Ignore:
Timestamp:
Jun 19, 2025, 7:47:13 PM (2 days ago)
Author:
jbclement
Message:

Mars PCM:
Big improvements of script "display_netcdf.py" in util folder. In particular:

  • Addition of a MOLA map which, if present, overlays on 2D map plots. A file "MOLA_1px_per_deg.npy" is provided for Mars in util folder;
  • Addition of the possibility to plot polar‐stereographic views;
  • More interactive (keep asking the user to plot variables until he/she quits for ex);
  • Simplification and generalization of the script to get user answers.

JBC

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/LMDZ.MARS/util/display_netcdf.py

    r3808 r3810  
    1414  - Optionally average over latitude and plot longitude vs. time heatmap
    1515  - Scalar output (ndim == 0 after slicing)
     16  - 2D cross-sections (altitude × latitude or altitude × longitude)
    1617
    1718Usage:
     
    1920       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
    2021           [--time-index 0] [--alt-index 0] [--cmap viridis] [--avg-lat] \
     22           [--slice-lon-index 10] [--slice-lat-index 20] [--show-topo] \
    2123           [--output out.png] [--extra-indices '{"nslope": 1}']
    2224
    23     --variable     : Name of the variable to visualize.
    24     --time-index   : Index along the Time dimension (0-based, ignored for purely 1D time series).
    25     --alt-index    : Index along the altitude dimension (0-based), if present.
    26     --cmap         : Matplotlib colormap (default: "jet").
    27     --avg-lat      : Average over latitude and plot longitude vs. time heatmap.
    28     --output       : If provided, save the figure to this filename instead of displaying.
    29     --extra-indices: JSON string to fix indices for any other dimensions.
    30                      For any dimension whose name contains "slope", use 1-based numbering here.
    31                      Example: '{"nslope": 1, "physical_points": 3}'
     25    --variable         : Name of the variable to visualize.
     26    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
     27    --alt-index        : Index along the altitude dimension (0-based), if present.
     28    --cmap             : Matplotlib colormap (default: "jet").
     29    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
     30    --slice-lon-index  : Fixed longitude index for altitude×longitude cross-section.
     31    --slice-lat-index  : Fixed latitude index for altitude×latitude cross-section.
     32    --show-topo        : Overlay MOLA topography on lat/lon maps.
     33    --output           : If provided, save the figure to this filename instead of displaying.
     34    --extra-indices    : JSON string to fix indices for any other dimensions.
     35                         For dimensions with "slope", use 1-based numbering here.
     36                         Example: '{"nslope": 1, "physical_points": 3}'
    3237
    3338  2) Interactive mode:
    3439       python display_netcdf.py
    35        (The script will prompt for everything, including averaging option.)
     40       (The script will prompt for everything, including averaging or slicing options.)
    3641"""
    3742
     
    4550import matplotlib.pyplot as plt
    4651import matplotlib.tri as mtri
     52import matplotlib.path as mpath
     53import cartopy.crs as ccrs
    4754from netCDF4 import Dataset
    4855
    49 # Constants to recognize dimension names
     56# Constants for recognized dimension names
    5057TIME_DIMS = ("Time", "time", "time_counter")
    5158ALT_DIMS  = ("altitude",)
     
    5360LON_DIMS  = ("longitude", "lon")
    5461
     62# Attempt to load MOLA topography
     63try:
     64    MOLA = np.load('MOLA_1px_per_deg.npy')  # shape (nlat, nlon) at 1° per pixel: lat from -90 to 90, lon from 0 to 360
     65    nlat, nlon = MOLA.shape
     66    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
     67    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
     68    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
     69    topo_loaded = True
     70    print("MOLA topography loaded successfully from 'MOLA_1px_per_deg.npy'.")
     71except Exception as e:
     72    topo_loaded = False
     73    print(f"Warning: failed to load MOLA topography ('MOLA_1px_per_deg.npy'): {e}")
     74
    5575
    5676def complete_filename(text, state):
    5777    """
    58     Readline tab-completion function for filesystem paths.
     78    Tab-completion for filesystem paths.
    5979    """
    6080    if "*" not in text:
     
    7292def make_varname_completer(varnames):
    7393    """
    74     Returns a readline completer function for the given list of variable names.
     94    Returns a readline completer for variable names.
    7595    """
    7696    def completer(text, state):
     
    106126    return None
    107127
     128
     129def overlay_topography(ax, transform, levels=10):
     130    """
     131    Overlay MOLA topography contours onto a given GeoAxes.
     132    """
     133    if not topo_loaded:
     134        return
     135    ax.contour(
     136        topo_lon2d, topo_lat2d, MOLA,
     137        levels=levels,
     138        linewidths=0.5,
     139        colors='black',
     140        transform=transform
     141    )
     142
     143
     144def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True):
     145    """
     146    Plot two polar‐stereographic views (north & south) of the same data.
     147    """
     148    figs = []  # collect figure handles
     149
     150    for pole in ("north", "south"):
     151        # Choose projection and extent for each pole
     152        if pole == "north":
     153            proj = ccrs.NorthPolarStereo(central_longitude=180)
     154            extent = [-180, 180, 60, 90]
     155        else:
     156            proj = ccrs.SouthPolarStereo(central_longitude=180)
     157            extent = [-180, 180, -90, -60]
     158
     159        # Create figure and GeoAxes
     160        fig = plt.figure(figsize=(8, 6))
     161        ax = fig.add_subplot(1, 1, 1, projection=proj, aspect=True)
     162        ax.set_global()
     163        ax.set_extent(extent, ccrs.PlateCarree())
     164
     165        # Draw circular boundary
     166        theta = np.linspace(0, 2 * np.pi, 100)
     167        center, radius = [0.5, 0.5], 0.5
     168        verts = np.vstack([np.sin(theta), np.cos(theta)]).T
     169        circle = mpath.Path(verts * radius + center)
     170        ax.set_boundary(circle, transform=ax.transAxes)
     171
     172        # Add meridians/parallels
     173        gl = ax.gridlines(
     174            draw_labels=True,
     175            color='k',
     176            xlocs=range(-180, 181, 30),
     177            ylocs=range(-90, 91, 10),
     178            linestyle='--',
     179            linewidth=0.5
     180        )
     181        #gl.top_labels = False
     182        #gl.right_labels = False
     183
     184        # Plot data in PlateCarree projection
     185        cf = ax.contourf(
     186            lon2d, lat2d, data2d,
     187            levels=100,
     188            cmap=colormap,
     189            transform=ccrs.PlateCarree()
     190        )
     191
     192        # Optionally overlay MOLA topography
     193        if topo_overlay:
     194            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)
     195
     196        # Colorbar and title
     197        cbar = fig.colorbar(cf, ax=ax, pad=0.1)
     198        label = varname + (f" ({units})" if units else "")
     199        cbar.set_label(label)
     200        ax.set_title(f"{varname} — {pole.capitalize()} Pole", pad=50)
     201
     202        figs.append(fig)
     203
     204    # Show both figures
     205    plt.show()
    108206
    109207def plot_variable(dataset, varname, time_index=None, alt_index=None,
     
    158256    # Average over latitude & plot time × lon heatmap
    159257    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
    160         # mean over lat axis
     258        # compute mean over lat axis
    161259        data_avg = np.nanmean(data_full, axis=lat_idx)
    162         # data_avg shape: (time, lon, ...)
    163         # we assume no other unfixed dims
    164         # get coordinates
     260        # prepare coordinates
    165261        time_var = find_coord_var(dataset, TIME_DIMS)
    166262        lon_var = find_coord_var(dataset, LON_DIMS)
     
    228324                    lats = dataset.variables[latv][:]
    229325                    lons = dataset.variables[lonv][:]
     326
     327                    # Unmask
    230328                    if hasattr(lats, "mask"):
    231329                        lats = np.where(lats.mask, np.nan, lats.data)
    232330                    if hasattr(lons, "mask"):
    233331                        lons = np.where(lons.mask, np.nan, lons.data)
    234                     triang = mtri.Triangulation(lons, lats)
    235                     plt.figure(figsize=(8, 6))
    236                     cf = plt.tricontourf(triang, dslice, cmap=colormap)
    237                     cbar = plt.colorbar(cf)
     332
     333                    # Convert radians to degrees if needed
     334                    lats_deg = np.round(np.degrees(lats), 6)
     335                    lons_deg = np.round(np.degrees(lons), 6)
     336
     337                    # Build regular grid
     338                    uniq_lats = np.unique(lats_deg)
     339                    uniq_lons = np.unique(lons_deg)
     340                    nlon = len(uniq_lons)
     341
     342                    data2d = []
     343                    for lat_val in uniq_lats:
     344                        mask = lats_deg == lat_val
     345                        slice_vals = dslice[mask]
     346                        lons_at_lat = lons_deg[mask]
     347                        if len(slice_vals) == 1:
     348                            row = np.full(nlon, slice_vals[0])
     349                        else:
     350                            order = np.argsort(lons_at_lat)
     351                            row = np.full(nlon, np.nan)
     352                            row[: len(slice_vals)] = slice_vals[order]
     353                        data2d.append(row)
     354                    data2d = np.array(data2d)
     355
     356                    # Wrap longitude if needed
     357                    if -180.0 in uniq_lons:
     358                        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
     359                        data2d = np.hstack([data2d, data2d[:, [idx]]])
     360                        uniq_lons = np.append(uniq_lons, 180.0)
     361
     362                    # Plot interpolated map
     363                    proj = ccrs.PlateCarree()
     364                    fig, ax = plt.subplots(subplot_kw=dict(projection=proj), figsize=(8, 6))
     365                    lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
     366                    lon_ticks = np.arange(-180, 181, 30)
     367                    lat_ticks = np.arange(-90, 91, 30)
     368                    ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
     369                    ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
     370                    ax.tick_params(
     371                        axis='x', which='major',
     372                        length=4,
     373                        direction='out',
     374                        pad=2,
     375                        labelsize=8
     376                    )
     377                    ax.tick_params(
     378                       axis='y', which='major',
     379                       length=4,
     380                       direction='out',
     381                       pad=2,
     382                       labelsize=8
     383                    )
     384                    cf = ax.contourf(
     385                        lon2d, lat2d, data2d,
     386                        levels=100,
     387                        cmap=colormap,
     388                        transform=proj
     389                    )
     390
     391                    # Overlay MOLA topography
     392                    overlay_topography(ax, transform=proj, levels=10)
     393
     394                    # Colorbar & labels
     395                    cbar = fig.colorbar(cf, ax=ax, pad=0.02)
    238396                    cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    239                     plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
    240                     plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
    241                     plt.title(f"{varname} (interpolated map over physical_points)")
     397                    ax.set_title(f"{varname} (interpolated map over physical_points)")
     398                    ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
     399                    ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
     400
     401                    # Prompt for polar-stereo views if interactive
     402                    if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
     403                        units = getattr(dataset.variables[varname], "units", None)
     404                        plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units)
     405
    242406                    if output_path:
    243407                        plt.savefig(output_path, bbox_inches="tight")
     
    282446        return
    283447
    284     # 2D: map or generic
    285     if dslice.ndim == 2:
     448    # if dslice.ndim == 2:
    286449        lat_idx2 = find_dim_index(dims, LAT_DIMS)
    287450        lon_idx2 = find_dim_index(dims, LON_DIMS)
     451
     452        # Geographic lat×lon slice
    288453        if lat_idx2 is not None and lon_idx2 is not None:
    289454            latv = find_coord_var(dataset, LAT_DIMS)
     
    291456            lats = dataset.variables[latv][:]
    292457            lons = dataset.variables[lonv][:]
     458
     459            # Handle masked arrays
    293460            if hasattr(lats, "mask"):
    294461                lats = np.where(lats.mask, np.nan, lats.data)
    295462            if hasattr(lons, "mask"):
    296463                lons = np.where(lons.mask, np.nan, lons.data)
    297             if lats.ndim == 1 and lons.ndim == 1:
    298                 lon2d, lat2d = np.meshgrid(lons, lats)
    299             else:
    300                 lat2d, lon2d = lats, lons
    301             plt.figure(figsize=(10, 6))
    302             cf = plt.contourf(lon2d, lat2d, dslice, cmap=colormap)
    303             cbar = plt.colorbar(cf)
    304             cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
    305             plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
    306             plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
    307             plt.title(f"{varname} (lat × lon)")
     464
     465            # Create map projection
     466            proj = ccrs.PlateCarree()
     467            fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=proj))
     468
     469            # Make meshgrid and plot
     470            lon2d, lat2d = np.meshgrid(lons, lats)
     471            cf = ax.contourf(
     472                lon2d, lat2d, dslice,
     473                levels=100,
     474                cmap=colormap,
     475                transform=proj
     476            )
     477
     478            # Overlay topography
     479            overlay_topography(ax, transform=proj, levels=10)
     480
     481            # Colorbar and labels
     482            lon_ticks = np.arange(-180, 181, 30)
     483            lat_ticks = np.arange(-90, 91, 30)
     484            ax.set_xticks(lon_ticks, crs=ccrs.PlateCarree())
     485            ax.set_yticks(lat_ticks, crs=ccrs.PlateCarree())
     486            ax.tick_params(
     487                axis='x', which='major',
     488                length=4,
     489                direction='out',
     490                pad=2,
     491                labelsize=8
     492            )
     493            ax.tick_params(
     494                axis='y', which='major',
     495                length=4,
     496                direction='out',
     497                pad=2,
     498                labelsize=8
     499            )
     500            cbar = fig.colorbar(cf, ax=ax, orientation="vertical", pad=0.02)
     501            cbar.set_label(varname + (f" ({dataset.variables[varname].units})"
     502                                      if hasattr(dataset.variables[varname], "units") else ""))
     503            ax.set_title(f"{varname} (lat × lon)")
     504            ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
     505            ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
     506
     507            # Prompt for polar-stereo views if interactive
     508            if sys.stdin.isatty() and input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
     509                units = getattr(dataset.variables[varname], "units", None)
     510                plot_polar_views(lon2d, lat2d, dslice, colormap, varname, units)
     511
    308512            if output_path:
    309513                plt.savefig(output_path, bbox_inches="tight")
     
    312516                plt.show()
    313517            return
    314         # generic 2D
     518
     519        # Generic 2D
    315520        plt.figure(figsize=(8, 6))
    316521        plt.imshow(dslice, aspect="auto")
     
    331536def visualize_variable_interactive(nc_path=None):
    332537    """
    333     Interactive mode: prompts for file, variable, displays dims,
    334     handles special case of pure time series, then guides user
    335     through any needed index selections.
    336     """
    337     # File selection
     538    Interactive loop: keep prompting for variables to plot until user quits.
     539    """
     540    # Open dataset
    338541    if nc_path:
    339542        path = nc_path
     
    342545        readline.parse_and_bind("tab: complete")
    343546        path = input("Enter path to NetCDF file: ").strip()
     547
    344548    if not os.path.isfile(path):
    345         print(f"Error: '{path}' not found."); return
     549        print(f"Error: '{path}' not found.")
     550        return
     551
    346552    ds = Dataset(path, "r")
    347 
    348     # Variable selection with autocomplete
    349     vars_ = list(ds.variables.keys())
    350     if not vars_:
    351         print("No variables found."); ds.close(); return
    352     if len(vars_) == 1:
    353         var = vars_[0]; print(f"Selected '{var}'")
    354     else:
    355         print("Available variables:")
    356         for v in vars_:
    357             print(f"  - {v}")
    358         readline.set_completer(make_varname_completer(vars_))
     553    var_list = list(ds.variables.keys())
     554    if not var_list:
     555        print("No variables found in file.")
     556        ds.close()
     557        return
     558
     559    # Enable interactive mode
     560    plt.ion()
     561
     562    while True:
     563        # Enable tab-completion for variable names
     564        readline.set_completer(make_varname_completer(var_list))
    359565        readline.parse_and_bind("tab: complete")
    360         var = input("Variable name: ").strip()
    361         if var not in ds.variables:
    362             print("Unknown variable."); ds.close(); return
    363 
    364     # DISPLAY DIMENSIONS AND SIZES
    365     dims  = ds.variables[var].dimensions
    366     shape = ds.variables[var].shape
    367     print(f"\nVariable '{var}' has {len(dims)} dimensions:")
    368     for name, size in zip(dims, shape):
    369         print(f"  - {name}: size {size}")
    370     print()
    371 
    372     # Identify dimension indices
    373     t_idx   = find_dim_index(dims, TIME_DIMS)
    374     lat_idx = find_dim_index(dims, LAT_DIMS)
    375     lon_idx = find_dim_index(dims, LON_DIMS)
    376     a_idx   = find_dim_index(dims, ALT_DIMS)
    377 
    378     # SPECIAL CASE: time-only series (all others singleton) → plot directly
    379     if (
    380         t_idx is not None and shape[t_idx] > 1 and
    381         all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
    382     ):
    383         print("Detected single-point spatial dims; plotting time series…")
    384         # récupérer les valeurs
    385         var_obj = ds.variables[var]
    386         data = var_obj[:].squeeze()   # shape (time,)
    387         # temps
    388         time_var = find_coord_var(ds, TIME_DIMS)
    389         if time_var:
    390             tvals = ds.variables[time_var][:]
    391         else:
    392             tvals = np.arange(data.shape[0])
    393         # masque éventuel
    394         if hasattr(data, "mask"):
    395             data = np.where(data.mask, np.nan, data.data)
    396         if hasattr(tvals, "mask"):
    397             tvals = np.where(tvals.mask, np.nan, tvals.data)
    398         # tracé
    399         plt.figure()
    400         plt.plot(tvals, data, marker="o")
    401         plt.xlabel(time_var or "Time Index")
    402         plt.ylabel(var + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
    403         plt.title(f"{var} vs {time_var or 'Index'}")
    404         plt.show()
    405         ds.close()
    406         return
    407 
    408     # Ask average over latitude only if Time, lat AND lon each >1
    409     avg_lat = False
    410     if (
    411         t_idx   is not None and shape[t_idx]  > 1 and
    412         lat_idx is not None and shape[lat_idx] > 1 and
    413         lon_idx is not None and shape[lon_idx] > 1
    414     ):
    415         u = input("Average over latitude & plot lon vs time? [y/n]: ").strip().lower()
    416         avg_lat = (u == "y")
    417 
    418     # Time index prompt
    419     ti = None
    420     if t_idx is not None:
    421         L = shape[t_idx]
    422         if L > 1:
     566
     567        print("\nAvailable variables:")
     568        for name in var_list:
     569            print(f"  - {name}")
     570        varname = input("\nEnter variable name to plot (or 'q' to quit): ").strip()
     571        if varname.lower() in ("q", "quit", "exit"):
     572            print("Exiting.")
     573            break
     574        if varname not in ds.variables:
     575            print(f"Variable '{varname}' not found. Try again.")
     576            continue
     577
     578        # Display dimensions and size
     579        var = ds.variables[varname]
     580        dims, shape = var.dimensions, var.shape
     581        print(f"\nVariable '{varname}' has dimensions:")
     582        for dim, size in zip(dims, shape):
     583            print(f"  - {dim}: size {size}")
     584        print()
     585
     586        # Prepare slicing parameters
     587        time_index = None
     588        alt_index = None
     589        avg = False
     590        extra_indices = {}
     591
     592        # Time index
     593        t_idx = find_dim_index(dims, TIME_DIMS)
     594        if t_idx is not None:
     595            if shape[t_idx] > 1:
     596                while True:
     597                    idx = input(f"Enter time index [1–{shape[t_idx]}] (press Enter for all): ").strip()
     598                    if idx == '':
     599                        time_index = None
     600                        break
     601                    if idx.isdigit():
     602                        i = int(idx)
     603                        if 1 <= i <= shape[t_idx]:
     604                            time_index = i - 1
     605                            break
     606                    print("Invalid entry. Please enter a valid number or press Enter.")
     607            else:
     608                time_index = 0
     609
     610        # Altitude index
     611        a_idx = find_dim_index(dims, ALT_DIMS)
     612        if a_idx is not None:
     613            if shape[a_idx] > 1:
     614                while True:
     615                    idx = input(f"Enter altitude index [1–{shape[a_idx]}] (press Enter for all): ").strip()
     616                    if idx == '':
     617                        alt_index = None
     618                        break
     619                    if idx.isdigit():
     620                        i = int(idx)
     621                        if 1 <= i <= shape[a_idx]:
     622                            alt_index = i - 1
     623                            break
     624                    print("Invalid entry. Please enter a valid number or press Enter.")
     625            else:
     626                alt_index = 0
     627
     628        # Average over latitude?
     629        lat_idx = find_dim_index(dims, LAT_DIMS)
     630        lon_idx = find_dim_index(dims, LON_DIMS)
     631        if (t_idx is not None and lat_idx is not None and lon_idx is not None and
     632            shape[t_idx] > 1 and shape[lat_idx] > 1 and shape[lon_idx] > 1):
     633            resp = input("Average over latitude and plot lon vs time? [y/n]: ").strip().lower()
     634            avg = (resp == 'y')
     635
     636        # Other dimensions
     637        for i, dname in enumerate(dims):
     638            if i in (t_idx, a_idx):
     639                continue
     640            size = shape[i]
     641            if size == 1:
     642                extra_indices[dname] = 0
     643                continue
    423644            while True:
    424                 u = input(f"Enter time index [0..{L-1}]: ").strip()
    425                 try:
    426                     ti = int(u)
    427                     if 0 <= ti < L:
     645                idx = input(f"Enter index [1–{size}] for '{dname}' (press Enter for all): ").strip()
     646                if idx == '':
     647                    # keep all values
     648                    break
     649                if idx.isdigit():
     650                    j = int(idx)
     651                    if 1 <= j <= size:
     652                        extra_indices[dname] = j - 1
    428653                        break
    429                 except:
    430                     pass
    431                 print("Invalid.")
    432         else:
    433             ti = 0; print("Only one time; using 0.")
    434 
    435     # Altitude index prompt
    436     ai = None
    437     if a_idx is not None:
    438         L = shape[a_idx]
    439         if L > 1:
    440             while True:
    441                 u = input(f"Enter altitude index [0..{L-1}]: ").strip()
    442                 try:
    443                     ai = int(u)
    444                     if 0 <= ai < L:
    445                         break
    446                 except:
    447                     pass
    448                 print("Invalid.")
    449         else:
    450             ai = 0; print("Only one altitude; using 0.")
    451 
    452     # Other dims
    453     extra = {}
    454     for idx, dname in enumerate(dims):
    455         if idx in (t_idx, a_idx):
    456             continue
    457         if dname.lower() in LAT_DIMS + LON_DIMS and shape[idx] == 1:
    458             extra[dname] = 0
    459             continue
    460         L = shape[idx]
    461         if L == 1:
    462             extra[dname] = 0
    463             continue
    464         if "slope" in dname.lower():
    465             prompt = f"Enter slope number [1..{L}] for '{dname}': "
    466         else:
    467             prompt = f"Enter index [0..{L-1}] or 'f' to plot '{dname}': "
    468         while True:
    469             u = input(prompt).strip().lower()
    470             if u == "f" and "slope" not in dname.lower():
    471                 break
    472             try:
    473                 iv = int(u)
    474                 if "slope" in dname.lower():
    475                     if 1 <= iv <= L:
    476                         extra[dname] = iv - 1
    477                         break
    478                 else:
    479                     if 0 <= iv < L:
    480                         extra[dname] = iv
    481                         break
    482             except:
    483                 pass
    484             print("Invalid.")
    485 
    486     plot_variable(ds, var, time_index=ti, alt_index=ai,
    487                   colormap="jet", output_path=None,
    488                   extra_indices=extra, avg_lat=avg_lat)
     654                print("Invalid entry. Please enter a valid number or press Enter.")
     655
     656        # Plot the variable
     657        plot_variable(
     658            ds, varname,
     659            time_index    = time_index,
     660            alt_index     = alt_index,
     661            colormap      = 'jet',
     662            output_path   = None,
     663            extra_indices = extra_indices,
     664            avg_lat       = avg
     665        )
     666
    489667    ds.close()
    490668
     
    496674    """
    497675    if not os.path.isfile(nc_file):
    498         print(f"Error: '{nc_file}' not found."); return
     676        print(f"Error: '{nc_file}' not found.")
     677        return
    499678    ds = Dataset(nc_file, "r")
    500679    if varname not in ds.variables:
    501         print(f"Variable '{varname}' not in file."); ds.close(); return
    502 
    503     # DISPLAY DIMENSIONS AND SIZES
     680        print(f"Variable '{varname}' not in file.")
     681        ds.close()
     682        return
     683
     684    # Display dimensions and size
    504685    dims  = ds.variables[varname].dimensions
    505686    shape = ds.variables[varname].shape
     
    509690    print()
    510691
    511     # SPECIAL CASE: time-only → plot directly
     692    # Special case: time-only → plot directly
    512693    t_idx = find_dim_index(dims, TIME_DIMS)
    513694    if (
     
    516697    ):
    517698        print("Detected single-point spatial dims; plotting time series…")
    518         # même logique que ci‑dessus
    519699        var_obj = ds.variables[varname]
    520700        data = var_obj[:].squeeze()
     
    541721        return
    542722
    543     # Si --avg-lat mais lat/lon/Time non compatibles → désactive
     723    # if --avg-lat but lat/lon/Time not compatible → disable
    544724    lat_idx = find_dim_index(dims, LAT_DIMS)
    545725    lon_idx = find_dim_index(dims, LON_DIMS)
     
    597777if __name__ == "__main__":
    598778    main()
     779
Note: See TracChangeset for help on using the changeset viewer.