source: trunk/LMDZ.MARS/util/display_netcdf.py @ 4007

Last change on this file since 4007 was 3867, checked in by jbclement, 6 months ago

Mars PCM:
Small usage improvements for the script "display_netcdf.py".
JBC

  • Property svn:executable set to *
File size: 30.2 KB
Line 
1#!/usr/bin/env python3
2##############################################################
3### Python script to visualize a variable in a NetCDF file ###
4##############################################################
5
6"""
7This script can display any numeric variable from a NetCDF file.
8It supports the following cases:
9  - Scalar output
10  - 1D time series
11  - 1D vertical profiles
12  - 2D latitude/longitude map
13  - 2D cross-sections
14  - Optionally average over latitude and plot longitude vs. time heatmap
15  - Optionally display polar stereographic view of 2D maps
16  - Optionally display 3D globe view of 2D maps
17
18Automatic setup from the environment file found in the "LMDZ.MARS/util" folder:
19  1. Make sure Conda is installed.
20  2. In terminal, navigate to the folder containing this script.
21  4. Create the environment:
22       conda env create -f display_netcdf.yml
23  5. Activate the environment:
24       conda activate my_env
25  6. Run the script:
26     python display_netcdf.py
27
28Usage:
29  1) Command-line mode:
30       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [options]
31     Options:
32    --variable         : Name of the variable to visualize.
33    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
34    --alt-index        : Index along the altitude dimension (0-based), if present.
35    --cmap             : Matplotlib colormap (default: "jet").
36    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
37    --slice-lon-index  : Fixed longitude index for altitude×longitude cross-section.
38    --slice-lat-index  : Fixed latitude index for altitude×latitude cross-section.
39    --show-topo        : Overlay MOLA topography on lat/lon maps.
40    --show-polar       :
41    --show-3d          :
42    --output           : If provided, save the figure to this filename instead of displaying.
43    --extra-indices    : JSON string to fix indices for any other dimensions.
44                         For dimensions with "slope", use 1-based numbering here.
45                         Example: '{"nslope": 1, "physical_points": 3}'
46
47  2) Interactive mode:
48       python display_netcdf.py
49     The script will prompt for everything.
50"""
51
52import os
53import sys
54import glob
55import readline
56import argparse
57import json
58import numpy as np
59import matplotlib.pyplot as plt
60import matplotlib.path as mpath
61import matplotlib.colors as mcolors
62import cartopy.crs as ccrs
63import pandas as pd
64from netCDF4 import Dataset
65
66# Attempt vedo import early for global use
67try:
68    import vedo
69    from vedo import *
70    from scipy.interpolate import RegularGridInterpolator
71    vedo_available = True
72except ImportError:
73    vedo_available = False
74
75# Constants for recognized dimension names
76TIME_DIMS = ("Time", "time", "time_counter")
77ALT_DIMS  = ("altitude",)
78LAT_DIMS  = ("latitude", "lat")
79LON_DIMS  = ("longitude", "lon")
80
81# Paths for MOLA data
82MOLA_NPY = 'MOLA_1px_per_deg.npy'
83MOLA_CSV = 'molaTeam_contour_31rgb_steps.csv'
84
85# Attempt to load MOLA topography
86if os.path.isfile(MOLA_NPY): # shape (nlat, nlon) at 1° per pixel: lat from -90 to 90, lon from 0 to 360
87    MOLA = np.load(MOLA_NPY) 
88    nlat, nlon = MOLA.shape
89    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
90    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
91    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
92    topo_loaded = True
93else:
94    print(f"Warning: '{MOLA_NPY}' not found! Topography contours disabled.")
95    topo_loaded = False
96
97
98# Attempt to load contour color table
99if os.path.isfile(MOLA_CSV):
100    color_table = pd.read_csv(MOLA_CSV)
101    csv_loaded = True
102else:
103    print(f"Warning: '{MOLA_CSV}' not found! 3D view colors disabled.")
104    csv_loaded = False
105
106
107def complete_filename(text, state):
108    """
109    Tab-completion for filesystem paths.
110    """
111    if "*" not in text:
112        pattern = text + "*"
113    else:
114        pattern = text
115    matches = glob.glob(os.path.expanduser(pattern))
116    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
117    try:
118        return matches[state]
119    except IndexError:
120        return None
121
122
123def make_varname_completer(varnames):
124    """
125    Returns a readline completer for variable names.
126    """
127    def completer(text, state):
128        options = [name for name in varnames if name.startswith(text)]
129        try:
130            return options[state]
131        except IndexError:
132            return None
133    return completer
134
135
136def find_dim_index(dims, candidates):
137    """
138    Search through dims tuple for any name in candidates.
139    Returns the index if found, else returns None.
140    """
141    for idx, dim in enumerate(dims):
142        for cand in candidates:
143            if cand.lower() == dim.lower():
144                return idx
145    return None
146
147
148def find_coord_var(dataset, candidates):
149    """
150    Among dataset variables, return the first variable whose name matches any candidate.
151    Returns None if none found.
152    """
153    for name in dataset.variables:
154        for cand in candidates:
155            if cand.lower() == name.lower():
156                return name
157    return None
158
159
160def overlay_topography(ax, transform, levels=10):
161    """
162    Overlay MOLA topography contours onto a given GeoAxes.
163    """
164    if not topo_loaded:
165        return
166    ax.contour(
167        topo_lon2d, topo_lat2d, MOLA,
168        levels=levels,
169        linewidths=0.5,
170        colors='black',
171        transform=transform
172    )
173
174
175def attach_format_coord(ax, mat, x, y, x_dim, y_dim, varname, is_pcolormesh=True, data_crs=ccrs.PlateCarree()):
176    """
177    Attach a format_coord function to the axes to display x, y, and value at cursor.
178    Works for both pcolormesh and imshow style grids.
179    """
180    # Determine dimensions
181    if mat.ndim == 2:
182        ny, nx = mat.shape
183    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
184        ny, nx, nc = mat.shape
185    else:
186        raise ValueError(f"Unsupported mat shape {mat.shape}")
187
188    # Edges or extents
189    if is_pcolormesh:
190        xedges, yedges = x, y
191    else:
192        x0, x1 = x.min(), x.max()
193        y0, y1 = y.min(), y.max()
194
195    # Detect if ax is a GeoAxes with a projection we can invert
196    proj = getattr(ax, 'projection', None)
197    geo_axes = (
198        isinstance(proj, ccrs.Projection)
199        and x_dim.lower() in LON_DIMS
200        and y_dim.lower() in LAT_DIMS
201    )
202
203    def format_coord(xp, yp):
204        # Geographic transform if appropriate
205        if geo_axes:
206            try:
207                lonp, latp = data_crs.transform_point(xp, yp, src_crs=proj)
208            except Exception:
209                lonp, latp = xp, yp
210            xi, yi = lonp, latp
211        else:
212            xi, yi = xp, yp
213
214        # Map to matrix indices
215        if is_pcolormesh:
216            col = np.searchsorted(xedges, xi) - 1
217            row = np.searchsorted(yedges, yi) - 1
218        else:
219            col = int((xi - x0) / (x1 - x0) * nx)
220            row = int((yi - y0) / (y1 - y0) * ny)
221
222        # Build the label
223        label_xy = f"{x_dim}={xi:.3g}, {y_dim}={yi:.3g}"
224        if 0 <= row < ny and 0 <= col < nx:
225            if mat.ndim == 2:
226                v = mat[row, col]
227                return f"{label_xy}, {varname}={v:.3g}"
228            else:
229                vals = mat[row, col]
230                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
231                return f"{label_xy}, {varname}=({txt})"
232        else:
233            return label_xy
234
235    ax.format_coord = format_coord
236
237
238def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True):
239    """
240    Plot two polar‐stereographic views (north & south) of the same data.
241    """
242    figs = []  # collect figure handles
243
244    for pole in ("north", "south"):
245        # Choose projection and extent for each pole
246        if pole == "north":
247            proj = ccrs.NorthPolarStereo(central_longitude=180)
248            extent = [-180, 180, 60, 90]
249        else:
250            proj = ccrs.SouthPolarStereo(central_longitude=180)
251            extent = [-180, 180, -90, -60]
252
253        # Create figure and GeoAxes
254        fig = plt.figure(figsize=(8, 6))
255        ax = fig.add_subplot(1, 1, 1, projection=proj, aspect=True)
256        ax.set_global()
257        ax.set_extent(extent, ccrs.PlateCarree())
258
259        # Draw circular boundary
260        theta = np.linspace(0, 2 * np.pi, 100)
261        center, radius = [0.5, 0.5], 0.5
262        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
263        circle = mpath.Path(verts * radius + center)
264        ax.set_boundary(circle, transform=ax.transAxes)
265
266        # Add meridians/parallels
267        gl = ax.gridlines(
268            draw_labels=True,
269            color='k',
270            xlocs=range(-180, 181, 30),
271            ylocs=range(-90, 91, 10),
272            linestyle='--',
273            linewidth=0.5
274        )
275
276        # Plot data in PlateCarree projection
277        cf = ax.pcolormesh(
278            lon2d, lat2d, data2d,
279            shading='auto',
280            cmap=colormap,
281            transform=ccrs.PlateCarree()
282        )
283        uniq_lons = np.unique(lon2d.ravel())
284        uniq_lats = np.unique(lat2d.ravel())
285        attach_format_coord(ax, data2d, uniq_lons, uniq_lats, 'lon', 'lat', varname, is_pcolormesh=True)
286
287        # Optionally overlay MOLA topography
288        if topo_overlay:
289            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)
290
291        # Colorbar and title
292        cbar = fig.colorbar(cf, ax=ax, pad=0.1)
293        label = varname + (f" ({units})" if units else "")
294        cbar.set_label(label)
295        ax.set_title(f"{varname} — {pole.capitalize()} polar region", pad=20, y=1.05, fontsize=12, fontweight='bold')
296
297        figs.append(fig)
298
299    # Show both figures
300    plt.show()
301
302
303def plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units=None):
304    """
305    Plot a 3D globe view of the data using vedo, with surface coloring based on data2d
306    and overlaid contour lines from MOLA topography.
307    """
308    if not vedo_available:
309        print("3D view skipped: vedo missing.")
310        return
311    if not csv_loaded:
312        print("3D view skipped: color table missing.")
313        return
314
315    # Prepare MOLA grid
316    nlat, nlon = MOLA.shape
317    lats = np.linspace(90, -90, nlat)
318    lons = np.linspace(-180, 180, nlon)
319    lon_grid, lat_grid = np.meshgrid(lons, lats)
320
321    # Interpolate data2d onto MOLA grid
322    lat_data = np.linspace(-90, 90, data2d.shape[0])
323    lon_data = np.linspace(-180, 180, data2d.shape[1])
324    interp2d = RegularGridInterpolator((lat_data, lon_data), data2d,
325                                       bounds_error=False, fill_value=None)
326    newdata2d = interp2d((lat_grid, lon_grid))
327
328    # Generate contour lines from MOLA
329    cs = plt.contour(lon_grid, lat_grid, MOLA, levels=10, linewidths=0)
330    plt.clf()
331    contour_lines = []
332    radius = 3389500 # Mars average radius [m]
333    for segs, level in zip(cs.allsegs, cs.levels):
334        for verts in segs:
335            lon_c = verts[:, 0]
336            lat_c = verts[:, 1]
337            phi_c = np.radians(90 - lat_c)
338            theta_c = np.radians(lon_c)
339            elev = RegularGridInterpolator((lats, lons), MOLA,
340                                           bounds_error=False,
341                                           fill_value=0.0)((lat_c, lon_c))
342            r_cont = radius + elev * 10
343            x_c = r_cont * np.sin(phi_c) * np.cos(theta_c) * 1.002
344            y_c = r_cont * np.sin(phi_c) * np.sin(theta_c) * 1.002
345            z_c = r_cont * np.cos(phi_c) * 1.002
346            pts = np.column_stack([x_c, y_c, z_c])
347            if pts.shape[0] > 1:
348                contour_lines.append(Line(pts, c='k', lw=0.5))
349
350    # Create sphere surface mesh
351    phi = np.deg2rad(90 - lat_grid)
352    theta = np.deg2rad(lon_grid)
353    r = radius + MOLA * 10
354    x = r * np.sin(phi) * np.cos(theta)
355    y = r * np.sin(phi) * np.sin(theta)
356    z = r * np.cos(phi)
357    pts = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)
358
359    # Build mesh faces
360    faces = []
361    for i in range(nlat - 1):
362        for j in range(nlon - 1):
363            p0 = i * nlon + j
364            p1 = p0 + 1
365            p2 = p0 + nlon
366            p3 = p2 + 1
367            faces.extend([(p0, p2, p1), (p1, p2, p3)])
368
369    mesh = Mesh([pts, faces])
370    mesh.cmap(colormap, newdata2d.ravel())
371    mesh.add_scalarbar(title=varname + (f' [{units}]' if units else ''), c='white')
372    mesh.lighting('default')
373
374    # Geographic grid lines
375    meridians, parallels, labels = [], [], []
376    zero_lon_offset = radius * 0.03
377    for lon in range(-150, 181, 30):
378        lat_line = np.linspace(-90, 90, nlat)
379        lon_line = np.full_like(lat_line, lon)
380        phi = np.deg2rad(90 - lat_line)
381        theta = np.deg2rad(lon_line)
382        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
383        rr = radius + elev * 10
384        pts_line = np.column_stack([
385            rr * np.sin(phi) * np.cos(theta),
386            rr * np.sin(phi) * np.sin(theta),
387            rr * np.cos(phi)
388        ]) * 1.005
389        label_pos = pts_line[len(pts_line)//2]
390        norm = np.linalg.norm(label_pos)
391        label_pos_out = label_pos / norm * (norm + radius * 0.02)
392        if lon == 0:
393            label_pos_out[1] += zero_lon_offset
394        meridians.append(Line(pts_line, c='k', lw=1)#.flagpole(
395            #f"{lon}°",
396            #point=label_pos_out,
397            #offset=[0, 0, radius * 0.05],
398            #s=radius*0.01,
399            #c='yellow'
400        #).follow_camera()
401        )
402
403    for lat in range(-60, 91, 30):
404        lon_line = np.linspace(-180, 180, nlon)
405        lat_line = np.full_like(lon_line, lat)
406        phi = np.deg2rad(90 - lat_line)
407        theta = np.deg2rad(lon_line)
408        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
409        rr = radius + elev * 10
410        pts_line = np.column_stack([
411            rr * np.sin(phi) * np.cos(theta),
412            rr * np.sin(phi) * np.sin(theta),
413            rr * np.cos(phi)
414        ]) * 1.005
415        label_pos = pts_line[len(pts_line)//2]
416        norm = np.linalg.norm(label_pos)
417        label_pos_out = label_pos / norm * (norm + radius * 0.02)
418        parallels.append(Line(pts_line, c='k', lw=1)#.flagpole(
419            #f"{lat}°",
420            #point=label_pos_out,
421            #offset=[0, 0, radius * 0.05],
422            #s=radius*0.01,
423            #c='yellow'
424        #).follow_camera()
425        )
426
427    # Create plotter
428    plotter = Plotter(title="3D globe view", bg="bb", axes=0)
429
430    # Configure camera
431    cam_dist = radius * 3
432    plotter.camera.SetPosition([cam_dist, 0, 0])
433    plotter.camera.SetFocalPoint([0, 0, 0])
434    plotter.camera.SetViewUp([0, 0, 1])
435
436    # Show the globe
437    plotter.show(mesh, *contour_lines, *meridians, *parallels)
438
439
440def transform_physical_points(dataset, var, data):
441    """
442    Transform a physical_points 1D array into a 2D grid of shape (nlat, nlon).
443    """
444    # Fetch lat/lon coordinate variables
445    lat_var = find_coord_var(dataset, LAT_DIMS)
446    lon_var = find_coord_var(dataset, LON_DIMS)
447    if lat_var is None or lon_var is None:
448        raise ValueError("Cannot find latitude or longitude variables for physical_points")
449    raw_lats = dataset.variables[lat_var][:]
450    raw_lons = dataset.variables[lon_var][:]
451    # Unmask
452    if hasattr(raw_lats, 'mask'):
453        raw_lats = np.where(raw_lats.mask, np.nan, raw_lats.data)
454    if hasattr(raw_lons, 'mask'):
455        raw_lons = np.where(raw_lons.mask, np.nan, raw_lons.data)
456    # Convert radians to degrees if in radians
457    if np.max(np.abs(raw_lats)) <= np.pi:
458        raw_lats = np.degrees(raw_lats)
459        raw_lons = np.degrees(raw_lons)
460    # Get unique coords
461    uniq_lats = np.unique(raw_lats)
462    uniq_lons = np.unique(raw_lons)
463    # Initialize grid
464    grid = np.full((uniq_lats.size, uniq_lons.size), np.nan)
465    # Build the grid
466    for value, lat, lon in zip(data.ravel(), raw_lats.ravel(), raw_lons.ravel()):
467        i = np.where(np.isclose(uniq_lats, lat))[0][0]
468        j = np.where(np.isclose(uniq_lons, lon))[0][0]
469        grid[i, j] = value
470    # Duplicate the pole value across all longitudes
471    for i in (0, -1):
472        row = grid[i, :]
473        count_good = np.count_nonzero(~np.isnan(row))
474        if count_good == 1:
475            pole_value = row[~np.isnan(row)][0]
476            grid[i, :] = pole_value
477    # Wrap longitude if needed
478    if -180.0 in uniq_lons:
479        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
480        grid = np.hstack([grid, grid[:, [idx]]])
481        uniq_lons = np.append(uniq_lons, 180.0)
482    return grid, uniq_lats, uniq_lons, lat_var, lon_var
483
484
485def get_dimension_indices(ds, varname):
486    """
487    For each dimension of the variable:
488     - if size == 1 → automatically select index 0
489     - otherwise prompt the user:
490         <number>     : take that specific index (1-based)
491         'a'          : average over this dimension
492         'e' or Enter : take all values
493    Returns {dim_name: int index, 'avg', or None}.
494    """
495    var = ds.variables[varname]
496    dims = var.dimensions
497    shape = var.shape
498    selection = {}
499    for dim, size in zip(dims, shape):
500        if size == 1:
501            selection[dim] = 0
502            continue
503        prompt = (
504                f"Available options for '{dim}' (size {size}):\n"
505                f"  > '1–{size}' to pick that index\n"
506                "  > 'a' to average over this dimension\n"
507                "  > 'e' or Enter to take all values\n"
508                "Choose: "
509        )
510        while True:
511            resp = input(prompt).strip().lower()
512            if resp in ("", "e"):
513                selection[dim] = None
514                break
515            if resp == 'a':
516                selection[dim] = 'avg'
517                break
518            if resp.isdigit():
519                n = int(resp)
520                if 1 <= n <= size:
521                    selection[dim] = n - 1
522                    break
523            print(f"  Invalid entry '{resp}'. Please enter a number, 'a', 'e', or just Enter.")
524    return selection
525
526
527def plot_variable(dataset, varname, colormap="jet", output_path=None, extra_indices=None):
528    """
529    Automatically select singleton dims, prompt for others,
530    allow user to choose x/y for 2D, handle special cases (physical_points, averaging).
531    """
532    var = dataset.variables[varname]
533    dims = list(var.dimensions)
534    # Read data
535    try:
536        data_full = var[:]
537    except Exception as e:
538        print(f"Error: Cannot read data for '{varname}': {e}")
539        return
540    # Unmask
541    if hasattr(data_full, 'mask'):
542        data_full = np.where(data_full.mask, np.nan, data_full.data)
543    # Initialize extra_indices
544    extra_indices = extra_indices or {}
545    # Handle averaging selections
546    for dim, mode in dict(extra_indices).items():
547        if mode == 'avg':
548            ax = dims.index(dim)
549            data_full = np.nanmean(data_full, axis=ax, keepdims=True)
550            extra_indices[dim] = 0
551    # Build slicer
552    slicer = []
553    for dim in dims:
554        idx = extra_indices.get(dim)
555        slicer.append(idx if isinstance(idx, int) else slice(None))
556    data_slice = data_full[tuple(slicer)]
557    nd = data_slice.ndim
558    # Special case: physical_points dimension
559    if nd == 1 and 'physical_points' in dims:
560        # Transform into 2D grid
561        grid, uniq_lats, uniq_lons, latv, lonv = transform_physical_points(dataset, var, data_slice)
562        # Plot map
563        proj = ccrs.PlateCarree()
564        fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(projection=proj))
565        lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
566        lon_ticks = np.arange(-180, 181, 30)
567        lat_ticks = np.arange(-90, 91, 30)
568        ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
569        ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
570        ax.tick_params(
571            axis='x', which='major',
572            length=4,
573            direction='out',
574            pad=2,
575            labelsize=8
576        )
577        ax.tick_params(
578            axis='y', which='major',
579            length=4,
580            direction='out',
581            pad=2,
582            labelsize=8
583        )
584        cf = ax.pcolormesh(lon2d, lat2d, grid, shading='auto', cmap=colormap, transform=ccrs.PlateCarree())
585        attach_format_coord(ax, grid, uniq_lons, uniq_lats, 'lon', 'lat', varname, is_pcolormesh=True)
586        overlay_topography(ax, transform=proj, levels=10) # Overlay MOLA topography
587        cbar = fig.colorbar(cf, ax=ax, pad=0.02)
588        cbar.set_label(varname)
589        ax.set_title(f"{varname} (physical_points)", fontweight='bold')
590        ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
591        ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
592        # Prompt for polar-stereo views if interactive
593        if input("Display polar-stereo views? [y = yes, anything else = no]: ").strip().lower() == "y":
594            units = getattr(var, 'units', None)
595            plot_polar_views(lon2d, lat2d, grid, colormap, varname, units)
596        # Prompt for 3D globe view if interactive
597        if input("Display 3D globe view? [y = yes, anything else = no]: ").strip().lower() == "y":
598            units = getattr(var, 'units', None)
599            plot_3D_globe(lon2d, lat2d, grid, colormap, varname, units)
600        if output_path:
601            fig.savefig(output_path, bbox_inches='tight')
602            print(f"Saved to {output_path}")
603        else:
604            plt.show()
605        return
606    # 0D
607    if nd == 0:
608        print(f"\033[36mScalar '{varname}': {float(data_slice)}\033[0m")
609        return
610    # 1D
611    if nd == 1:
612        rem = [(i, d) for i, (d, s) in enumerate(zip(dims, slicer)) if isinstance(s, slice)]
613        axis_idx, dim_name = rem[0]
614        coord_var = find_coord_var(dataset, [dim_name])
615        if coord_var:
616            x = dataset.variables[coord_var][:]
617            if hasattr(x, 'mask'):
618                x = np.where(x.mask, np.nan, x.data)
619            xlabel = coord_var
620        else:
621            x = np.arange(data_slice.shape[0])
622            xlabel = dim_name
623        y = data_slice
624        plt.figure(figsize=(8, 4))
625        plt.plot(x, y)
626        plt.grid(True)
627        plt.xlabel(xlabel)
628        plt.ylabel(varname)
629        plt.title(f"{varname} vs {xlabel}")
630        if output_path:
631            plt.savefig(output_path, bbox_inches='tight')
632            print(f"Saved plot to {output_path}")
633        else:
634            plt.show()
635        return
636    # 2D
637    if nd == 2:
638        remaining = [d for d, idx in zip(dims, slicer) if isinstance(idx, slice)]
639        # Choose X/Y interactively
640        resp = input(f"Which dimension on X? {remaining}: ").strip()
641        if resp == remaining[1]:
642            x_dim, y_dim = remaining[1], remaining[0]
643        else:
644            x_dim, y_dim = remaining[0], remaining[1]
645        def get_coords(dim):
646            coord_var = find_coord_var(dataset, [dim])
647            if coord_var:
648                arr = dataset.variables[coord_var][:]
649                if hasattr(arr, 'mask'):
650                    arr = np.where(arr.mask, np.nan, arr.data)
651                return arr
652            return np.arange(data_slice.shape[remaining.index(dim)])
653        x_coords = get_coords(x_dim)
654        y_coords = get_coords(y_dim)
655        order = [remaining.index(y_dim), remaining.index(x_dim)]
656        plot_data = np.moveaxis(data_slice, order, [0, 1])
657        fig, ax = plt.subplots(figsize=(8, 6))
658        im = ax.pcolormesh(x_coords, y_coords, plot_data, shading='auto', cmap=colormap)
659        attach_format_coord(ax, plot_data, x_coords, y_coords, x_dim, y_dim, varname, is_pcolormesh=True)
660        cbar = fig.colorbar(im, ax=ax, pad=0.02)
661        cbar.set_label(varname)
662        ax.set_xlabel(x_dim)
663        ax.set_ylabel(y_dim)
664        ax.set_title(f"{varname} ({y_dim} vs {x_dim})")
665        ax.grid(True)
666        if {x_dim, y_dim} & set(LAT_DIMS) and {x_dim, y_dim} & set(LON_DIMS):
667            # Prompt for polar-stereo views if interactive
668            if sys.stdin.isatty() and input("Display polar-stereo views? [y = yes, anything else = no]: ").strip().lower() == "y":
669                units = getattr(dataset.variables[varname], "units", None)
670                plot_polar_views(x_coords, y_coords, plot_data, colormap, varname, units)
671            # Prompt for 3D globe view if interactive
672            if sys.stdin.isatty() and input("Display 3D globe view? [y = yes, anything else = no]: ").strip().lower() == "y":
673                units = getattr(dataset.variables[varname], "units", None)
674                plot_3D_globe(x_coords, y_coords, plot_data, colormap, varname, units)
675        if output_path:
676            fig.savefig(output_path, bbox_inches='tight')
677            print(f"Saved to {output_path}")
678        else:
679            plt.show()
680        return
681    print(f"Plotting for ndim={nd} not yet supported.")
682
683
684def visualize_variable_interactive(nc_path=None):
685    """
686    Interactive loop: keep prompting for variables to plot until user quits.
687    """
688    # Open dataset
689    if nc_path:
690        path = nc_path
691    else:
692        readline.set_completer(complete_filename)
693        readline.parse_and_bind("tab: complete")
694        path = input("Enter path to NetCDF file: ").strip()
695
696    if not os.path.isfile(path):
697        print(f"Error: '{path}' not found.")
698        return
699
700    ds = Dataset(path, "r")
701    var_list = list(ds.variables.keys())
702    if not var_list:
703        print("No variables found in file.")
704        ds.close()
705        return
706
707    # Enable interactive mode
708    plt.ion()
709
710    while True:
711        # Enable tab-completion for variable names
712        readline.set_completer(make_varname_completer(var_list))
713        readline.parse_and_bind("tab: complete")
714
715        print("\nAvailable variables:")
716        for name in var_list:
717            print(f"  > {name}")
718        varname = input("\nEnter variable name to plot (or Enter to quit): ").strip()
719        if varname == "":
720            print("Exiting.")
721            break
722        if varname not in ds.variables:
723            print(f"Variable '{varname}' not found. Try again.")
724            continue
725
726        # Display dimensions and size
727        var = ds.variables[varname]
728        dims, shape = var.dimensions, var.shape
729        print(f"\nVariable '{varname}' has dimensions:")
730        for dim, size in zip(dims, shape):
731            print(f"  - {dim} (size {size})")
732        print()
733
734        # Prepare slicing parameters
735        selection = get_dimension_indices(ds, varname)
736
737        # Plot the variable
738        plot_variable(
739            ds,
740            varname,
741            colormap    = 'jet',
742            output_path = None,
743            extra_indices = selection
744        )
745
746    ds.close()
747
748
749def visualize_variable_cli(nc_file, varname, time_index, alt_index,
750                           colormap, output_path, extra_json, avg_lat):
751    """
752    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
753    """
754    if not os.path.isfile(nc_file):
755        print(f"Error: '{nc_file}' not found.")
756        return
757    ds = Dataset(nc_file, "r")
758    if varname not in ds.variables:
759        print(f"Variable '{varname}' not in file.")
760        ds.close()
761        return
762
763    # Display dimensions and size
764    dims  = ds.variables[varname].dimensions
765    shape = ds.variables[varname].shape
766    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
767    for name, size in zip(dims, shape):
768        print(f"  - {name}: size {size}")
769    print()
770
771    # Special case: time-only → plot directly
772    t_idx = find_dim_index(dims, TIME_DIMS)
773    if (
774        t_idx is not None and shape[t_idx] > 1 and
775        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
776    ):
777        print("Detected single-point spatial dims; plotting time series…")
778        var_obj = ds.variables[varname]
779        data = var_obj[:].squeeze()
780        time_var = find_coord_var(ds, TIME_DIMS)
781        if time_var:
782            tvals = ds.variables[time_var][:]
783        else:
784            tvals = np.arange(data.shape[0])
785        if hasattr(data, "mask"):
786            data = np.where(data.mask, np.nan, data.data)
787        if hasattr(tvals, "mask"):
788            tvals = np.where(tvals.mask, np.nan, tvals.data)
789        plt.figure()
790        plt.plot(tvals, data, marker="o")
791        plt.xlabel(time_var or "Time Index")
792        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
793        plt.title(f"{varname} vs {time_var or 'Index'}", fontweight='bold')
794        if output_path:
795            plt.savefig(output_path, bbox_inches="tight")
796            print(f"Saved to {output_path}")
797        else:
798            plt.show()
799        ds.close()
800        return
801
802    # if --avg-lat but lat/lon/Time not compatible → disable
803    lat_idx = find_dim_index(dims, LAT_DIMS)
804    lon_idx = find_dim_index(dims, LON_DIMS)
805    if avg_lat and not (
806        t_idx   is not None and shape[t_idx]  > 1 and
807        lat_idx is not None and shape[lat_idx] > 1 and
808        lon_idx is not None and shape[lon_idx] > 1
809    ):
810        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
811        avg_lat = False
812
813    # Parse extra indices JSON
814    extra = {}
815    if extra_json:
816        try:
817            parsed = json.loads(extra_json)
818            for k, v in parsed.items():
819                if isinstance(v, int):
820                    if "slope" in k.lower():
821                        extra[k] = v - 1
822                    else:
823                        extra[k] = v
824        except:
825            print("Warning: bad extra-indices.")
826
827    plot_variable(ds, varname, time_index, alt_index,
828                  colormap, output_path, extra, avg_lat)
829    ds.close()
830
831
832def main():
833    parser = argparse.ArgumentParser()
834    parser.add_argument('nc_file', nargs='?', help='NetCDF file (omit for interactive)')
835    parser.add_argument('-v','--variable', help='Variable name')
836    parser.add_argument('-t','--time-index', type=int, help='Time index (0-based)')
837    parser.add_argument('-a','--alt-index', type=int, help='Altitude index (0-based)')
838    parser.add_argument('-c','--cmap', default='jet', help='Colormap')
839    parser.add_argument('--avg-lat', action='store_true', help='Average over latitude')
840    parser.add_argument('--show-polar', action='store_true', help='Enable polar-stereo views')
841    parser.add_argument('--show-3d', action='store_true', help='Enable 3D globe view')
842    parser.add_argument('-o','--output', help='Save figure path')
843    parser.add_argument('-e','--extra-indices', help='JSON string for other dims')
844    args = parser.parse_args()
845
846    if args.nc_file and args.variable:
847        visualize_variable_cli(
848            args.nc_file, args.variable,
849            args.time_index, args.alt_index,
850            args.cmap, args.output,
851            args.extra_indices, args.avg_lat
852        )
853    else:
854        visualize_variable_interactive(args.nc_file)
855
856
857if __name__ == "__main__":
858    main()
859
Note: See TracBrowser for help on using the repository browser.