#!/usr/bin/env python3
##############################################################
### Python script to visualize a variable in a NetCDF file ###
##############################################################

r"""
This script can display any numeric variable from a NetCDF file.
It supports the following cases:
  - 1D time series (Time)
  - 1D vertical profiles (e.g., subsurface_layers)
  - 2D latitude/longitude map
  - Generic 2D arrays (any two remaining dimensions, shown as an image)
  - Variables with dimension "physical_points" displayed on a 2D map if lat/lon are present
  - Optionally average over latitude and plot a longitude vs. time heatmap
  - Scalar output (ndim == 0 after slicing)

Usage:
  1) Command-line mode:
       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
           [--time-index 0] [--alt-index 0] [--cmap viridis] [--avg-lat] \
           [--output out.png] [--extra-indices '{"nslope": 1}']

       --variable     : Name of the variable to visualize.
       --time-index   : Index along the Time dimension (0-based; ignored for purely
                        1D time series and for --avg-lat).
       --alt-index    : Index along the altitude dimension (0-based), if present.
       --cmap         : Matplotlib colormap (default: "jet").
       --avg-lat      : Average over latitude and plot a longitude vs. time heatmap.
       --output       : If provided, save the figure to this filename instead of displaying it.
       --extra-indices: JSON string to fix indices for any other dimensions.
                        For any dimension whose name contains "slope", use 1-based numbering here.
                        Example: '{"nslope": 1, "physical_points": 3}'

  2) Interactive mode:
       python display_netcdf.py
       (The script will prompt for everything, including the averaging option.)
"""

import os
import sys
import glob
import argparse
import json

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
from netCDF4 import Dataset

try:
    import readline
except ImportError:  # e.g. stock Python on Windows: tab completion just becomes unavailable
    readline = None

# Constants to recognize dimension names
TIME_DIMS = ("Time", "time", "time_counter")
ALT_DIMS = ("altitude",)
LAT_DIMS = ("latitude", "lat")
LON_DIMS = ("longitude", "lon")


# ---------------------------------------------------------------------------
# Tab-completion helpers
# ---------------------------------------------------------------------------

def _install_completer(completer):
    """
    Install `completer` for tab-completion in input(), if readline is available.

    Python's readline treats "/", "~", "-" and other punctuation as word
    delimiters by default, so a path such as "data/ou" would reach the
    completer as just "ou". Restricting delimiters to whitespace lets the
    completer see (and replace) the whole path or variable name.
    """
    if readline is None:
        return
    readline.set_completer(completer)
    readline.set_completer_delims(" \t\n")
    if "libedit" in (readline.__doc__ or ""):
        # macOS may ship libedit, which uses a different key-binding syntax
        readline.parse_and_bind("bind ^I rl_complete")
    else:
        readline.parse_and_bind("tab: complete")


def complete_filename(text, state):
    """
    Readline tab-completion function for filesystem paths.

    `text` is used as a glob prefix ("*" is appended unless the user already
    typed a wildcard) and "~" is expanded. Directories get a trailing "/" so
    completion can continue into them. Returns the `state`-th match, or None
    once matches are exhausted.
    """
    pattern = text if "*" in text else text + "*"
    # sorted() keeps the match order stable across the repeated calls readline makes
    matches = [
        m + "/" if os.path.isdir(m) else m
        for m in sorted(glob.glob(os.path.expanduser(pattern)))
    ]
    try:
        return matches[state]
    except IndexError:
        return None


def make_varname_completer(varnames):
    """
    Return a readline completer that completes names from `varnames` by prefix.
    """
    def completer(text, state):
        options = [name for name in varnames if name.startswith(text)]
        try:
            return options[state]
        except IndexError:
            return None
    return completer


# ---------------------------------------------------------------------------
# Dimension / coordinate lookup
# ---------------------------------------------------------------------------

def find_dim_index(dims, candidates):
    """
    Return the position in `dims` of the first dimension whose name matches any
    of `candidates` (case-insensitive), or None if there is none.
    """
    for idx, dim in enumerate(dims):
        for cand in candidates:
            if cand.lower() == dim.lower():
                return idx
    return None


