source: trunk/LMDZ.MARS/util/display_netcdf.py @ 3849

Last change on this file since 3849 was 3849, checked in by jbclement, 3 days ago

Mars PCM:
Simplification and generalization of the "display_netcdf.py" script which can now handle more cases of shapes/dimensions in the variables + using 'pcolormesh' instead of 'imshow' to better stick to the original data.
JBC

  • Property svn:executable set to *
File size: 29.8 KB
Line 
1#!/usr/bin/env python3
2##############################################################
3### Python script to visualize a variable in a NetCDF file ###
4##############################################################
5
6"""
7This script can display any numeric variable from a NetCDF file.
8It supports the following cases:
9  - Scalar output
10  - 1D time series
11  - 1D vertical profiles
12  - 2D latitude/longitude map
13  - 2D cross-sections
14  - Optionally average over latitude and plot longitude vs. time heatmap
15  - Optionally display polar stereographic view of 2D maps
16  - Optionally display 3D globe view of 2D maps
17
18Automatic setup from the environment file found in the "LMDZ.MARS/util" folder:
19  1. Make sure Conda is installed.
20  2. In terminal, navigate to the folder containing this script.
21  4. Create the environment:
22       conda env create -f display_netcdf.yml
23  5. Activate the environment:
24       conda activate my_env
25  6. Run the script:
26     python display_netcdf.py
27
28Usage:
29  1) Command-line mode:
30       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [options]
31     Options:
32    --variable         : Name of the variable to visualize.
33    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
34    --alt-index        : Index along the altitude dimension (0-based), if present.
35    --cmap             : Matplotlib colormap (default: "jet").
36    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
37    --slice-lon-index  : Fixed longitude index for altitude×longitude cross-section.
38    --slice-lat-index  : Fixed latitude index for altitude×latitude cross-section.
39    --show-topo        : Overlay MOLA topography on lat/lon maps.
40    --show-polar       :
41    --show-3d          :
42    --output           : If provided, save the figure to this filename instead of displaying.
43    --extra-indices    : JSON string to fix indices for any other dimensions.
44                         For dimensions with "slope", use 1-based numbering here.
45                         Example: '{"nslope": 1, "physical_points": 3}'
46
47  2) Interactive mode:
48       python display_netcdf.py
49     The script will prompt for everything.
50"""
51
52import os
53import sys
54import glob
55import readline
56import argparse
57import json
58import numpy as np
59import matplotlib.pyplot as plt
60import matplotlib.path as mpath
61import matplotlib.colors as mcolors
62import cartopy.crs as ccrs
63import pandas as pd
64from netCDF4 import Dataset
65
66# Attempt vedo import early for global use
67try:
68    import vedo
69    from vedo import *
70    from scipy.interpolate import RegularGridInterpolator
71    vedo_available = True
72except ImportError:
73    vedo_available = False
74
75# Constants for recognized dimension names
76TIME_DIMS = ("Time", "time", "time_counter")
77ALT_DIMS  = ("altitude",)
78LAT_DIMS  = ("latitude", "lat")
79LON_DIMS  = ("longitude", "lon")
80
81# Paths for MOLA data
82MOLA_NPY = 'MOLA_1px_per_deg.npy'
83MOLA_CSV = 'molaTeam_contour_31rgb_steps.csv'
84
85# Attempt to load MOLA topography
86try:
87    MOLA = np.load('MOLA_1px_per_deg.npy') # shape (nlat, nlon) at 1° per pixel: lat from -90 to 90, lon from 0 to 360
88    nlat, nlon = MOLA.shape
89    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
90    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
91    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
92    topo_loaded = True
93except Exception as e:
94    print(f"Warning: '{MOLA_NPY}' not found: {e}")
95    topo_loaded = False
96
97
98# Attempt to load contour color table
99if os.path.isfile(MOLA_CSV):
100    color_table = pd.read_csv(MOLA_CSV)
101    csv_loaded = True
102else:
103    print(f"Warning: '{MOLA_CSV}' not found. 3D view colors disabled.")
104    csv_loaded = False
105
106
107def complete_filename(text, state):
108    """
109    Tab-completion for filesystem paths.
110    """
111    if "*" not in text:
112        pattern = text + "*"
113    else:
114        pattern = text
115    matches = glob.glob(os.path.expanduser(pattern))
116    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
117    try:
118        return matches[state]
119    except IndexError:
120        return None
121
122
123def make_varname_completer(varnames):
124    """
125    Returns a readline completer for variable names.
126    """
127    def completer(text, state):
128        options = [name for name in varnames if name.startswith(text)]
129        try:
130            return options[state]
131        except IndexError:
132            return None
133    return completer
134
135
136def find_dim_index(dims, candidates):
137    """
138    Search through dims tuple for any name in candidates.
139    Returns the index if found, else returns None.
140    """
141    for idx, dim in enumerate(dims):
142        for cand in candidates:
143            if cand.lower() == dim.lower():
144                return idx
145    return None
146
147
148def find_coord_var(dataset, candidates):
149    """
150    Among dataset variables, return the first variable whose name matches any candidate.
151    Returns None if none found.
152    """
153    for name in dataset.variables:
154        for cand in candidates:
155            if cand.lower() == name.lower():
156                return name
157    return None
158
159
160def overlay_topography(ax, transform, levels=10):
161    """
162    Overlay MOLA topography contours onto a given GeoAxes.
163    """
164    if not topo_loaded:
165        return
166    ax.contour(
167        topo_lon2d, topo_lat2d, MOLA,
168        levels=levels,
169        linewidths=0.5,
170        colors='black',
171        transform=transform
172    )
173
174
175def attach_format_coord(ax, mat, x, y, is_pcolormesh=True, data_crs=ccrs.PlateCarree()):
176    """
177    Attach a format_coord function to the axes to display x, y, and value at cursor.
178    Works for both pcolormesh and imshow style grids.
179    """
180    # Determine dimensions
181    if mat.ndim == 2:
182        ny, nx = mat.shape
183    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
184        ny, nx, nc = mat.shape
185    else:
186        raise ValueError(f"Unsupported mat shape {mat.shape}")
187    # Edges or extents
188    if is_pcolormesh:
189        xedges, yedges = x, y
190    else:
191        x0, x1 = x.min(), x.max()
192        y0, y1 = y.min(), y.max()
193
194    # Detect if ax is a GeoAxes with a projection we can invert
195    proj = getattr(ax, 'projection', None)
196    use_geo = isinstance(proj, ccrs.Projection)
197
198    def format_coord(xp, yp):
199        # If GeoAxes, invert back to geographic lon/lat
200        if use_geo:
201            try:
202                lonp, latp = data_crs.transform_point(xp, yp, src_crs=proj)
203            except Exception:
204                lonp, latp = xp, yp
205            xi, yi = lonp, latp
206        else:
207            xi, yi = xp, yp
208
209        # Map to matrix indices
210        if is_pcolormesh:
211            col = np.searchsorted(xedges, xi) - 1
212            row = np.searchsorted(yedges, yi) - 1
213        else:
214            col = int((xi - x0) / (x1 - x0) * nx)
215            row = int((yi - y0) / (y1 - y0) * ny)
216
217        # Within bounds?
218        if 0 <= row < ny and 0 <= col < nx:
219            if mat.ndim == 2:
220                v = mat[row, col]
221                return f"lon={xi:.3g}, lat={yi:.3g}, val={v:.3g}"
222            else:
223                vals = mat[row, col]
224                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
225                return f"lon={xi:.3g}, lat={yi:.3g}, val=({txt})"
226        # Out of bounds: still show coords
227        return f"lon={xi:.3g}, lat={yi:.3g}"
228
229    ax.format_coord = format_coord
230
231
232def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True):
233    """
234    Plot two polar‐stereographic views (north & south) of the same data.
235    """
236    figs = []  # collect figure handles
237
238    for pole in ("north", "south"):
239        # Choose projection and extent for each pole
240        if pole == "north":
241            proj = ccrs.NorthPolarStereo(central_longitude=180)
242            extent = [-180, 180, 60, 90]
243        else:
244            proj = ccrs.SouthPolarStereo(central_longitude=180)
245            extent = [-180, 180, -90, -60]
246
247        # Create figure and GeoAxes
248        fig = plt.figure(figsize=(8, 6))
249        ax = fig.add_subplot(1, 1, 1, projection=proj, aspect=True)
250        ax.set_global()
251        ax.set_extent(extent, ccrs.PlateCarree())
252
253        # Draw circular boundary
254        theta = np.linspace(0, 2 * np.pi, 100)
255        center, radius = [0.5, 0.5], 0.5
256        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
257        circle = mpath.Path(verts * radius + center)
258        ax.set_boundary(circle, transform=ax.transAxes)
259
260        # Add meridians/parallels
261        gl = ax.gridlines(
262            draw_labels=True,
263            color='k',
264            xlocs=range(-180, 181, 30),
265            ylocs=range(-90, 91, 10),
266            linestyle='--',
267            linewidth=0.5
268        )
269
270        # Plot data in PlateCarree projection
271        cf = ax.pcolormesh(
272            lon2d, lat2d, data2d,
273            shading='auto',
274            cmap=colormap,
275            transform=ccrs.PlateCarree()
276        )
277        uniq_lons = np.unique(lon2d.ravel())
278        uniq_lats = np.unique(lat2d.ravel())
279        attach_format_coord(ax, data2d, uniq_lons, uniq_lats, is_pcolormesh=True)
280
281        # Optionally overlay MOLA topography
282        if topo_overlay:
283            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)
284
285        # Colorbar and title
286        cbar = fig.colorbar(cf, ax=ax, pad=0.1)
287        label = varname + (f" ({units})" if units else "")
288        cbar.set_label(label)
289        ax.set_title(f"{varname} — {pole.capitalize()} polar region", pad=20, y=1.05, fontsize=12, fontweight='bold')
290
291        figs.append(fig)
292
293    # Show both figures
294    plt.show()
295
296
297def plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units=None):
298    """
299    Plot a 3D globe view of the data using vedo, with surface coloring based on data2d
300    and overlaid contour lines from MOLA topography.
301    """
302    if not vedo_available:
303        print("3D view skipped: vedo missing.")
304        return
305    if not csv_loaded:
306        print("3D view skipped: color table missing.")
307        return
308
309    # Prepare MOLA grid
310    nlat, nlon = MOLA.shape
311    lats = np.linspace(90, -90, nlat)
312    lons = np.linspace(-180, 180, nlon)
313    lon_grid, lat_grid = np.meshgrid(lons, lats)
314
315    # Interpolate data2d onto MOLA grid
316    lat_data = np.linspace(-90, 90, data2d.shape[0])
317    lon_data = np.linspace(-180, 180, data2d.shape[1])
318    interp2d = RegularGridInterpolator((lat_data, lon_data), data2d,
319                                       bounds_error=False, fill_value=None)
320    newdata2d = interp2d((lat_grid, lon_grid))
321
322    # Generate contour lines from MOLA
323    cs = plt.contour(lon_grid, lat_grid, MOLA, levels=10, linewidths=0)
324    plt.clf()
325    contour_lines = []
326    radius = 3389500 # Mars average radius [m]
327    for segs, level in zip(cs.allsegs, cs.levels):
328        for verts in segs:
329            lon_c = verts[:, 0]
330            lat_c = verts[:, 1]
331            phi_c = np.radians(90 - lat_c)
332            theta_c = np.radians(lon_c)
333            elev = RegularGridInterpolator((lats, lons), MOLA,
334                                           bounds_error=False,
335                                           fill_value=0.0)((lat_c, lon_c))
336            r_cont = radius + elev * 10
337            x_c = r_cont * np.sin(phi_c) * np.cos(theta_c) * 1.002
338            y_c = r_cont * np.sin(phi_c) * np.sin(theta_c) * 1.002
339            z_c = r_cont * np.cos(phi_c) * 1.002
340            pts = np.column_stack([x_c, y_c, z_c])
341            if pts.shape[0] > 1:
342                contour_lines.append(Line(pts, c='k', lw=0.5))
343
344    # Create sphere surface mesh
345    phi = np.deg2rad(90 - lat_grid)
346    theta = np.deg2rad(lon_grid)
347    r = radius + MOLA * 10
348    x = r * np.sin(phi) * np.cos(theta)
349    y = r * np.sin(phi) * np.sin(theta)
350    z = r * np.cos(phi)
351    pts = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)
352
353    # Build mesh faces
354    faces = []
355    for i in range(nlat - 1):
356        for j in range(nlon - 1):
357            p0 = i * nlon + j
358            p1 = p0 + 1
359            p2 = p0 + nlon
360            p3 = p2 + 1
361            faces.extend([(p0, p2, p1), (p1, p2, p3)])
362
363    mesh = Mesh([pts, faces])
364    mesh.cmap(colormap, newdata2d.ravel())
365    mesh.add_scalarbar(title=varname + (f' [{units}]' if units else ''), c='white')
366    mesh.lighting('default')
367
368    # Geographic grid lines
369    meridians, parallels, labels = [], [], []
370    zero_lon_offset = radius * 0.03
371    for lon in range(-150, 181, 30):
372        lat_line = np.linspace(-90, 90, nlat)
373        lon_line = np.full_like(lat_line, lon)
374        phi = np.deg2rad(90 - lat_line)
375        theta = np.deg2rad(lon_line)
376        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
377        rr = radius + elev * 10
378        pts_line = np.column_stack([
379            rr * np.sin(phi) * np.cos(theta),
380            rr * np.sin(phi) * np.sin(theta),
381            rr * np.cos(phi)
382        ]) * 1.005
383        label_pos = pts_line[len(pts_line)//2]
384        norm = np.linalg.norm(label_pos)
385        label_pos_out = label_pos / norm * (norm + radius * 0.02)
386        if lon == 0:
387            label_pos_out[1] += zero_lon_offset
388        meridians.append(Line(pts_line, c='k', lw=1)#.flagpole(
389            #f"{lon}°",
390            #point=label_pos_out,
391            #offset=[0, 0, radius * 0.05],
392            #s=radius*0.01,
393            #c='yellow'
394        #).follow_camera()
395        )
396
397    for lat in range(-60, 91, 30):
398        lon_line = np.linspace(-180, 180, nlon)
399        lat_line = np.full_like(lon_line, lat)
400        phi = np.deg2rad(90 - lat_line)
401        theta = np.deg2rad(lon_line)
402        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
403        rr = radius + elev * 10
404        pts_line = np.column_stack([
405            rr * np.sin(phi) * np.cos(theta),
406            rr * np.sin(phi) * np.sin(theta),
407            rr * np.cos(phi)
408        ]) * 1.005
409        label_pos = pts_line[len(pts_line)//2]
410        norm = np.linalg.norm(label_pos)
411        label_pos_out = label_pos / norm * (norm + radius * 0.02)
412        parallels.append(Line(pts_line, c='k', lw=1)#.flagpole(
413            #f"{lat}°",
414            #point=label_pos_out,
415            #offset=[0, 0, radius * 0.05],
416            #s=radius*0.01,
417            #c='yellow'
418        #).follow_camera()
419        )
420
421    # Create plotter
422    plotter = Plotter(title="3D globe view", bg="bb", axes=0)
423
424    # Configure camera
425    cam_dist = radius * 3
426    plotter.camera.SetPosition([cam_dist, 0, 0])
427    plotter.camera.SetFocalPoint([0, 0, 0])
428    plotter.camera.SetViewUp([0, 0, 1])
429
430    # Show the globe
431    plotter.show(mesh, *contour_lines, *meridians, *parallels)
432
433
434def transform_physical_points(dataset, var, data):
435    """
436    Transform a physical_points 1D array into a 2D grid of shape (nlat, nlon).
437    """
438    # Fetch lat/lon coordinate variables
439    lat_var = find_coord_var(dataset, LAT_DIMS)
440    lon_var = find_coord_var(dataset, LON_DIMS)
441    if lat_var is None or lon_var is None:
442        raise ValueError("Cannot find latitude or longitude variables for physical_points")
443    raw_lats = dataset.variables[lat_var][:]
444    raw_lons = dataset.variables[lon_var][:]
445    # Unmask
446    if hasattr(raw_lats, 'mask'):
447        raw_lats = np.where(raw_lats.mask, np.nan, raw_lats.data)
448    if hasattr(raw_lons, 'mask'):
449        raw_lons = np.where(raw_lons.mask, np.nan, raw_lons.data)
450    # Convert radians to degrees if in radians
451    if np.max(np.abs(raw_lats)) <= np.pi:
452        raw_lats = np.degrees(raw_lats)
453        raw_lons = np.degrees(raw_lons)
454    # Get unique coords
455    uniq_lats = np.unique(raw_lats)
456    uniq_lons = np.unique(raw_lons)
457    # Initialize grid
458    grid = np.full((uniq_lats.size, uniq_lons.size), np.nan)
459    # Build the grid
460    for value, lat, lon in zip(data.ravel(), raw_lats.ravel(), raw_lons.ravel()):
461        i = np.where(np.isclose(uniq_lats, lat))[0][0]
462        j = np.where(np.isclose(uniq_lons, lon))[0][0]
463        grid[i, j] = value
464    # Duplicate the pole value across all longitudes
465    for i in (0, -1):
466        row = grid[i, :]
467        count_good = np.count_nonzero(~np.isnan(row))
468        if count_good == 1:
469            pole_value = row[~np.isnan(row)][0]
470            grid[i, :] = pole_value
471    # Wrap longitude if needed
472    if -180.0 in uniq_lons:
473        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
474        grid = np.hstack([grid, grid[:, [idx]]])
475        uniq_lons = np.append(uniq_lons, 180.0)
476    return grid, uniq_lats, uniq_lons, lat_var, lon_var
477
478
479def get_dimension_indices(ds, varname):
480    """
481    For each dimension of the variable:
482     - if size == 1 → automatically select index 0
483     - otherwise prompt the user:
484         <number>     : take that specific index (1-based)
485         'a'          : average over this dimension
486         'e' or Enter : take all values
487    Returns {dim_name: int index, 'avg', or None}.
488    """
489    var = ds.variables[varname]
490    dims = var.dimensions
491    shape = var.shape
492    selection = {}
493    for dim, size in zip(dims, shape):
494        if size == 1:
495            selection[dim] = 0
496            continue
497        prompt = (
498                f"Available options for '{dim}' (size {size}):\n"
499                f"  - '1–{size}' to pick that index\n"
500                "  - 'a' to average over this dimension\n"
501                "  - 'e' or Enter to take all values\n"
502                "Choose: "
503        )
504        while True:
505            resp = input(prompt).strip().lower()
506            if resp in ("", "e"):
507                selection[dim] = None
508                break
509            if resp == 'a':
510                selection[dim] = 'avg'
511                break
512            if resp.isdigit():
513                n = int(resp)
514                if 1 <= n <= size:
515                    selection[dim] = n - 1
516                    break
517            print(f"  Invalid entry '{resp}'. Please enter a number, 'a', 'e', or just Enter.")
518    return selection
519
520
521def plot_variable(dataset, varname, colormap="jet", output_path=None, extra_indices=None):
522    """
523    Automatically select singleton dims, prompt for others,
524    allow user to choose x/y for 2D, handle special cases (physical_points, averaging).
525    """
526    var = dataset.variables[varname]
527    dims = list(var.dimensions)
528    # Read data
529    try:
530        data_full = var[:]
531    except Exception as e:
532        print(f"Error: Cannot read data for '{varname}': {e}")
533        return
534    # Unmask
535    if hasattr(data_full, 'mask'):
536        data_full = np.where(data_full.mask, np.nan, data_full.data)
537    # Initialize extra_indices
538    extra_indices = extra_indices or {}
539    # Handle averaging selections
540    for dim, mode in dict(extra_indices).items():
541        if mode == 'avg':
542            ax = dims.index(dim)
543            data_full = np.nanmean(data_full, axis=ax, keepdims=True)
544            extra_indices[dim] = 0
545    # Build slicer
546    slicer = []
547    for dim in dims:
548        idx = extra_indices.get(dim)
549        slicer.append(idx if isinstance(idx, int) else slice(None))
550    data_slice = data_full[tuple(slicer)]
551    nd = data_slice.ndim
552    # Special case: physical_points dimension
553    if nd == 1 and 'physical_points' in dims:
554        # Transform into 2D grid
555        grid, uniq_lats, uniq_lons, latv, lonv = transform_physical_points(dataset, var, data_slice)
556        # Plot map
557        proj = ccrs.PlateCarree()
558        fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(projection=proj))
559        lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
560        lon_ticks = np.arange(-180, 181, 30)
561        lat_ticks = np.arange(-90, 91, 30)
562        ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
563        ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
564        ax.tick_params(
565            axis='x', which='major',
566            length=4,
567            direction='out',
568            pad=2,
569            labelsize=8
570        )
571        ax.tick_params(
572            axis='y', which='major',
573            length=4,
574            direction='out',
575            pad=2,
576            labelsize=8
577        )
578        cf = ax.pcolormesh(lon2d, lat2d, grid, shading='auto', cmap=colormap, transform=ccrs.PlateCarree())
579        attach_format_coord(ax, grid, uniq_lons, uniq_lats, is_pcolormesh=True)
580        overlay_topography(ax, transform=proj, levels=10) # Overlay MOLA topography
581        cbar = fig.colorbar(cf, ax=ax, pad=0.02)
582        cbar.set_label(varname)
583        ax.set_title(f"{varname} (physical_points)", fontweight='bold')
584        ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
585        ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
586        # Prompt for polar-stereo views if interactive
587        if input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
588            units = getattr(var, 'units', None)
589            plot_polar_views(lon2d, lat2d, grid, colormap, varname, units)
590        # Prompt for 3D globe view if interactive
591        if input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
592            units = getattr(var, 'units', None)
593            plot_3D_globe(lon2d, lat2d, grid, colormap, varname, units)
594        if output_path:
595            fig.savefig(output_path, bbox_inches='tight')
596            print(f"Saved to {output_path}")
597        else:
598            plt.show()
599        return
600    # 0D
601    if nd == 0:
602        print(f"\033[36mScalar '{varname}': {float(data_slice)}\033[0m")
603        return
604    # 1D
605    if nd == 1:
606        rem = [(i, d) for i, (d, s) in enumerate(zip(dims, slicer)) if isinstance(s, slice)]
607        axis_idx, dim_name = rem[0]
608        coord_var = find_coord_var(dataset, [dim_name])
609        if coord_var:
610            x = dataset.variables[coord_var][:]
611            if hasattr(x, 'mask'):
612                x = np.where(x.mask, np.nan, x.data)
613            xlabel = coord_var
614        else:
615            x = np.arange(data_slice.shape[0])
616            xlabel = dim_name
617        y = data_slice
618        plt.figure(figsize=(8, 4))
619        plt.plot(x, y)
620        plt.grid(True)
621        plt.xlabel(xlabel)
622        plt.ylabel(varname)
623        plt.title(f"{varname} vs {xlabel}")
624        if output_path:
625            plt.savefig(output_path, bbox_inches='tight')
626            print(f"Saved plot to {output_path}")
627        else:
628            plt.show()
629        return
630    # 2D
631    if nd == 2:
632        remaining = [d for d, idx in zip(dims, slicer) if isinstance(idx, slice)]
633        # Choose X/Y interactively
634        resp = input(f"Which dimension on X? {remaining}: ").strip()
635        if resp == remaining[1]:
636            x_dim, y_dim = remaining[1], remaining[0]
637        else:
638            x_dim, y_dim = remaining[0], remaining[1]
639        def get_coords(dim):
640            coord_var = find_coord_var(dataset, [dim])
641            if coord_var:
642                arr = dataset.variables[coord_var][:]
643                if hasattr(arr, 'mask'):
644                    arr = np.where(arr.mask, np.nan, arr.data)
645                return arr
646            return np.arange(data_slice.shape[remaining.index(dim)])
647        x_coords = get_coords(x_dim)
648        y_coords = get_coords(y_dim)
649        order = [remaining.index(y_dim), remaining.index(x_dim)]
650        plot_data = np.moveaxis(data_slice, order, [0, 1])
651        fig, ax = plt.subplots(figsize=(8, 6))
652        im = ax.pcolormesh(x_coords, y_coords, plot_data, shading='auto', cmap=colormap)
653        attach_format_coord(ax, plot_data, x_coords, y_coords, is_pcolormesh=True)
654        cbar = fig.colorbar(im, ax=ax, pad=0.02)
655        cbar.set_label(varname)
656        ax.set_xlabel(x_dim)
657        ax.set_ylabel(y_dim)
658        ax.set_title(f"{varname} ({y_dim} vs {x_dim})")
659        ax.grid(True)
660        # Prompt for polar-stereo views if interactive
661        if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
662            units = getattr(dataset.variables[varname], "units", None)
663            plot_polar_views(x_coords, y_coords, plot_data, colormap, varname, units)
664        # Prompt for 3D globe view if interactive
665        if sys.stdin.isatty() and input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
666            units = getattr(dataset.variables[varname], "units", None)
667            plot_3D_globe(x_coords, y_coords, plot_data, colormap, varname, units)
668        if output_path:
669            fig.savefig(output_path, bbox_inches='tight')
670            print(f"Saved to {output_path}")
671        else:
672            plt.show()
673        return
674    print(f"Plotting for ndim={nd} not yet supported.")
675
676
677def visualize_variable_interactive(nc_path=None):
678    """
679    Interactive loop: keep prompting for variables to plot until user quits.
680    """
681    # Open dataset
682    if nc_path:
683        path = nc_path
684    else:
685        readline.set_completer(complete_filename)
686        readline.parse_and_bind("tab: complete")
687        path = input("Enter path to NetCDF file: ").strip()
688
689    if not os.path.isfile(path):
690        print(f"Error: '{path}' not found.")
691        return
692
693    ds = Dataset(path, "r")
694    var_list = list(ds.variables.keys())
695    if not var_list:
696        print("No variables found in file.")
697        ds.close()
698        return
699
700    # Enable interactive mode
701    plt.ion()
702
703    while True:
704        # Enable tab-completion for variable names
705        readline.set_completer(make_varname_completer(var_list))
706        readline.parse_and_bind("tab: complete")
707
708        print("\nAvailable variables:")
709        for name in var_list:
710            print(f"  - {name}")
711        varname = input("\nEnter variable name to plot (or 'q' to quit): ").strip()
712        if varname.lower() in ("q", "quit", "exit"):
713            print("Exiting.")
714            break
715        if varname not in ds.variables:
716            print(f"Variable '{varname}' not found. Try again.")
717            continue
718
719        # Display dimensions and size
720        var = ds.variables[varname]
721        dims, shape = var.dimensions, var.shape
722        print(f"\nVariable '{varname}' has dimensions:")
723        for dim, size in zip(dims, shape):
724            print(f"  > {dim}: size {size}")
725        print()
726
727        # Prepare slicing parameters
728        selection = get_dimension_indices(ds, varname)
729
730        # Plot the variable
731        plot_variable(
732            ds,
733            varname,
734            colormap    = 'jet',
735            output_path = None,
736            extra_indices = selection
737        )
738
739    ds.close()
740
741
742def visualize_variable_cli(nc_file, varname, time_index, alt_index,
743                           colormap, output_path, extra_json, avg_lat):
744    """
745    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
746    """
747    if not os.path.isfile(nc_file):
748        print(f"Error: '{nc_file}' not found.")
749        return
750    ds = Dataset(nc_file, "r")
751    if varname not in ds.variables:
752        print(f"Variable '{varname}' not in file.")
753        ds.close()
754        return
755
756    # Display dimensions and size
757    dims  = ds.variables[varname].dimensions
758    shape = ds.variables[varname].shape
759    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
760    for name, size in zip(dims, shape):
761        print(f"  - {name}: size {size}")
762    print()
763
764    # Special case: time-only → plot directly
765    t_idx = find_dim_index(dims, TIME_DIMS)
766    if (
767        t_idx is not None and shape[t_idx] > 1 and
768        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
769    ):
770        print("Detected single-point spatial dims; plotting time series…")
771        var_obj = ds.variables[varname]
772        data = var_obj[:].squeeze()
773        time_var = find_coord_var(ds, TIME_DIMS)
774        if time_var:
775            tvals = ds.variables[time_var][:]
776        else:
777            tvals = np.arange(data.shape[0])
778        if hasattr(data, "mask"):
779            data = np.where(data.mask, np.nan, data.data)
780        if hasattr(tvals, "mask"):
781            tvals = np.where(tvals.mask, np.nan, tvals.data)
782        plt.figure()
783        plt.plot(tvals, data, marker="o")
784        plt.xlabel(time_var or "Time Index")
785        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
786        plt.title(f"{varname} vs {time_var or 'Index'}", fontweight='bold')
787        if output_path:
788            plt.savefig(output_path, bbox_inches="tight")
789            print(f"Saved to {output_path}")
790        else:
791            plt.show()
792        ds.close()
793        return
794
795    # if --avg-lat but lat/lon/Time not compatible → disable
796    lat_idx = find_dim_index(dims, LAT_DIMS)
797    lon_idx = find_dim_index(dims, LON_DIMS)
798    if avg_lat and not (
799        t_idx   is not None and shape[t_idx]  > 1 and
800        lat_idx is not None and shape[lat_idx] > 1 and
801        lon_idx is not None and shape[lon_idx] > 1
802    ):
803        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
804        avg_lat = False
805
806    # Parse extra indices JSON
807    extra = {}
808    if extra_json:
809        try:
810            parsed = json.loads(extra_json)
811            for k, v in parsed.items():
812                if isinstance(v, int):
813                    if "slope" in k.lower():
814                        extra[k] = v - 1
815                    else:
816                        extra[k] = v
817        except:
818            print("Warning: bad extra-indices.")
819
820    plot_variable(ds, varname, time_index, alt_index,
821                  colormap, output_path, extra, avg_lat)
822    ds.close()
823
824
825def main():
826    parser = argparse.ArgumentParser()
827    parser.add_argument('nc_file', nargs='?', help='NetCDF file (omit for interactive)')
828    parser.add_argument('-v','--variable', help='Variable name')
829    parser.add_argument('-t','--time-index', type=int, help='Time index (0-based)')
830    parser.add_argument('-a','--alt-index', type=int, help='Altitude index (0-based)')
831    parser.add_argument('-c','--cmap', default='jet', help='Colormap')
832    parser.add_argument('--avg-lat', action='store_true', help='Average over latitude')
833    parser.add_argument('--show-polar', action='store_true', help='Enable polar-stereo views')
834    parser.add_argument('--show-3d', action='store_true', help='Enable 3D globe view')
835    parser.add_argument('-o','--output', help='Save figure path')
836    parser.add_argument('-e','--extra-indices', help='JSON string for other dims')
837    args = parser.parse_args()
838
839    if args.nc_file and args.variable:
840        visualize_variable_cli(
841            args.nc_file, args.variable,
842            args.time_index, args.alt_index,
843            args.cmap, args.output,
844            args.extra_indices, args.avg_lat
845        )
846    else:
847        visualize_variable_interactive(args.nc_file)
848
849
850if __name__ == "__main__":
851    main()
852
Note: See TracBrowser for help on using the repository browser.