Changeset 3798


Ignore:
Timestamp:
Jun 5, 2025, 5:16:44 PM (44 hours ago)
Author:
jbclement
Message:

Mars PCM:
Handle correctly more variables with different types/shapes/dimensions.
JBC

Location:
trunk/LMDZ.MARS
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/LMDZ.MARS/changelog.txt

    r3783 r3798  
    48494849== 28/05/2025 == JBC
    48504850Big improvement of Python scripts in util folder to analyse/display variables in NetCDF files.
     4851
     4852== 05/06/2025 == JBC
     4853Handle correctly more variables with different types/shapes/dimensions.
  • trunk/LMDZ.MARS/libf/phymars/phyetat0_mod.F90

    r3727 r3798  
    878878        if (abs(latitude(ngrid) - (-pi/2.)) < 1.e-5) perennial_co2ice(ngrid,:) = 10*1.6e3 ! 10m which is convert to kg/m^2
    879879    endif ! not found
    880   else ! no startfiphyle
     880  else ! no startphy_file
    881881    h2o_ice_depth = -1.
    882882    lag_co2_ice = -1.
  • trunk/LMDZ.MARS/util/analyse_netcdf.py

    r3783 r3798  
    126126            continue
    127127
    128         if np.issubdtype(dtype, np.number) or hasattr(variable[:], "mask"):
     128        if np.issubdtype(dtype, np.number):
    129129            analyze_variable(variable)
    130130        else:
  • trunk/LMDZ.MARS/util/display_netcdf.py

    r3783 r3798  
    55
    66"""
    7 This script can display any numeric variable from a NetCDF file on a lat/lon map.
    8 It supports variables of dimension:
    9   - (latitude, longitude)
    10   - (time, latitude, longitude)
    11   - (altitude, latitude, longitude)
    12   - (time, altitude, latitude, longitude)
     7This script can display any numeric variable from a NetCDF file.
     8It supports the following cases:
     9  - 1D time series (Time)
     10  - 1D vertical profiles (e.g., subsurface_layers)
     11  - 2D latitude/longitude map
     12  - 2D (Time × another dimension)
     13  - Variables with dimension “physical_points” as 2D map if lat/lon present,
     14    or generic 2D plot if the remaining axes are spatial
     15  - Scalar output (ndim == 0 after slicing)
    1316
    1417Usage:
    1518  1) Command-line mode:
    16        python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [--time-index 0] [--alt-index 0] [--cmap viridis] [--output out.png]
    17 
    18   2) Interactive mode through the prompt:
     19       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
     20           [--time-index 0] [--cmap viridis] [--output out.png] \
     21           [--extra-indices '{"dim1": idx1, "dim2": idx2}']
     22
     23    --variable     : Name of the variable to visualize.
     24    --time-index   : Index along the Time dimension (ignored for purely 1D time series).
     25    --alt-index    : Index along the altitude dimension, if present.
     26    --cmap         : Matplotlib colormap for contourf (default: "jet").
     27    --output       : If provided, save the figure to this filename instead of displaying.
     28    --extra-indices: JSON string to fix indices for any dimensions other than Time, lat, lon, or altitude.
     29                     Example: '{"nslope": 0, "physical_points": 2}'
     30                     Omitting a dimension means it remains unfixed (useful to plot a 1D profile).
     31
     32  2) Interactive mode:
    1933       python display_netcdf.py
    20 
    21 The script will:
    22   > Attempt to locate latitude and longitude variables (searching for
    23     names like 'latitude', 'lat', 'longitude', 'lon').
    24   > If there is exactly one variable in the dataset, select it automatically.
    25   > Prompt for time/altitude indices if needed (or accept via CLI).
    26   > Handle masked arrays, converting masked values to NaN.
    27   > Plot with a default colormap ('jet'), adjustable via --cmap.
    28   > Optionally save the figure instead of displaying it interactively.
     34       (The script will prompt for the NetCDF file, the variable, etc.)
    2935"""
    30 
    3136
    3237import os
     
    3540import readline
    3641import argparse
     42import json
    3743import numpy as np
    3844import matplotlib.pyplot as plt
    3945from netCDF4 import Dataset
     46
     47# Constants to recognize dimension names
     48TIME_DIMS = ("Time", "time", "time_counter")
     49ALT_DIMS  = ("altitude",)
     50LAT_DIMS  = ("latitude", "lat")
     51LON_DIMS  = ("longitude", "lon")
    4052
    4153
     
    6981
    7082
    71 # Helper functions to detect common dimension names
    72 TIME_DIMS = ("Time", "time", "time_counter")
    73 ALT_DIMS  = ("altitude",)
    74 LAT_DIMS  = ("latitude", "lat")
    75 LON_DIMS  = ("longitude", "lon")
    76 
    77 
    7883def find_dim_index(dims, candidates):
    7984    """
     
    100105
    101106
    102 # Core plotting helper
    103 
    104 def plot_variable(dataset, varname, time_index=None, alt_index=None, colormap="jet", output_path=None):
    105     """
    106     Extracts the requested slice from the variable and plots it on a lat/lon grid.
     107def plot_variable(dataset, varname, time_index=None, alt_index=None, colormap="jet",
     108                  output_path=None, extra_indices=None):
     109    """
     110    Extracts the requested slice from the variable and plots it according to the data shape:
     111
     112    - Pure 1D time series → time-series plot
     113    - After slicing:
     114        • If data_slice.ndim == 0 → print the scalar value
     115        • If data_slice.ndim == 1:
     116            • If the remaining dimension is “subsurface_layers” (or another known coordinate) → vertical profile
     117            • Else → simple plot vs. index
     118        • If data_slice.ndim == 2:
     119            • If lat/lon exist → contourf map
     120            • Else → imshow generic 2D plot
     121    - If data_slice.ndim is neither 0, 1, nor 2 → error message
    107122
    108123    Parameters
    109124    ----------
    110     dataset    : netCDF4.Dataset object (already open)
    111     varname    : string name of the variable to plot
    112     time_index : int or None (if variable has a time dimension)
    113     alt_index  : int or None (if variable has an altitude dimension)
    114     colormap   : string colormap name (passed to plt.contourf)
    115     output_path: string filepath to save figure, or None to display interactively
     125    dataset       : netCDF4.Dataset object (already open)
     126    varname       : name of the variable to plot
     127    time_index    : int or None (if variable has a time dimension, ignored for pure time series)
     128    alt_index     : int or None (if variable has an altitude dimension)
     129    colormap      : string colormap name (passed to plt.contourf)
     130    output_path   : string filepath to save figure, or None to display interactively
     131    extra_indices : dict { dimension_name (str) : index (int) } for slicing all
     132                     dimensions except Time/lat/lon/alt. If a dimension is not
     133                     included, it remains “slice(None)” (useful for 1D plots).
    116134    """
    117135    var = dataset.variables[varname]
    118     dims = var.dimensions
    119 
     136    dims = var.dimensions  # tuple of dimension names
     137
     138    # Read the full data (could be a masked array)
    120139    try:
    121140        data_full = var[:]
     
    128147        data_full = np.where(data_full.mask, np.nan, data_full.data)
    129148
    130     # Identify dimension indices
     149    # ------------------------------------------------------------------------
     150    # 1) Pure 1D time series (dims == ('Time',) or equivalent)
     151    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
     152        # Plot the time series directly
     153        time_varname = find_coord_var(dataset, TIME_DIMS)
     154        if time_varname:
     155            time_vals = dataset.variables[time_varname][:]
     156            if hasattr(time_vals, "mask"):
     157                time_vals = np.where(time_vals.mask, np.nan, time_vals.data)
     158        else:
     159            time_vals = np.arange(data_full.shape[0])
     160
     161        plt.figure()
     162        plt.plot(time_vals, data_full, marker='o')
     163        xlabel = time_varname if time_varname else "Time Index"
     164        plt.xlabel(xlabel)
     165        ylabel = varname
     166        if hasattr(var, "units"):
     167            ylabel += f" ({var.units})"
     168        plt.ylabel(ylabel)
     169        plt.title(f"{varname} vs {xlabel}")
     170
     171        if output_path:
     172            try:
     173                plt.savefig(output_path, bbox_inches="tight")
     174                print(f"Figure saved to '{output_path}'")
     175            except Exception as e:
     176                print(f"Error saving figure: {e}")
     177        else:
     178            plt.show()
     179        plt.close()
     180        return
     181    # ------------------------------------------------------------------------
     182
     183    # Identify special dimension indices
    131184    t_idx = find_dim_index(dims, TIME_DIMS)
    132185    a_idx = find_dim_index(dims, ALT_DIMS)
     
    134187    lon_idx = find_dim_index(dims, LON_DIMS)
    135188
    136     # Check that lat and lon dims exist
    137     if lat_idx is None or lon_idx is None:
    138         print("Error: Could not find 'latitude'/'lat' and 'longitude'/'lon' dimensions for plotting.")
    139         return
    140 
    141     # Build a slice with defaults
     189    # Build the slicer list
    142190    slicer = [slice(None)] * len(dims)
     191
     192    # Apply slicing on Time and altitude if specified
    143193    if t_idx is not None:
    144194        if time_index is None:
     
    152202        slicer[a_idx] = alt_index
    153203
    154     # Extract the 2D slice
     204    # Apply slicing on all “extra” dimensions (except Time/lat/lon/alt)
     205    if extra_indices is None:
     206        extra_indices = {}
     207    for dim_name, idx_val in extra_indices.items():
     208        if dim_name in dims:
     209            dim_index = dims.index(dim_name)
     210            slicer[dim_index] = idx_val
     211
     212    # Extract the sliced data
    155213    try:
    156214        data_slice = data_full[tuple(slicer)]
     
    159217        return
    160218
    161     # After slicing, data_slice should be 2D
    162     if data_slice.ndim != 2:
    163         print(f"Error: After slicing, data for '{varname}' is not 2D (ndim={data_slice.ndim}).")
    164         return
    165 
    166     nlat, nlon = data_slice.shape
    167     # Handle too-small grid (1x1, 1xN, Nx1)
    168     if nlat < 2 or nlon < 2:
    169         lat_varname = find_coord_var(dataset, LAT_DIMS)
    170         lon_varname = find_coord_var(dataset, LON_DIMS)
    171         if lat_varname and lon_varname:
    172             lat_vals = dataset.variables[lat_varname][:]
    173             lon_vals = dataset.variables[lon_varname][:]
    174             if hasattr(lat_vals, "mask"):
    175                 lat_vals = np.where(lat_vals.mask, np.nan, lat_vals.data)
    176             if hasattr(lon_vals, "mask"):
    177                 lon_vals = np.where(lon_vals.mask, np.nan, lon_vals.data)
    178             # Single point
    179             if nlat == 1 and nlon == 1:
    180                 print(f"Single data point: value={data_slice[0,0]} at (lon={lon_vals[0]}, lat={lat_vals[0]})")
    181                 return
    182             # 1 x N -> plot vs lon
    183             if nlat == 1 and nlon > 1:
    184                 x = lon_vals
    185                 y = data_slice[0, :]
     219    # CASE: After slicing, if data_slice.ndim == 0 → scalar
     220    if np.ndim(data_slice) == 0:
     221        try:
     222            scalar_val = float(data_slice)
     223        except Exception:
     224            scalar_val = data_slice
     225        print(f"Scalar result for '{varname}': {scalar_val}")
     226        return
     227
     228    # CASE: After slicing, if data_slice.ndim == 1 (vertical profile or simple vector)
     229    if data_slice.ndim == 1:
     230        # Identify the remaining dimension
     231        rem_dim = None
     232        for di, dname in enumerate(dims):
     233            if slicer[di] == slice(None):
     234                rem_dim = (di, dname)
     235                break
     236
     237        if rem_dim is not None:
     238            di, dname = rem_dim
     239            coord_var = None
     240
     241            # If it's "subsurface_layers", look for coordinate "soildepth"
     242            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
     243                coord_var = "soildepth"
     244            # If there is a variable with the same name, use it
     245            elif dname in dataset.variables:
     246                coord_var = dname
     247
     248            if coord_var:
     249                coord_vals = dataset.variables[coord_var][:]
     250                if hasattr(coord_vals, "mask"):
     251                    coord_vals = np.where(coord_vals.mask, np.nan, coord_vals.data)
     252                x = data_slice
     253                y = coord_vals
     254
    186255                plt.figure()
    187256                plt.plot(x, y, marker='o')
    188                 plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_varname], 'units', 'degrees')})")
    189                 plt.ylabel(varname)
    190                 plt.title(f"{varname} (lat={lat_vals[0]})")
    191             # N x 1 -> plot vs lat
    192             elif nlon == 1 and nlat > 1:
    193                 x = lat_vals
    194                 y = data_slice[:, 0]
     257                # Invert Y-axis if it's a depth coordinate
     258                if dname.lower() == "subsurface_layers":
     259                    plt.gca().invert_yaxis()
     260
     261                xlabel = varname
     262                if hasattr(var, "units"):
     263                    xlabel += f" ({var.units})"
     264                plt.xlabel(xlabel)
     265
     266                ylabel = coord_var
     267                if hasattr(dataset.variables[coord_var], "units"):
     268                    ylabel += f" ({dataset.variables[coord_var].units})"
     269                plt.ylabel(ylabel)
     270
     271                plt.title(f"{varname} vs {coord_var}")
     272
     273                if output_path:
     274                    try:
     275                        plt.savefig(output_path, bbox_inches="tight")
     276                        print(f"Figure saved to '{output_path}'")
     277                    except Exception as e:
     278                        print(f"Error saving figure: {e}")
     279                else:
     280                    plt.show()
     281                plt.close()
     282                return
     283            else:
     284                # No known coordinate found → simple plot vs index
    195285                plt.figure()
    196                 plt.plot(x, y, marker='o')
    197                 plt.xlabel(f"Latitude ({getattr(dataset.variables[lat_varname], 'units', 'degrees')})")
    198                 plt.ylabel(varname)
    199                 plt.title(f"{varname} (lon={lon_vals[0]})")
     286                plt.plot(data_slice, marker='o')
     287                plt.xlabel("Index")
     288                ylabel = varname
     289                if hasattr(var, "units"):
     290                    ylabel += f" ({var.units})"
     291                plt.ylabel(ylabel)
     292                plt.title(f"{varname} (1D)")
     293
     294                if output_path:
     295                    try:
     296                        plt.savefig(output_path, bbox_inches="tight")
     297                        print(f"Figure saved to '{output_path}'")
     298                    except Exception as e:
     299                        print(f"Error saving figure: {e}")
     300                else:
     301                    plt.show()
     302                plt.close()
     303                return
     304
     305        else:
     306            # Unable to identify the remaining dimension → error
     307            print(f"Error: After slicing, data for '{varname}' is 1D but remaining dimension is unknown.")
     308            return
     309
     310    # CASE: After slicing, if data_slice.ndim == 2
     311    if data_slice.ndim == 2:
     312        # If lat and lon exist in the original dims, re-find their indices
     313        lat_idx2 = find_dim_index(dims, LAT_DIMS)
     314        lon_idx2 = find_dim_index(dims, LON_DIMS)
     315
     316        if lat_idx2 is not None and lon_idx2 is not None:
     317            # We have a 2D variable on a lat×lon grid
     318            lat_varname = find_coord_var(dataset, LAT_DIMS)
     319            lon_varname = find_coord_var(dataset, LON_DIMS)
     320            if lat_varname is None or lon_varname is None:
     321                print("Error: Could not locate latitude/longitude variables in the dataset.")
     322                return
     323
     324            lat_var = dataset.variables[lat_varname][:]
     325            lon_var = dataset.variables[lon_varname][:]
     326            if hasattr(lat_var, "mask"):
     327                lat_var = np.where(lat_var.mask, np.nan, lat_var.data)
     328            if hasattr(lon_var, "mask"):
     329                lon_var = np.where(lon_var.mask, np.nan, lon_var.data)
     330
     331            # Build 2D coordinate arrays
     332            if lat_var.ndim == 1 and lon_var.ndim == 1:
     333                lon2d, lat2d = np.meshgrid(lon_var, lat_var)
     334            elif lat_var.ndim == 2 and lon_var.ndim == 2:
     335                lat2d, lon2d = lat_var, lon_var
    200336            else:
    201                 print("Unexpected slice shape.")
     337                print("Error: Latitude and longitude must both be either 1D or 2D.")
    202338                return
     339
     340            plt.figure(figsize=(10, 6))
     341            cf = plt.contourf(lon2d, lat2d, data_slice, cmap=colormap)
     342            cbar = plt.colorbar(cf)
     343            if hasattr(var, "units"):
     344                cbar.set_label(f"{varname} ({var.units})")
     345            else:
     346                cbar.set_label(varname)
     347
     348            lon_label = f"Longitude ({getattr(dataset.variables[lon_varname], 'units', 'degrees')})"
     349            lat_label = f"Latitude ({getattr(dataset.variables[lat_varname], 'units', 'degrees')})"
     350            plt.xlabel(lon_label)
     351            plt.ylabel(lat_label)
     352            plt.title(f"{varname} (lat × lon)")
     353
    203354            if output_path:
    204                 plt.savefig(output_path, bbox_inches="tight")
    205                 print(f"Figure saved to '{output_path}'")
     355                try:
     356                    plt.savefig(output_path, bbox_inches="tight")
     357                    print(f"Figure saved to '{output_path}'")
     358                except Exception as e:
     359                    print(f"Error saving figure: {e}")
    206360            else:
    207361                plt.show()
     
    209363            return
    210364
    211     # Locate coordinate variables
    212     lat_varname = find_coord_var(dataset, LAT_DIMS)
    213     lon_varname = find_coord_var(dataset, LON_DIMS)
    214     if lat_varname is None or lon_varname is None:
    215         print("Error: Could not locate latitude/longitude variables in the dataset.")
    216         return
    217 
    218     lat_var = dataset.variables[lat_varname][:]
    219     lon_var = dataset.variables[lon_varname][:]
    220     if hasattr(lat_var, "mask"):
    221         lat_var = np.where(lat_var.mask, np.nan, lat_var.data)
    222     if hasattr(lon_var, "mask"):
    223         lon_var = np.where(lon_var.mask, np.nan, lon_var.data)
    224 
    225     # Build 2D coordinate arrays
    226     if lat_var.ndim == 1 and lon_var.ndim == 1:
    227         lon2d, lat2d = np.meshgrid(lon_var, lat_var)
    228     elif lat_var.ndim == 2 and lon_var.ndim == 2:
    229         lat2d, lon2d = lat_var, lon_var
    230     else:
    231         print("Error: Latitude and longitude variables must both be either 1D or 2D.")
    232         return
    233 
    234     # Retrieve units if available
    235     var_units = getattr(var, "units", "")
    236     lat_units = getattr(dataset.variables[lat_varname], "units", "degrees")
    237     lon_units = getattr(dataset.variables[lon_varname], "units", "degrees")
    238 
    239     # Plot
    240     plt.figure(figsize=(10, 6))
    241     cf = plt.contourf(lon2d, lat2d, data_slice, cmap=colormap)
    242     cbar = plt.colorbar(cf)
    243     if var_units:
    244         cbar.set_label(f"{varname} ({var_units})")
    245     else:
    246         cbar.set_label(varname)
    247 
    248     plt.xlabel(f"Longitude ({lon_units})")
    249     plt.ylabel(f"Latitude ({lat_units})")
    250     plt.title(f"{varname} Visualization")
    251 
    252     if output_path:
    253         try:
    254             plt.savefig(output_path, bbox_inches="tight")
    255             print(f"Figure saved to '{output_path}'")
    256         except Exception as e:
    257             print(f"Error saving figure: {e}")
    258     else:
    259         plt.show()
    260     plt.close()
     365        else:
     366            # No lat/lon → two non-geographical dimensions; plot with imshow
     367            plt.figure(figsize=(8, 6))
     368            plt.imshow(data_slice, aspect='auto')
     369            cb_label = varname
     370            if hasattr(var, "units"):
     371                cb_label += f" ({var.units})"
     372            plt.colorbar(label=cb_label)
     373            plt.xlabel("Dimension 2 Index")
     374            plt.ylabel("Dimension 1 Index")
     375            plt.title(f"{varname} (2D without lat/lon)")
     376
     377            if output_path:
     378                try:
     379                    plt.savefig(output_path, bbox_inches="tight")
     380                    print(f"Figure saved to '{output_path}'")
     381                except Exception as e:
     382                    print(f"Error saving figure: {e}")
     383            else:
     384                plt.show()
     385            plt.close()
     386            return
     387
     388    # CASE: data_slice.ndim is neither 0, 1, nor 2
     389    print(f"Error: After slicing, data for '{varname}' has ndim={data_slice.ndim}, which is not supported.")
     390    return
    261391
    262392
    263393def visualize_variable_interactive(nc_path=None):
    264394    """
    265     Interactive mode: if nc_path is provided, skip prompting for filename.
    266     Otherwise, prompt for filename. Then select variable (automatically if only one), prompt for indices.
     395    Interactive mode: prompts for the NetCDF file if not provided, then for the variable,
     396    then for Time/altitude indices (skipped entirely if variable is purely 1D over Time),
     397    and for each other dimension offers to fix an index or to plot along that dimension (by typing 'f').
     398
     399    If a dimension has length 1, the index 0 is chosen automatically.
    267400    """
    268401    # Determine file path
     
    309442            return
    310443
    311     dims = ds.variables[var_input].dimensions
     444    dims = ds.variables[var_input].dimensions  # tuple of dimension names
     445
     446    # If the variable is purely 1D over Time, plot immediately without asking for time index
     447    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
     448        plot_variable(
     449            dataset=ds,
     450            varname=var_input,
     451            time_index=None,
     452            alt_index=None,
     453            colormap="jet",
     454            output_path=None,
     455            extra_indices=None
     456        )
     457        ds.close()
     458        return
     459
     460    # Otherwise, proceed to prompt for time/altitude and other dimensions
    312461    time_idx = None
    313462    alt_idx = None
    314463
     464    # Prompt for time index if applicable
    315465    t_idx = find_dim_index(dims, TIME_DIMS)
    316466    if t_idx is not None:
    317         length = ds.variables[var_input].shape[t_idx]
    318         if length > 1:
     467        time_len = ds.variables[var_input].shape[t_idx]
     468        if time_len > 1:
    319469            while True:
    320470                try:
    321                     user_t = input(f"Enter time index [0..{length - 1}]: ").strip()
     471                    user_t = input(f"Enter time index [0..{time_len - 1}]: ").strip()
    322472                    if user_t == "":
    323473                        print("No time index entered. Exiting.")
    324474                        ds.close()
    325475                        return
    326                     time_idx = int(user_t)
    327                     if 0 <= time_idx < length:
     476                    ti = int(user_t)
     477                    if 0 <= ti < time_len:
     478                        time_idx = ti
    328479                        break
    329480                except ValueError:
    330481                    pass
    331                 print(f"Invalid index. Enter an integer between 0 and {length - 1}.")
     482                print(f"Invalid index. Enter an integer between 0 and {time_len - 1}.")
    332483        else:
    333484            time_idx = 0
    334485            print("Only one time step available; using index 0.")
    335486
     487    # Prompt for altitude index if applicable
    336488    a_idx = find_dim_index(dims, ALT_DIMS)
    337489    if a_idx is not None:
    338         length = ds.variables[var_input].shape[a_idx]
    339         if length > 1:
     490        alt_len = ds.variables[var_input].shape[a_idx]
     491        if alt_len > 1:
    340492            while True:
    341493                try:
    342                     user_a = input(f"Enter altitude index [0..{length - 1}]: ").strip()
     494                    user_a = input(f"Enter altitude index [0..{alt_len - 1}]: ").strip()
    343495                    if user_a == "":
    344496                        print("No altitude index entered. Exiting.")
    345497                        ds.close()
    346498                        return
    347                     alt_idx = int(user_a)
    348                     if 0 <= alt_idx < length:
     499                    ai = int(user_a)
     500                    if 0 <= ai < alt_len:
     501                        alt_idx = ai
    349502                        break
    350503                except ValueError:
    351504                    pass
    352                 print(f"Invalid index. Enter an integer between 0 and {length - 1}.")
     505                print(f"Invalid index. Enter an integer between 0 and {alt_len - 1}.")
    353506        else:
    354507            alt_idx = 0
    355508            print("Only one altitude level available; using index 0.")
    356509
     510    # Identify other dimensions (excluding Time/lat/lon/alt)
     511    other_dims = []
     512    for idx_dim, dim_name in enumerate(dims):
     513        if idx_dim == t_idx or idx_dim == a_idx:
     514            continue
     515        if dim_name.lower() in (d.lower() for d in LAT_DIMS + LON_DIMS):
     516            continue
     517        other_dims.append((idx_dim, dim_name))
     518
     519    # For each other dimension, ask user to fix an index or type 'f' to plot along that dimension
     520    extra_indices = {}
     521    for idx_dim, dim_name in other_dims:
     522        dim_len = ds.variables[var_input].shape[idx_dim]
     523        if dim_len == 1:
     524            extra_indices[dim_name] = 0
     525            print(f"Dimension '{dim_name}' has length 1; using index 0.")
     526        else:
     527            while True:
     528                prompt = (
     529                    f"Enter index for '{dim_name}' [0..{dim_len - 1}] "
     530                    f"or 'f' to plot along '{dim_name}': "
     531                )
     532                user_i = input(prompt).strip().lower()
     533                if user_i == 'f':
     534                    # Leave this dimension unfixed → no key in extra_indices
     535                    break
     536                if user_i == "":
     537                    print("No index entered. Exiting.")
     538                    ds.close()
     539                    return
     540                try:
     541                    idx_val = int(user_i)
     542                    if 0 <= idx_val < dim_len:
     543                        extra_indices[dim_name] = idx_val
     544                        break
     545                except ValueError:
     546                    pass
     547                print(f"Invalid index. Enter an integer between 0 and {dim_len - 1}, or 'f'.")
     548
     549    # Finally, call plot_variable with collected indices
    357550    plot_variable(
    358551        dataset=ds,
     
    361554        alt_index=alt_idx,
    362555        colormap="jet",
    363         output_path=None
     556        output_path=None,
     557        extra_indices=extra_indices
    364558    )
    365559    ds.close()
    366560
    367561
    368 def visualize_variable_cli(nc_path, varname, time_index, alt_index, colormap, output_path):
    369     """
    370     Command-line mode: directly visualize based on provided arguments.
     562def visualize_variable_cli(nc_path, varname, time_index, alt_index, colormap, output_path, extra_json):
     563    """
     564    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
    371565    """
    372566    if not os.path.isfile(nc_path):
    373567        print(f"Error: '{nc_path}' not found.")
    374568        return
     569
    375570    try:
    376571        ds = Dataset(nc_path, mode="r")
     
    383578        ds.close()
    384579        return
     580
     581    # Parse extra_indices if provided
     582    extra_indices = {}
     583    if extra_json:
     584        try:
     585            parsed = json.loads(extra_json)
     586            if isinstance(parsed, dict):
     587                for k, v in parsed.items():
     588                    if isinstance(k, str) and isinstance(v, int):
     589                        extra_indices[k] = v
     590            else:
     591                print("Warning: --extra-indices is not a JSON object. Ignored.")
     592        except json.JSONDecodeError:
     593            print("Warning: --extra-indices is not valid JSON. Ignored.")
    385594
    386595    plot_variable(
     
    390599        alt_index=alt_index,
    391600        colormap=colormap,
    392         output_path=output_path
     601        output_path=output_path,
     602        extra_indices=extra_indices
    393603    )
    394604    ds.close()
     
    397607def main():
    398608    parser = argparse.ArgumentParser(
    399         description="Visualize a 2D slice of a NetCDF variable on a latitude-longitude map."
     609        description="Visualize a 1D/2D slice of a NetCDF variable on various dimension types."
    400610    )
    401611    parser.add_argument(
     
    411621        "--time-index", "-t",
    412622        type=int,
    413         help="Index along the time dimension (if applicable)."
     623        help="Index on the Time dimension, if applicable (ignored for pure 1D time series)."
    414624    )
    415625    parser.add_argument(
    416626        "--alt-index", "-a",
    417627        type=int,
    418         help="Index along the altitude dimension (if applicable)."
     628        help="Index on the altitude dimension, if applicable."
    419629    )
    420630    parser.add_argument(
     
    425635    parser.add_argument(
    426636        "--output", "-o",
    427         help="If provided, save the plot to this file instead of displaying it."
     637        help="If provided, save the figure to this file instead of displaying it."
     638    )
     639    parser.add_argument(
     640        "--extra-indices", "-e",
     641        help="JSON string to fix indices of dimensions outside Time/lat/lon/alt. "
     642             "Example: '{\"nslope\":0, \"physical_points\":2}'."
    428643    )
    429644
    430645    args = parser.parse_args()
    431646
    432     # If nc_file is provided but variable is missing: ask only for variable
    433     if args.nc_file and not args.variable:
    434         visualize_variable_interactive(nc_path=args.nc_file)
    435     # If either nc_file or variable is missing, run fully interactive
    436     elif not args.nc_file or not args.variable:
    437         visualize_variable_interactive()
    438     else:
     647    # If both nc_file and variable are provided → CLI mode
     648    if args.nc_file and args.variable:
    439649        visualize_variable_cli(
    440650            nc_path=args.nc_file,
     
    443653            alt_index=args.alt_index,
    444654            colormap=args.cmap,
    445             output_path=args.output
     655            output_path=args.output,
     656            extra_json=args.extra_indices
    446657        )
     658    else:
     659        # Otherwise → fully interactive mode
     660        visualize_variable_interactive(nc_path=args.nc_file)
    447661
    448662
Note: See TracChangeset for help on using the changeset viewer.