def find_coord_var(dataset, candidates):
    """
    Return the name of the first dataset variable (in file order) whose name
    matches any of `candidates` (case-insensitive), or None if none does.
    """
    for name in dataset.variables:
        for cand in candidates:
            if cand.lower() == name.lower():
                return name
    return None


# ---------------------------------------------------------------------------
# Small shared helpers
# ---------------------------------------------------------------------------

def _to_nan_array(values):
    """
    Return `values` with masked entries replaced by NaN.

    netCDF4 returns masked arrays for variables with fill values; NaN is
    handled uniformly by matplotlib and np.nanmean. Input without a `mask`
    attribute is returned unchanged.
    """
    if hasattr(values, "mask"):
        return np.where(values.mask, np.nan, values.data)
    return values


def _units_suffix(obj):
    """Return " (<units>)" when `obj` has a `units` attribute, else ""."""
    return f" ({obj.units})" if hasattr(obj, "units") else ""


def _finish_plot(output_path):
    """Save the current figure to `output_path` if given, otherwise display it."""
    if output_path:
        plt.savefig(output_path, bbox_inches="tight")
        print(f"Saved to {output_path}")
    else:
        plt.show()


def _print_dimensions(ds, varname):
    """Print the dimension names and sizes of `varname`; return (dims, shape)."""
    dims = ds.variables[varname].dimensions
    shape = ds.variables[varname].shape
    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
    for name, size in zip(dims, shape):
        print(f" - {name}: size {size}")
    print()
    return dims, shape


def _is_pure_time_series(dims, shape):
    """True if the time axis has more than one entry and every other axis has length 1."""
    t_idx = find_dim_index(dims, TIME_DIMS)
    return (
        t_idx is not None
        and shape[t_idx] > 1
        and all(size == 1 for i, size in enumerate(shape) if i != t_idx)
    )


def _can_average_lat(dims, shape):
    """True if time, latitude and longitude axes all exist with length > 1."""
    axes = [find_dim_index(dims, names) for names in (TIME_DIMS, LAT_DIMS, LON_DIMS)]
    return all(i is not None and shape[i] > 1 for i in axes)


# ---------------------------------------------------------------------------
# Plot kinds
# ---------------------------------------------------------------------------

def _plot_time_series(dataset, varname, data, output_path=None):
    """
    Plot the 1D array `data` of variable `varname` against the file's time
    coordinate, or against the sample index when no time variable exists.
    """
    var = dataset.variables[varname]
    time_var = find_coord_var(dataset, TIME_DIMS)
    if time_var:
        tvals = _to_nan_array(dataset.variables[time_var][:])
    else:
        tvals = np.arange(data.shape[0])
    plt.figure()
    plt.plot(tvals, data, marker="o")
    plt.xlabel(time_var or "Time Index")
    plt.ylabel(varname + _units_suffix(var))
    plt.title(f"{varname} vs {time_var or 'Index'}")
    _finish_plot(output_path)


def _plot_lat_average(dataset, varname, data_full, t_idx, lat_idx, lon_idx,
                      a_idx, alt_index, extra_indices, colormap, output_path):
    """
    Average `data_full` over latitude and draw a longitude × time heatmap.

    Every axis other than time/lat/lon is first reduced to a single index:
    from `extra_indices` (by dimension name), from `alt_index` for the
    altitude axis, or index 0 when the axis has length 1. The remaining axes
    are then reordered to (time, lat, lon), so the result does not depend on
    the dimension order stored in the file. Prints an error and returns
    without plotting if some axis cannot be reduced.
    """
    var = dataset.variables[varname]
    dims = var.dimensions
    kept = (t_idx, lat_idx, lon_idx)

    slicer = [slice(None)] * len(dims)
    for i, dname in enumerate(dims):
        if i in kept:
            continue
        if dname in extra_indices:
            slicer[i] = extra_indices[dname]
        elif i == a_idx and alt_index is not None:
            slicer[i] = alt_index
        elif data_full.shape[i] == 1:
            slicer[i] = 0
        else:
            print(f"Error: please supply an index for '{dname}' "
                  f"to average '{varname}' over latitude.")
            return
    try:
        cube = data_full[tuple(slicer)]
    except (IndexError, TypeError, ValueError) as e:
        print(f"Error slicing '{varname}': {e}")
        return
    if np.ndim(cube) != 3:
        print(f"Error: cannot reduce '{varname}' to (time, lat, lon).")
        return

    # Integer indices drop their axes; the kept axes stay in file order.
    order = sorted(kept)
    cube = np.moveaxis(cube, [order.index(i) for i in kept], [0, 1, 2])
    data_avg = np.nanmean(cube, axis=1)  # (time, lat, lon) -> (time, lon)

    time_var = find_coord_var(dataset, TIME_DIMS)
    lon_var = find_coord_var(dataset, LON_DIMS)
    if time_var:
        tvals = _to_nan_array(dataset.variables[time_var][:])
    else:
        tvals = np.arange(data_avg.shape[0])
    if lon_var:
        lons = _to_nan_array(dataset.variables[lon_var][:])
    else:
        lons = np.arange(data_avg.shape[1])

    plt.figure(figsize=(10, 6))
    plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
    if lon_var:
        plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
    else:
        plt.xlabel("Longitude index")
    plt.ylabel(time_var or "Time Index")
    cbar = plt.colorbar()
    cbar.set_label(varname + _units_suffix(var))
    plt.title(f"{varname} averaged over latitude")
    _finish_plot(output_path)


