source: trunk/LMDZ.MARS/util/display_netcdf.py

Last change on this file was 3810, checked in by jbclement, 20 hours ago

Mars PCM:
Big improvements of script "display_netcdf.py" in util folder. In particular:

  • Addition of a MOLA map which, if present, overlays on 2D map plots. A file "MOLA_1px_per_deg.npy" is provided for Mars in util folder;
  • Addition of the possibility to plot polar‐stereographic views;
  • More interactive (keep asking the user to plot variables until he/she quits for ex);
  • Simplification and generalization of the script to get user answers.

JBC

  • Property svn:executable set to *
File size: 28.4 KB
Line 
1#!/usr/bin/env python3
2##############################################################
3### Python script to visualize a variable in a NetCDF file ###
4##############################################################
5
6"""
7This script can display any numeric variable from a NetCDF file.
8It supports the following cases:
9  - 1D time series (Time)
10  - 1D vertical profiles (e.g., subsurface_layers)
11  - 2D latitude/longitude map
12  - 2D (Time × another dimension)
13  - Variables with dimension “physical_points” displayed on a 2D map if lat/lon are present
14  - Optionally average over latitude and plot longitude vs. time heatmap
15  - Scalar output (ndim == 0 after slicing)
16  - 2D cross-sections (altitude × latitude or altitude × longitude)
17
18Usage:
19  1) Command-line mode:
20       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
21           [--time-index 0] [--alt-index 0] [--cmap viridis] [--avg-lat] \
22           [--slice-lon-index 10] [--slice-lat-index 20] [--show-topo] \
23           [--output out.png] [--extra-indices '{"nslope": 1}']
24
25    --variable         : Name of the variable to visualize.
26    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
27    --alt-index        : Index along the altitude dimension (0-based), if present.
28    --cmap             : Matplotlib colormap (default: "jet").
29    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
30    --slice-lon-index  : Fixed longitude index for altitude×longitude cross-section.
31    --slice-lat-index  : Fixed latitude index for altitude×latitude cross-section.
32    --show-topo        : Overlay MOLA topography on lat/lon maps.
33    --output           : If provided, save the figure to this filename instead of displaying.
34    --extra-indices    : JSON string to fix indices for any other dimensions.
35                         For dimensions with "slope", use 1-based numbering here.
36                         Example: '{"nslope": 1, "physical_points": 3}'
37
38  2) Interactive mode:
39       python display_netcdf.py
40       (The script will prompt for everything, including averaging or slicing options.)
41"""
42
43import os
44import sys
45import glob
46import readline
47import argparse
48import json
49import numpy as np
50import matplotlib.pyplot as plt
51import matplotlib.tri as mtri
52import matplotlib.path as mpath
53import cartopy.crs as ccrs
54from netCDF4 import Dataset
55
56# Constants for recognized dimension names
57TIME_DIMS = ("Time", "time", "time_counter")
58ALT_DIMS  = ("altitude",)
59LAT_DIMS  = ("latitude", "lat")
60LON_DIMS  = ("longitude", "lon")
61
62# Attempt to load MOLA topography
63try:
64    MOLA = np.load('MOLA_1px_per_deg.npy')  # shape (nlat, nlon) at 1° per pixel: lat from -90 to 90, lon from 0 to 360
65    nlat, nlon = MOLA.shape
66    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
67    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
68    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
69    topo_loaded = True
70    print("MOLA topography loaded successfully from 'MOLA_1px_per_deg.npy'.")
71except Exception as e:
72    topo_loaded = False
73    print(f"Warning: failed to load MOLA topography ('MOLA_1px_per_deg.npy'): {e}")
74
75
76def complete_filename(text, state):
77    """
78    Tab-completion for filesystem paths.
79    """
80    if "*" not in text:
81        pattern = text + "*"
82    else:
83        pattern = text
84    matches = glob.glob(os.path.expanduser(pattern))
85    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
86    try:
87        return matches[state]
88    except IndexError:
89        return None
90
91
92def make_varname_completer(varnames):
93    """
94    Returns a readline completer for variable names.
95    """
96    def completer(text, state):
97        options = [name for name in varnames if name.startswith(text)]
98        try:
99            return options[state]
100        except IndexError:
101            return None
102    return completer
103
104
105def find_dim_index(dims, candidates):
106    """
107    Search through dims tuple for any name in candidates.
108    Returns the index if found, else returns None.
109    """
110    for idx, dim in enumerate(dims):
111        for cand in candidates:
112            if cand.lower() == dim.lower():
113                return idx
114    return None
115
116
117def find_coord_var(dataset, candidates):
118    """
119    Among dataset variables, return the first variable whose name matches any candidate.
120    Returns None if none found.
121    """
122    for name in dataset.variables:
123        for cand in candidates:
124            if cand.lower() == name.lower():
125                return name
126    return None
127
128
129def overlay_topography(ax, transform, levels=10):
130    """
131    Overlay MOLA topography contours onto a given GeoAxes.
132    """
133    if not topo_loaded:
134        return
135    ax.contour(
136        topo_lon2d, topo_lat2d, MOLA,
137        levels=levels,
138        linewidths=0.5,
139        colors='black',
140        transform=transform
141    )
142
143
144def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True):
145    """
146    Plot two polar‐stereographic views (north & south) of the same data.
147    """
148    figs = []  # collect figure handles
149
150    for pole in ("north", "south"):
151        # Choose projection and extent for each pole
152        if pole == "north":
153            proj = ccrs.NorthPolarStereo(central_longitude=180)
154            extent = [-180, 180, 60, 90]
155        else:
156            proj = ccrs.SouthPolarStereo(central_longitude=180)
157            extent = [-180, 180, -90, -60]
158
159        # Create figure and GeoAxes
160        fig = plt.figure(figsize=(8, 6))
161        ax = fig.add_subplot(1, 1, 1, projection=proj, aspect=True)
162        ax.set_global()
163        ax.set_extent(extent, ccrs.PlateCarree())
164
165        # Draw circular boundary
166        theta = np.linspace(0, 2 * np.pi, 100)
167        center, radius = [0.5, 0.5], 0.5
168        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
169        circle = mpath.Path(verts * radius + center)
170        ax.set_boundary(circle, transform=ax.transAxes)
171
172        # Add meridians/parallels
173        gl = ax.gridlines(
174            draw_labels=True,
175            color='k',
176            xlocs=range(-180, 181, 30),
177            ylocs=range(-90, 91, 10),
178            linestyle='--',
179            linewidth=0.5
180        )
181        #gl.top_labels = False
182        #gl.right_labels = False
183
184        # Plot data in PlateCarree projection
185        cf = ax.contourf(
186            lon2d, lat2d, data2d,
187            levels=100,
188            cmap=colormap,
189            transform=ccrs.PlateCarree()
190        )
191
192        # Optionally overlay MOLA topography
193        if topo_overlay:
194            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)
195
196        # Colorbar and title
197        cbar = fig.colorbar(cf, ax=ax, pad=0.1)
198        label = varname + (f" ({units})" if units else "")
199        cbar.set_label(label)
200        ax.set_title(f"{varname} — {pole.capitalize()} Pole", pad=50)
201
202        figs.append(fig)
203
204    # Show both figures
205    plt.show()
206
207def plot_variable(dataset, varname, time_index=None, alt_index=None,
208                  colormap="jet", output_path=None, extra_indices=None,
209                  avg_lat=False):
210    """
211    Core plotting logic: reads the variable, handles masks,
212    determines dimensionality, and creates the appropriate plot:
213      - 1D time series
214      - 1D profiles or physical_points maps
215      - 2D lat×lon or generic 2D
216      - Time×lon heatmap if avg_lat=True
217      - Scalar printing
218    """
219    var = dataset.variables[varname]
220    dims = var.dimensions
221
222    # Read full data
223    try:
224        data_full = var[:]
225    except Exception as e:
226        print(f"Error: Cannot read data for '{varname}': {e}")
227        return
228    if hasattr(data_full, "mask"):
229        data_full = np.where(data_full.mask, np.nan, data_full.data)
230
231    # Pure 1D time series
232    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
233        time_var = find_coord_var(dataset, TIME_DIMS)
234        tvals = (dataset.variables[time_var][:] if time_var
235                 else np.arange(data_full.shape[0]))
236        if hasattr(tvals, "mask"):
237            tvals = np.where(tvals.mask, np.nan, tvals.data)
238        plt.figure()
239        plt.plot(tvals, data_full, marker="o")
240        plt.xlabel(time_var or "Time Index")
241        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
242        plt.title(f"{varname} vs {time_var or 'Index'}")
243        if output_path:
244            plt.savefig(output_path, bbox_inches="tight")
245            print(f"Saved to {output_path}")
246        else:
247            plt.show()
248        return
249
250    # Identify dims
251    t_idx = find_dim_index(dims, TIME_DIMS)
252    lat_idx = find_dim_index(dims, LAT_DIMS)
253    lon_idx = find_dim_index(dims, LON_DIMS)
254    a_idx = find_dim_index(dims, ALT_DIMS)
255
256    # Average over latitude & plot time × lon heatmap
257    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
258        # compute mean over lat axis
259        data_avg = np.nanmean(data_full, axis=lat_idx)
260        # prepare coordinates
261        time_var = find_coord_var(dataset, TIME_DIMS)
262        lon_var = find_coord_var(dataset, LON_DIMS)
263        tvals = dataset.variables[time_var][:]
264        lons = dataset.variables[lon_var][:]
265        if hasattr(tvals, "mask"):
266            tvals = np.where(tvals.mask, np.nan, tvals.data)
267        if hasattr(lons, "mask"):
268            lons = np.where(lons.mask, np.nan, lons.data)
269        plt.figure(figsize=(10, 6))
270        plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
271        plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
272        plt.ylabel(time_var)
273        cbar = plt.colorbar()
274        cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
275        plt.title(f"{varname} averaged over latitude")
276        if output_path:
277            plt.savefig(output_path, bbox_inches="tight")
278            print(f"Saved to {output_path}")
279        else:
280            plt.show()
281        return
282
283    # Build slicer for other cases
284    slicer = [slice(None)] * len(dims)
285    if t_idx is not None:
286        if time_index is None:
287            print("Error: please supply a time index.")
288            return
289        slicer[t_idx] = time_index
290    if a_idx is not None:
291        if alt_index is None:
292            print("Error: please supply an altitude index.")
293            return
294        slicer[a_idx] = alt_index
295
296    if extra_indices is None:
297        extra_indices = {}
298    for dn, idx_val in extra_indices.items():
299        if dn in dims:
300            slicer[dims.index(dn)] = idx_val
301
302    # Extract slice
303    try:
304        dslice = data_full[tuple(slicer)]
305    except Exception as e:
306        print(f"Error slicing '{varname}': {e}")
307        return
308
309    # Scalar
310    if np.ndim(dslice) == 0:
311        print(f"Scalar '{varname}': {float(dslice)}")
312        return
313
314    # 1D: vector, profile, or physical_points
315    if dslice.ndim == 1:
316        rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]
317        if rem:
318            di, dname = rem[0]
319            # physical_points → interpolated map
320            if dname.lower() == "physical_points":
321                latv = find_coord_var(dataset, LAT_DIMS)
322                lonv = find_coord_var(dataset, LON_DIMS)
323                if latv and lonv:
324                    lats = dataset.variables[latv][:]
325                    lons = dataset.variables[lonv][:]
326
327                    # Unmask
328                    if hasattr(lats, "mask"):
329                        lats = np.where(lats.mask, np.nan, lats.data)
330                    if hasattr(lons, "mask"):
331                        lons = np.where(lons.mask, np.nan, lons.data)
332
333                    # Convert radians to degrees if needed
334                    lats_deg = np.round(np.degrees(lats), 6)
335                    lons_deg = np.round(np.degrees(lons), 6)
336
337                    # Build regular grid
338                    uniq_lats = np.unique(lats_deg)
339                    uniq_lons = np.unique(lons_deg)
340                    nlon = len(uniq_lons)
341
342                    data2d = []
343                    for lat_val in uniq_lats:
344                        mask = lats_deg == lat_val
345                        slice_vals = dslice[mask]
346                        lons_at_lat = lons_deg[mask]
347                        if len(slice_vals) == 1:
348                            row = np.full(nlon, slice_vals[0])
349                        else:
350                            order = np.argsort(lons_at_lat)
351                            row = np.full(nlon, np.nan)
352                            row[: len(slice_vals)] = slice_vals[order]
353                        data2d.append(row)
354                    data2d = np.array(data2d)
355
356                    # Wrap longitude if needed
357                    if -180.0 in uniq_lons:
358                        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
359                        data2d = np.hstack([data2d, data2d[:, [idx]]])
360                        uniq_lons = np.append(uniq_lons, 180.0)
361
362                    # Plot interpolated map
363                    proj = ccrs.PlateCarree()
364                    fig, ax = plt.subplots(subplot_kw=dict(projection=proj), figsize=(8, 6))
365                    lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
366                    lon_ticks = np.arange(-180, 181, 30)
367                    lat_ticks = np.arange(-90, 91, 30)
368                    ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
369                    ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
370                    ax.tick_params(
371                        axis='x', which='major',
372                        length=4,
373                        direction='out',
374                        pad=2,
375                        labelsize=8
376                    )
377                    ax.tick_params(
378                       axis='y', which='major',
379                       length=4,
380                       direction='out',
381                       pad=2,
382                       labelsize=8
383                    )
384                    cf = ax.contourf(
385                        lon2d, lat2d, data2d,
386                        levels=100,
387                        cmap=colormap,
388                        transform=proj
389                    )
390
391                    # Overlay MOLA topography
392                    overlay_topography(ax, transform=proj, levels=10)
393
394                    # Colorbar & labels
395                    cbar = fig.colorbar(cf, ax=ax, pad=0.02)
396                    cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
397                    ax.set_title(f"{varname} (interpolated map over physical_points)")
398                    ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
399                    ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
400
401                    # Prompt for polar-stereo views if interactive
402                    if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
403                        units = getattr(dataset.variables[varname], "units", None)
404                        plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units)
405
406                    if output_path:
407                        plt.savefig(output_path, bbox_inches="tight")
408                        print(f"Saved to {output_path}")
409                    else:
410                        plt.show()
411                    return
412            # vertical profile?
413            coord = None
414            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
415                coord = "soildepth"
416            elif dname in dataset.variables:
417                coord = dname
418            if coord:
419                coords = dataset.variables[coord][:]
420                if hasattr(coords, "mask"):
421                    coords = np.where(coords.mask, np.nan, coords.data)
422                plt.figure()
423                plt.plot(dslice, coords, marker="o")
424                if dname.lower() == "subsurface_layers":
425                    plt.gca().invert_yaxis()
426                plt.xlabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
427                plt.ylabel(coord + (f" ({dataset.variables[coord].units})" if hasattr(dataset.variables[coord], "units") else ""))
428                plt.title(f"{varname} vs {coord}")
429                if output_path:
430                    plt.savefig(output_path, bbox_inches="tight")
431                    print(f"Saved to {output_path}")
432                else:
433                    plt.show()
434                return
435        # generic 1D
436        plt.figure()
437        plt.plot(dslice, marker="o")
438        plt.xlabel("Index")
439        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
440        plt.title(f"{varname} (1D)")
441        if output_path:
442            plt.savefig(output_path, bbox_inches="tight")
443            print(f"Saved to {output_path}")
444        else:
445            plt.show()
446        return
447
448    # if dslice.ndim == 2:
449        lat_idx2 = find_dim_index(dims, LAT_DIMS)
450        lon_idx2 = find_dim_index(dims, LON_DIMS)
451
452        # Geographic lat×lon slice
453        if lat_idx2 is not None and lon_idx2 is not None:
454            latv = find_coord_var(dataset, LAT_DIMS)
455            lonv = find_coord_var(dataset, LON_DIMS)
456            lats = dataset.variables[latv][:]
457            lons = dataset.variables[lonv][:]
458
459            # Handle masked arrays
460            if hasattr(lats, "mask"):
461                lats = np.where(lats.mask, np.nan, lats.data)
462            if hasattr(lons, "mask"):
463                lons = np.where(lons.mask, np.nan, lons.data)
464
465            # Create map projection
466            proj = ccrs.PlateCarree()
467            fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=proj))
468
469            # Make meshgrid and plot
470            lon2d, lat2d = np.meshgrid(lons, lats)
471            cf = ax.contourf(
472                lon2d, lat2d, dslice,
473                levels=100,
474                cmap=colormap,
475                transform=proj
476            )
477
478            # Overlay topography
479            overlay_topography(ax, transform=proj, levels=10)
480
481            # Colorbar and labels
482            lon_ticks = np.arange(-180, 181, 30)
483            lat_ticks = np.arange(-90, 91, 30)
484            ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
485            ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
486            ax.tick_params(
487                axis='x', which='major',
488                length=4,
489                direction='out',
490                pad=2,
491                labelsize=8
492            )
493            ax.tick_params(
494                axis='y', which='major',
495                length=4,
496                direction='out',
497                pad=2,
498                labelsize=8
499            )
500            cbar = fig.colorbar(cf, ax=ax, orientation="vertical", pad=0.02)
501            cbar.set_label(varname + (f" ({dataset.variables[varname].units})"
502                                      if hasattr(dataset.variables[varname], "units") else ""))
503            ax.set_title(f"{varname} (lat × lon)")
504            ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
505            ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
506
507            # Prompt for polar-stereo views if interactive
508            if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
509                units = getattr(dataset.variables[varname], "units", None)
510                plot_polar_views(lon2d, lat2d, dslice, colormap, varname, units)
511
512            if output_path:
513                plt.savefig(output_path, bbox_inches="tight")
514                print(f"Saved to {output_path}")
515            else:
516                plt.show()
517            return
518
519        # Generic 2D
520        plt.figure(figsize=(8, 6))
521        plt.imshow(dslice, aspect="auto")
522        plt.colorbar(label=varname + (f" ({var.units})" if hasattr(var, "units") else ""))
523        plt.xlabel("Dim 2 index")
524        plt.ylabel("Dim 1 index")
525        plt.title(f"{varname} (2D)")
526        if output_path:
527            plt.savefig(output_path, bbox_inches="tight")
528            print(f"Saved to {output_path}")
529        else:
530            plt.show()
531        return
532
533    print(f"Error: ndim={dslice.ndim} not supported.")
534
535
536def visualize_variable_interactive(nc_path=None):
537    """
538    Interactive loop: keep prompting for variables to plot until user quits.
539    """
540    # Open dataset
541    if nc_path:
542        path = nc_path
543    else:
544        readline.set_completer(complete_filename)
545        readline.parse_and_bind("tab: complete")
546        path = input("Enter path to NetCDF file: ").strip()
547
548    if not os.path.isfile(path):
549        print(f"Error: '{path}' not found.")
550        return
551
552    ds = Dataset(path, "r")
553    var_list = list(ds.variables.keys())
554    if not var_list:
555        print("No variables found in file.")
556        ds.close()
557        return
558
559    # Enable interactive mode
560    plt.ion()
561
562    while True:
563        # Enable tab-completion for variable names
564        readline.set_completer(make_varname_completer(var_list))
565        readline.parse_and_bind("tab: complete")
566
567        print("\nAvailable variables:")
568        for name in var_list:
569            print(f"  - {name}")
570        varname = input("\nEnter variable name to plot (or 'q' to quit): ").strip()
571        if varname.lower() in ("q", "quit", "exit"):
572            print("Exiting.")
573            break
574        if varname not in ds.variables:
575            print(f"Variable '{varname}' not found. Try again.")
576            continue
577
578        # Display dimensions and size
579        var = ds.variables[varname]
580        dims, shape = var.dimensions, var.shape
581        print(f"\nVariable '{varname}' has dimensions:")
582        for dim, size in zip(dims, shape):
583            print(f"  - {dim}: size {size}")
584        print()
585
586        # Prepare slicing parameters
587        time_index = None
588        alt_index = None
589        avg = False
590        extra_indices = {}
591
592        # Time index
593        t_idx = find_dim_index(dims, TIME_DIMS)
594        if t_idx is not None:
595            if shape[t_idx] > 1:
596                while True:
597                    idx = input(f"Enter time index [1–{shape[t_idx]}] (press Enter for all): ").strip()
598                    if idx == '':
599                        time_index = None
600                        break
601                    if idx.isdigit():
602                        i = int(idx)
603                        if 1 <= i <= shape[t_idx]:
604                            time_index = i - 1
605                            break
606                    print("Invalid entry. Please enter a valid number or press Enter.")
607            else:
608                time_index = 0
609
610        # Altitude index
611        a_idx = find_dim_index(dims, ALT_DIMS)
612        if a_idx is not None:
613            if shape[a_idx] > 1:
614                while True:
615                    idx = input(f"Enter altitude index [1–{shape[a_idx]}] (press Enter for all): ").strip()
616                    if idx == '':
617                        alt_index = None
618                        break
619                    if idx.isdigit():
620                        i = int(idx)
621                        if 1 <= i <= shape[a_idx]:
622                            alt_index = i - 1
623                            break
624                    print("Invalid entry. Please enter a valid number or press Enter.")
625            else:
626                alt_index = 0
627
628        # Average over latitude?
629        lat_idx = find_dim_index(dims, LAT_DIMS)
630        lon_idx = find_dim_index(dims, LON_DIMS)
631        if (t_idx is not None and lat_idx is not None and lon_idx is not None and
632            shape[t_idx] > 1 and shape[lat_idx] > 1 and shape[lon_idx] > 1):
633            resp = input("Average over latitude and plot lon vs time? [y/n]: ").strip().lower()
634            avg = (resp == 'y')
635
636        # Other dimensions
637        for i, dname in enumerate(dims):
638            if i in (t_idx, a_idx):
639                continue
640            size = shape[i]
641            if size == 1:
642                extra_indices[dname] = 0
643                continue
644            while True:
645                idx = input(f"Enter index [1–{size}] for '{dname}' (press Enter for all): ").strip()
646                if idx == '':
647                    # keep all values
648                    break
649                if idx.isdigit():
650                    j = int(idx)
651                    if 1 <= j <= size:
652                        extra_indices[dname] = j - 1
653                        break
654                print("Invalid entry. Please enter a valid number or press Enter.")
655
656        # Plot the variable
657        plot_variable(
658            ds, varname,
659            time_index    = time_index,
660            alt_index     = alt_index,
661            colormap      = 'jet',
662            output_path   = None,
663            extra_indices = extra_indices,
664            avg_lat       = avg
665        )
666
667    ds.close()
668
669
670def visualize_variable_cli(nc_file, varname, time_index, alt_index,
671                           colormap, output_path, extra_json, avg_lat):
672    """
673    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
674    """
675    if not os.path.isfile(nc_file):
676        print(f"Error: '{nc_file}' not found.")
677        return
678    ds = Dataset(nc_file, "r")
679    if varname not in ds.variables:
680        print(f"Variable '{varname}' not in file.")
681        ds.close()
682        return
683
684    # Display dimensions and size
685    dims  = ds.variables[varname].dimensions
686    shape = ds.variables[varname].shape
687    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
688    for name, size in zip(dims, shape):
689        print(f"  - {name}: size {size}")
690    print()
691
692    # Special case: time-only → plot directly
693    t_idx = find_dim_index(dims, TIME_DIMS)
694    if (
695        t_idx is not None and shape[t_idx] > 1 and
696        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
697    ):
698        print("Detected single-point spatial dims; plotting time series…")
699        var_obj = ds.variables[varname]
700        data = var_obj[:].squeeze()
701        time_var = find_coord_var(ds, TIME_DIMS)
702        if time_var:
703            tvals = ds.variables[time_var][:]
704        else:
705            tvals = np.arange(data.shape[0])
706        if hasattr(data, "mask"):
707            data = np.where(data.mask, np.nan, data.data)
708        if hasattr(tvals, "mask"):
709            tvals = np.where(tvals.mask, np.nan, tvals.data)
710        plt.figure()
711        plt.plot(tvals, data, marker="o")
712        plt.xlabel(time_var or "Time Index")
713        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
714        plt.title(f"{varname} vs {time_var or 'Index'}")
715        if output_path:
716            plt.savefig(output_path, bbox_inches="tight")
717            print(f"Saved to {output_path}")
718        else:
719            plt.show()
720        ds.close()
721        return
722
723    # if --avg-lat but lat/lon/Time not compatible → disable
724    lat_idx = find_dim_index(dims, LAT_DIMS)
725    lon_idx = find_dim_index(dims, LON_DIMS)
726    if avg_lat and not (
727        t_idx   is not None and shape[t_idx]  > 1 and
728        lat_idx is not None and shape[lat_idx] > 1 and
729        lon_idx is not None and shape[lon_idx] > 1
730    ):
731        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
732        avg_lat = False
733
734    # Parse extra indices JSON
735    extra = {}
736    if extra_json:
737        try:
738            parsed = json.loads(extra_json)
739            for k, v in parsed.items():
740                if isinstance(v, int):
741                    if "slope" in k.lower():
742                        extra[k] = v - 1
743                    else:
744                        extra[k] = v
745        except:
746            print("Warning: bad extra-indices.")
747
748    plot_variable(ds, varname, time_index, alt_index,
749                  colormap, output_path, extra, avg_lat)
750    ds.close()
751
752
753def main():
754    parser = argparse.ArgumentParser()
755    parser.add_argument("nc_file", nargs="?", help="NetCDF file (omit for interactive)")
756    parser.add_argument("-v", "--variable", help="Variable name")
757    parser.add_argument("-t", "--time-index", type=int, help="Time index (0-based)")
758    parser.add_argument("-a", "--alt-index", type=int, help="Altitude index (0-based)")
759    parser.add_argument("-c", "--cmap", default="jet", help="Colormap")
760    parser.add_argument("--avg-lat", action="store_true",
761                        help="Average over latitude (time × lon heatmap)")
762    parser.add_argument("-o", "--output", help="Save figure path")
763    parser.add_argument("-e", "--extra-indices", help="JSON for other dims")
764    args = parser.parse_args()
765
766    if args.nc_file and args.variable:
767        visualize_variable_cli(
768            args.nc_file, args.variable,
769            args.time_index, args.alt_index,
770            args.cmap, args.output,
771            args.extra_indices, args.avg_lat
772        )
773    else:
774        visualize_variable_interactive(args.nc_file)
775
776
777if __name__ == "__main__":
778    main()
779
Note: See TracBrowser for help on using the repository browser.