Changeset 3849


Ignore:
Timestamp:
Jul 15, 2025, 12:24:07 PM (34 hours ago)
Author:
jbclement
Message:

Mars PCM:
Simplification and generalization of the "display_netcdf.py" script which can now handle more cases of shapes/dimensions in the variables + using 'pcolormesh' instead of 'imshow' to better stick to the original data.
JBC

Location:
trunk/LMDZ.MARS
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/LMDZ.MARS/changelog.txt

    r3848 r3849  
    49104910== 15/07/2025 == JBC
    49114911In the 1D model, if no input profile found, the 'dust_mass' and 'dust_number' tracers are initialized according to "Conrath dust" in line with what it's done for CO2 and HDO, instead of being set to 0 which caused bugs.
     4912
     4913== 15/07/2025 == JBC
     4914Simplification and generalization of the "display_netcdf.py" script which can now handle more cases of shapes/dimensions in the variables + using 'pcolormesh' instead of 'imshow' to better stick to the original data.
  • trunk/LMDZ.MARS/util/display_netcdf.py

    r3839 r3849  
    173173
    174174
    175 def attach_format_coord(ax, mat, x, y, is_pcolormesh=True):
     175def attach_format_coord(ax, mat, x, y, is_pcolormesh=True, data_crs=ccrs.PlateCarree()):
    176176    """
    177177    Attach a format_coord function to the axes to display x, y, and value at cursor.
     
    192192        y0, y1 = y.min(), y.max()
    193193
     194    # Detect if ax is a GeoAxes with a projection we can invert
     195    proj = getattr(ax, 'projection', None)
     196    use_geo = isinstance(proj, ccrs.Projection)
     197
    194198    def format_coord(xp, yp):
    195         # Map to indices
     199        # If GeoAxes, invert back to geographic lon/lat
     200        if use_geo:
     201            try:
     202                lonp, latp = data_crs.transform_point(xp, yp, src_crs=proj)
     203            except Exception:
     204                lonp, latp = xp, yp
     205            xi, yi = lonp, latp
     206        else:
     207            xi, yi = xp, yp
     208
     209        # Map to matrix indices
    196210        if is_pcolormesh:
    197             col = np.searchsorted(xedges, xp) - 1
    198             row = np.searchsorted(yedges, yp) - 1
    199         else:
    200             col = int((xp - x0) / (x1 - x0) * nx)
    201             row = int((yp - y0) / (y1 - y0) * ny)
     211            col = np.searchsorted(xedges, xi) - 1
     212            row = np.searchsorted(yedges, yi) - 1
     213        else:
     214            col = int((xi - x0) / (x1 - x0) * nx)
     215            row = int((yi - y0) / (y1 - y0) * ny)
     216
    202217        # Within bounds?
    203218        if 0 <= row < ny and 0 <= col < nx:
    204219            if mat.ndim == 2:
    205220                v = mat[row, col]
    206                 return f"x={xp:.3g}, y={yp:.3g}, val={v:.3g}"
     221                return f"lon={xi:.3g}, lat={yi:.3g}, val={v:.3g}"
    207222            else:
    208223                vals = mat[row, col]
    209224                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
    210                 return f"x={xp:.3g}, y={yp:.3g}, val=({txt})"
    211         return f"x={xp:.3g}, y={yp:.3g}"
     225                return f"lon={xi:.3g}, lat={yi:.3g}, val=({txt})"
     226        # Out of bounds: still show coords
     227        return f"lon={xi:.3g}, lat={yi:.3g}"
    212228
    213229    ax.format_coord = format_coord
     
    253269
    254270        # Plot data in PlateCarree projection
    255         cf = ax.contourf(
     271        cf = ax.pcolormesh(
    256272            lon2d, lat2d, data2d,
    257             levels=100,
     273            shading='auto',
    258274            cmap=colormap,
    259275            transform=ccrs.PlateCarree()
    260276        )
     277        uniq_lons = np.unique(lon2d.ravel())
     278        uniq_lats = np.unique(lat2d.ravel())
     279        attach_format_coord(ax, data2d, uniq_lons, uniq_lats, is_pcolormesh=True)
    261280
    262281        # Optionally overlay MOLA topography
     
    413432
    414433
    415 def plot_variable(dataset, varname, time_index=None, alt_index=None,
    416                   colormap="jet", output_path=None, extra_indices=None,
    417                   avg_lat=False):
    418     """
    419     Core plotting logic: reads the variable, handles masks,
    420     determines dimensionality, and creates the appropriate plot:
    421       - 1D time series
    422       - 1D profiles or physical_points maps
    423       - 2D lat×lon or generic 2D
    424       - Time×lon heatmap if avg_lat=True
    425       - Scalar printing
     434def transform_physical_points(dataset, var, data):
     435    """
     436    Transform a physical_points 1D array into a 2D grid of shape (nlat, nlon).
     437    """
     438    # Fetch lat/lon coordinate variables
     439    lat_var = find_coord_var(dataset, LAT_DIMS)
     440    lon_var = find_coord_var(dataset, LON_DIMS)
     441    if lat_var is None or lon_var is None:
     442        raise ValueError("Cannot find latitude or longitude variables for physical_points")
     443    raw_lats = dataset.variables[lat_var][:]
     444    raw_lons = dataset.variables[lon_var][:]
     445    # Unmask
     446    if hasattr(raw_lats, 'mask'):
     447        raw_lats = np.where(raw_lats.mask, np.nan, raw_lats.data)
     448    if hasattr(raw_lons, 'mask'):
     449        raw_lons = np.where(raw_lons.mask, np.nan, raw_lons.data)
     450    # Convert radians to degrees if in radians
     451    if np.max(np.abs(raw_lats)) <= np.pi:
     452        raw_lats = np.degrees(raw_lats)
     453        raw_lons = np.degrees(raw_lons)
     454    # Get unique coords
     455    uniq_lats = np.unique(raw_lats)
     456    uniq_lons = np.unique(raw_lons)
     457    # Initialize grid
     458    grid = np.full((uniq_lats.size, uniq_lons.size), np.nan)
     459    # Build the grid
     460    for value, lat, lon in zip(data.ravel(), raw_lats.ravel(), raw_lons.ravel()):
     461        i = np.where(np.isclose(uniq_lats, lat))[0][0]
     462        j = np.where(np.isclose(uniq_lons, lon))[0][0]
     463        grid[i, j] = value
     464    # Duplicate the pole value across all longitudes
     465    for i in (0, -1):
     466        row = grid[i, :]
     467        count_good = np.count_nonzero(~np.isnan(row))
     468        if count_good == 1:
     469            pole_value = row[~np.isnan(row)][0]
     470            grid[i, :] = pole_value
     471    # Wrap longitude if needed
     472    if -180.0 in uniq_lons:
     473        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
     474        grid = np.hstack([grid, grid[:, [idx]]])
     475        uniq_lons = np.append(uniq_lons, 180.0)
     476    return grid, uniq_lats, uniq_lons, lat_var, lon_var
     477
     478
     479def get_dimension_indices(ds, varname):
     480    """
     481    For each dimension of the variable:
     482     - if size == 1 → automatically select index 0
     483     - otherwise prompt the user:
     484         <number>     : take that specific index (1-based)
     485         'a'          : average over this dimension
     486         'e' or Enter : take all values
     487    Returns {dim_name: int index, 'avg', or None}.
     488    """
     489    var = ds.variables[varname]
     490    dims = var.dimensions
     491    shape = var.shape
     492    selection = {}
     493    for dim, size in zip(dims, shape):
     494        if size == 1:
     495            selection[dim] = 0
     496            continue
     497        prompt = (
     498                f"Available options for '{dim}' (size {size}):\n"
     499                f"  - '1–{size}' to pick that index\n"
     500                "  - 'a' to average over this dimension\n"
     501                "  - 'e' or Enter to take all values\n"
     502                "Choose: "
     503        )
     504        while True:
     505            resp = input(prompt).strip().lower()
     506            if resp in ("", "e"):
     507                selection[dim] = None
     508                break
     509            if resp == 'a':
     510                selection[dim] = 'avg'
     511                break
     512            if resp.isdigit():
     513                n = int(resp)
     514                if 1 <= n <= size:
     515                    selection[dim] = n - 1
     516                    break
     517            print(f"  Invalid entry '{resp}'. Please enter a number, 'a', 'e', or just Enter.")
     518    return selection
     519
     520
     521def plot_variable(dataset, varname, colormap="jet", output_path=None, extra_indices=None):
     522    """
     523    Automatically select singleton dims, prompt for others,
     524    allow user to choose x/y for 2D, handle special cases (physical_points, averaging).
    426525    """
    427526    var = dataset.variables[varname]
    428     dims = var.dimensions
    429 
    430     # Read full data
     527    dims = list(var.dimensions)
     528    # Read data
    431529    try:
    432530        data_full = var[:]
     
    434532        print(f"Error: Cannot read data for '{varname}': {e}")
    435533        return
    436     if hasattr(data_full, "mask"):
     534    # Unmask
     535    if hasattr(data_full, 'mask'):
    437536        data_full = np.where(data_full.mask, np.nan, data_full.data)
    438 
    439     # If Time and altitude are both present and neither indexed,
    440     # and every other dim has size 1:
    441     t_idx = find_dim_index(dims, TIME_DIMS)
    442     a_idx = find_dim_index(dims, ALT_DIMS)
    443     shape = var.shape
    444     if (t_idx is not None and a_idx is not None
    445         and time_index is None and alt_index is None
    446         and all(size == 1 for i, size in enumerate(shape) if i not in (t_idx, a_idx))):
    447 
    448         # Build a slicer that keeps Time & altitude, drops other singletons
    449         slicer = [0] * len(dims)
    450         slicer[t_idx] = slice(None)
    451         slicer[a_idx] = slice(None)
    452         data2d = data_full[tuple(slicer)]  # shape (ntime, nalt)
    453 
    454         # Coordinate arrays
    455         tvar = find_coord_var(dataset, TIME_DIMS)
    456         avar = find_coord_var(dataset, ALT_DIMS)
    457         tvals = dataset.variables[tvar][:]
    458         avals = dataset.variables[avar][:]
    459 
    460         # Unmask if necessary
    461         if hasattr(tvals, "mask"):
    462             tvals = np.where(tvals.mask, np.nan, tvals.data)
    463         if hasattr(avals, "mask"):
    464             avals = np.where(avals.mask, np.nan, avals.data)
    465 
    466         # Plot heatmap with x=time, y=altitude
    467         fig, ax = plt.subplots(figsize=(10, 6))
    468         T, A = np.meshgrid(tvals, avals)
    469         im = ax.pcolormesh(
    470             T, A, data2d.T,
    471             shading="auto", cmap=colormap
     537    # Initialize extra_indices
     538    extra_indices = extra_indices or {}
     539    # Handle averaging selections
     540    for dim, mode in dict(extra_indices).items():
     541        if mode == 'avg':
     542            ax = dims.index(dim)
     543            data_full = np.nanmean(data_full, axis=ax, keepdims=True)
     544            extra_indices[dim] = 0
     545    # Build slicer
     546    slicer = []
     547    for dim in dims:
     548        idx = extra_indices.get(dim)
     549        slicer.append(idx if isinstance(idx, int) else slice(None))
     550    data_slice = data_full[tuple(slicer)]
     551    nd = data_slice.ndim
     552    # Special case: physical_points dimension
     553    if nd == 1 and 'physical_points' in dims:
     554        # Transform into 2D grid
     555        grid, uniq_lats, uniq_lons, latv, lonv = transform_physical_points(dataset, var, data_slice)
     556        # Plot map
     557        proj = ccrs.PlateCarree()
     558        fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(projection=proj))
     559        lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
     560        lon_ticks = np.arange(-180, 181, 30)
     561        lat_ticks = np.arange(-90, 91, 30)
     562        ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
     563        ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
     564        ax.tick_params(
     565            axis='x', which='major',
     566            length=4,
     567            direction='out',
     568            pad=2,
     569            labelsize=8
    472570        )
    473         dt = tvals[1] - tvals[0]
    474         da = avals[1] - avals[0]
    475         x_edges = np.concatenate([tvals - dt/2, [tvals[-1] + dt/2]])
    476         y_edges = np.concatenate([avals - da/2, [avals[-1] + da/2]])
    477         attach_format_coord(ax, data2d.T, x_edges, y_edges, is_pcolormesh=True)
    478         ax.set_xlabel(tvar)
    479         ax.set_ylabel(avar)
    480         cbar = fig.colorbar(im, ax=ax)
    481         cbar.set_label(varname + (f" ({getattr(var, 'units','')})"))
    482         ax.set_title(f"{varname} — {avar} vs {tvar}", fontweight="bold")
    483 
     571        ax.tick_params(
     572            axis='y', which='major',
     573            length=4,
     574            direction='out',
     575            pad=2,
     576            labelsize=8
     577        )
     578        cf = ax.pcolormesh(lon2d, lat2d, grid, shading='auto', cmap=colormap, transform=ccrs.PlateCarree())
     579        attach_format_coord(ax, grid, uniq_lons, uniq_lats, is_pcolormesh=True)
     580        overlay_topography(ax, transform=proj, levels=10) # Overlay MOLA topography
     581        cbar = fig.colorbar(cf, ax=ax, pad=0.02)
     582        cbar.set_label(varname)
     583        ax.set_title(f"{varname} (physical_points)", fontweight='bold')
     584        ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
     585        ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
     586        # Prompt for polar-stereo views if interactive
     587        if input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
     588            units = getattr(var, 'units', None)
     589            plot_polar_views(lon2d, lat2d, grid, colormap, varname, units)
     590        # Prompt for 3D globe view if interactive
     591        if input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
     592            units = getattr(var, 'units', None)
     593            plot_3D_globe(lon2d, lat2d, grid, colormap, varname, units)
    484594        if output_path:
    485             fig.savefig(output_path, bbox_inches="tight")
     595            fig.savefig(output_path, bbox_inches='tight')
    486596            print(f"Saved to {output_path}")
    487597        else:
    488598            plt.show()
    489599        return
    490 
    491     # Pure 1D time series
    492     if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
    493         time_var = find_coord_var(dataset, TIME_DIMS)
    494         tvals = (dataset.variables[time_var][:] if time_var
    495                  else np.arange(data_full.shape[0]))
    496         if hasattr(tvals, "mask"):
    497             tvals = np.where(tvals.mask, np.nan, tvals.data)
    498         plt.figure()
    499         plt.plot(tvals, data_full, marker="o")
    500         plt.xlabel(time_var or "Time Index")
    501         plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    502         plt.title(f"{varname} vs {time_var or 'Index'}", fontweight='bold')
     600    # 0D
     601    if nd == 0:
     602        print(f"\033[36mScalar '{varname}': {float(data_slice)}\033[0m")
     603        return
     604    # 1D
     605    if nd == 1:
     606        rem = [(i, d) for i, (d, s) in enumerate(zip(dims, slicer)) if isinstance(s, slice)]
     607        axis_idx, dim_name = rem[0]
     608        coord_var = find_coord_var(dataset, [dim_name])
     609        if coord_var:
     610            x = dataset.variables[coord_var][:]
     611            if hasattr(x, 'mask'):
     612                x = np.where(x.mask, np.nan, x.data)
     613            xlabel = coord_var
     614        else:
     615            x = np.arange(data_slice.shape[0])
     616            xlabel = dim_name
     617        y = data_slice
     618        plt.figure(figsize=(8, 4))
     619        plt.plot(x, y)
     620        plt.grid(True)
     621        plt.xlabel(xlabel)
     622        plt.ylabel(varname)
     623        plt.title(f"{varname} vs {xlabel}")
    503624        if output_path:
    504             plt.savefig(output_path, bbox_inches="tight")
     625            plt.savefig(output_path, bbox_inches='tight')
     626            print(f"Saved plot to {output_path}")
     627        else:
     628            plt.show()
     629        return
     630    # 2D
     631    if nd == 2:
     632        remaining = [d for d, idx in zip(dims, slicer) if isinstance(idx, slice)]
     633        # Choose X/Y interactively
     634        resp = input(f"Which dimension on X? {remaining}: ").strip()
     635        if resp == remaining[1]:
     636            x_dim, y_dim = remaining[1], remaining[0]
     637        else:
     638            x_dim, y_dim = remaining[0], remaining[1]
     639        def get_coords(dim):
     640            coord_var = find_coord_var(dataset, [dim])
     641            if coord_var:
     642                arr = dataset.variables[coord_var][:]
     643                if hasattr(arr, 'mask'):
     644                    arr = np.where(arr.mask, np.nan, arr.data)
     645                return arr
     646            return np.arange(data_slice.shape[remaining.index(dim)])
     647        x_coords = get_coords(x_dim)
     648        y_coords = get_coords(y_dim)
     649        order = [remaining.index(y_dim), remaining.index(x_dim)]
     650        plot_data = np.moveaxis(data_slice, order, [0, 1])
     651        fig, ax = plt.subplots(figsize=(8, 6))
     652        im = ax.pcolormesh(x_coords, y_coords, plot_data, shading='auto', cmap=colormap)
     653        attach_format_coord(ax, plot_data, x_coords, y_coords, is_pcolormesh=True)
     654        cbar = fig.colorbar(im, ax=ax, pad=0.02)
     655        cbar.set_label(varname)
     656        ax.set_xlabel(x_dim)
     657        ax.set_ylabel(y_dim)
     658        ax.set_title(f"{varname} ({y_dim} vs {x_dim})")
     659        ax.grid(True)
     660        # Prompt for polar-stereo views if interactive
     661        if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
     662            units = getattr(dataset.variables[varname], "units", None)
     663            plot_polar_views(x_coords, y_coords, plot_data, colormap, varname, units)
     664        # Prompt for 3D globe view if interactive
     665        if sys.stdin.isatty() and input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
     666            units = getattr(dataset.variables[varname], "units", None)
     667            plot_3D_globe(x_coords, y_coords, plot_data, colormap, varname, units)
     668        if output_path:
     669            fig.savefig(output_path, bbox_inches='tight')
    505670            print(f"Saved to {output_path}")
    506671        else:
    507672            plt.show()
    508673        return
    509 
    510     # Identify dims
    511     t_idx = find_dim_index(dims, TIME_DIMS)
    512     lat_idx = find_dim_index(dims, LAT_DIMS)
    513     lon_idx = find_dim_index(dims, LON_DIMS)
    514     a_idx = find_dim_index(dims, ALT_DIMS)
    515 
    516     # Average over latitude & plot time × lon heatmap
    517     if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
    518         # compute mean over lat axis
    519         data_avg = np.nanmean(data_full, axis=lat_idx)
    520         # prepare coordinates
    521         time_var = find_coord_var(dataset, TIME_DIMS)
    522         lon_var = find_coord_var(dataset, LON_DIMS)
    523         tvals = dataset.variables[time_var][:]
    524         lons = dataset.variables[lon_var][:]
    525         if hasattr(tvals, "mask"):
    526             tvals = np.where(tvals.mask, np.nan, tvals.data)
    527         if hasattr(lons, "mask"):
    528             lons = np.where(lons.mask, np.nan, lons.data)
    529         fig, ax = ax.subplots(figsize=(10, 6))
    530         im = plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
    531         dx = lons[1] - lons[0]
    532         dy = tvals[1] - tvals[0]
    533         x_edges = np.concatenate([lons - dx/2, [lons[-1] + dx/2]])
    534         y_edges = np.concatenate([tvals - dy/2, [tvals[-1] + dy/2]])
    535         attach_format_coord(ax, data_avg.T, x_edges, y_edges, is_pcolormesh=True)
    536         ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
    537         ax.set_ylabel(time_var)
    538         cbar = fig.colorbar(im, ax=ax)
    539         cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    540         ax.set_title(f"{varname} averaged over latitude", fontweight='bold')
    541         if output_path:
    542             fig.savefig(output_path, bbox_inches="tight")
    543             print(f"Saved to {output_path}")
    544         else:
    545             plt.show()
    546         return
    547 
    548     # Build slicer for other cases
    549     slicer = [slice(None)] * len(dims)
    550     if t_idx is not None:
    551         if time_index is None:
    552             print("Error: please supply a time index.")
    553             return
    554         slicer[t_idx] = time_index
    555     if a_idx is not None:
    556         if alt_index is None:
    557             print("Error: please supply an altitude index.")
    558             return
    559         slicer[a_idx] = alt_index
    560 
    561     if extra_indices is None:
    562         extra_indices = {}
    563     for dn, idx_val in extra_indices.items():
    564         if dn in dims:
    565             slicer[dims.index(dn)] = idx_val
    566 
    567     # Extract slice
    568     try:
    569         dslice = data_full[tuple(slicer)]
    570     except Exception as e:
    571         print(f"Error slicing '{varname}': {e}")
    572         return
    573 
    574     # Scalar
    575     if np.ndim(dslice) == 0:
    576         print(f"Scalar '{varname}': {float(dslice)}")
    577         return
    578 
    579     # 1D: vector, profile, or physical_points
    580     if dslice.ndim == 1:
    581         rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]
    582         if rem:
    583             di, dname = rem[0]
    584             # physical_points → interpolated map
    585             if dname.lower() == "physical_points":
    586                 latv = find_coord_var(dataset, LAT_DIMS)
    587                 lonv = find_coord_var(dataset, LON_DIMS)
    588                 if latv and lonv:
    589                     lats = dataset.variables[latv][:]
    590                     lons = dataset.variables[lonv][:]
    591 
    592                     # Unmask
    593                     if hasattr(lats, "mask"):
    594                         lats = np.where(lats.mask, np.nan, lats.data)
    595                     if hasattr(lons, "mask"):
    596                         lons = np.where(lons.mask, np.nan, lons.data)
    597 
    598                     # Convert radians to degrees if needed
    599                     lats_deg = np.round(np.degrees(lats), 6)
    600                     lons_deg = np.round(np.degrees(lons), 6)
    601 
    602                     # Build regular grid
    603                     uniq_lats = np.unique(lats_deg)
    604                     uniq_lons = np.unique(lons_deg)
    605                     nlon = len(uniq_lons)
    606 
    607                     data2d = []
    608                     for lat_val in uniq_lats:
    609                         mask = lats_deg == lat_val
    610                         slice_vals = dslice[mask]
    611                         lons_at_lat = lons_deg[mask]
    612                         if len(slice_vals) == 1:
    613                             row = np.full(nlon, slice_vals[0])
    614                         else:
    615                             order = np.argsort(lons_at_lat)
    616                             row = np.full(nlon, np.nan)
    617                             row[: len(slice_vals)] = slice_vals[order]
    618                         data2d.append(row)
    619                     data2d = np.array(data2d)
    620 
    621                     # Wrap longitude if needed
    622                     if -180.0 in uniq_lons:
    623                         idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
    624                         data2d = np.hstack([data2d, data2d[:, [idx]]])
    625                         uniq_lons = np.append(uniq_lons, 180.0)
    626 
    627                     # Plot interpolated map
    628                     proj = ccrs.PlateCarree()
    629                     fig, ax = plt.subplots(subplot_kw=dict(projection=proj), figsize=(8, 6))
    630                     lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
    631                     lon_ticks = np.arange(-180, 181, 30)
    632                     lat_ticks = np.arange(-90, 91, 30)
    633                     ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
    634                     ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
    635                     ax.tick_params(
    636                         axis='x', which='major',
    637                         length=4,
    638                         direction='out',
    639                         pad=2,
    640                         labelsize=8
    641                     )
    642                     ax.tick_params(
    643                        axis='y', which='major',
    644                        length=4,
    645                        direction='out',
    646                        pad=2,
    647                        labelsize=8
    648                     )
    649                     cf = ax.contourf(
    650                         lon2d, lat2d, data2d,
    651                         levels=100,
    652                         cmap=colormap,
    653                         transform=proj
    654                     )
    655 
    656                     # Overlay MOLA topography
    657                     overlay_topography(ax, transform=proj, levels=10)
    658 
    659                     # Colorbar & labels
    660                     cbar = fig.colorbar(cf, ax=ax, pad=0.02)
    661                     cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    662                     ax.set_title(f"{varname} (interpolated map over physical_points)", fontweight='bold')
    663                     ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
    664                     ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
    665 
    666                     # Prompt for polar-stereo views if interactive
    667                     if input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
    668                         units = getattr(dataset.variables[varname], "units", None)
    669                         plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units)
    670 
    671                     # Prompt for 3D globe view if interactive
    672                     if input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
    673                         units = getattr(dataset.variables[varname], "units", None)
    674                         plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units)
    675 
    676                     if output_path:
    677                         plt.savefig(output_path, bbox_inches="tight")
    678                         print(f"Saved to {output_path}")
    679                     else:
    680                         plt.show()
    681                     return
    682             # vertical profile?
    683             coord = None
    684             if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
    685                 coord = "soildepth"
    686             elif dname in dataset.variables:
    687                 coord = dname
    688             if coord:
    689                 coords = dataset.variables[coord][:]
    690                 if hasattr(coords, "mask"):
    691                     coords = np.where(coords.mask, np.nan, coords.data)
    692                 plt.figure()
    693                 plt.plot(dslice, coords, marker="o")
    694                 if dname.lower() == "subsurface_layers":
    695                     plt.gca().invert_yaxis()
    696                 plt.xlabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    697                 plt.ylabel(coord + (f" ({dataset.variables[coord].units})" if hasattr(dataset.variables[coord], "units") else ""))
    698                 plt.title(f"{varname} vs {coord}", fontweight='bold')
    699                 if output_path:
    700                     plt.savefig(output_path, bbox_inches="tight")
    701                     print(f"Saved to {output_path}")
    702                 else:
    703                     plt.show()
    704                 return
    705         # generic 1D
    706         plt.figure()
    707         plt.plot(dslice, marker="o")
    708         plt.xlabel("Index")
    709         plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    710         plt.title(f"{varname} (1D)", fontweight='bold')
    711         if output_path:
    712             plt.savefig(output_path, bbox_inches="tight")
    713             print(f"Saved to {output_path}")
    714         else:
    715             plt.show()
    716         return
    717 
    718     if dslice.ndim == 2:
    719         lat_idx2 = find_dim_index(dims, LAT_DIMS)
    720         lon_idx2 = find_dim_index(dims, LON_DIMS)
    721 
    722         # Geographic lat×lon slice
    723         if lat_idx2 is not None and lon_idx2 is not None:
    724             latv = find_coord_var(dataset, LAT_DIMS)
    725             lonv = find_coord_var(dataset, LON_DIMS)
    726             lats = dataset.variables[latv][:]
    727             lons = dataset.variables[lonv][:]
    728 
    729             # Correct latitudes order
    730             if lats[0] > lats[-1]:
    731                 lats = lats[::-1]
    732                 dslice = np.flipud(dslice)
    733 
    734             # Handle masked arrays
    735             if hasattr(lats, "mask"):
    736                 lats = np.where(lats.mask, np.nan, lats.data)
    737             if hasattr(lons, "mask"):
    738                 lons = np.where(lons.mask, np.nan, lons.data)
    739 
    740             # Create map projection
    741             proj = ccrs.PlateCarree()
    742             fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=proj))
    743 
    744             # Make meshgrid and plot
    745             lon2d, lat2d = np.meshgrid(lons, lats)
    746             cf = ax.contourf(
    747                 lon2d, lat2d, dslice,
    748                 levels=100,
    749                 cmap=colormap,
    750                 transform=proj
    751             )
    752 
    753             # Overlay topography
    754             overlay_topography(ax, transform=proj, levels=10)
    755 
    756             # Colorbar and labels
    757             lon_ticks = np.arange(-180, 181, 30)
    758             lat_ticks = np.arange(-90, 91, 30)
    759             ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
    760             ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
    761             ax.tick_params(
    762                 axis='x', which='major',
    763                 length=4,
    764                 direction='out',
    765                 pad=2,
    766                 labelsize=8
    767             )
    768             ax.tick_params(
    769                 axis='y', which='major',
    770                 length=4,
    771                 direction='out',
    772                 pad=2,
    773                 labelsize=8
    774             )
    775             cbar = fig.colorbar(cf, ax=ax, orientation="vertical", pad=0.02)
    776             cbar.set_label(varname + (f" ({dataset.variables[varname].units})"
    777                                       if hasattr(dataset.variables[varname], "units") else ""))
    778             ax.set_title(f"{varname} (lat × lon)", fontweight='bold')
    779             ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
    780             ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
    781 
    782             # Prompt for polar-stereo views if interactive
    783             if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
    784                 units = getattr(dataset.variables[varname], "units", None)
    785                 plot_polar_views(lon2d, lat2d, dslice, colormap, varname, units)
    786 
    787             # Prompt for 3D globe view if interactive
    788             if sys.stdin.isatty() and input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
    789                 units = getattr(dataset.variables[varname], "units", None)
    790                 plot_3D_globe(lon2d, lat2d, dslice, colormap, varname, units)
    791 
    792             if output_path:
    793                 plt.savefig(output_path, bbox_inches="tight")
    794                 print(f"Saved to {output_path}")
    795             else:
    796                 plt.show()
    797             return
    798 
    799         # Generic 2D
    800         fig, ax = plt.subplots(figsize=(8, 6))
    801         im = ax.imshow(
    802             dslice,
    803             aspect="auto",
    804             interpolation='nearest'
    805         )
    806         x0, x1 = 0, dslice.shape[1] - 1
    807         y0, y1 = 0, dslice.shape[0] - 1
    808         x_centers = np.linspace(x0, x1, dslice.shape[1])
    809         y_centers = np.linspace(y0, y1, dslice.shape[0])
    810         attach_format_coord(ax, dslice, x_centers, y_centers, is_pcolormesh=False)
    811         cbar = fig.colorbar(im, ax=ax, orientation='vertical')
    812         cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    813         ax.set_xlabel("Dim 2 index")
    814         ax.set_ylabel("Dim 1 index")
    815         ax.set_title(f"{varname} (2D)")
    816 
    817         if output_path:
    818             fig.savefig(output_path, bbox_inches="tight")
    819             print(f"Saved to {output_path}")
    820         else:
    821             plt.show()
    822         return
    823 
    824     print(f"Error: ndim={dslice.ndim} not supported.")
     674    print(f"Plotting for ndim={nd} not yet supported.")
    825675
    826676
     
    872722        print(f"\nVariable '{varname}' has dimensions:")
    873723        for dim, size in zip(dims, shape):
    874             print(f"  - {dim}: size {size}")
     724            print(f"  > {dim}: size {size}")
    875725        print()
    876726
    877727        # Prepare slicing parameters
    878         time_index = None
    879         alt_index = None
    880         avg = False
    881         extra_indices = {}
    882 
    883         # Time index
    884         t_idx = find_dim_index(dims, TIME_DIMS)
    885         if t_idx is not None:
    886             if shape[t_idx] > 1:
    887                 while True:
    888                     idx = input(f"Enter time index [1–{shape[t_idx]}] (press Enter for all): ").strip()
    889                     if idx == '':
    890                         time_index = None
    891                         break
    892                     if idx.isdigit():
    893                         i = int(idx)
    894                         if 1 <= i <= shape[t_idx]:
    895                             time_index = i - 1
    896                             break
    897                     print("Invalid entry. Please enter a valid number or press Enter.")
    898             else:
    899                 time_index = 0
    900 
    901         # Altitude index
    902         a_idx = find_dim_index(dims, ALT_DIMS)
    903         if a_idx is not None:
    904             if shape[a_idx] > 1:
    905                 while True:
    906                     idx = input(f"Enter altitude index [1–{shape[a_idx]}] (press Enter for all): ").strip()
    907                     if idx == '':
    908                         alt_index = None
    909                         break
    910                     if idx.isdigit():
    911                         i = int(idx)
    912                         if 1 <= i <= shape[a_idx]:
    913                             alt_index = i - 1
    914                             break
    915                     print("Invalid entry. Please enter a valid number or press Enter.")
    916             else:
    917                 alt_index = 0
    918 
    919         # Average over latitude?
    920         lat_idx = find_dim_index(dims, LAT_DIMS)
    921         lon_idx = find_dim_index(dims, LON_DIMS)
    922         if (t_idx is not None and lat_idx is not None and lon_idx is not None and
    923             shape[t_idx] > 1 and shape[lat_idx] > 1 and shape[lon_idx] > 1):
    924             resp = input("Average over latitude and plot lon vs time? [y/n]: ").strip().lower()
    925             avg = (resp == 'y')
    926 
    927         # Other dimensions
    928         for i, dname in enumerate(dims):
    929             if i in (t_idx, a_idx):
    930                 continue
    931             size = shape[i]
    932             if size == 1:
    933                 extra_indices[dname] = 0
    934                 continue
    935             while True:
    936                 idx = input(f"Enter index [1–{size}] for '{dname}' (press Enter for all): ").strip()
    937                 if idx == '':
    938                     # keep all values
    939                     break
    940                 if idx.isdigit():
    941                     j = int(idx)
    942                     if 1 <= j <= size:
    943                         extra_indices[dname] = j - 1
    944                         break
    945                 print("Invalid entry. Please enter a valid number or press Enter.")
     728        selection = get_dimension_indices(ds, varname)
    946729
    947730        # Plot the variable
    948731        plot_variable(
    949             ds, varname,
    950             time_index    = time_index,
    951             alt_index     = alt_index,
    952             colormap      = 'jet',
    953             output_path   = None,
    954             extra_indices = extra_indices,
    955             avg_lat       = avg
     732            ds,
     733            varname,
     734            colormap    = 'jet',
     735            output_path = None,
     736            extra_indices = selection
    956737        )
    957738
  • trunk/LMDZ.MARS/util/display_netcdf.yml

    r3839 r3849  
    77  - _openmp_mutex=4.5=2_gnu
    88  - aiohappyeyeballs=2.6.1=pyhd8ed1ab_0
    9   - aiohttp=3.12.13=py312h178313f_0
     9  - aiohttp=3.12.14=py312h8a5da7c_0
    1010  - aiosignal=1.4.0=pyhd8ed1ab_0
    1111  - alsa-lib=1.2.14=hb9d3cd8_0
     
    1818  - bzip2=1.0.8=h4bc722e_7
    1919  - c-ares=1.34.5=hb9d3cd8_0
    20   - ca-certificates=2025.6.15=hbd8a1cb_0
     20  - ca-certificates=2025.7.9=hbd8a1cb_0
    2121  - cairo=1.18.4=h3394656_0
    2222  - cartopy=0.24.0=py312hf9745cd_0
    23   - certifi=2025.6.15=pyhd8ed1ab_0
     23  - certifi=2025.7.9=pyhd8ed1ab_0
    2424  - cftime=1.6.4=py312hc0a28a1_1
    2525  - contourpy=1.3.2=py312h68727a3_0
     
    4141  - freetype=2.13.3=ha770c72_1
    4242  - fribidi=1.0.10=h36c2ea0_0
    43   - frozenlist=1.6.0=py312hb9e946c_0
     43  - frozenlist=1.7.0=py312h447239a_0
    4444  - gdk-pixbuf=2.42.12=hb9ae30d_0
    4545  - geos=3.13.1=h97f6797_0
     
    5151  - harfbuzz=11.2.1=h3beb420_0
    5252  - hdf4=4.2.15=h2a13503_7
    53   - hdf5=1.14.6=nompi_h2d575fe_101
     53  - hdf5=1.14.6=nompi_h6e4c0c1_102
    5454  - icu=75.1=he02047a_0
    5555  - idna=3.10=pyhd8ed1ab_1
     
    6060  - lame=3.100=h166bdaf_1003
    6161  - lcms2=2.17=h717163a_0
    62   - ld_impl_linux-64=2.44=h1423503_0
     62  - ld_impl_linux-64=2.44=h1423503_1
    6363  - lerc=4.0.0=h0aef613_1
    64   - level-zero=1.23.0=h84d6215_0
     64  - level-zero=1.23.1=h84d6215_0
    6565  - libabseil=20250127.1=cxx17_hbbce691_0
    6666  - libaec=1.1.4=h3f801dc_0
     
    7474  - libcap=2.75=h39aace5_0
    7575  - libcblas=3.9.0=32_he106b2a_openblas
    76   - libclang-cpp20.1=20.1.7=default_h1df26ce_0
    77   - libclang13=20.1.7=default_he06ed0a_0
     76  - libclang-cpp20.1=20.1.8=default_hddf928d_0
     77  - libclang13=20.1.8=default_ha444ac7_0
    7878  - libcups=2.3.3=hb8b1518_5
    7979  - libcurl=8.14.1=h332b0f4_0
     
    105105  - libjpeg-turbo=3.1.0=hb9d3cd8_0
    106106  - liblapack=3.9.0=32_h7ac8fdf_openblas
    107   - libllvm20=20.1.7=he9d0ab4_0
     107  - libllvm20=20.1.8=hecd9e04_0
    108108  - liblzma=5.8.1=hb9d3cd8_2
    109109  - libnetcdf=4.9.2=nompi_h0134ee8_117
     
    134134  - librsvg=2.58.4=he92a37e_3
    135135  - libsndfile=1.2.2=hc60ed4a_1
    136   - libsqlite=3.50.2=h6cd9bfd_0
     136  - libsqlite=3.50.2=hee844dc_2
    137137  - libssh2=1.11.1=hcf80075_0
    138138  - libstdcxx=15.1.0=h8f9b012_3
     
    149149  - libvorbis=1.3.7=h9c3ff4c_0
    150150  - libvpx=1.14.1=hac33072_0
    151   - libwebp-base=1.5.0=h851e524_0
     151  - libwebp-base=1.6.0=hd42ef1d_0
    152152  - libxcb=1.17.0=h8a09558_0
    153153  - libxcrypt=4.4.36=hd590300_1
     
    176176  - openssl=3.5.1=h7b32b05_0
    177177  - packaging=25.0=pyh29332c3_1
    178   - pandas=2.3.0=py312hf9745cd_0
     178  - pandas=2.3.1=py312hf79963d_0
    179179  - pango=1.56.4=hadf4263_0
    180180  - pcre2=10.45=hc749103_0
     
    206206  - six=1.17.0=pyhd8ed1ab_0
    207207  - snappy=1.2.1=h8bd8927_1
    208   - sqlite=3.50.2=hb7a22d2_0
     208  - sqlite=3.50.2=heff268d_2
    209209  - svt-av1=3.0.2=h5888daf_0
    210210  - tbb=2022.1.0=h4ce085d_0
Note: See TracChangeset for help on using the changeset viewer.