def _plot_1d(dataset, varname, dslice, dname, colormap, output_path):
    """
    Plot a 1D slice whose only free dimension is `dname` (may be None).

    Tries, in order:
      1. an interpolated lat/lon map when `dname` is "physical_points" and
         lat/lon variables of matching length exist;
      2. a profile against a coordinate variable ("soildepth" for
         subsurface_layers, else a variable named like the dimension) of
         matching shape;
      3. a plain value-vs-index plot.
    """
    var = dataset.variables[varname]
    label = varname + _units_suffix(var)
    lname = dname.lower() if dname else ""

    # physical_points -> interpolated map
    if lname == "physical_points":
        latv = find_coord_var(dataset, LAT_DIMS)
        lonv = find_coord_var(dataset, LON_DIMS)
        if latv and lonv:
            lats = _to_nan_array(dataset.variables[latv][:])
            lons = _to_nan_array(dataset.variables[lonv][:])
            # Triangulation needs one (lon, lat) pair per data value
            if np.shape(lats) == dslice.shape and np.shape(lons) == dslice.shape:
                triang = mtri.Triangulation(lons, lats)
                plt.figure(figsize=(8, 6))
                cf = plt.tricontourf(triang, dslice, cmap=colormap)
                cbar = plt.colorbar(cf)
                cbar.set_label(label)
                plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
                plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
                plt.title(f"{varname} (interpolated map over physical_points)")
                _finish_plot(output_path)
                return

    # vertical profile?
    coord = None
    if lname == "subsurface_layers" and "soildepth" in dataset.variables:
        coord = "soildepth"
    elif dname and dname in dataset.variables:
        coord = dname
    if coord:
        coord_var = dataset.variables[coord]
        coords = _to_nan_array(coord_var[:])
        if np.shape(coords) == dslice.shape:
            plt.figure()
            plt.plot(dslice, coords, marker="o")
            if lname == "subsurface_layers":
                # show deeper layers lower down (assumes soildepth grows with depth)
                plt.gca().invert_yaxis()
            plt.xlabel(label)
            plt.ylabel(coord + _units_suffix(coord_var))
            plt.title(f"{varname} vs {coord}")
            _finish_plot(output_path)
            return

    # generic 1D
    plt.figure()
    plt.plot(dslice, marker="o")
    plt.xlabel("Index")
    plt.ylabel(label)
    plt.title(f"{varname} (1D)")
    _finish_plot(output_path)


