Changeset 3808 for trunk/LMDZ.MARS/util


Ignore:
Timestamp:
Jun 16, 2025, 4:04:32 PM (25 hours ago)
Author:
jbclement
Message:

Mars PCM:

  • Bug corrections for the Python script displaying variables in a NetCDF file regarding the dimensions + addition of options (for example to average over longitude).
  • Improvement for the Python script analyzing variables in a NetCDF file.

JBC

Location:
trunk/LMDZ.MARS/util
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/LMDZ.MARS/util/analyse_netcdf.py

    r3798 r3808  
    6262        data = variable[:]
    6363    except Exception as e:
    64         print(f"\nUnable to read variable '{name}': {e}")
     64        print(f"\nError: Unable to read variable '{name}': {e}")
    6565        return
    6666
     
    7676        print(f"  Dimensions: {dims}")
    7777        print(f"  Shape     : {shape}")
    78         print("  Entire variable is NaN or masked.")
     78        print("  \033[91mAnomaly: entire variable is NaN or masked!\033[0m")
    7979        return
    8080
     
    9696    print(f"  Mean value: {data_mean:>12.6e}")
    9797    if has_nan:
    98         print(f"  \033[91mContains NaN values!\033[0m")
     98        print(f"  \033[91mAnomaly: contains NaN values!\033[0m")
    9999    if has_negative:
    100         print(f"  \033[93mWarning: contains negative values!\033[0m")
     100        print(f"  \033[93mCaution: contains negative values!\033[0m")
    101101
    102102def analyze_netcdf_file(nc_path):
     
    123123        except Exception:
    124124            # If reading dtype fails, skip it
    125             print(f"\nSkipping variable with unknown type: {var_name}")
     125            print(f"\nWarning: Skipping variable with unknown type: {var_name}")
    126126            continue
    127127
     
    129129            analyze_variable(variable)
    130130        else:
    131             print(f"\nSkipping non-numeric variable: {var_name}")
     131            print(f"\nWarning: Skipping non-numeric variable: {var_name}")
    132132
    133133    ds.close()
     
    168168if __name__ == "__main__":
    169169    main()
    170 
  • trunk/LMDZ.MARS/util/display_netcdf.py

    r3798 r3808  
    1111  - 2D latitude/longitude map
    1212  - 2D (Time × another dimension)
    13   - Variables with dimension “physical_points” as 2D map if lat/lon present,
    14     or generic 2D plot if the remaining axes are spatial
     13  - Variables with dimension “physical_points” displayed on a 2D map if lat/lon are present
     14  - Optionally average over latitude and plot longitude vs. time heatmap
    1515  - Scalar output (ndim == 0 after slicing)
    1616
     
    1818  1) Command-line mode:
    1919       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
    20            [--time-index 0] [--cmap viridis] [--output out.png] \
    21            [--extra-indices '{"dim1": idx1, "dim2": idx2}']
     20           [--time-index 0] [--alt-index 0] [--cmap viridis] [--avg-lat] \
     21           [--output out.png] [--extra-indices '{"nslope": 1}']
    2222
    2323    --variable     : Name of the variable to visualize.
    24     --time-index   : Index along the Time dimension (ignored for purely 1D time series).
    25     --alt-index    : Index along the altitude dimension, if present.
    26     --cmap         : Matplotlib colormap for contourf (default: "jet").
     24    --time-index   : Index along the Time dimension (0-based, ignored for purely 1D time series).
     25    --alt-index    : Index along the altitude dimension (0-based), if present.
     26    --cmap         : Matplotlib colormap (default: "jet").
     27    --avg-lat      : Average over latitude and plot longitude vs. time heatmap.
    2728    --output       : If provided, save the figure to this filename instead of displaying.
    28     --extra-indices: JSON string to fix indices for any dimensions other than Time, lat, lon, or altitude.
    29                      Example: '{"nslope": 0, "physical_points": 2}'
    30                      Omitting a dimension means it remains unfixed (useful to plot a 1D profile).
     29    --extra-indices: JSON string to fix indices for any other dimensions.
     30                     For any dimension whose name contains "slope", use 1-based numbering here.
     31                     Example: '{"nslope": 1, "physical_points": 3}'
    3132
    3233  2) Interactive mode:
    3334       python display_netcdf.py
    34        (The script will prompt for the NetCDF file, the variable, etc.)
     35       (The script will prompt for everything, including averaging option.)
    3536"""
    3637
     
    4344import numpy as np
    4445import matplotlib.pyplot as plt
     46import matplotlib.tri as mtri
    4547from netCDF4 import Dataset
    4648
     
    105107
    106108
    107 def plot_variable(dataset, varname, time_index=None, alt_index=None, colormap="jet",
    108                   output_path=None, extra_indices=None):
    109     """
    110     Extracts the requested slice from the variable and plots it according to the data shape:
    111 
    112     - Pure 1D time series → time-series plot
    113     - After slicing:
    114         • If data_slice.ndim == 0 → print the scalar value
    115         • If data_slice.ndim == 1:
    116             • If the remaining dimension is “subsurface_layers” (or another known coordinate) → vertical profile
    117             • Else → simple plot vs. index
    118         • If data_slice.ndim == 2:
    119             • If lat/lon exist → contourf map
    120             • Else → imshow generic 2D plot
    121     - If data_slice.ndim is neither 0, 1, nor 2 → error message
    122 
    123     Parameters
    124     ----------
    125     dataset       : netCDF4.Dataset object (already open)
    126     varname       : name of the variable to plot
    127     time_index    : int or None (if variable has a time dimension, ignored for pure time series)
    128     alt_index     : int or None (if variable has an altitude dimension)
    129     colormap      : string colormap name (passed to plt.contourf)
    130     output_path   : string filepath to save figure, or None to display interactively
    131     extra_indices : dict { dimension_name (str) : index (int) } for slicing all
    132                      dimensions except Time/lat/lon/alt. If a dimension is not
    133                      included, it remains “slice(None)” (useful for 1D plots).
     109def plot_variable(dataset, varname, time_index=None, alt_index=None,
     110                  colormap="jet", output_path=None, extra_indices=None,
     111                  avg_lat=False):
     112    """
     113    Core plotting logic: reads the variable, handles masks,
     114    determines dimensionality, and creates the appropriate plot:
     115      - 1D time series
     116      - 1D profiles or physical_points maps
     117      - 2D lat×lon or generic 2D
     118      - Time×lon heatmap if avg_lat=True
     119      - Scalar printing
    134120    """
    135121    var = dataset.variables[varname]
    136     dims = var.dimensions  # tuple of dimension names
    137 
    138     # Read the full data (could be a masked array)
     122    dims = var.dimensions
     123
     124    # Read full data
    139125    try:
    140126        data_full = var[:]
    141127    except Exception as e:
    142         print(f"Error: Cannot read data for variable '{varname}': {e}")
    143         return
    144 
    145     # Convert masked array to NaN
     128        print(f"Error: Cannot read data for '{varname}': {e}")
     129        return
    146130    if hasattr(data_full, "mask"):
    147131        data_full = np.where(data_full.mask, np.nan, data_full.data)
    148132
    149     # ------------------------------------------------------------------------
    150     # 1) Pure 1D time series (dims == ('Time',) or equivalent)
     133    # Pure 1D time series
    151134    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
    152         # Plot the time series directly
    153         time_varname = find_coord_var(dataset, TIME_DIMS)
    154         if time_varname:
    155             time_vals = dataset.variables[time_varname][:]
    156             if hasattr(time_vals, "mask"):
    157                 time_vals = np.where(time_vals.mask, np.nan, time_vals.data)
    158         else:
    159             time_vals = np.arange(data_full.shape[0])
    160 
     135        time_var = find_coord_var(dataset, TIME_DIMS)
     136        tvals = (dataset.variables[time_var][:] if time_var
     137                 else np.arange(data_full.shape[0]))
     138        if hasattr(tvals, "mask"):
     139            tvals = np.where(tvals.mask, np.nan, tvals.data)
    161140        plt.figure()
    162         plt.plot(time_vals, data_full, marker='o')
    163         xlabel = time_varname if time_varname else "Time Index"
    164         plt.xlabel(xlabel)
    165         ylabel = varname
    166         if hasattr(var, "units"):
    167             ylabel += f" ({var.units})"
    168         plt.ylabel(ylabel)
    169         plt.title(f"{varname} vs {xlabel}")
    170 
     141        plt.plot(tvals, data_full, marker="o")
     142        plt.xlabel(time_var or "Time Index")
     143        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
     144        plt.title(f"{varname} vs {time_var or 'Index'}")
    171145        if output_path:
    172             try:
    173                 plt.savefig(output_path, bbox_inches="tight")
    174                 print(f"Figure saved to '{output_path}'")
    175             except Exception as e:
    176                 print(f"Error saving figure: {e}")
     146            plt.savefig(output_path, bbox_inches="tight")
     147            print(f"Saved to {output_path}")
    177148        else:
    178149            plt.show()
    179         plt.close()
    180         return
    181     # ------------------------------------------------------------------------
    182 
    183     # Identify special dimension indices
     150        return
     151
     152    # Identify dims
    184153    t_idx = find_dim_index(dims, TIME_DIMS)
    185     a_idx = find_dim_index(dims, ALT_DIMS)
    186154    lat_idx = find_dim_index(dims, LAT_DIMS)
    187155    lon_idx = find_dim_index(dims, LON_DIMS)
    188 
    189     # Build the slicer list
     156    a_idx = find_dim_index(dims, ALT_DIMS)
     157
     158    # Average over latitude & plot time × lon heatmap
     159    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
     160        # mean over lat axis
     161        data_avg = np.nanmean(data_full, axis=lat_idx)
     162        # data_avg shape: (time, lon, ...)
     163        # we assume no other unfixed dims
     164        # get coordinates
     165        time_var = find_coord_var(dataset, TIME_DIMS)
     166        lon_var = find_coord_var(dataset, LON_DIMS)
     167        tvals = dataset.variables[time_var][:]
     168        lons = dataset.variables[lon_var][:]
     169        if hasattr(tvals, "mask"):
     170            tvals = np.where(tvals.mask, np.nan, tvals.data)
     171        if hasattr(lons, "mask"):
     172            lons = np.where(lons.mask, np.nan, lons.data)
     173        plt.figure(figsize=(10, 6))
     174        plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
     175        plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
     176        plt.ylabel(time_var)
     177        cbar = plt.colorbar()
     178        cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
     179        plt.title(f"{varname} averaged over latitude")
     180        if output_path:
     181            plt.savefig(output_path, bbox_inches="tight")
     182            print(f"Saved to {output_path}")
     183        else:
     184            plt.show()
     185        return
     186
     187    # Build slicer for other cases
    190188    slicer = [slice(None)] * len(dims)
    191 
    192     # Apply slicing on Time and altitude if specified
    193189    if t_idx is not None:
    194190        if time_index is None:
    195             print("Error: Variable has a time dimension; please supply a time index.")
     191            print("Error: please supply a time index.")
    196192            return
    197193        slicer[t_idx] = time_index
    198194    if a_idx is not None:
    199195        if alt_index is None:
    200             print("Error: Variable has an altitude dimension; please supply an altitude index.")
     196            print("Error: please supply an altitude index.")
    201197            return
    202198        slicer[a_idx] = alt_index
    203199
    204     # Apply slicing on all “extra” dimensions (except Time/lat/lon/alt)
    205200    if extra_indices is None:
    206201        extra_indices = {}
    207     for dim_name, idx_val in extra_indices.items():
    208         if dim_name in dims:
    209             dim_index = dims.index(dim_name)
    210             slicer[dim_index] = idx_val
    211 
    212     # Extract the sliced data
     202    for dn, idx_val in extra_indices.items():
     203        if dn in dims:
     204            slicer[dims.index(dn)] = idx_val
     205
     206    # Extract slice
    213207    try:
    214         data_slice = data_full[tuple(slicer)]
     208        dslice = data_full[tuple(slicer)]
    215209    except Exception as e:
    216         print(f"Error: Could not slice variable '{varname}': {e}")
    217         return
    218 
    219     # CASE: After slicing, if data_slice.ndim == 0 → scalar
    220     if np.ndim(data_slice) == 0:
    221         try:
    222             scalar_val = float(data_slice)
    223         except Exception:
    224             scalar_val = data_slice
    225         print(f"Scalar result for '{varname}': {scalar_val}")
    226         return
    227 
    228     # CASE: After slicing, if data_slice.ndim == 1 (vertical profile or simple vector)
    229     if data_slice.ndim == 1:
    230         # Identify the remaining dimension
    231         rem_dim = None
    232         for di, dname in enumerate(dims):
    233             if slicer[di] == slice(None):
    234                 rem_dim = (di, dname)
    235                 break
    236 
    237         if rem_dim is not None:
    238             di, dname = rem_dim
    239             coord_var = None
    240 
    241             # If it's "subsurface_layers", look for coordinate "soildepth"
     210        print(f"Error slicing '{varname}': {e}")
     211        return
     212
     213    # Scalar
     214    if np.ndim(dslice) == 0:
     215        print(f"Scalar '{varname}': {float(dslice)}")
     216        return
     217
     218    # 1D: vector, profile, or physical_points
     219    if dslice.ndim == 1:
     220        rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]
     221        if rem:
     222            di, dname = rem[0]
     223            # physical_points → interpolated map
     224            if dname.lower() == "physical_points":
     225                latv = find_coord_var(dataset, LAT_DIMS)
     226                lonv = find_coord_var(dataset, LON_DIMS)
     227                if latv and lonv:
     228                    lats = dataset.variables[latv][:]
     229                    lons = dataset.variables[lonv][:]
     230                    if hasattr(lats, "mask"):
     231                        lats = np.where(lats.mask, np.nan, lats.data)
     232                    if hasattr(lons, "mask"):
     233                        lons = np.where(lons.mask, np.nan, lons.data)
     234                    triang = mtri.Triangulation(lons, lats)
     235                    plt.figure(figsize=(8, 6))
     236                    cf = plt.tricontourf(triang, dslice, cmap=colormap)
     237                    cbar = plt.colorbar(cf)
     238                    cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
     239                    plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
     240                    plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
     241                    plt.title(f"{varname} (interpolated map over physical_points)")
     242                    if output_path:
     243                        plt.savefig(output_path, bbox_inches="tight")
     244                        print(f"Saved to {output_path}")
     245                    else:
     246                        plt.show()
     247                    return
     248            # vertical profile?
     249            coord = None
    242250            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
    243                 coord_var = "soildepth"
    244             # If there is a variable with the same name, use it
     251                coord = "soildepth"
    245252            elif dname in dataset.variables:
    246                 coord_var = dname
    247 
    248             if coord_var:
    249                 coord_vals = dataset.variables[coord_var][:]
    250                 if hasattr(coord_vals, "mask"):
    251                     coord_vals = np.where(coord_vals.mask, np.nan, coord_vals.data)
    252                 x = data_slice
    253                 y = coord_vals
    254 
     253                coord = dname
     254            if coord:
     255                coords = dataset.variables[coord][:]
     256                if hasattr(coords, "mask"):
     257                    coords = np.where(coords.mask, np.nan, coords.data)
    255258                plt.figure()
    256                 plt.plot(x, y, marker='o')
    257                 # Invert Y-axis if it's a depth coordinate
     259                plt.plot(dslice, coords, marker="o")
    258260                if dname.lower() == "subsurface_layers":
    259261                    plt.gca().invert_yaxis()
    260 
    261                 xlabel = varname
    262                 if hasattr(var, "units"):
    263                     xlabel += f" ({var.units})"
    264                 plt.xlabel(xlabel)
    265 
    266                 ylabel = coord_var
    267                 if hasattr(dataset.variables[coord_var], "units"):
    268                     ylabel += f" ({dataset.variables[coord_var].units})"
    269                 plt.ylabel(ylabel)
    270 
    271                 plt.title(f"{varname} vs {coord_var}")
    272 
     262                plt.xlabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
     263                plt.ylabel(coord + (f" ({dataset.variables[coord].units})" if hasattr(dataset.variables[coord], "units") else ""))
     264                plt.title(f"{varname} vs {coord}")
    273265                if output_path:
    274                     try:
    275                         plt.savefig(output_path, bbox_inches="tight")
    276                         print(f"Figure saved to '{output_path}'")
    277                     except Exception as e:
    278                         print(f"Error saving figure: {e}")
     266                    plt.savefig(output_path, bbox_inches="tight")
     267                    print(f"Saved to {output_path}")
    279268                else:
    280269                    plt.show()
    281                 plt.close()
    282270                return
    283             else:
    284                 # No known coordinate found → simple plot vs index
    285                 plt.figure()
    286                 plt.plot(data_slice, marker='o')
    287                 plt.xlabel("Index")
    288                 ylabel = varname
    289                 if hasattr(var, "units"):
    290                     ylabel += f" ({var.units})"
    291                 plt.ylabel(ylabel)
    292                 plt.title(f"{varname} (1D)")
    293 
    294                 if output_path:
    295                     try:
    296                         plt.savefig(output_path, bbox_inches="tight")
    297                         print(f"Figure saved to '{output_path}'")
    298                     except Exception as e:
    299                         print(f"Error saving figure: {e}")
    300                 else:
    301                     plt.show()
    302                 plt.close()
    303                 return
    304 
    305         else:
    306             # Unable to identify the remaining dimension → error
    307             print(f"Error: After slicing, data for '{varname}' is 1D but remaining dimension is unknown.")
    308             return
    309 
    310     # CASE: After slicing, if data_slice.ndim == 2
    311     if data_slice.ndim == 2:
    312         # If lat and lon exist in the original dims, re-find their indices
     271        # generic 1D
     272        plt.figure()
     273        plt.plot(dslice, marker="o")
     274        plt.xlabel("Index")
     275        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
     276        plt.title(f"{varname} (1D)")
     277        if output_path:
     278            plt.savefig(output_path, bbox_inches="tight")
     279            print(f"Saved to {output_path}")
     280        else:
     281            plt.show()
     282        return
     283
     284    # 2D: map or generic
     285    if dslice.ndim == 2:
    313286        lat_idx2 = find_dim_index(dims, LAT_DIMS)
    314287        lon_idx2 = find_dim_index(dims, LON_DIMS)
    315 
    316288        if lat_idx2 is not None and lon_idx2 is not None:
    317             # We have a 2D variable on a lat×lon grid
    318             lat_varname = find_coord_var(dataset, LAT_DIMS)
    319             lon_varname = find_coord_var(dataset, LON_DIMS)
    320             if lat_varname is None or lon_varname is None:
    321                 print("Error: Could not locate latitude/longitude variables in the dataset.")
    322                 return
    323 
    324             lat_var = dataset.variables[lat_varname][:]
    325             lon_var = dataset.variables[lon_varname][:]
    326             if hasattr(lat_var, "mask"):
    327                 lat_var = np.where(lat_var.mask, np.nan, lat_var.data)
    328             if hasattr(lon_var, "mask"):
    329                 lon_var = np.where(lon_var.mask, np.nan, lon_var.data)
    330 
    331             # Build 2D coordinate arrays
    332             if lat_var.ndim == 1 and lon_var.ndim == 1:
    333                 lon2d, lat2d = np.meshgrid(lon_var, lat_var)
    334             elif lat_var.ndim == 2 and lon_var.ndim == 2:
    335                 lat2d, lon2d = lat_var, lon_var
     289            latv = find_coord_var(dataset, LAT_DIMS)
     290            lonv = find_coord_var(dataset, LON_DIMS)
     291            lats = dataset.variables[latv][:]
     292            lons = dataset.variables[lonv][:]
     293            if hasattr(lats, "mask"):
     294                lats = np.where(lats.mask, np.nan, lats.data)
     295            if hasattr(lons, "mask"):
     296                lons = np.where(lons.mask, np.nan, lons.data)
     297            if lats.ndim == 1 and lons.ndim == 1:
     298                lon2d, lat2d = np.meshgrid(lons, lats)
    336299            else:
    337                 print("Error: Latitude and longitude must both be either 1D or 2D.")
    338                 return
    339 
     300                lat2d, lon2d = lats, lons
    340301            plt.figure(figsize=(10, 6))
    341             cf = plt.contourf(lon2d, lat2d, data_slice, cmap=colormap)
     302            cf = plt.contourf(lon2d, lat2d, dslice, cmap=colormap)
    342303            cbar = plt.colorbar(cf)
    343             if hasattr(var, "units"):
    344                 cbar.set_label(f"{varname} ({var.units})")
    345             else:
    346                 cbar.set_label(varname)
    347 
    348             lon_label = f"Longitude ({getattr(dataset.variables[lon_varname], 'units', 'degrees')})"
    349             lat_label = f"Latitude ({getattr(dataset.variables[lat_varname], 'units', 'degrees')})"
    350             plt.xlabel(lon_label)
    351             plt.ylabel(lat_label)
     304            cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
     305            plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
     306            plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
    352307            plt.title(f"{varname} (lat × lon)")
    353 
    354308            if output_path:
    355                 try:
    356                     plt.savefig(output_path, bbox_inches="tight")
    357                     print(f"Figure saved to '{output_path}'")
    358                 except Exception as e:
    359                     print(f"Error saving figure: {e}")
     309                plt.savefig(output_path, bbox_inches="tight")
     310                print(f"Saved to {output_path}")
    360311            else:
    361312                plt.show()
    362             plt.close()
    363313            return
    364 
    365         else:
    366             # No lat/lon → two non-geographical dimensions; plot with imshow
    367             plt.figure(figsize=(8, 6))
    368             plt.imshow(data_slice, aspect='auto')
    369             cb_label = varname
    370             if hasattr(var, "units"):
    371                 cb_label += f" ({var.units})"
    372             plt.colorbar(label=cb_label)
    373             plt.xlabel("Dimension 2 Index")
    374             plt.ylabel("Dimension 1 Index")
    375             plt.title(f"{varname} (2D without lat/lon)")
    376 
    377             if output_path:
    378                 try:
    379                     plt.savefig(output_path, bbox_inches="tight")
    380                     print(f"Figure saved to '{output_path}'")
    381                 except Exception as e:
    382                     print(f"Error saving figure: {e}")
    383             else:
    384                 plt.show()
    385             plt.close()
    386             return
    387 
    388     # CASE: data_slice.ndim is neither 0, 1, nor 2
    389     print(f"Error: After slicing, data for '{varname}' has ndim={data_slice.ndim}, which is not supported.")
    390     return
     314        # generic 2D
     315        plt.figure(figsize=(8, 6))
     316        plt.imshow(dslice, aspect="auto")
     317        plt.colorbar(label=varname + (f" ({var.units})" if hasattr(var, "units") else ""))
     318        plt.xlabel("Dim 2 index")
     319        plt.ylabel("Dim 1 index")
     320        plt.title(f"{varname} (2D)")
     321        if output_path:
     322            plt.savefig(output_path, bbox_inches="tight")
     323            print(f"Saved to {output_path}")
     324        else:
     325            plt.show()
     326        return
     327
     328    print(f"Error: ndim={dslice.ndim} not supported.")
    391329
    392330
    393331def visualize_variable_interactive(nc_path=None):
    394332    """
    395     Interactive mode: prompts for the NetCDF file if not provided, then for the variable,
    396     then for Time/altitude indices (skipped entirely if variable is purely 1D over Time),
    397     and for each other dimension offers to fix an index or to plot along that dimension (by typing 'f').
    398 
    399     If a dimension has length 1, the index 0 is chosen automatically.
    400     """
    401     # Determine file path
     333    Interactive mode: prompts for file, variable, displays dims,
     334    handles special case of pure time series, then guides user
     335    through any needed index selections.
     336    """
     337    # File selection
    402338    if nc_path:
    403         file_input = nc_path
     339        path = nc_path
    404340    else:
    405341        readline.set_completer(complete_filename)
    406342        readline.parse_and_bind("tab: complete")
    407         file_input = input("Enter the path to the NetCDF file: ").strip()
    408 
    409     if not file_input:
    410         print("No file specified. Exiting.")
    411         return
    412     if not os.path.isfile(file_input):
    413         print(f"Error: '{file_input}' not found.")
    414         return
    415 
    416     try:
    417         ds = Dataset(file_input, mode="r")
    418     except Exception as e:
    419         print(f"Error: Unable to open '{file_input}': {e}")
    420         return
    421 
    422     varnames = list(ds.variables.keys())
    423     if not varnames:
    424         print("Error: No variables found in the dataset.")
     343        path = input("Enter path to NetCDF file: ").strip()
     344    if not os.path.isfile(path):
     345        print(f"Error: '{path}' not found."); return
     346    ds = Dataset(path, "r")
     347
     348    # Variable selection with autocomplete
     349    vars_ = list(ds.variables.keys())
     350    if not vars_:
     351        print("No variables found."); ds.close(); return
     352    if len(vars_) == 1:
     353        var = vars_[0]; print(f"Selected '{var}'")
     354    else:
     355        print("Available variables:")
     356        for v in vars_:
     357            print(f"  - {v}")
     358        readline.set_completer(make_varname_completer(vars_))
     359        readline.parse_and_bind("tab: complete")
     360        var = input("Variable name: ").strip()
     361        if var not in ds.variables:
     362            print("Unknown variable."); ds.close(); return
     363
     364    # DISPLAY DIMENSIONS AND SIZES
     365    dims  = ds.variables[var].dimensions
     366    shape = ds.variables[var].shape
     367    print(f"\nVariable '{var}' has {len(dims)} dimensions:")
     368    for name, size in zip(dims, shape):
     369        print(f"  - {name}: size {size}")
     370    print()
     371
     372    # Identify dimension indices
     373    t_idx   = find_dim_index(dims, TIME_DIMS)
     374    lat_idx = find_dim_index(dims, LAT_DIMS)
     375    lon_idx = find_dim_index(dims, LON_DIMS)
     376    a_idx   = find_dim_index(dims, ALT_DIMS)
     377
     378    # SPECIAL CASE: time-only series (all others singleton) → plot directly
     379    if (
     380        t_idx is not None and shape[t_idx] > 1 and
     381        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
     382    ):
     383        print("Detected single-point spatial dims; plotting time series…")
     384        # récupérer les valeurs
     385        var_obj = ds.variables[var]
     386        data = var_obj[:].squeeze()   # shape (time,)
     387        # temps
     388        time_var = find_coord_var(ds, TIME_DIMS)
     389        if time_var:
     390            tvals = ds.variables[time_var][:]
     391        else:
     392            tvals = np.arange(data.shape[0])
     393        # masque éventuel
     394        if hasattr(data, "mask"):
     395            data = np.where(data.mask, np.nan, data.data)
     396        if hasattr(tvals, "mask"):
     397            tvals = np.where(tvals.mask, np.nan, tvals.data)
     398        # tracé
     399        plt.figure()
     400        plt.plot(tvals, data, marker="o")
     401        plt.xlabel(time_var or "Time Index")
     402        plt.ylabel(var + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
     403        plt.title(f"{var} vs {time_var or 'Index'}")
     404        plt.show()
    425405        ds.close()
    426406        return
    427407
    428     # Auto-select if only one variable
    429     if len(varnames) == 1:
    430         var_input = varnames[0]
    431         print(f"Automatically selected the only variable: '{var_input}'")
    432     else:
    433         print("\nAvailable variables:")
    434         for name in varnames:
    435             print(f"  - {name}")
    436         print()
    437         readline.set_completer(make_varname_completer(varnames))
    438         var_input = input("Enter the name of the variable to visualize: ").strip()
    439         if var_input not in ds.variables:
    440             print(f"Error: Variable '{var_input}' not found. Exiting.")
    441             ds.close()
    442             return
    443 
    444     dims = ds.variables[var_input].dimensions  # tuple of dimension names
    445 
    446     # If the variable is purely 1D over Time, plot immediately without asking for time index
    447     if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
    448         plot_variable(
    449             dataset=ds,
    450             varname=var_input,
    451             time_index=None,
    452             alt_index=None,
    453             colormap="jet",
    454             output_path=None,
    455             extra_indices=None
    456         )
     408    # Ask average over latitude only if Time, lat AND lon each >1
     409    avg_lat = False
     410    if (
     411        t_idx   is not None and shape[t_idx]  > 1 and
     412        lat_idx is not None and shape[lat_idx] > 1 and
     413        lon_idx is not None and shape[lon_idx] > 1
     414    ):
     415        u = input("Average over latitude & plot lon vs time? [y/n]: ").strip().lower()
     416        avg_lat = (u == "y")
     417
     418    # Time index prompt
     419    ti = None
     420    if t_idx is not None:
     421        L = shape[t_idx]
     422        if L > 1:
     423            while True:
     424                u = input(f"Enter time index [0..{L-1}]: ").strip()
     425                try:
     426                    ti = int(u)
     427                    if 0 <= ti < L:
     428                        break
     429                except:
     430                    pass
     431                print("Invalid.")
     432        else:
     433            ti = 0; print("Only one time; using 0.")
     434
     435    # Altitude index prompt
     436    ai = None
     437    if a_idx is not None:
     438        L = shape[a_idx]
     439        if L > 1:
     440            while True:
     441                u = input(f"Enter altitude index [0..{L-1}]: ").strip()
     442                try:
     443                    ai = int(u)
     444                    if 0 <= ai < L:
     445                        break
     446                except:
     447                    pass
     448                print("Invalid.")
     449        else:
     450            ai = 0; print("Only one altitude; using 0.")
     451
     452    # Other dims
     453    extra = {}
     454    for idx, dname in enumerate(dims):
     455        if idx in (t_idx, a_idx):
     456            continue
     457        if dname.lower() in LAT_DIMS + LON_DIMS and shape[idx] == 1:
     458            extra[dname] = 0
     459            continue
     460        L = shape[idx]
     461        if L == 1:
     462            extra[dname] = 0
     463            continue
     464        if "slope" in dname.lower():
     465            prompt = f"Enter slope number [1..{L}] for '{dname}': "
     466        else:
     467            prompt = f"Enter index [0..{L-1}] or 'f' to plot '{dname}': "
     468        while True:
     469            u = input(prompt).strip().lower()
     470            if u == "f" and "slope" not in dname.lower():
     471                break
     472            try:
     473                iv = int(u)
     474                if "slope" in dname.lower():
     475                    if 1 <= iv <= L:
     476                        extra[dname] = iv - 1
     477                        break
     478                else:
     479                    if 0 <= iv < L:
     480                        extra[dname] = iv
     481                        break
     482            except:
     483                pass
     484            print("Invalid.")
     485
     486    plot_variable(ds, var, time_index=ti, alt_index=ai,
     487                  colormap="jet", output_path=None,
     488                  extra_indices=extra, avg_lat=avg_lat)
     489    ds.close()
     490
     491
     492def visualize_variable_cli(nc_file, varname, time_index, alt_index,
     493                           colormap, output_path, extra_json, avg_lat):
     494    """
     495    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
     496    """
     497    if not os.path.isfile(nc_file):
     498        print(f"Error: '{nc_file}' not found."); return
     499    ds = Dataset(nc_file, "r")
     500    if varname not in ds.variables:
     501        print(f"Variable '{varname}' not in file."); ds.close(); return
     502
     503    # DISPLAY DIMENSIONS AND SIZES
     504    dims  = ds.variables[varname].dimensions
     505    shape = ds.variables[varname].shape
     506    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
     507    for name, size in zip(dims, shape):
     508        print(f"  - {name}: size {size}")
     509    print()
     510
     511    # SPECIAL CASE: time-only → plot directly
     512    t_idx = find_dim_index(dims, TIME_DIMS)
     513    if (
     514        t_idx is not None and shape[t_idx] > 1 and
     515        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
     516    ):
     517        print("Detected single-point spatial dims; plotting time series…")
     518        # même logique que ci‑dessus
     519        var_obj = ds.variables[varname]
     520        data = var_obj[:].squeeze()
     521        time_var = find_coord_var(ds, TIME_DIMS)
     522        if time_var:
     523            tvals = ds.variables[time_var][:]
     524        else:
     525            tvals = np.arange(data.shape[0])
     526        if hasattr(data, "mask"):
     527            data = np.where(data.mask, np.nan, data.data)
     528        if hasattr(tvals, "mask"):
     529            tvals = np.where(tvals.mask, np.nan, tvals.data)
     530        plt.figure()
     531        plt.plot(tvals, data, marker="o")
     532        plt.xlabel(time_var or "Time Index")
     533        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
     534        plt.title(f"{varname} vs {time_var or 'Index'}")
     535        if output_path:
     536            plt.savefig(output_path, bbox_inches="tight")
     537            print(f"Saved to {output_path}")
     538        else:
     539            plt.show()
    457540        ds.close()
    458541        return
    459542
    460     # Otherwise, proceed to prompt for time/altitude and other dimensions
    461     time_idx = None
    462     alt_idx = None
    463 
    464     # Prompt for time index if applicable
    465     t_idx = find_dim_index(dims, TIME_DIMS)
    466     if t_idx is not None:
    467         time_len = ds.variables[var_input].shape[t_idx]
    468         if time_len > 1:
    469             while True:
    470                 try:
    471                     user_t = input(f"Enter time index [0..{time_len - 1}]: ").strip()
    472                     if user_t == "":
    473                         print("No time index entered. Exiting.")
    474                         ds.close()
    475                         return
    476                     ti = int(user_t)
    477                     if 0 <= ti < time_len:
    478                         time_idx = ti
    479                         break
    480                 except ValueError:
    481                     pass
    482                 print(f"Invalid index. Enter an integer between 0 and {time_len - 1}.")
    483         else:
    484             time_idx = 0
    485             print("Only one time step available; using index 0.")
    486 
    487     # Prompt for altitude index if applicable
    488     a_idx = find_dim_index(dims, ALT_DIMS)
    489     if a_idx is not None:
    490         alt_len = ds.variables[var_input].shape[a_idx]
    491         if alt_len > 1:
    492             while True:
    493                 try:
    494                     user_a = input(f"Enter altitude index [0..{alt_len - 1}]: ").strip()
    495                     if user_a == "":
    496                         print("No altitude index entered. Exiting.")
    497                         ds.close()
    498                         return
    499                     ai = int(user_a)
    500                     if 0 <= ai < alt_len:
    501                         alt_idx = ai
    502                         break
    503                 except ValueError:
    504                     pass
    505                 print(f"Invalid index. Enter an integer between 0 and {alt_len - 1}.")
    506         else:
    507             alt_idx = 0
    508             print("Only one altitude level available; using index 0.")
    509 
    510     # Identify other dimensions (excluding Time/lat/lon/alt)
    511     other_dims = []
    512     for idx_dim, dim_name in enumerate(dims):
    513         if idx_dim == t_idx or idx_dim == a_idx:
    514             continue
    515         if dim_name.lower() in (d.lower() for d in LAT_DIMS + LON_DIMS):
    516             continue
    517         other_dims.append((idx_dim, dim_name))
    518 
    519     # For each other dimension, ask user to fix an index or type 'f' to plot along that dimension
    520     extra_indices = {}
    521     for idx_dim, dim_name in other_dims:
    522         dim_len = ds.variables[var_input].shape[idx_dim]
    523         if dim_len == 1:
    524             extra_indices[dim_name] = 0
    525             print(f"Dimension '{dim_name}' has length 1; using index 0.")
    526         else:
    527             while True:
    528                 prompt = (
    529                     f"Enter index for '{dim_name}' [0..{dim_len - 1}] "
    530                     f"or 'f' to plot along '{dim_name}': "
    531                 )
    532                 user_i = input(prompt).strip().lower()
    533                 if user_i == 'f':
    534                     # Leave this dimension unfixed → no key in extra_indices
    535                     break
    536                 if user_i == "":
    537                     print("No index entered. Exiting.")
    538                     ds.close()
    539                     return
    540                 try:
    541                     idx_val = int(user_i)
    542                     if 0 <= idx_val < dim_len:
    543                         extra_indices[dim_name] = idx_val
    544                         break
    545                 except ValueError:
    546                     pass
    547                 print(f"Invalid index. Enter an integer between 0 and {dim_len - 1}, or 'f'.")
    548 
    549     # Finally, call plot_variable with collected indices
    550     plot_variable(
    551         dataset=ds,
    552         varname=var_input,
    553         time_index=time_idx,
    554         alt_index=alt_idx,
    555         colormap="jet",
    556         output_path=None,
    557         extra_indices=extra_indices
    558     )
    559     ds.close()
    560 
    561 
    562 def visualize_variable_cli(nc_path, varname, time_index, alt_index, colormap, output_path, extra_json):
    563     """
    564     Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
    565     """
    566     if not os.path.isfile(nc_path):
    567         print(f"Error: '{nc_path}' not found.")
    568         return
    569 
    570     try:
    571         ds = Dataset(nc_path, mode="r")
    572     except Exception as e:
    573         print(f"Error: Unable to open '{nc_path}': {e}")
    574         return
    575 
    576     if varname not in ds.variables:
    577         print(f"Error: Variable '{varname}' not found in '{nc_path}'.")
    578         ds.close()
    579         return
    580 
    581     # Parse extra_indices if provided
    582     extra_indices = {}
     543    # Si --avg-lat mais lat/lon/Time non compatibles → désactive
     544    lat_idx = find_dim_index(dims, LAT_DIMS)
     545    lon_idx = find_dim_index(dims, LON_DIMS)
     546    if avg_lat and not (
     547        t_idx   is not None and shape[t_idx]  > 1 and
     548        lat_idx is not None and shape[lat_idx] > 1 and
     549        lon_idx is not None and shape[lon_idx] > 1
     550    ):
     551        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
     552        avg_lat = False
     553
     554    # Parse extra indices JSON
     555    extra = {}
    583556    if extra_json:
    584557        try:
    585558            parsed = json.loads(extra_json)
    586             if isinstance(parsed, dict):
    587                 for k, v in parsed.items():
    588                     if isinstance(k, str) and isinstance(v, int):
    589                         extra_indices[k] = v
    590             else:
    591                 print("Warning: --extra-indices is not a JSON object. Ignored.")
    592         except json.JSONDecodeError:
    593             print("Warning: --extra-indices is not valid JSON. Ignored.")
    594 
    595     plot_variable(
    596         dataset=ds,
    597         varname=varname,
    598         time_index=time_index,
    599         alt_index=alt_index,
    600         colormap=colormap,
    601         output_path=output_path,
    602         extra_indices=extra_indices
    603     )
     559            for k, v in parsed.items():
     560                if isinstance(v, int):
     561                    if "slope" in k.lower():
     562                        extra[k] = v - 1
     563                    else:
     564                        extra[k] = v
     565        except:
     566            print("Warning: bad extra-indices.")
     567
     568    plot_variable(ds, varname, time_index, alt_index,
     569                  colormap, output_path, extra, avg_lat)
    604570    ds.close()
    605571
    606572
    607573def main():
    608     parser = argparse.ArgumentParser(
    609         description="Visualize a 1D/2D slice of a NetCDF variable on various dimension types."
    610     )
    611     parser.add_argument(
    612         "nc_file",
    613         nargs="?",
    614         help="Path to the NetCDF file (interactive if omitted)."
    615     )
    616     parser.add_argument(
    617         "--variable", "-v",
    618         help="Name of the variable to visualize."
    619     )
    620     parser.add_argument(
    621         "--time-index", "-t",
    622         type=int,
    623         help="Index on the Time dimension, if applicable (ignored for pure 1D time series)."
    624     )
    625     parser.add_argument(
    626         "--alt-index", "-a",
    627         type=int,
    628         help="Index on the altitude dimension, if applicable."
    629     )
    630     parser.add_argument(
    631         "--cmap", "-c",
    632         default="jet",
    633         help="Matplotlib colormap (default: 'jet')."
    634     )
    635     parser.add_argument(
    636         "--output", "-o",
    637         help="If provided, save the figure to this file instead of displaying it."
    638     )
    639     parser.add_argument(
    640         "--extra-indices", "-e",
    641         help="JSON string to fix indices of dimensions outside Time/lat/lon/alt. "
    642              "Example: '{\"nslope\":0, \"physical_points\":2}'."
    643     )
    644 
     574    parser = argparse.ArgumentParser()
     575    parser.add_argument("nc_file", nargs="?", help="NetCDF file (omit for interactive)")
     576    parser.add_argument("-v", "--variable", help="Variable name")
     577    parser.add_argument("-t", "--time-index", type=int, help="Time index (0-based)")
     578    parser.add_argument("-a", "--alt-index", type=int, help="Altitude index (0-based)")
     579    parser.add_argument("-c", "--cmap", default="jet", help="Colormap")
     580    parser.add_argument("--avg-lat", action="store_true",
     581                        help="Average over latitude (time × lon heatmap)")
     582    parser.add_argument("-o", "--output", help="Save figure path")
     583    parser.add_argument("-e", "--extra-indices", help="JSON for other dims")
    645584    args = parser.parse_args()
    646585
    647     # If both nc_file and variable are provided → CLI mode
    648586    if args.nc_file and args.variable:
    649587        visualize_variable_cli(
    650             nc_path=args.nc_file,
    651             varname=args.variable,
    652             time_index=args.time_index,
    653             alt_index=args.alt_index,
    654             colormap=args.cmap,
    655             output_path=args.output,
    656             extra_json=args.extra_indices
     588            args.nc_file, args.variable,
     589            args.time_index, args.alt_index,
     590            args.cmap, args.output,
     591            args.extra_indices, args.avg_lat
    657592        )
    658593    else:
    659         # Otherwise → fully interactive mode
    660         visualize_variable_interactive(nc_path=args.nc_file)
     594        visualize_variable_interactive(args.nc_file)
    661595
    662596
    663597if __name__ == "__main__":
    664598    main()
    665 
Note: See TracChangeset for help on using the changeset viewer.