source: trunk/LMDZ.PLUTO/util/script_figures/display_netcdf.py @ 3833

Last change on this file since 3833 was 3833, checked in by afalco, 32 hours ago

Pluto: updated plots scripts.
Fixed some issues with reading XIOS, etc.
Included display_netcdf.py tool from Mars PCM.
AF

  • Property svn:executable set to *
File size: 34.5 KB
Line 
1#!/usr/bin/env python3
2##############################################################
3### Python script to visualize a variable in a NetCDF file ###
4##############################################################
5
6"""
7This script can display any numeric variable from a NetCDF file.
8It supports the following cases:
9  - Scalar output
10  - 1D time series
11  - 1D vertical profiles
12  - 2D latitude/longitude map
13  - 2D cross-sections
14  - Optionally average over latitude and plot longitude vs. time heatmap
15  - Optionally display polar stereographic view of 2D maps
16  - Optionally display 3D globe view of 2D maps
17
18Automatic setup from the environment file found in the "LMDZ.MARS/util" folder:
19  1. Make sure Conda is installed.
20  2. In terminal, navigate to the folder containing this script.
21  4. Create the environment:
22       conda env create -f display_netcdf.yml
23  5. Activate the environment:
24       conda activate my_env
25  6. Run the script:
26     python display_netcdf.py
27
28Usage:
29  1) Command-line mode:
30       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [options]
31     Options:
32    --variable         : Name of the variable to visualize.
33    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
34    --alt-index        : Index along the altitude dimension (0-based), if present.
35    --cmap             : Matplotlib colormap (default: "jet").
36    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
37    --slice-lon-index  : Fixed longitude index for altitude×longitude cross-section.
38    --slice-lat-index  : Fixed latitude index for altitude×latitude cross-section.
39    --show-topo        : Overlay MOLA topography on lat/lon maps.
40    --show-polar       :
41    --show-3d          :
42    --output           : If provided, save the figure to this filename instead of displaying.
43    --extra-indices    : JSON string to fix indices for any other dimensions.
44                         For dimensions with "slope", use 1-based numbering here.
45                         Example: '{"nslope": 1, "physical_points": 3}'
46
47  2) Interactive mode:
48       python display_netcdf.py
49     The script will prompt for everything.
50"""
51
52import os
53import sys
54import glob
55import readline
56import argparse
57import json
58import numpy as np
59import matplotlib.pyplot as plt
60import matplotlib.path as mpath
61import matplotlib.colors as mcolors
62import cartopy.crs as ccrs
63import pandas as pd
64from netCDF4 import Dataset
65
66# Attempt vedo import early for global use
67try:
68    import vedo
69    from vedo import *
70    from scipy.interpolate import RegularGridInterpolator
71    vedo_available = True
72except ImportError:
73    vedo_available = False
74
75# Constants for recognized dimension names
76TIME_DIMS = ("Time", "time", "time_counter")
77ALT_DIMS  = ("altitude",)
78LAT_DIMS  = ("latitude", "lat")
79LON_DIMS  = ("longitude", "lon")
80
81# Paths for MOLA data
82MOLA_NPY = 'MOLA_1px_per_deg.npy'
83MOLA_CSV = 'molaTeam_contour_31rgb_steps.csv'
84
85# Attempt to load MOLA topography
86try:
87    MOLA = np.load('MOLA_1px_per_deg.npy')  # shape (nlat, nlon) at 1° per pixel: lat from -90 to 90, lon from 0 to 360
88    nlat, nlon = MOLA.shape
89    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
90    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
91    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
92    topo_loaded = True
93except Exception as e:
94    print(f"Warning: '{MOLA_NPY}' not found: {e}")
95    topo_loaded = False
96
97
98# Attempt to load contour color table
99if os.path.isfile(MOLA_CSV):
100    color_table = pd.read_csv(MOLA_CSV)
101    csv_loaded = True
102else:
103    print(f"Warning: '{MOLA_CSV}' not found. 3D view colors disabled.")
104    csv_loaded = False
105
106
107def complete_filename(text, state):
108    """
109    Tab-completion for filesystem paths.
110    """
111    if "*" not in text:
112        pattern = text + "*"
113    else:
114        pattern = text
115    matches = glob.glob(os.path.expanduser(pattern))
116    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
117    try:
118        return matches[state]
119    except IndexError:
120        return None
121
122
123def make_varname_completer(varnames):
124    """
125    Returns a readline completer for variable names.
126    """
127    def completer(text, state):
128        options = [name for name in varnames if name.startswith(text)]
129        try:
130            return options[state]
131        except IndexError:
132            return None
133    return completer
134
135
136def find_dim_index(dims, candidates):
137    """
138    Search through dims tuple for any name in candidates.
139    Returns the index if found, else returns None.
140    """
141    for idx, dim in enumerate(dims):
142        for cand in candidates:
143            if cand.lower() == dim.lower():
144                return idx
145    return None
146
147
148def find_coord_var(dataset, candidates):
149    """
150    Among dataset variables, return the first variable whose name matches any candidate.
151    Returns None if none found.
152    """
153    for name in dataset.variables:
154        for cand in candidates:
155            if cand.lower() == name.lower():
156                return name
157    return None
158
159
160def overlay_topography(ax, transform, levels=10):
161    """
162    Overlay MOLA topography contours onto a given GeoAxes.
163    """
164    if not topo_loaded:
165        return
166    ax.contour(
167        topo_lon2d, topo_lat2d, MOLA,
168        levels=levels,
169        linewidths=0.5,
170        colors='black',
171        transform=transform
172    )
173
174
175def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True):
176    """
177    Plot two polar‐stereographic views (north & south) of the same data.
178    """
179    figs = []  # collect figure handles
180
181    for pole in ("north", "south"):
182        # Choose projection and extent for each pole
183        if pole == "north":
184            proj = ccrs.NorthPolarStereo(central_longitude=180)
185            extent = [-180, 180, 60, 90]
186        else:
187            proj = ccrs.SouthPolarStereo(central_longitude=180)
188            extent = [-180, 180, -90, -60]
189
190        # Create figure and GeoAxes
191        fig = plt.figure(figsize=(8, 6))
192        ax = fig.add_subplot(1, 1, 1, projection=proj, aspect=True)
193        ax.set_global()
194        ax.set_extent(extent, ccrs.PlateCarree())
195
196        # Draw circular boundary
197        theta = np.linspace(0, 2 * np.pi, 100)
198        center, radius = [0.5, 0.5], 0.5
199        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
200        circle = mpath.Path(verts * radius + center)
201        ax.set_boundary(circle, transform=ax.transAxes)
202
203        # Add meridians/parallels
204        gl = ax.gridlines(
205            draw_labels=True,
206            color='k',
207            xlocs=range(-180, 181, 30),
208            ylocs=range(-90, 91, 10),
209            linestyle='--',
210            linewidth=0.5
211        )
212        #gl.top_labels = False
213        #gl.right_labels = False
214
215        # Plot data in PlateCarree projection
216        cf = ax.contourf(
217            lon2d, lat2d, data2d,
218            levels=100,
219            cmap=colormap,
220            transform=ccrs.PlateCarree()
221        )
222
223        # Optionally overlay MOLA topography
224        if topo_overlay:
225            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)
226
227        # Colorbar and title
228        cbar = fig.colorbar(cf, ax=ax, pad=0.1)
229        label = varname + (f" ({units})" if units else "")
230        cbar.set_label(label)
231        ax.set_title(f"{varname} — {pole.capitalize()} Pole", pad=50)
232
233        figs.append(fig)
234
235    # Show both figures
236    plt.show()
237
238
239def plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units=None):
240    """
241    Plot a 3D globe view of the data using vedo, with surface coloring based on data2d
242    and overlaid contour lines from MOLA topography.
243    """
244    if not vedo_available:
245        print("3D view skipped: vedo missing.")
246        return
247    if not csv_loaded:
248        print("3D view skipped: color table missing.")
249        return
250
251    # Prepare MOLA grid
252    nlat, nlon = MOLA.shape
253    lats = np.linspace(90, -90, nlat)
254    lons = np.linspace(-180, 180, nlon)
255    lon_grid, lat_grid = np.meshgrid(lons, lats)
256
257    # Interpolate data2d onto MOLA grid
258    lat_data = np.linspace(-90, 90, data2d.shape[0])
259    lon_data = np.linspace(-180, 180, data2d.shape[1])
260    interp2d = RegularGridInterpolator((lat_data, lon_data), data2d,
261                                       bounds_error=False, fill_value=None)
262    newdata2d = interp2d((lat_grid, lon_grid))
263
264    # Generate contour lines from MOLA
265    cs = plt.contour(lon_grid, lat_grid, MOLA, levels=10, linewidths=0)
266    plt.clf()
267    contour_lines = []
268    radius = 3389500 # Mars average radius [m]
269    for segs, level in zip(cs.allsegs, cs.levels):
270        for verts in segs:
271            lon_c = verts[:, 0]
272            lat_c = verts[:, 1]
273            phi_c = np.radians(90 - lat_c)
274            theta_c = np.radians(lon_c)
275            elev = RegularGridInterpolator((lats, lons), MOLA,
276                                           bounds_error=False,
277                                           fill_value=0.0)((lat_c, lon_c))
278            r_cont = radius + elev * 10
279            x_c = r_cont * np.sin(phi_c) * np.cos(theta_c) * 1.002
280            y_c = r_cont * np.sin(phi_c) * np.sin(theta_c) * 1.002
281            z_c = r_cont * np.cos(phi_c) * 1.002
282            pts = np.column_stack([x_c, y_c, z_c])
283            if pts.shape[0] > 1:
284                contour_lines.append(Line(pts, c='k', lw=0.5))
285
286    # Create sphere surface mesh
287    phi = np.deg2rad(90 - lat_grid)
288    theta = np.deg2rad(lon_grid)
289    r = radius + MOLA * 10
290    x = r * np.sin(phi) * np.cos(theta)
291    y = r * np.sin(phi) * np.sin(theta)
292    z = r * np.cos(phi)
293    pts = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)
294
295    # Build mesh faces
296    faces = []
297    for i in range(nlat - 1):
298        for j in range(nlon - 1):
299            p0 = i * nlon + j
300            p1 = p0 + 1
301            p2 = p0 + nlon
302            p3 = p2 + 1
303            faces.extend([(p0, p2, p1), (p1, p2, p3)])
304
305    mesh = Mesh([pts, faces])
306    mesh.cmap(colormap, newdata2d.ravel())
307    mesh.add_scalarbar(title=varname + (f' [{units}]' if units else ''), c='white')
308    mesh.lighting('default')
309
310    # Geographic grid lines
311    meridians, parallels, labels = [], [], []
312    zero_lon_offset = radius * 0.03
313    for lon in range(-150, 181, 30):
314        lat_line = np.linspace(-90, 90, nlat)
315        lon_line = np.full_like(lat_line, lon)
316        phi = np.deg2rad(90 - lat_line)
317        theta = np.deg2rad(lon_line)
318        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
319        rr = radius + elev * 10
320        pts_line = np.column_stack([
321            rr * np.sin(phi) * np.cos(theta),
322            rr * np.sin(phi) * np.sin(theta),
323            rr * np.cos(phi)
324        ]) * 1.005
325        label_pos = pts_line[len(pts_line)//2]
326        norm = np.linalg.norm(label_pos)
327        label_pos_out = label_pos / norm * (norm + radius * 0.02)
328        if lon == 0:
329            label_pos_out[1] += zero_lon_offset
330        meridians.append(Line(pts_line, c='k', lw=1)#.flagpole(
331            #f"{lon}°",
332            #point=label_pos_out,
333            #offset=[0, 0, radius * 0.05],
334            #s=radius*0.01,
335            #c='yellow'
336        #).follow_camera()
337        )
338
339    for lat in range(-60, 91, 30):
340        lon_line = np.linspace(-180, 180, nlon)
341        lat_line = np.full_like(lon_line, lat)
342        phi = np.deg2rad(90 - lat_line)
343        theta = np.deg2rad(lon_line)
344        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
345        rr = radius + elev * 10
346        pts_line = np.column_stack([
347            rr * np.sin(phi) * np.cos(theta),
348            rr * np.sin(phi) * np.sin(theta),
349            rr * np.cos(phi)
350        ]) * 1.005
351        label_pos = pts_line[len(pts_line)//2]
352        norm = np.linalg.norm(label_pos)
353        label_pos_out = label_pos / norm * (norm + radius * 0.02)
354        parallels.append(Line(pts_line, c='k', lw=1)#.flagpole(
355            #f"{lat}°",
356            #point=label_pos_out,
357            #offset=[0, 0, radius * 0.05],
358            #s=radius*0.01,
359            #c='yellow'
360        #).follow_camera()
361        )
362
363    # Create plotter
364    plotter = Plotter(title="3D topography view", bg="bb", axes=0)
365
366    # Configure camera
367    cam_dist = radius * 3
368    plotter.camera.SetPosition([cam_dist, 0, 0])
369    plotter.camera.SetFocalPoint([0, 0, 0])
370    plotter.camera.SetViewUp([0, 0, 1])
371
372    # Show the globe
373    plotter.show(mesh, *contour_lines, *meridians, *parallels)
374
375
376def plot_variable(dataset, varname, time_index=None, alt_index=None,
377                  colormap="jet", output_path=None, extra_indices=None,
378                  avg_lat=False):
379    """
380    Core plotting logic: reads the variable, handles masks,
381    determines dimensionality, and creates the appropriate plot:
382      - 1D time series
383      - 1D profiles or physical_points maps
384      - 2D lat×lon or generic 2D
385      - Time×lon heatmap if avg_lat=True
386      - Scalar printing
387    """
388    var = dataset.variables[varname]
389    dims = var.dimensions
390
391    # Read full data
392    try:
393        data_full = var[:]
394    except Exception as e:
395        print(f"Error: Cannot read data for '{varname}': {e}")
396        return
397    if hasattr(data_full, "mask"):
398        data_full = np.where(data_full.mask, np.nan, data_full.data)
399
400    # Pure 1D time series
401    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
402        time_var = find_coord_var(dataset, TIME_DIMS)
403        tvals = (dataset.variables[time_var][:] if time_var
404                 else np.arange(data_full.shape[0]))
405        if hasattr(tvals, "mask"):
406            tvals = np.where(tvals.mask, np.nan, tvals.data)
407        plt.figure()
408        plt.plot(tvals, data_full, marker="o")
409        plt.xlabel(time_var or "Time Index")
410        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
411        plt.title(f"{varname} vs {time_var or 'Index'}")
412        if output_path:
413            plt.savefig(output_path, bbox_inches="tight")
414            print(f"Saved to {output_path}")
415        else:
416            plt.show()
417        return
418
419    # Identify dims
420    t_idx = find_dim_index(dims, TIME_DIMS)
421    lat_idx = find_dim_index(dims, LAT_DIMS)
422    lon_idx = find_dim_index(dims, LON_DIMS)
423    a_idx = find_dim_index(dims, ALT_DIMS)
424
425    # Average over latitude & plot time × lon heatmap
426    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
427        # compute mean over lat axis
428        data_avg = np.nanmean(data_full, axis=lat_idx)
429        # prepare coordinates
430        time_var = find_coord_var(dataset, TIME_DIMS)
431        lon_var = find_coord_var(dataset, LON_DIMS)
432        tvals = dataset.variables[time_var][:]
433        lons = dataset.variables[lon_var][:]
434        if hasattr(tvals, "mask"):
435            tvals = np.where(tvals.mask, np.nan, tvals.data)
436        if hasattr(lons, "mask"):
437            lons = np.where(lons.mask, np.nan, lons.data)
438        plt.figure(figsize=(10, 6))
439        plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
440        plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
441        plt.ylabel(time_var)
442        cbar = plt.colorbar()
443        cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
444        plt.title(f"{varname} averaged over latitude")
445        if output_path:
446            plt.savefig(output_path, bbox_inches="tight")
447            print(f"Saved to {output_path}")
448        else:
449            plt.show()
450        return
451
452    # Build slicer for other cases
453    slicer = [slice(None)] * len(dims)
454    if t_idx is not None:
455        if time_index is None:
456            print("Error: please supply a time index.")
457            return
458        slicer[t_idx] = time_index
459    if a_idx is not None:
460        if alt_index is None:
461            print("Error: please supply an altitude index.")
462            return
463        slicer[a_idx] = alt_index
464
465    if extra_indices is None:
466        extra_indices = {}
467    for dn, idx_val in extra_indices.items():
468        if dn in dims:
469            slicer[dims.index(dn)] = idx_val
470
471    # Extract slice
472    try:
473        dslice = data_full[tuple(slicer)]
474    except Exception as e:
475        print(f"Error slicing '{varname}': {e}")
476        return
477
478    # Scalar
479    if np.ndim(dslice) == 0:
480        print(f"Scalar '{varname}': {float(dslice)}")
481        return
482
483    # 1D: vector, profile, or physical_points
484    if dslice.ndim == 1:
485        rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]
486        if rem:
487            di, dname = rem[0]
488            # physical_points → interpolated map
489            if dname.lower() == "physical_points":
490                latv = find_coord_var(dataset, LAT_DIMS)
491                lonv = find_coord_var(dataset, LON_DIMS)
492                if latv and lonv:
493                    lats = dataset.variables[latv][:]
494                    lons = dataset.variables[lonv][:]
495
496                    # Unmask
497                    if hasattr(lats, "mask"):
498                        lats = np.where(lats.mask, np.nan, lats.data)
499                    if hasattr(lons, "mask"):
500                        lons = np.where(lons.mask, np.nan, lons.data)
501
502                    # Convert radians to degrees if needed
503                    lats_deg = np.round(np.degrees(lats), 6)
504                    lons_deg = np.round(np.degrees(lons), 6)
505
506                    # Build regular grid
507                    uniq_lats = np.unique(lats_deg)
508                    uniq_lons = np.unique(lons_deg)
509                    nlon = len(uniq_lons)
510
511                    data2d = []
512                    for lat_val in uniq_lats:
513                        mask = lats_deg == lat_val
514                        slice_vals = dslice[mask]
515                        lons_at_lat = lons_deg[mask]
516                        if len(slice_vals) == 1:
517                            row = np.full(nlon, slice_vals[0])
518                        else:
519                            order = np.argsort(lons_at_lat)
520                            row = np.full(nlon, np.nan)
521                            row[: len(slice_vals)] = slice_vals[order]
522                        data2d.append(row)
523                    data2d = np.array(data2d)
524
525                    # Wrap longitude if needed
526                    if -180.0 in uniq_lons:
527                        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
528                        data2d = np.hstack([data2d, data2d[:, [idx]]])
529                        uniq_lons = np.append(uniq_lons, 180.0)
530
531                    # Plot interpolated map
532                    proj = ccrs.PlateCarree()
533                    fig, ax = plt.subplots(subplot_kw=dict(projection=proj), figsize=(8, 6))
534                    lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
535                    lon_ticks = np.arange(-180, 181, 30)
536                    lat_ticks = np.arange(-90, 91, 30)
537                    ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
538                    ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
539                    ax.tick_params(
540                        axis='x', which='major',
541                        length=4,
542                        direction='out',
543                        pad=2,
544                        labelsize=8
545                    )
546                    ax.tick_params(
547                       axis='y', which='major',
548                       length=4,
549                       direction='out',
550                       pad=2,
551                       labelsize=8
552                    )
553                    cf = ax.contourf(
554                        lon2d, lat2d, data2d,
555                        levels=100,
556                        cmap=colormap,
557                        transform=proj
558                    )
559
560                    # Overlay MOLA topography
561                    overlay_topography(ax, transform=proj, levels=10)
562
563                    # Colorbar & labels
564                    cbar = fig.colorbar(cf, ax=ax, pad=0.02)
565                    cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
566                    ax.set_title(f"{varname} (interpolated map over physical_points)")
567                    ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
568                    ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
569
570                    # Prompt for polar-stereo views if interactive
571                    if input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
572                        units = getattr(dataset.variables[varname], "units", None)
573                        plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units)
574
575                    # Prompt for 3D globe view if interactive
576                    if input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
577                        units = getattr(dataset.variables[varname], "units", None)
578                        plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units)
579
580                    if output_path:
581                        plt.savefig(output_path, bbox_inches="tight")
582                        print(f"Saved to {output_path}")
583                    else:
584                        plt.show()
585                    return
586            # vertical profile?
587            coord = None
588            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
589                coord = "soildepth"
590            elif dname in dataset.variables:
591                coord = dname
592            if coord:
593                coords = dataset.variables[coord][:]
594                if hasattr(coords, "mask"):
595                    coords = np.where(coords.mask, np.nan, coords.data)
596                plt.figure()
597                plt.plot(dslice, coords, marker="o")
598                if dname.lower() == "subsurface_layers":
599                    plt.gca().invert_yaxis()
600                plt.xlabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
601                plt.ylabel(coord + (f" ({dataset.variables[coord].units})" if hasattr(dataset.variables[coord], "units") else ""))
602                plt.title(f"{varname} vs {coord}")
603                if output_path:
604                    plt.savefig(output_path, bbox_inches="tight")
605                    print(f"Saved to {output_path}")
606                else:
607                    plt.show()
608                return
609        # generic 1D
610        plt.figure()
611        plt.plot(dslice, marker="o")
612        plt.xlabel("Index")
613        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
614        plt.title(f"{varname} (1D)")
615        if output_path:
616            plt.savefig(output_path, bbox_inches="tight")
617            print(f"Saved to {output_path}")
618        else:
619            plt.show()
620        return
621
622    # if dslice.ndim == 2:
623        lat_idx2 = find_dim_index(dims, LAT_DIMS)
624        lon_idx2 = find_dim_index(dims, LON_DIMS)
625
626        # Geographic lat×lon slice
627        if lat_idx2 is not None and lon_idx2 is not None:
628            latv = find_coord_var(dataset, LAT_DIMS)
629            lonv = find_coord_var(dataset, LON_DIMS)
630            lats = dataset.variables[latv][:]
631            lons = dataset.variables[lonv][:]
632
633            # Handle masked arrays
634            if hasattr(lats, "mask"):
635                lats = np.where(lats.mask, np.nan, lats.data)
636            if hasattr(lons, "mask"):
637                lons = np.where(lons.mask, np.nan, lons.data)
638
639            # Create map projection
640            proj = ccrs.PlateCarree()
641            fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=proj))
642
643            # Make meshgrid and plot
644            lon2d, lat2d = np.meshgrid(lons, lats)
645            cf = ax.contourf(
646                lon2d, lat2d, dslice,
647                levels=100,
648                cmap=colormap,
649                transform=proj
650            )
651
652            # Overlay topography
653            overlay_topography(ax, transform=proj, levels=10)
654
655            # Colorbar and labels
656            lon_ticks = np.arange(-180, 181, 30)
657            lat_ticks = np.arange(-90, 91, 30)
658            ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
659            ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
660            ax.tick_params(
661                axis='x', which='major',
662                length=4,
663                direction='out',
664                pad=2,
665                labelsize=8
666            )
667            ax.tick_params(
668                axis='y', which='major',
669                length=4,
670                direction='out',
671                pad=2,
672                labelsize=8
673            )
674            cbar = fig.colorbar(cf, ax=ax, orientation="vertical", pad=0.02)
675            cbar.set_label(varname + (f" ({dataset.variables[varname].units})"
676                                      if hasattr(dataset.variables[varname], "units") else ""))
677            ax.set_title(f"{varname} (lat × lon)")
678            ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
679            ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
680
681            # Prompt for polar-stereo views if interactive
682            if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
683                units = getattr(dataset.variables[varname], "units", None)
684                plot_polar_views(lon2d, lat2d, dslice, colormap, varname, units)
685
686            # Prompt for 3D globe view if interactive
687            if sys.stdin.isatty() and input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
688                units = getattr(dataset.variables[varname], "units", None)
689                plot_3D_globe(lon2d, lat2d, dslice, colormap, varname, units)
690
691            if output_path:
692                plt.savefig(output_path, bbox_inches="tight")
693                print(f"Saved to {output_path}")
694            else:
695                plt.show()
696            return
697
698        # Generic 2D
699        plt.figure(figsize=(8, 6))
700        plt.imshow(dslice, aspect="auto")
701        plt.colorbar(label=varname + (f" ({var.units})" if hasattr(var, "units") else ""))
702        plt.xlabel("Dim 2 index")
703        plt.ylabel("Dim 1 index")
704        plt.title(f"{varname} (2D)")
705        if output_path:
706            plt.savefig(output_path, bbox_inches="tight")
707            print(f"Saved to {output_path}")
708        else:
709            plt.show()
710        return
711
712    print(f"Error: ndim={dslice.ndim} not supported.")
713
714
715def visualize_variable_interactive(nc_path=None):
716    """
717    Interactive loop: keep prompting for variables to plot until user quits.
718    """
719    # Open dataset
720    if nc_path:
721        path = nc_path
722    else:
723        readline.set_completer(complete_filename)
724        readline.parse_and_bind("tab: complete")
725        path = input("Enter path to NetCDF file: ").strip()
726
727    if not os.path.isfile(path):
728        print(f"Error: '{path}' not found.")
729        return
730
731    ds = Dataset(path, "r")
732    var_list = list(ds.variables.keys())
733    if not var_list:
734        print("No variables found in file.")
735        ds.close()
736        return
737
738    # Enable interactive mode
739    plt.ion()
740
741    while True:
742        # Enable tab-completion for variable names
743        readline.set_completer(make_varname_completer(var_list))
744        readline.parse_and_bind("tab: complete")
745
746        print("\nAvailable variables:")
747        for name in var_list:
748            print(f"  - {name}")
749        varname = input("\nEnter variable name to plot (or 'q' to quit): ").strip()
750        if varname.lower() in ("q", "quit", "exit"):
751            print("Exiting.")
752            break
753        if varname not in ds.variables:
754            print(f"Variable '{varname}' not found. Try again.")
755            continue
756
757        # Display dimensions and size
758        var = ds.variables[varname]
759        dims, shape = var.dimensions, var.shape
760        print(f"\nVariable '{varname}' has dimensions:")
761        for dim, size in zip(dims, shape):
762            print(f"  - {dim}: size {size}")
763        print()
764
765        # Prepare slicing parameters
766        time_index = None
767        alt_index = None
768        avg = False
769        extra_indices = {}
770
771        # Time index
772        t_idx = find_dim_index(dims, TIME_DIMS)
773        if t_idx is not None:
774            if shape[t_idx] > 1:
775                while True:
776                    idx = input(f"Enter time index [1–{shape[t_idx]}] (press Enter for all): ").strip()
777                    if idx == '':
778                        time_index = None
779                        break
780                    if idx.isdigit():
781                        i = int(idx)
782                        if 1 <= i <= shape[t_idx]:
783                            time_index = i - 1
784                            break
785                    print("Invalid entry. Please enter a valid number or press Enter.")
786            else:
787                time_index = 0
788
789        # Altitude index
790        a_idx = find_dim_index(dims, ALT_DIMS)
791        if a_idx is not None:
792            if shape[a_idx] > 1:
793                while True:
794                    idx = input(f"Enter altitude index [1–{shape[a_idx]}] (press Enter for all): ").strip()
795                    if idx == '':
796                        alt_index = None
797                        break
798                    if idx.isdigit():
799                        i = int(idx)
800                        if 1 <= i <= shape[a_idx]:
801                            alt_index = i - 1
802                            break
803                    print("Invalid entry. Please enter a valid number or press Enter.")
804            else:
805                alt_index = 0
806
807        # Average over latitude?
808        lat_idx = find_dim_index(dims, LAT_DIMS)
809        lon_idx = find_dim_index(dims, LON_DIMS)
810        if (t_idx is not None and lat_idx is not None and lon_idx is not None and
811            shape[t_idx] > 1 and shape[lat_idx] > 1 and shape[lon_idx] > 1):
812            resp = input("Average over latitude and plot lon vs time? [y/n]: ").strip().lower()
813            avg = (resp == 'y')
814
815        # Other dimensions
816        for i, dname in enumerate(dims):
817            if i in (t_idx, a_idx):
818                continue
819            size = shape[i]
820            if size == 1:
821                extra_indices[dname] = 0
822                continue
823            while True:
824                idx = input(f"Enter index [1–{size}] for '{dname}' (press Enter for all): ").strip()
825                if idx == '':
826                    # keep all values
827                    break
828                if idx.isdigit():
829                    j = int(idx)
830                    if 1 <= j <= size:
831                        extra_indices[dname] = j - 1
832                        break
833                print("Invalid entry. Please enter a valid number or press Enter.")
834
835        # Plot the variable
836        plot_variable(
837            ds, varname,
838            time_index    = time_index,
839            alt_index     = alt_index,
840            colormap      = 'jet',
841            output_path   = None,
842            extra_indices = extra_indices,
843            avg_lat       = avg
844        )
845
846    ds.close()
847
848
849def visualize_variable_cli(nc_file, varname, time_index, alt_index,
850                           colormap, output_path, extra_json, avg_lat):
851    """
852    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
853    """
854    if not os.path.isfile(nc_file):
855        print(f"Error: '{nc_file}' not found.")
856        return
857    ds = Dataset(nc_file, "r")
858    if varname not in ds.variables:
859        print(f"Variable '{varname}' not in file.")
860        ds.close()
861        return
862
863    # Display dimensions and size
864    dims  = ds.variables[varname].dimensions
865    shape = ds.variables[varname].shape
866    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
867    for name, size in zip(dims, shape):
868        print(f"  - {name}: size {size}")
869    print()
870
871    # Special case: time-only → plot directly
872    t_idx = find_dim_index(dims, TIME_DIMS)
873    if (
874        t_idx is not None and shape[t_idx] > 1 and
875        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
876    ):
877        print("Detected single-point spatial dims; plotting time series…")
878        var_obj = ds.variables[varname]
879        data = var_obj[:].squeeze()
880        time_var = find_coord_var(ds, TIME_DIMS)
881        if time_var:
882            tvals = ds.variables[time_var][:]
883        else:
884            tvals = np.arange(data.shape[0])
885        if hasattr(data, "mask"):
886            data = np.where(data.mask, np.nan, data.data)
887        if hasattr(tvals, "mask"):
888            tvals = np.where(tvals.mask, np.nan, tvals.data)
889        plt.figure()
890        plt.plot(tvals, data, marker="o")
891        plt.xlabel(time_var or "Time Index")
892        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
893        plt.title(f"{varname} vs {time_var or 'Index'}")
894        if output_path:
895            plt.savefig(output_path, bbox_inches="tight")
896            print(f"Saved to {output_path}")
897        else:
898            plt.show()
899        ds.close()
900        return
901
902    # if --avg-lat but lat/lon/Time not compatible → disable
903    lat_idx = find_dim_index(dims, LAT_DIMS)
904    lon_idx = find_dim_index(dims, LON_DIMS)
905    if avg_lat and not (
906        t_idx   is not None and shape[t_idx]  > 1 and
907        lat_idx is not None and shape[lat_idx] > 1 and
908        lon_idx is not None and shape[lon_idx] > 1
909    ):
910        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
911        avg_lat = False
912
913    # Parse extra indices JSON
914    extra = {}
915    if extra_json:
916        try:
917            parsed = json.loads(extra_json)
918            for k, v in parsed.items():
919                if isinstance(v, int):
920                    if "slope" in k.lower():
921                        extra[k] = v - 1
922                    else:
923                        extra[k] = v
924        except:
925            print("Warning: bad extra-indices.")
926
927    plot_variable(ds, varname, time_index, alt_index,
928                  colormap, output_path, extra, avg_lat)
929    ds.close()
930
931
932def main():
933    parser = argparse.ArgumentParser()
934    parser.add_argument('nc_file', nargs='?', help='NetCDF file (omit for interactive)')
935    parser.add_argument('-v','--variable', help='Variable name')
936    parser.add_argument('-t','--time-index', type=int, help='Time index (0-based)')
937    parser.add_argument('-a','--alt-index', type=int, help='Altitude index (0-based)')
938    parser.add_argument('-c','--cmap', default='jet', help='Colormap')
939    parser.add_argument('--avg-lat', action='store_true', help='Average over latitude')
940    parser.add_argument('--show-polar', action='store_true', help='Enable polar-stereo views')
941    parser.add_argument('--show-3d', action='store_true', help='Enable 3D globe view')
942    parser.add_argument('-o','--output', help='Save figure path')
943    parser.add_argument('-e','--extra-indices', help='JSON string for other dims')
944    args = parser.parse_args()
945
946    if args.nc_file and args.variable:
947        visualize_variable_cli(
948            args.nc_file, args.variable,
949            args.time_index, args.alt_index,
950            args.cmap, args.output,
951            args.extra_indices, args.avg_lat
952        )
953    else:
954        visualize_variable_interactive(args.nc_file)
955
956
957if __name__ == "__main__":
958    main()
959
Note: See TracBrowser for help on using the repository browser.