def _plot_2d(dataset, varname, dslice, free_axes, lat_idx, lon_idx,
             colormap, output_path):
    """
    Plot a 2D slice. `free_axes` are the original axis positions of the two
    remaining dimensions. When they are exactly latitude and longitude and
    matching coordinate variables exist, draw a filled-contour map;
    otherwise show the raw array as an image.
    """
    var = dataset.variables[varname]
    label = varname + _units_suffix(var)

    if lat_idx in free_axes and lon_idx in free_axes:
        latv = find_coord_var(dataset, LAT_DIMS)
        lonv = find_coord_var(dataset, LON_DIMS)
        if latv and lonv:
            lats = _to_nan_array(dataset.variables[latv][:])
            lons = _to_nan_array(dataset.variables[lonv][:])
            # contourf wants rows = latitude; transpose files stored as (lon, lat)
            field = dslice if lat_idx < lon_idx else dslice.T
            if np.ndim(lats) == 1 and np.ndim(lons) == 1:
                lon2d, lat2d = np.meshgrid(lons, lats)
            else:
                lat2d, lon2d = lats, lons
            if np.shape(lat2d) == field.shape and np.shape(lon2d) == field.shape:
                plt.figure(figsize=(10, 6))
                cf = plt.contourf(lon2d, lat2d, field, cmap=colormap)
                cbar = plt.colorbar(cf)
                cbar.set_label(label)
                plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
                plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
                plt.title(f"{varname} (lat × lon)")
                _finish_plot(output_path)
                return

    # generic 2D
    plt.figure(figsize=(8, 6))
    plt.imshow(dslice, aspect="auto", cmap=colormap)
    plt.colorbar(label=label)
    plt.xlabel("Dim 2 index")
    plt.ylabel("Dim 1 index")
    plt.title(f"{varname} (2D)")
    _finish_plot(output_path)


def plot_variable(dataset, varname, time_index=None, alt_index=None,
                  colormap="jet", output_path=None, extra_indices=None,
                  avg_lat=False):
    """
    Read `varname` from `dataset`, slice it, and draw the appropriate plot.

    Args:
        dataset: open netCDF4.Dataset.
        varname: name of the variable to plot.
        time_index: index along the time axis; required when the variable has
            one, except for pure 1D time series and avg_lat heatmaps.
        alt_index: index along the altitude axis; required when it exists.
        colormap: matplotlib colormap name for 2D/map plots.
        output_path: save the figure here instead of showing it.
        extra_indices: {dimension name: 0-based index} for any other axes;
            these override time_index/alt_index for the same dimension.
        avg_lat: average over latitude and draw a longitude × time heatmap
            (only when time, lat and lon axes all exist).

    Produces one of: a 1D time series, a physical_points map, a profile,
    a generic 1D plot, a lat × lon map, a generic 2D image, a lat-averaged
    heatmap, or prints a scalar. Problems are reported with print() and the
    function returns without plotting.
    """
    var = dataset.variables[varname]
    dims = var.dimensions

    try:
        data_full = var[:]
    except Exception as e:  # report unreadable data instead of aborting the session
        print(f"Error: Cannot read data for '{varname}': {e}")
        return
    data_full = _to_nan_array(data_full)

    # Pure 1D time series
    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
        _plot_time_series(dataset, varname, data_full, output_path)
        return

    t_idx = find_dim_index(dims, TIME_DIMS)
    lat_idx = find_dim_index(dims, LAT_DIMS)
    lon_idx = find_dim_index(dims, LON_DIMS)
    a_idx = find_dim_index(dims, ALT_DIMS)
    if extra_indices is None:
        extra_indices = {}

    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
        _plot_lat_average(dataset, varname, data_full, t_idx, lat_idx, lon_idx,
                          a_idx, alt_index, extra_indices, colormap, output_path)
        return

    # Build the slicer: fixed integers for chosen axes, slice(None) for free ones
    slicer = [slice(None)] * len(dims)
    if t_idx is not None:
        if time_index is None:
            print("Error: please supply a time index.")
            return
        slicer[t_idx] = time_index
    if a_idx is not None:
        if alt_index is None:
            print("Error: please supply an altitude index.")
            return
        slicer[a_idx] = alt_index
    for dn, idx_val in extra_indices.items():
        if dn in dims:
            slicer[dims.index(dn)] = idx_val

    try:
        dslice = data_full[tuple(slicer)]
    except (IndexError, TypeError, ValueError) as e:
        print(f"Error slicing '{varname}': {e}")
        return

    if np.ndim(dslice) == 0:
        print(f"Scalar '{varname}': {float(dslice)}")
        return

    free_axes = [i for i, s in enumerate(slicer) if isinstance(s, slice)]
    if dslice.ndim == 1:
        dname = dims[free_axes[0]] if free_axes else None
        _plot_1d(dataset, varname, dslice, dname, colormap, output_path)
        return
    if dslice.ndim == 2:
        _plot_2d(dataset, varname, dslice, free_axes, lat_idx, lon_idx,
                 colormap, output_path)
        return

    print(f"Error: ndim={dslice.ndim} not supported.")


