source: trunk/LMDZ.MARS/util/display_netcdf.py

Last change on this file was 3839, checked in by jbclement, 6 days ago

Mars PCM:
In "display_netcdf.py", addition of the possibility to display altitude as function of time for 1D variables + showing value at cursor for 2D heatmaps.
JBC

  • Property svn:executable set to *
File size: 38.8 KB
RevLine 
[3783]1#!/usr/bin/env python3
[3459]2##############################################################
3### Python script to visualize a variable in a NetCDF file ###
4##############################################################
5
[3783]6"""
[3798]7This script can display any numeric variable from a NetCDF file.
8It supports the following cases:
[3818]9  - Scalar output
10  - 1D time series
11  - 1D vertical profiles
[3798]12  - 2D latitude/longitude map
[3818]13  - 2D cross-sections
[3808]14  - Optionally average over latitude and plot longitude vs. time heatmap
[3818]15  - Optionally display polar stereographic view of 2D maps
16  - Optionally display 3D globe view of 2D maps
[3459]17
[3818]18Automatic setup from the environment file found in the "LMDZ.MARS/util" folder:
19  1. Make sure Conda is installed.
20  2. In terminal, navigate to the folder containing this script.
21  4. Create the environment:
22       conda env create -f display_netcdf.yml
23  5. Activate the environment:
24       conda activate my_env
25  6. Run the script:
26     python display_netcdf.py
27
[3783]28Usage:
29  1) Command-line mode:
[3818]30       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [options]
31     Options:
[3810]32    --variable         : Name of the variable to visualize.
33    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
34    --alt-index        : Index along the altitude dimension (0-based), if present.
35    --cmap             : Matplotlib colormap (default: "jet").
36    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
37    --slice-lon-index  : Fixed longitude index for altitude×longitude cross-section.
38    --slice-lat-index  : Fixed latitude index for altitude×latitude cross-section.
39    --show-topo        : Overlay MOLA topography on lat/lon maps.
[3818]40    --show-polar       :
41    --show-3d          :
[3810]42    --output           : If provided, save the figure to this filename instead of displaying.
43    --extra-indices    : JSON string to fix indices for any other dimensions.
44                         For dimensions with "slope", use 1-based numbering here.
45                         Example: '{"nslope": 1, "physical_points": 3}'
[3798]46
47  2) Interactive mode:
[3783]48       python display_netcdf.py
[3818]49     The script will prompt for everything.
[3783]50"""
51
[3459]52import os
[3783]53import sys
54import glob
[3459]55import readline
[3783]56import argparse
[3798]57import json
[3783]58import numpy as np
59import matplotlib.pyplot as plt
[3810]60import matplotlib.path as mpath
[3818]61import matplotlib.colors as mcolors
[3810]62import cartopy.crs as ccrs
[3818]63import pandas as pd
[3459]64from netCDF4 import Dataset
65
[3818]66# Attempt vedo import early for global use
67try:
68    import vedo
69    from vedo import *
70    from scipy.interpolate import RegularGridInterpolator
71    vedo_available = True
72except ImportError:
73    vedo_available = False
74
[3810]75# Constants for recognized dimension names
[3798]76TIME_DIMS = ("Time", "time", "time_counter")
77ALT_DIMS  = ("altitude",)
78LAT_DIMS  = ("latitude", "lat")
79LON_DIMS  = ("longitude", "lon")
[3783]80
[3818]81# Paths for MOLA data
82MOLA_NPY = 'MOLA_1px_per_deg.npy'
83MOLA_CSV = 'molaTeam_contour_31rgb_steps.csv'
84
[3810]85# Attempt to load MOLA topography
86try:
[3839]87    MOLA = np.load('MOLA_1px_per_deg.npy') # shape (nlat, nlon) at 1° per pixel: lat from -90 to 90, lon from 0 to 360
[3810]88    nlat, nlon = MOLA.shape
89    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
90    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
91    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
92    topo_loaded = True
93except Exception as e:
[3818]94    print(f"Warning: '{MOLA_NPY}' not found: {e}")
[3810]95    topo_loaded = False
[3798]96
[3810]97
[3818]98# Attempt to load contour color table
99if os.path.isfile(MOLA_CSV):
100    color_table = pd.read_csv(MOLA_CSV)
101    csv_loaded = True
102else:
103    print(f"Warning: '{MOLA_CSV}' not found. 3D view colors disabled.")
104    csv_loaded = False
105
106
[3783]107def complete_filename(text, state):
108    """
[3810]109    Tab-completion for filesystem paths.
[3783]110    """
111    if "*" not in text:
112        pattern = text + "*"
113    else:
114        pattern = text
115    matches = glob.glob(os.path.expanduser(pattern))
116    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
[3459]117    try:
118        return matches[state]
119    except IndexError:
120        return None
121
[3783]122
123def make_varname_completer(varnames):
124    """
[3810]125    Returns a readline completer for variable names.
[3783]126    """
[3459]127    def completer(text, state):
[3783]128        options = [name for name in varnames if name.startswith(text)]
129        try:
[3459]130            return options[state]
[3783]131        except IndexError:
[3459]132            return None
133    return completer
134
[3783]135
136def find_dim_index(dims, candidates):
137    """
138    Search through dims tuple for any name in candidates.
139    Returns the index if found, else returns None.
140    """
141    for idx, dim in enumerate(dims):
142        for cand in candidates:
143            if cand.lower() == dim.lower():
144                return idx
145    return None
146
147
148def find_coord_var(dataset, candidates):
149    """
150    Among dataset variables, return the first variable whose name matches any candidate.
151    Returns None if none found.
152    """
153    for name in dataset.variables:
154        for cand in candidates:
155            if cand.lower() == name.lower():
156                return name
157    return None
158
159
[3810]160def overlay_topography(ax, transform, levels=10):
161    """
162    Overlay MOLA topography contours onto a given GeoAxes.
163    """
164    if not topo_loaded:
165        return
166    ax.contour(
167        topo_lon2d, topo_lat2d, MOLA,
168        levels=levels,
169        linewidths=0.5,
170        colors='black',
171        transform=transform
172    )
173
174
[3839]175def attach_format_coord(ax, mat, x, y, is_pcolormesh=True):
176    """
177    Attach a format_coord function to the axes to display x, y, and value at cursor.
178    Works for both pcolormesh and imshow style grids.
179    """
180    # Determine dimensions
181    if mat.ndim == 2:
182        ny, nx = mat.shape
183    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
184        ny, nx, nc = mat.shape
185    else:
186        raise ValueError(f"Unsupported mat shape {mat.shape}")
187    # Edges or extents
188    if is_pcolormesh:
189        xedges, yedges = x, y
190    else:
191        x0, x1 = x.min(), x.max()
192        y0, y1 = y.min(), y.max()
193
194    def format_coord(xp, yp):
195        # Map to indices
196        if is_pcolormesh:
197            col = np.searchsorted(xedges, xp) - 1
198            row = np.searchsorted(yedges, yp) - 1
199        else:
200            col = int((xp - x0) / (x1 - x0) * nx)
201            row = int((yp - y0) / (y1 - y0) * ny)
202        # Within bounds?
203        if 0 <= row < ny and 0 <= col < nx:
204            if mat.ndim == 2:
205                v = mat[row, col]
206                return f"x={xp:.3g}, y={yp:.3g}, val={v:.3g}"
207            else:
208                vals = mat[row, col]
209                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
210                return f"x={xp:.3g}, y={yp:.3g}, val=({txt})"
211        return f"x={xp:.3g}, y={yp:.3g}"
212
213    ax.format_coord = format_coord
214
215
[3810]216def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True):
217    """
218    Plot two polar‐stereographic views (north & south) of the same data.
219    """
220    figs = []  # collect figure handles
221
222    for pole in ("north", "south"):
223        # Choose projection and extent for each pole
224        if pole == "north":
225            proj = ccrs.NorthPolarStereo(central_longitude=180)
226            extent = [-180, 180, 60, 90]
227        else:
228            proj = ccrs.SouthPolarStereo(central_longitude=180)
229            extent = [-180, 180, -90, -60]
230
231        # Create figure and GeoAxes
232        fig = plt.figure(figsize=(8, 6))
233        ax = fig.add_subplot(1, 1, 1, projection=proj, aspect=True)
234        ax.set_global()
235        ax.set_extent(extent, ccrs.PlateCarree())
236
237        # Draw circular boundary
238        theta = np.linspace(0, 2 * np.pi, 100)
239        center, radius = [0.5, 0.5], 0.5
240        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
241        circle = mpath.Path(verts * radius + center)
242        ax.set_boundary(circle, transform=ax.transAxes)
243
244        # Add meridians/parallels
245        gl = ax.gridlines(
246            draw_labels=True,
247            color='k',
248            xlocs=range(-180, 181, 30),
249            ylocs=range(-90, 91, 10),
250            linestyle='--',
251            linewidth=0.5
252        )
253
254        # Plot data in PlateCarree projection
255        cf = ax.contourf(
256            lon2d, lat2d, data2d,
257            levels=100,
258            cmap=colormap,
259            transform=ccrs.PlateCarree()
260        )
261
262        # Optionally overlay MOLA topography
263        if topo_overlay:
264            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)
265
266        # Colorbar and title
267        cbar = fig.colorbar(cf, ax=ax, pad=0.1)
268        label = varname + (f" ({units})" if units else "")
269        cbar.set_label(label)
[3824]270        ax.set_title(f"{varname} — {pole.capitalize()} polar region", pad=20, y=1.05, fontsize=12, fontweight='bold')
[3810]271
272        figs.append(fig)
273
274    # Show both figures
275    plt.show()
276
[3818]277
278def plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units=None):
279    """
280    Plot a 3D globe view of the data using vedo, with surface coloring based on data2d
281    and overlaid contour lines from MOLA topography.
282    """
283    if not vedo_available:
284        print("3D view skipped: vedo missing.")
285        return
286    if not csv_loaded:
287        print("3D view skipped: color table missing.")
288        return
289
290    # Prepare MOLA grid
291    nlat, nlon = MOLA.shape
292    lats = np.linspace(90, -90, nlat)
293    lons = np.linspace(-180, 180, nlon)
294    lon_grid, lat_grid = np.meshgrid(lons, lats)
295
296    # Interpolate data2d onto MOLA grid
297    lat_data = np.linspace(-90, 90, data2d.shape[0])
298    lon_data = np.linspace(-180, 180, data2d.shape[1])
299    interp2d = RegularGridInterpolator((lat_data, lon_data), data2d,
300                                       bounds_error=False, fill_value=None)
301    newdata2d = interp2d((lat_grid, lon_grid))
302
303    # Generate contour lines from MOLA
304    cs = plt.contour(lon_grid, lat_grid, MOLA, levels=10, linewidths=0)
305    plt.clf()
306    contour_lines = []
307    radius = 3389500 # Mars average radius [m]
308    for segs, level in zip(cs.allsegs, cs.levels):
309        for verts in segs:
310            lon_c = verts[:, 0]
311            lat_c = verts[:, 1]
312            phi_c = np.radians(90 - lat_c)
313            theta_c = np.radians(lon_c)
314            elev = RegularGridInterpolator((lats, lons), MOLA,
315                                           bounds_error=False,
316                                           fill_value=0.0)((lat_c, lon_c))
317            r_cont = radius + elev * 10
318            x_c = r_cont * np.sin(phi_c) * np.cos(theta_c) * 1.002
319            y_c = r_cont * np.sin(phi_c) * np.sin(theta_c) * 1.002
320            z_c = r_cont * np.cos(phi_c) * 1.002
321            pts = np.column_stack([x_c, y_c, z_c])
322            if pts.shape[0] > 1:
323                contour_lines.append(Line(pts, c='k', lw=0.5))
324
325    # Create sphere surface mesh
326    phi = np.deg2rad(90 - lat_grid)
327    theta = np.deg2rad(lon_grid)
328    r = radius + MOLA * 10
329    x = r * np.sin(phi) * np.cos(theta)
330    y = r * np.sin(phi) * np.sin(theta)
331    z = r * np.cos(phi)
332    pts = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)
333
334    # Build mesh faces
335    faces = []
336    for i in range(nlat - 1):
337        for j in range(nlon - 1):
338            p0 = i * nlon + j
339            p1 = p0 + 1
340            p2 = p0 + nlon
341            p3 = p2 + 1
342            faces.extend([(p0, p2, p1), (p1, p2, p3)])
343
344    mesh = Mesh([pts, faces])
345    mesh.cmap(colormap, newdata2d.ravel())
346    mesh.add_scalarbar(title=varname + (f' [{units}]' if units else ''), c='white')
347    mesh.lighting('default')
348
349    # Geographic grid lines
350    meridians, parallels, labels = [], [], []
351    zero_lon_offset = radius * 0.03
352    for lon in range(-150, 181, 30):
353        lat_line = np.linspace(-90, 90, nlat)
354        lon_line = np.full_like(lat_line, lon)
355        phi = np.deg2rad(90 - lat_line)
356        theta = np.deg2rad(lon_line)
357        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
358        rr = radius + elev * 10
359        pts_line = np.column_stack([
360            rr * np.sin(phi) * np.cos(theta),
361            rr * np.sin(phi) * np.sin(theta),
362            rr * np.cos(phi)
363        ]) * 1.005
364        label_pos = pts_line[len(pts_line)//2]
365        norm = np.linalg.norm(label_pos)
366        label_pos_out = label_pos / norm * (norm + radius * 0.02)
367        if lon == 0:
368            label_pos_out[1] += zero_lon_offset
369        meridians.append(Line(pts_line, c='k', lw=1)#.flagpole(
370            #f"{lon}°",
371            #point=label_pos_out,
372            #offset=[0, 0, radius * 0.05],
373            #s=radius*0.01,
374            #c='yellow'
375        #).follow_camera()
376        )
377
378    for lat in range(-60, 91, 30):
379        lon_line = np.linspace(-180, 180, nlon)
380        lat_line = np.full_like(lon_line, lat)
381        phi = np.deg2rad(90 - lat_line)
382        theta = np.deg2rad(lon_line)
383        elev = RegularGridInterpolator((lats, lons), MOLA)((lat_line, lon_line))
384        rr = radius + elev * 10
385        pts_line = np.column_stack([
386            rr * np.sin(phi) * np.cos(theta),
387            rr * np.sin(phi) * np.sin(theta),
388            rr * np.cos(phi)
389        ]) * 1.005
390        label_pos = pts_line[len(pts_line)//2]
391        norm = np.linalg.norm(label_pos)
392        label_pos_out = label_pos / norm * (norm + radius * 0.02)
393        parallels.append(Line(pts_line, c='k', lw=1)#.flagpole(
394            #f"{lat}°",
395            #point=label_pos_out,
396            #offset=[0, 0, radius * 0.05],
397            #s=radius*0.01,
398            #c='yellow'
399        #).follow_camera()
400        )
401
402    # Create plotter
[3824]403    plotter = Plotter(title="3D globe view", bg="bb", axes=0)
[3818]404
405    # Configure camera
406    cam_dist = radius * 3
407    plotter.camera.SetPosition([cam_dist, 0, 0])
408    plotter.camera.SetFocalPoint([0, 0, 0])
409    plotter.camera.SetViewUp([0, 0, 1])
410
411    # Show the globe
412    plotter.show(mesh, *contour_lines, *meridians, *parallels)
413
414
[3808]415def plot_variable(dataset, varname, time_index=None, alt_index=None,
416                  colormap="jet", output_path=None, extra_indices=None,
417                  avg_lat=False):
[3783]418    """
[3808]419    Core plotting logic: reads the variable, handles masks,
420    determines dimensionality, and creates the appropriate plot:
421      - 1D time series
422      - 1D profiles or physical_points maps
423      - 2D lat×lon or generic 2D
424      - Time×lon heatmap if avg_lat=True
425      - Scalar printing
[3783]426    """
427    var = dataset.variables[varname]
[3808]428    dims = var.dimensions
[3783]429
[3808]430    # Read full data
[3459]431    try:
[3783]432        data_full = var[:]
433    except Exception as e:
[3808]434        print(f"Error: Cannot read data for '{varname}': {e}")
[3459]435        return
[3783]436    if hasattr(data_full, "mask"):
437        data_full = np.where(data_full.mask, np.nan, data_full.data)
438
[3839]439    # If Time and altitude are both present and neither indexed,
440    # and every other dim has size 1:
441    t_idx = find_dim_index(dims, TIME_DIMS)
442    a_idx = find_dim_index(dims, ALT_DIMS)
443    shape = var.shape
444    if (t_idx is not None and a_idx is not None
445        and time_index is None and alt_index is None
446        and all(size == 1 for i, size in enumerate(shape) if i not in (t_idx, a_idx))):
447
448        # Build a slicer that keeps Time & altitude, drops other singletons
449        slicer = [0] * len(dims)
450        slicer[t_idx] = slice(None)
451        slicer[a_idx] = slice(None)
452        data2d = data_full[tuple(slicer)]  # shape (ntime, nalt)
453
454        # Coordinate arrays
455        tvar = find_coord_var(dataset, TIME_DIMS)
456        avar = find_coord_var(dataset, ALT_DIMS)
457        tvals = dataset.variables[tvar][:]
458        avals = dataset.variables[avar][:]
459
460        # Unmask if necessary
461        if hasattr(tvals, "mask"):
462            tvals = np.where(tvals.mask, np.nan, tvals.data)
463        if hasattr(avals, "mask"):
464            avals = np.where(avals.mask, np.nan, avals.data)
465
466        # Plot heatmap with x=time, y=altitude
467        fig, ax = plt.subplots(figsize=(10, 6))
468        T, A = np.meshgrid(tvals, avals)
469        im = ax.pcolormesh(
470            T, A, data2d.T,
471            shading="auto", cmap=colormap
472        )
473        dt = tvals[1] - tvals[0]
474        da = avals[1] - avals[0]
475        x_edges = np.concatenate([tvals - dt/2, [tvals[-1] + dt/2]])
476        y_edges = np.concatenate([avals - da/2, [avals[-1] + da/2]])
477        attach_format_coord(ax, data2d.T, x_edges, y_edges, is_pcolormesh=True)
478        ax.set_xlabel(tvar)
479        ax.set_ylabel(avar)
480        cbar = fig.colorbar(im, ax=ax)
481        cbar.set_label(varname + (f" ({getattr(var, 'units','')})"))
482        ax.set_title(f"{varname} — {avar} vs {tvar}", fontweight="bold")
483
484        if output_path:
485            fig.savefig(output_path, bbox_inches="tight")
486            print(f"Saved to {output_path}")
487        else:
488            plt.show()
489        return
490
[3808]491    # Pure 1D time series
[3798]492    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
[3808]493        time_var = find_coord_var(dataset, TIME_DIMS)
494        tvals = (dataset.variables[time_var][:] if time_var
495                 else np.arange(data_full.shape[0]))
496        if hasattr(tvals, "mask"):
497            tvals = np.where(tvals.mask, np.nan, tvals.data)
[3798]498        plt.figure()
[3808]499        plt.plot(tvals, data_full, marker="o")
500        plt.xlabel(time_var or "Time Index")
501        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
[3824]502        plt.title(f"{varname} vs {time_var or 'Index'}", fontweight='bold')
[3798]503        if output_path:
[3808]504            plt.savefig(output_path, bbox_inches="tight")
505            print(f"Saved to {output_path}")
[3798]506        else:
507            plt.show()
508        return
509
[3808]510    # Identify dims
[3783]511    t_idx = find_dim_index(dims, TIME_DIMS)
512    lat_idx = find_dim_index(dims, LAT_DIMS)
513    lon_idx = find_dim_index(dims, LON_DIMS)
[3808]514    a_idx = find_dim_index(dims, ALT_DIMS)
[3783]515
[3808]516    # Average over latitude & plot time × lon heatmap
517    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
[3810]518        # compute mean over lat axis
[3808]519        data_avg = np.nanmean(data_full, axis=lat_idx)
[3810]520        # prepare coordinates
[3808]521        time_var = find_coord_var(dataset, TIME_DIMS)
522        lon_var = find_coord_var(dataset, LON_DIMS)
523        tvals = dataset.variables[time_var][:]
524        lons = dataset.variables[lon_var][:]
525        if hasattr(tvals, "mask"):
526            tvals = np.where(tvals.mask, np.nan, tvals.data)
527        if hasattr(lons, "mask"):
528            lons = np.where(lons.mask, np.nan, lons.data)
[3839]529        fig, ax = ax.subplots(figsize=(10, 6))
530        im = plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
531        dx = lons[1] - lons[0]
532        dy = tvals[1] - tvals[0]
533        x_edges = np.concatenate([lons - dx/2, [lons[-1] + dx/2]])
534        y_edges = np.concatenate([tvals - dy/2, [tvals[-1] + dy/2]])
535        attach_format_coord(ax, data_avg.T, x_edges, y_edges, is_pcolormesh=True)
536        ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
537        ax.set_ylabel(time_var)
538        cbar = fig.colorbar(im, ax=ax)
[3808]539        cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
[3839]540        ax.set_title(f"{varname} averaged over latitude", fontweight='bold')
[3808]541        if output_path:
[3839]542            fig.savefig(output_path, bbox_inches="tight")
[3808]543            print(f"Saved to {output_path}")
544        else:
545            plt.show()
546        return
547
548    # Build slicer for other cases
[3798]549    slicer = [slice(None)] * len(dims)
[3783]550    if t_idx is not None:
551        if time_index is None:
[3808]552            print("Error: please supply a time index.")
[3783]553            return
554        slicer[t_idx] = time_index
555    if a_idx is not None:
556        if alt_index is None:
[3808]557            print("Error: please supply an altitude index.")
[3783]558            return
559        slicer[a_idx] = alt_index
560
[3798]561    if extra_indices is None:
562        extra_indices = {}
[3808]563    for dn, idx_val in extra_indices.items():
564        if dn in dims:
565            slicer[dims.index(dn)] = idx_val
[3798]566
[3808]567    # Extract slice
[3783]568    try:
[3808]569        dslice = data_full[tuple(slicer)]
[3783]570    except Exception as e:
[3808]571        print(f"Error slicing '{varname}': {e}")
[3783]572        return
573
[3808]574    # Scalar
575    if np.ndim(dslice) == 0:
576        print(f"Scalar '{varname}': {float(dslice)}")
[3783]577        return
578
[3808]579    # 1D: vector, profile, or physical_points
580    if dslice.ndim == 1:
581        rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]
582        if rem:
583            di, dname = rem[0]
584            # physical_points → interpolated map
585            if dname.lower() == "physical_points":
586                latv = find_coord_var(dataset, LAT_DIMS)
587                lonv = find_coord_var(dataset, LON_DIMS)
588                if latv and lonv:
589                    lats = dataset.variables[latv][:]
590                    lons = dataset.variables[lonv][:]
[3810]591
592                    # Unmask
[3808]593                    if hasattr(lats, "mask"):
594                        lats = np.where(lats.mask, np.nan, lats.data)
595                    if hasattr(lons, "mask"):
596                        lons = np.where(lons.mask, np.nan, lons.data)
[3810]597
598                    # Convert radians to degrees if needed
599                    lats_deg = np.round(np.degrees(lats), 6)
600                    lons_deg = np.round(np.degrees(lons), 6)
601
602                    # Build regular grid
603                    uniq_lats = np.unique(lats_deg)
604                    uniq_lons = np.unique(lons_deg)
605                    nlon = len(uniq_lons)
606
607                    data2d = []
608                    for lat_val in uniq_lats:
609                        mask = lats_deg == lat_val
610                        slice_vals = dslice[mask]
611                        lons_at_lat = lons_deg[mask]
612                        if len(slice_vals) == 1:
613                            row = np.full(nlon, slice_vals[0])
614                        else:
615                            order = np.argsort(lons_at_lat)
616                            row = np.full(nlon, np.nan)
617                            row[: len(slice_vals)] = slice_vals[order]
618                        data2d.append(row)
619                    data2d = np.array(data2d)
620
621                    # Wrap longitude if needed
622                    if -180.0 in uniq_lons:
623                        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
624                        data2d = np.hstack([data2d, data2d[:, [idx]]])
625                        uniq_lons = np.append(uniq_lons, 180.0)
626
627                    # Plot interpolated map
628                    proj = ccrs.PlateCarree()
629                    fig, ax = plt.subplots(subplot_kw=dict(projection=proj), figsize=(8, 6))
630                    lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
631                    lon_ticks = np.arange(-180, 181, 30)
632                    lat_ticks = np.arange(-90, 91, 30)
633                    ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
634                    ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
635                    ax.tick_params(
636                        axis='x', which='major',
637                        length=4,
638                        direction='out',
639                        pad=2,
640                        labelsize=8
641                    )
642                    ax.tick_params(
643                       axis='y', which='major',
644                       length=4,
645                       direction='out',
646                       pad=2,
647                       labelsize=8
648                    )
649                    cf = ax.contourf(
650                        lon2d, lat2d, data2d,
651                        levels=100,
652                        cmap=colormap,
653                        transform=proj
654                    )
655
656                    # Overlay MOLA topography
657                    overlay_topography(ax, transform=proj, levels=10)
658
659                    # Colorbar & labels
660                    cbar = fig.colorbar(cf, ax=ax, pad=0.02)
[3808]661                    cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
[3824]662                    ax.set_title(f"{varname} (interpolated map over physical_points)", fontweight='bold')
[3810]663                    ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
664                    ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
665
666                    # Prompt for polar-stereo views if interactive
[3818]667                    if input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
[3810]668                        units = getattr(dataset.variables[varname], "units", None)
669                        plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units)
670
[3818]671                    # Prompt for 3D globe view if interactive
672                    if input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
673                        units = getattr(dataset.variables[varname], "units", None)
674                        plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units)
675
[3808]676                    if output_path:
677                        plt.savefig(output_path, bbox_inches="tight")
678                        print(f"Saved to {output_path}")
679                    else:
680                        plt.show()
681                    return
682            # vertical profile?
683            coord = None
[3798]684            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
[3808]685                coord = "soildepth"
[3798]686            elif dname in dataset.variables:
[3808]687                coord = dname
688            if coord:
689                coords = dataset.variables[coord][:]
690                if hasattr(coords, "mask"):
691                    coords = np.where(coords.mask, np.nan, coords.data)
[3783]692                plt.figure()
[3808]693                plt.plot(dslice, coords, marker="o")
[3798]694                if dname.lower() == "subsurface_layers":
695                    plt.gca().invert_yaxis()
[3808]696                plt.xlabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
697                plt.ylabel(coord + (f" ({dataset.variables[coord].units})" if hasattr(dataset.variables[coord], "units") else ""))
[3824]698                plt.title(f"{varname} vs {coord}", fontweight='bold')
[3798]699                if output_path:
[3808]700                    plt.savefig(output_path, bbox_inches="tight")
701                    print(f"Saved to {output_path}")
[3798]702                else:
703                    plt.show()
704                return
[3808]705        # generic 1D
706        plt.figure()
707        plt.plot(dslice, marker="o")
708        plt.xlabel("Index")
709        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
[3824]710        plt.title(f"{varname} (1D)", fontweight='bold')
[3808]711        if output_path:
712            plt.savefig(output_path, bbox_inches="tight")
713            print(f"Saved to {output_path}")
[3798]714        else:
[3808]715            plt.show()
716        return
[3798]717
[3824]718    if dslice.ndim == 2:
[3798]719        lat_idx2 = find_dim_index(dims, LAT_DIMS)
720        lon_idx2 = find_dim_index(dims, LON_DIMS)
[3810]721
722        # Geographic lat×lon slice
[3798]723        if lat_idx2 is not None and lon_idx2 is not None:
[3808]724            latv = find_coord_var(dataset, LAT_DIMS)
725            lonv = find_coord_var(dataset, LON_DIMS)
726            lats = dataset.variables[latv][:]
727            lons = dataset.variables[lonv][:]
[3810]728
[3824]729            # Correct latitudes order
730            if lats[0] > lats[-1]:
731                lats = lats[::-1]
732                dslice = np.flipud(dslice)
733
[3810]734            # Handle masked arrays
[3808]735            if hasattr(lats, "mask"):
736                lats = np.where(lats.mask, np.nan, lats.data)
737            if hasattr(lons, "mask"):
738                lons = np.where(lons.mask, np.nan, lons.data)
[3810]739
740            # Create map projection
741            proj = ccrs.PlateCarree()
742            fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=proj))
743
744            # Make meshgrid and plot
745            lon2d, lat2d = np.meshgrid(lons, lats)
746            cf = ax.contourf(
747                lon2d, lat2d, dslice,
748                levels=100,
749                cmap=colormap,
750                transform=proj
751            )
752
753            # Overlay topography
754            overlay_topography(ax, transform=proj, levels=10)
755
756            # Colorbar and labels
757            lon_ticks = np.arange(-180, 181, 30)
758            lat_ticks = np.arange(-90, 91, 30)
759            ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
760            ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
761            ax.tick_params(
762                axis='x', which='major',
763                length=4,
764                direction='out',
765                pad=2,
766                labelsize=8
767            )
768            ax.tick_params(
769                axis='y', which='major',
770                length=4,
771                direction='out',
772                pad=2,
773                labelsize=8
774            )
775            cbar = fig.colorbar(cf, ax=ax, orientation="vertical", pad=0.02)
776            cbar.set_label(varname + (f" ({dataset.variables[varname].units})"
777                                      if hasattr(dataset.variables[varname], "units") else ""))
[3824]778            ax.set_title(f"{varname} (lat × lon)", fontweight='bold')
[3810]779            ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
780            ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
781
782            # Prompt for polar-stereo views if interactive
783            if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
784                units = getattr(dataset.variables[varname], "units", None)
785                plot_polar_views(lon2d, lat2d, dslice, colormap, varname, units)
786
[3818]787            # Prompt for 3D globe view if interactive
788            if sys.stdin.isatty() and input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
789                units = getattr(dataset.variables[varname], "units", None)
790                plot_3D_globe(lon2d, lat2d, dslice, colormap, varname, units)
791
[3783]792            if output_path:
[3808]793                plt.savefig(output_path, bbox_inches="tight")
794                print(f"Saved to {output_path}")
[3783]795            else:
796                plt.show()
797            return
[3810]798
799        # Generic 2D
[3839]800        fig, ax = plt.subplots(figsize=(8, 6))
801        im = ax.imshow(
802            dslice,
803            aspect="auto",
804            interpolation='nearest'
805        )
806        x0, x1 = 0, dslice.shape[1] - 1
807        y0, y1 = 0, dslice.shape[0] - 1
808        x_centers = np.linspace(x0, x1, dslice.shape[1])
809        y_centers = np.linspace(y0, y1, dslice.shape[0])
810        attach_format_coord(ax, dslice, x_centers, y_centers, is_pcolormesh=False)
811        cbar = fig.colorbar(im, ax=ax, orientation='vertical')
812        cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
813        ax.set_xlabel("Dim 2 index")
814        ax.set_ylabel("Dim 1 index")
815        ax.set_title(f"{varname} (2D)")
816
[3808]817        if output_path:
[3839]818            fig.savefig(output_path, bbox_inches="tight")
[3808]819            print(f"Saved to {output_path}")
[3798]820        else:
[3808]821            plt.show()
822        return
[3783]823
[3808]824    print(f"Error: ndim={dslice.ndim} not supported.")
[3783]825
826
827def visualize_variable_interactive(nc_path=None):
828    """
[3810]829    Interactive loop: keep prompting for variables to plot until user quits.
[3783]830    """
[3810]831    # Open dataset
[3783]832    if nc_path:
[3808]833        path = nc_path
[3783]834    else:
835        readline.set_completer(complete_filename)
836        readline.parse_and_bind("tab: complete")
[3808]837        path = input("Enter path to NetCDF file: ").strip()
[3810]838
[3808]839    if not os.path.isfile(path):
[3810]840        print(f"Error: '{path}' not found.")
841        return
842
[3808]843    ds = Dataset(path, "r")
[3810]844    var_list = list(ds.variables.keys())
845    if not var_list:
846        print("No variables found in file.")
847        ds.close()
848        return
[3783]849
[3810]850    # Enable interactive mode
851    plt.ion()
852
853    while True:
854        # Enable tab-completion for variable names
855        readline.set_completer(make_varname_completer(var_list))
[3808]856        readline.parse_and_bind("tab: complete")
[3783]857
[3810]858        print("\nAvailable variables:")
859        for name in var_list:
860            print(f"  - {name}")
861        varname = input("\nEnter variable name to plot (or 'q' to quit): ").strip()
862        if varname.lower() in ("q", "quit", "exit"):
863            print("Exiting.")
864            break
865        if varname not in ds.variables:
866            print(f"Variable '{varname}' not found. Try again.")
867            continue
[3783]868
[3810]869        # Display dimensions and size
870        var = ds.variables[varname]
871        dims, shape = var.dimensions, var.shape
872        print(f"\nVariable '{varname}' has dimensions:")
873        for dim, size in zip(dims, shape):
874            print(f"  - {dim}: size {size}")
875        print()
[3783]876
[3810]877        # Prepare slicing parameters
878        time_index = None
879        alt_index = None
880        avg = False
881        extra_indices = {}
[3798]882
[3810]883        # Time index
884        t_idx = find_dim_index(dims, TIME_DIMS)
885        if t_idx is not None:
886            if shape[t_idx] > 1:
887                while True:
888                    idx = input(f"Enter time index [1–{shape[t_idx]}] (press Enter for all): ").strip()
889                    if idx == '':
890                        time_index = None
891                        break
892                    if idx.isdigit():
893                        i = int(idx)
894                        if 1 <= i <= shape[t_idx]:
895                            time_index = i - 1
896                            break
897                    print("Invalid entry. Please enter a valid number or press Enter.")
898            else:
899                time_index = 0
[3783]900
[3810]901        # Altitude index
902        a_idx = find_dim_index(dims, ALT_DIMS)
903        if a_idx is not None:
904            if shape[a_idx] > 1:
905                while True:
906                    idx = input(f"Enter altitude index [1–{shape[a_idx]}] (press Enter for all): ").strip()
907                    if idx == '':
908                        alt_index = None
[3783]909                        break
[3810]910                    if idx.isdigit():
911                        i = int(idx)
912                        if 1 <= i <= shape[a_idx]:
913                            alt_index = i - 1
914                            break
915                    print("Invalid entry. Please enter a valid number or press Enter.")
916            else:
917                alt_index = 0
[3783]918
[3810]919        # Average over latitude?
920        lat_idx = find_dim_index(dims, LAT_DIMS)
921        lon_idx = find_dim_index(dims, LON_DIMS)
922        if (t_idx is not None and lat_idx is not None and lon_idx is not None and
923            shape[t_idx] > 1 and shape[lat_idx] > 1 and shape[lon_idx] > 1):
924            resp = input("Average over latitude and plot lon vs time? [y/n]: ").strip().lower()
925            avg = (resp == 'y')
926
927        # Other dimensions
928        for i, dname in enumerate(dims):
929            if i in (t_idx, a_idx):
930                continue
931            size = shape[i]
932            if size == 1:
933                extra_indices[dname] = 0
934                continue
[3783]935            while True:
[3810]936                idx = input(f"Enter index [1–{size}] for '{dname}' (press Enter for all): ").strip()
937                if idx == '':
938                    # keep all values
939                    break
940                if idx.isdigit():
941                    j = int(idx)
942                    if 1 <= j <= size:
943                        extra_indices[dname] = j - 1
[3783]944                        break
[3810]945                print("Invalid entry. Please enter a valid number or press Enter.")
[3783]946
[3810]947        # Plot the variable
948        plot_variable(
949            ds, varname,
950            time_index    = time_index,
951            alt_index     = alt_index,
952            colormap      = 'jet',
953            output_path   = None,
954            extra_indices = extra_indices,
955            avg_lat       = avg
956        )
[3798]957
[3783]958    ds.close()
959
960
[3808]961def visualize_variable_cli(nc_file, varname, time_index, alt_index,
962                           colormap, output_path, extra_json, avg_lat):
[3783]963    """
[3798]964    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
[3783]965    """
[3808]966    if not os.path.isfile(nc_file):
[3810]967        print(f"Error: '{nc_file}' not found.")
968        return
[3808]969    ds = Dataset(nc_file, "r")
970    if varname not in ds.variables:
[3810]971        print(f"Variable '{varname}' not in file.")
972        ds.close()
973        return
[3798]974
[3810]975    # Display dimensions and size
[3808]976    dims  = ds.variables[varname].dimensions
977    shape = ds.variables[varname].shape
978    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
979    for name, size in zip(dims, shape):
980        print(f"  - {name}: size {size}")
981    print()
[3783]982
[3810]983    # Special case: time-only → plot directly
[3808]984    t_idx = find_dim_index(dims, TIME_DIMS)
985    if (
986        t_idx is not None and shape[t_idx] > 1 and
987        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
988    ):
989        print("Detected single-point spatial dims; plotting time series…")
990        var_obj = ds.variables[varname]
991        data = var_obj[:].squeeze()
992        time_var = find_coord_var(ds, TIME_DIMS)
993        if time_var:
994            tvals = ds.variables[time_var][:]
995        else:
996            tvals = np.arange(data.shape[0])
997        if hasattr(data, "mask"):
998            data = np.where(data.mask, np.nan, data.data)
999        if hasattr(tvals, "mask"):
1000            tvals = np.where(tvals.mask, np.nan, tvals.data)
1001        plt.figure()
1002        plt.plot(tvals, data, marker="o")
1003        plt.xlabel(time_var or "Time Index")
1004        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
[3824]1005        plt.title(f"{varname} vs {time_var or 'Index'}", fontweight='bold')
[3808]1006        if output_path:
1007            plt.savefig(output_path, bbox_inches="tight")
1008            print(f"Saved to {output_path}")
1009        else:
1010            plt.show()
[3783]1011        ds.close()
1012        return
1013
[3810]1014    # if --avg-lat but lat/lon/Time not compatible → disable
[3808]1015    lat_idx = find_dim_index(dims, LAT_DIMS)
1016    lon_idx = find_dim_index(dims, LON_DIMS)
1017    if avg_lat and not (
1018        t_idx   is not None and shape[t_idx]  > 1 and
1019        lat_idx is not None and shape[lat_idx] > 1 and
1020        lon_idx is not None and shape[lon_idx] > 1
1021    ):
1022        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
1023        avg_lat = False
1024
1025    # Parse extra indices JSON
1026    extra = {}
[3798]1027    if extra_json:
1028        try:
1029            parsed = json.loads(extra_json)
[3808]1030            for k, v in parsed.items():
1031                if isinstance(v, int):
1032                    if "slope" in k.lower():
1033                        extra[k] = v - 1
1034                    else:
1035                        extra[k] = v
1036        except:
1037            print("Warning: bad extra-indices.")
[3798]1038
[3808]1039    plot_variable(ds, varname, time_index, alt_index,
1040                  colormap, output_path, extra, avg_lat)
[3783]1041    ds.close()
1042
1043
1044def main():
[3808]1045    parser = argparse.ArgumentParser()
[3818]1046    parser.add_argument('nc_file', nargs='?', help='NetCDF file (omit for interactive)')
1047    parser.add_argument('-v','--variable', help='Variable name')
1048    parser.add_argument('-t','--time-index', type=int, help='Time index (0-based)')
1049    parser.add_argument('-a','--alt-index', type=int, help='Altitude index (0-based)')
1050    parser.add_argument('-c','--cmap', default='jet', help='Colormap')
1051    parser.add_argument('--avg-lat', action='store_true', help='Average over latitude')
1052    parser.add_argument('--show-polar', action='store_true', help='Enable polar-stereo views')
1053    parser.add_argument('--show-3d', action='store_true', help='Enable 3D globe view')
1054    parser.add_argument('-o','--output', help='Save figure path')
1055    parser.add_argument('-e','--extra-indices', help='JSON string for other dims')
[3783]1056    args = parser.parse_args()
1057
[3798]1058    if args.nc_file and args.variable:
[3783]1059        visualize_variable_cli(
[3808]1060            args.nc_file, args.variable,
1061            args.time_index, args.alt_index,
1062            args.cmap, args.output,
1063            args.extra_indices, args.avg_lat
[3783]1064        )
[3798]1065    else:
[3808]1066        visualize_variable_interactive(args.nc_file)
[3783]1067
1068
1069if __name__ == "__main__":
1070    main()
[3810]1071
Note: See TracBrowser for help on using the repository browser.