# ---------------------------------------------------------------------------
# Interactive mode
# ---------------------------------------------------------------------------

def _prompt_int(prompt, low, high, allow_free=False):
    """
    Prompt until the user enters an integer in [low, high] and return it.
    With allow_free=True, typing 'f' returns None (keep the dimension free).
    """
    while True:
        answer = input(prompt).strip().lower()
        if allow_free and answer == "f":
            return None
        try:
            value = int(answer)
        except ValueError:
            print("Invalid.")
            continue
        if low <= value <= high:
            return value
        print("Invalid.")


def _interactive_session(ds):
    """Prompt for a variable of the open dataset `ds` and its indices, then plot it."""
    vars_ = list(ds.variables.keys())
    if not vars_:
        print("No variables found.")
        return
    if len(vars_) == 1:
        var = vars_[0]
        print(f"Selected '{var}'")
    else:
        print("Available variables:")
        for v in vars_:
            print(f" - {v}")
        _install_completer(make_varname_completer(vars_))
        var = input("Variable name: ").strip()
        if var not in ds.variables:
            print("Unknown variable.")
            return

    dims, shape = _print_dimensions(ds, var)

    # Special case: time is the only non-singleton axis -> plot it directly
    if _is_pure_time_series(dims, shape):
        print("Detected single-point spatial dims; plotting time series…")
        data = _to_nan_array(ds.variables[var][:].squeeze())
        _plot_time_series(ds, var, data)
        return

    t_idx = find_dim_index(dims, TIME_DIMS)
    lat_idx = find_dim_index(dims, LAT_DIMS)
    lon_idx = find_dim_index(dims, LON_DIMS)
    a_idx = find_dim_index(dims, ALT_DIMS)

    # Offer latitude averaging only if Time, lat AND lon each have length > 1
    avg_lat = False
    if _can_average_lat(dims, shape):
        answer = input("Average over latitude & plot lon vs time? [y/n]: ").strip().lower()
        avg_lat = (answer == "y")

    # Time index (not needed when averaging: the heatmap uses every time step)
    ti = None
    if t_idx is not None and not avg_lat:
        size = shape[t_idx]
        if size > 1:
            ti = _prompt_int(f"Enter time index [0..{size-1}]: ", 0, size - 1)
        else:
            ti = 0
            print("Only one time; using 0.")

    # Altitude index
    ai = None
    if a_idx is not None:
        size = shape[a_idx]
        if size > 1:
            ai = _prompt_int(f"Enter altitude index [0..{size-1}]: ", 0, size - 1)
        else:
            ai = 0
            print("Only one altitude; using 0.")

    # Other dims
    extra = {}
    for idx, dname in enumerate(dims):
        if idx in (t_idx, a_idx):
            continue
        size = shape[idx]
        if size == 1:
            extra[dname] = 0
            continue
        if avg_lat and idx in (lat_idx, lon_idx):
            continue  # averaged over / plotted along, so never fixed
        if "slope" in dname.lower():
            # slopes are numbered from 1 for the user, stored 0-based
            extra[dname] = _prompt_int(
                f"Enter slope number [1..{size}] for '{dname}': ", 1, size) - 1
        elif avg_lat:
            # the heatmap needs every non-time/lat/lon axis fixed
            extra[dname] = _prompt_int(
                f"Enter index [0..{size-1}] for '{dname}': ", 0, size - 1)
        else:
            chosen = _prompt_int(
                f"Enter index [0..{size-1}] or 'f' to plot '{dname}': ",
                0, size - 1, allow_free=True)
            if chosen is not None:
                extra[dname] = chosen

    plot_variable(ds, var, time_index=ti, alt_index=ai, colormap="jet",
                  output_path=None, extra_indices=extra, avg_lat=avg_lat)


def visualize_variable_interactive(nc_path=None):
    """
    Interactive mode: prompt for the file (unless `nc_path` is given) and the
    variable, display its dimensions, plot pure time series directly, and
    otherwise guide the user through any needed index selections.
    """
    if nc_path:
        path = nc_path
    else:
        _install_completer(complete_filename)
        path = input("Enter path to NetCDF file: ").strip()
    path = os.path.expanduser(path)
    if not os.path.isfile(path):
        print(f"Error: '{path}' not found.")
        return

    ds = Dataset(path, "r")
    try:
        _interactive_session(ds)
    finally:
        ds.close()


# ---------------------------------------------------------------------------
# Command-line mode
# ---------------------------------------------------------------------------

def _parse_extra_indices(extra_json):
    """
    Parse the --extra-indices JSON object into {dimension name: 0-based index}.

    Dimensions whose name contains "slope" are given 1-based on the command
    line and converted to 0-based here. Malformed JSON or non-object input
    yields a warning and an empty dict; non-integer values are skipped with
    a warning (booleans are rejected even though bool subclasses int).
    """
    extra = {}
    if not extra_json:
        return extra
    try:
        parsed = json.loads(extra_json)
    except ValueError:
        print("Warning: bad extra-indices.")
        return extra
    if not isinstance(parsed, dict):
        print("Warning: bad extra-indices.")
        return extra
    for key, value in parsed.items():
        if isinstance(value, bool) or not isinstance(value, int):
            print(f"Warning: ignoring non-integer index for '{key}'.")
            continue
        extra[key] = value - 1 if "slope" in key.lower() else value
    return extra


def visualize_variable_cli(nc_file, varname, time_index, alt_index, colormap,
                           output_path, extra_json, avg_lat):
    """
    Command-line mode: visualize `varname` from `nc_file` directly, parsing
    the --extra-indices argument (JSON string). `avg_lat` is disabled with
    a note when the variable lacks time/lat/lon axes of length > 1.
    """
    if not os.path.isfile(nc_file):
        print(f"Error: '{nc_file}' not found.")
        return

    ds = Dataset(nc_file, "r")
    try:
        if varname not in ds.variables:
            print(f"Variable '{varname}' not in file.")
            return

        dims, shape = _print_dimensions(ds, varname)

        # Special case: time is the only non-singleton axis -> plot it directly
        if _is_pure_time_series(dims, shape):
            print("Detected single-point spatial dims; plotting time series…")
            data = _to_nan_array(ds.variables[varname][:].squeeze())
            _plot_time_series(ds, varname, data, output_path)
            return

        # If --avg-lat was requested but Time/lat/lon are not all > 1, disable it
        if avg_lat and not _can_average_lat(dims, shape):
            print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
            avg_lat = False

        extra = _parse_extra_indices(extra_json)
        plot_variable(ds, varname, time_index, alt_index, colormap,
                      output_path, extra, avg_lat)
    finally:
        ds.close()


def main():
    """Parse command-line arguments and dispatch to CLI or interactive mode."""
    parser = argparse.ArgumentParser(
        description="Visualize a variable from a NetCDF file.")
    parser.add_argument("nc_file", nargs="?", help="NetCDF file (omit for interactive)")
    parser.add_argument("-v", "--variable", help="Variable name")
    parser.add_argument("-t", "--time-index", type=int, help="Time index (0-based)")
    parser.add_argument("-a", "--alt-index", type=int, help="Altitude index (0-based)")
    parser.add_argument("-c", "--cmap", default="jet", help="Colormap")
    parser.add_argument("--avg-lat", action="store_true",
                        help="Average over latitude (time × lon heatmap)")
    parser.add_argument("-o", "--output", help="Save figure path")
    parser.add_argument("-e", "--extra-indices", help="JSON for other dims")
    args = parser.parse_args()

    if args.nc_file and args.variable:
        visualize_variable_cli(
            args.nc_file, args.variable,
            args.time_index, args.alt_index,
            args.cmap, args.output,
            args.extra_indices, args.avg_lat,
        )
    else:
        visualize_variable_interactive(args.nc_file)


if __name__ == "__main__":
    main()