source: trunk/LMDZ.MARS/util/display_netcdf.py @ 3807

Last change on this file since 3807 was 3798, checked in by jbclement, 2 weeks ago

Mars PCM:
Handle correctly more variables with different types/shapes/dimensions.
JBC

  • Property svn:executable set to *
File size: 23.4 KB
Line 
1#!/usr/bin/env python3
2##############################################################
3### Python script to visualize a variable in a NetCDF file ###
4##############################################################
5
6"""
7This script can display any numeric variable from a NetCDF file.
8It supports the following cases:
9  - 1D time series (Time)
10  - 1D vertical profiles (e.g., subsurface_layers)
11  - 2D latitude/longitude map
12  - 2D (Time × another dimension)
13  - Variables with dimension “physical_points” as 2D map if lat/lon present,
14    or generic 2D plot if the remaining axes are spatial
15  - Scalar output (ndim == 0 after slicing)
16
17Usage:
18  1) Command-line mode:
19       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
20           [--time-index 0] [--cmap viridis] [--output out.png] \
21           [--extra-indices '{"dim1": idx1, "dim2": idx2}']
22
23    --variable     : Name of the variable to visualize.
24    --time-index   : Index along the Time dimension (ignored for purely 1D time series).
25    --alt-index    : Index along the altitude dimension, if present.
26    --cmap         : Matplotlib colormap for contourf (default: "jet").
27    --output       : If provided, save the figure to this filename instead of displaying.
28    --extra-indices: JSON string to fix indices for any dimensions other than Time, lat, lon, or altitude.
29                     Example: '{"nslope": 0, "physical_points": 2}'
30                     Omitting a dimension means it remains unfixed (useful to plot a 1D profile).
31
32  2) Interactive mode:
33       python display_netcdf.py
34       (The script will prompt for the NetCDF file, the variable, etc.)
35"""
36
37import os
38import sys
39import glob
40import readline
41import argparse
42import json
43import numpy as np
44import matplotlib.pyplot as plt
45from netCDF4 import Dataset
46
47# Constants to recognize dimension names
48TIME_DIMS = ("Time", "time", "time_counter")
49ALT_DIMS  = ("altitude",)
50LAT_DIMS  = ("latitude", "lat")
51LON_DIMS  = ("longitude", "lon")
52
53
54def complete_filename(text, state):
55    """
56    Readline tab-completion function for filesystem paths.
57    """
58    if "*" not in text:
59        pattern = text + "*"
60    else:
61        pattern = text
62    matches = glob.glob(os.path.expanduser(pattern))
63    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
64    try:
65        return matches[state]
66    except IndexError:
67        return None
68
69
70def make_varname_completer(varnames):
71    """
72    Returns a readline completer function for the given list of variable names.
73    """
74    def completer(text, state):
75        options = [name for name in varnames if name.startswith(text)]
76        try:
77            return options[state]
78        except IndexError:
79            return None
80    return completer
81
82
83def find_dim_index(dims, candidates):
84    """
85    Search through dims tuple for any name in candidates.
86    Returns the index if found, else returns None.
87    """
88    for idx, dim in enumerate(dims):
89        for cand in candidates:
90            if cand.lower() == dim.lower():
91                return idx
92    return None
93
94
95def find_coord_var(dataset, candidates):
96    """
97    Among dataset variables, return the first variable whose name matches any candidate.
98    Returns None if none found.
99    """
100    for name in dataset.variables:
101        for cand in candidates:
102            if cand.lower() == name.lower():
103                return name
104    return None
105
106
107def plot_variable(dataset, varname, time_index=None, alt_index=None, colormap="jet",
108                  output_path=None, extra_indices=None):
109    """
110    Extracts the requested slice from the variable and plots it according to the data shape:
111
112    - Pure 1D time series → time-series plot
113    - After slicing:
114        • If data_slice.ndim == 0 → print the scalar value
115        • If data_slice.ndim == 1:
116            • If the remaining dimension is “subsurface_layers” (or another known coordinate) → vertical profile
117            • Else → simple plot vs. index
118        • If data_slice.ndim == 2:
119            • If lat/lon exist → contourf map
120            • Else → imshow generic 2D plot
121    - If data_slice.ndim is neither 0, 1, nor 2 → error message
122
123    Parameters
124    ----------
125    dataset       : netCDF4.Dataset object (already open)
126    varname       : name of the variable to plot
127    time_index    : int or None (if variable has a time dimension, ignored for pure time series)
128    alt_index     : int or None (if variable has an altitude dimension)
129    colormap      : string colormap name (passed to plt.contourf)
130    output_path   : string filepath to save figure, or None to display interactively
131    extra_indices : dict { dimension_name (str) : index (int) } for slicing all
132                     dimensions except Time/lat/lon/alt. If a dimension is not
133                     included, it remains “slice(None)” (useful for 1D plots).
134    """
135    var = dataset.variables[varname]
136    dims = var.dimensions  # tuple of dimension names
137
138    # Read the full data (could be a masked array)
139    try:
140        data_full = var[:]
141    except Exception as e:
142        print(f"Error: Cannot read data for variable '{varname}': {e}")
143        return
144
145    # Convert masked array to NaN
146    if hasattr(data_full, "mask"):
147        data_full = np.where(data_full.mask, np.nan, data_full.data)
148
149    # ------------------------------------------------------------------------
150    # 1) Pure 1D time series (dims == ('Time',) or equivalent)
151    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
152        # Plot the time series directly
153        time_varname = find_coord_var(dataset, TIME_DIMS)
154        if time_varname:
155            time_vals = dataset.variables[time_varname][:]
156            if hasattr(time_vals, "mask"):
157                time_vals = np.where(time_vals.mask, np.nan, time_vals.data)
158        else:
159            time_vals = np.arange(data_full.shape[0])
160
161        plt.figure()
162        plt.plot(time_vals, data_full, marker='o')
163        xlabel = time_varname if time_varname else "Time Index"
164        plt.xlabel(xlabel)
165        ylabel = varname
166        if hasattr(var, "units"):
167            ylabel += f" ({var.units})"
168        plt.ylabel(ylabel)
169        plt.title(f"{varname} vs {xlabel}")
170
171        if output_path:
172            try:
173                plt.savefig(output_path, bbox_inches="tight")
174                print(f"Figure saved to '{output_path}'")
175            except Exception as e:
176                print(f"Error saving figure: {e}")
177        else:
178            plt.show()
179        plt.close()
180        return
181    # ------------------------------------------------------------------------
182
183    # Identify special dimension indices
184    t_idx = find_dim_index(dims, TIME_DIMS)
185    a_idx = find_dim_index(dims, ALT_DIMS)
186    lat_idx = find_dim_index(dims, LAT_DIMS)
187    lon_idx = find_dim_index(dims, LON_DIMS)
188
189    # Build the slicer list
190    slicer = [slice(None)] * len(dims)
191
192    # Apply slicing on Time and altitude if specified
193    if t_idx is not None:
194        if time_index is None:
195            print("Error: Variable has a time dimension; please supply a time index.")
196            return
197        slicer[t_idx] = time_index
198    if a_idx is not None:
199        if alt_index is None:
200            print("Error: Variable has an altitude dimension; please supply an altitude index.")
201            return
202        slicer[a_idx] = alt_index
203
204    # Apply slicing on all “extra” dimensions (except Time/lat/lon/alt)
205    if extra_indices is None:
206        extra_indices = {}
207    for dim_name, idx_val in extra_indices.items():
208        if dim_name in dims:
209            dim_index = dims.index(dim_name)
210            slicer[dim_index] = idx_val
211
212    # Extract the sliced data
213    try:
214        data_slice = data_full[tuple(slicer)]
215    except Exception as e:
216        print(f"Error: Could not slice variable '{varname}': {e}")
217        return
218
219    # CASE: After slicing, if data_slice.ndim == 0 → scalar
220    if np.ndim(data_slice) == 0:
221        try:
222            scalar_val = float(data_slice)
223        except Exception:
224            scalar_val = data_slice
225        print(f"Scalar result for '{varname}': {scalar_val}")
226        return
227
228    # CASE: After slicing, if data_slice.ndim == 1 (vertical profile or simple vector)
229    if data_slice.ndim == 1:
230        # Identify the remaining dimension
231        rem_dim = None
232        for di, dname in enumerate(dims):
233            if slicer[di] == slice(None):
234                rem_dim = (di, dname)
235                break
236
237        if rem_dim is not None:
238            di, dname = rem_dim
239            coord_var = None
240
241            # If it's "subsurface_layers", look for coordinate "soildepth"
242            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
243                coord_var = "soildepth"
244            # If there is a variable with the same name, use it
245            elif dname in dataset.variables:
246                coord_var = dname
247
248            if coord_var:
249                coord_vals = dataset.variables[coord_var][:]
250                if hasattr(coord_vals, "mask"):
251                    coord_vals = np.where(coord_vals.mask, np.nan, coord_vals.data)
252                x = data_slice
253                y = coord_vals
254
255                plt.figure()
256                plt.plot(x, y, marker='o')
257                # Invert Y-axis if it's a depth coordinate
258                if dname.lower() == "subsurface_layers":
259                    plt.gca().invert_yaxis()
260
261                xlabel = varname
262                if hasattr(var, "units"):
263                    xlabel += f" ({var.units})"
264                plt.xlabel(xlabel)
265
266                ylabel = coord_var
267                if hasattr(dataset.variables[coord_var], "units"):
268                    ylabel += f" ({dataset.variables[coord_var].units})"
269                plt.ylabel(ylabel)
270
271                plt.title(f"{varname} vs {coord_var}")
272
273                if output_path:
274                    try:
275                        plt.savefig(output_path, bbox_inches="tight")
276                        print(f"Figure saved to '{output_path}'")
277                    except Exception as e:
278                        print(f"Error saving figure: {e}")
279                else:
280                    plt.show()
281                plt.close()
282                return
283            else:
284                # No known coordinate found → simple plot vs index
285                plt.figure()
286                plt.plot(data_slice, marker='o')
287                plt.xlabel("Index")
288                ylabel = varname
289                if hasattr(var, "units"):
290                    ylabel += f" ({var.units})"
291                plt.ylabel(ylabel)
292                plt.title(f"{varname} (1D)")
293
294                if output_path:
295                    try:
296                        plt.savefig(output_path, bbox_inches="tight")
297                        print(f"Figure saved to '{output_path}'")
298                    except Exception as e:
299                        print(f"Error saving figure: {e}")
300                else:
301                    plt.show()
302                plt.close()
303                return
304
305        else:
306            # Unable to identify the remaining dimension → error
307            print(f"Error: After slicing, data for '{varname}' is 1D but remaining dimension is unknown.")
308            return
309
310    # CASE: After slicing, if data_slice.ndim == 2
311    if data_slice.ndim == 2:
312        # If lat and lon exist in the original dims, re-find their indices
313        lat_idx2 = find_dim_index(dims, LAT_DIMS)
314        lon_idx2 = find_dim_index(dims, LON_DIMS)
315
316        if lat_idx2 is not None and lon_idx2 is not None:
317            # We have a 2D variable on a lat×lon grid
318            lat_varname = find_coord_var(dataset, LAT_DIMS)
319            lon_varname = find_coord_var(dataset, LON_DIMS)
320            if lat_varname is None or lon_varname is None:
321                print("Error: Could not locate latitude/longitude variables in the dataset.")
322                return
323
324            lat_var = dataset.variables[lat_varname][:]
325            lon_var = dataset.variables[lon_varname][:]
326            if hasattr(lat_var, "mask"):
327                lat_var = np.where(lat_var.mask, np.nan, lat_var.data)
328            if hasattr(lon_var, "mask"):
329                lon_var = np.where(lon_var.mask, np.nan, lon_var.data)
330
331            # Build 2D coordinate arrays
332            if lat_var.ndim == 1 and lon_var.ndim == 1:
333                lon2d, lat2d = np.meshgrid(lon_var, lat_var)
334            elif lat_var.ndim == 2 and lon_var.ndim == 2:
335                lat2d, lon2d = lat_var, lon_var
336            else:
337                print("Error: Latitude and longitude must both be either 1D or 2D.")
338                return
339
340            plt.figure(figsize=(10, 6))
341            cf = plt.contourf(lon2d, lat2d, data_slice, cmap=colormap)
342            cbar = plt.colorbar(cf)
343            if hasattr(var, "units"):
344                cbar.set_label(f"{varname} ({var.units})")
345            else:
346                cbar.set_label(varname)
347
348            lon_label = f"Longitude ({getattr(dataset.variables[lon_varname], 'units', 'degrees')})"
349            lat_label = f"Latitude ({getattr(dataset.variables[lat_varname], 'units', 'degrees')})"
350            plt.xlabel(lon_label)
351            plt.ylabel(lat_label)
352            plt.title(f"{varname} (lat × lon)")
353
354            if output_path:
355                try:
356                    plt.savefig(output_path, bbox_inches="tight")
357                    print(f"Figure saved to '{output_path}'")
358                except Exception as e:
359                    print(f"Error saving figure: {e}")
360            else:
361                plt.show()
362            plt.close()
363            return
364
365        else:
366            # No lat/lon → two non-geographical dimensions; plot with imshow
367            plt.figure(figsize=(8, 6))
368            plt.imshow(data_slice, aspect='auto')
369            cb_label = varname
370            if hasattr(var, "units"):
371                cb_label += f" ({var.units})"
372            plt.colorbar(label=cb_label)
373            plt.xlabel("Dimension 2 Index")
374            plt.ylabel("Dimension 1 Index")
375            plt.title(f"{varname} (2D without lat/lon)")
376
377            if output_path:
378                try:
379                    plt.savefig(output_path, bbox_inches="tight")
380                    print(f"Figure saved to '{output_path}'")
381                except Exception as e:
382                    print(f"Error saving figure: {e}")
383            else:
384                plt.show()
385            plt.close()
386            return
387
388    # CASE: data_slice.ndim is neither 0, 1, nor 2
389    print(f"Error: After slicing, data for '{varname}' has ndim={data_slice.ndim}, which is not supported.")
390    return
391
392
393def visualize_variable_interactive(nc_path=None):
394    """
395    Interactive mode: prompts for the NetCDF file if not provided, then for the variable,
396    then for Time/altitude indices (skipped entirely if variable is purely 1D over Time),
397    and for each other dimension offers to fix an index or to plot along that dimension (by typing 'f').
398
399    If a dimension has length 1, the index 0 is chosen automatically.
400    """
401    # Determine file path
402    if nc_path:
403        file_input = nc_path
404    else:
405        readline.set_completer(complete_filename)
406        readline.parse_and_bind("tab: complete")
407        file_input = input("Enter the path to the NetCDF file: ").strip()
408
409    if not file_input:
410        print("No file specified. Exiting.")
411        return
412    if not os.path.isfile(file_input):
413        print(f"Error: '{file_input}' not found.")
414        return
415
416    try:
417        ds = Dataset(file_input, mode="r")
418    except Exception as e:
419        print(f"Error: Unable to open '{file_input}': {e}")
420        return
421
422    varnames = list(ds.variables.keys())
423    if not varnames:
424        print("Error: No variables found in the dataset.")
425        ds.close()
426        return
427
428    # Auto-select if only one variable
429    if len(varnames) == 1:
430        var_input = varnames[0]
431        print(f"Automatically selected the only variable: '{var_input}'")
432    else:
433        print("\nAvailable variables:")
434        for name in varnames:
435            print(f"  - {name}")
436        print()
437        readline.set_completer(make_varname_completer(varnames))
438        var_input = input("Enter the name of the variable to visualize: ").strip()
439        if var_input not in ds.variables:
440            print(f"Error: Variable '{var_input}' not found. Exiting.")
441            ds.close()
442            return
443
444    dims = ds.variables[var_input].dimensions  # tuple of dimension names
445
446    # If the variable is purely 1D over Time, plot immediately without asking for time index
447    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
448        plot_variable(
449            dataset=ds,
450            varname=var_input,
451            time_index=None,
452            alt_index=None,
453            colormap="jet",
454            output_path=None,
455            extra_indices=None
456        )
457        ds.close()
458        return
459
460    # Otherwise, proceed to prompt for time/altitude and other dimensions
461    time_idx = None
462    alt_idx = None
463
464    # Prompt for time index if applicable
465    t_idx = find_dim_index(dims, TIME_DIMS)
466    if t_idx is not None:
467        time_len = ds.variables[var_input].shape[t_idx]
468        if time_len > 1:
469            while True:
470                try:
471                    user_t = input(f"Enter time index [0..{time_len - 1}]: ").strip()
472                    if user_t == "":
473                        print("No time index entered. Exiting.")
474                        ds.close()
475                        return
476                    ti = int(user_t)
477                    if 0 <= ti < time_len:
478                        time_idx = ti
479                        break
480                except ValueError:
481                    pass
482                print(f"Invalid index. Enter an integer between 0 and {time_len - 1}.")
483        else:
484            time_idx = 0
485            print("Only one time step available; using index 0.")
486
487    # Prompt for altitude index if applicable
488    a_idx = find_dim_index(dims, ALT_DIMS)
489    if a_idx is not None:
490        alt_len = ds.variables[var_input].shape[a_idx]
491        if alt_len > 1:
492            while True:
493                try:
494                    user_a = input(f"Enter altitude index [0..{alt_len - 1}]: ").strip()
495                    if user_a == "":
496                        print("No altitude index entered. Exiting.")
497                        ds.close()
498                        return
499                    ai = int(user_a)
500                    if 0 <= ai < alt_len:
501                        alt_idx = ai
502                        break
503                except ValueError:
504                    pass
505                print(f"Invalid index. Enter an integer between 0 and {alt_len - 1}.")
506        else:
507            alt_idx = 0
508            print("Only one altitude level available; using index 0.")
509
510    # Identify other dimensions (excluding Time/lat/lon/alt)
511    other_dims = []
512    for idx_dim, dim_name in enumerate(dims):
513        if idx_dim == t_idx or idx_dim == a_idx:
514            continue
515        if dim_name.lower() in (d.lower() for d in LAT_DIMS + LON_DIMS):
516            continue
517        other_dims.append((idx_dim, dim_name))
518
519    # For each other dimension, ask user to fix an index or type 'f' to plot along that dimension
520    extra_indices = {}
521    for idx_dim, dim_name in other_dims:
522        dim_len = ds.variables[var_input].shape[idx_dim]
523        if dim_len == 1:
524            extra_indices[dim_name] = 0
525            print(f"Dimension '{dim_name}' has length 1; using index 0.")
526        else:
527            while True:
528                prompt = (
529                    f"Enter index for '{dim_name}' [0..{dim_len - 1}] "
530                    f"or 'f' to plot along '{dim_name}': "
531                )
532                user_i = input(prompt).strip().lower()
533                if user_i == 'f':
534                    # Leave this dimension unfixed → no key in extra_indices
535                    break
536                if user_i == "":
537                    print("No index entered. Exiting.")
538                    ds.close()
539                    return
540                try:
541                    idx_val = int(user_i)
542                    if 0 <= idx_val < dim_len:
543                        extra_indices[dim_name] = idx_val
544                        break
545                except ValueError:
546                    pass
547                print(f"Invalid index. Enter an integer between 0 and {dim_len - 1}, or 'f'.")
548
549    # Finally, call plot_variable with collected indices
550    plot_variable(
551        dataset=ds,
552        varname=var_input,
553        time_index=time_idx,
554        alt_index=alt_idx,
555        colormap="jet",
556        output_path=None,
557        extra_indices=extra_indices
558    )
559    ds.close()
560
561
562def visualize_variable_cli(nc_path, varname, time_index, alt_index, colormap, output_path, extra_json):
563    """
564    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
565    """
566    if not os.path.isfile(nc_path):
567        print(f"Error: '{nc_path}' not found.")
568        return
569
570    try:
571        ds = Dataset(nc_path, mode="r")
572    except Exception as e:
573        print(f"Error: Unable to open '{nc_path}': {e}")
574        return
575
576    if varname not in ds.variables:
577        print(f"Error: Variable '{varname}' not found in '{nc_path}'.")
578        ds.close()
579        return
580
581    # Parse extra_indices if provided
582    extra_indices = {}
583    if extra_json:
584        try:
585            parsed = json.loads(extra_json)
586            if isinstance(parsed, dict):
587                for k, v in parsed.items():
588                    if isinstance(k, str) and isinstance(v, int):
589                        extra_indices[k] = v
590            else:
591                print("Warning: --extra-indices is not a JSON object. Ignored.")
592        except json.JSONDecodeError:
593            print("Warning: --extra-indices is not valid JSON. Ignored.")
594
595    plot_variable(
596        dataset=ds,
597        varname=varname,
598        time_index=time_index,
599        alt_index=alt_index,
600        colormap=colormap,
601        output_path=output_path,
602        extra_indices=extra_indices
603    )
604    ds.close()
605
606
607def main():
608    parser = argparse.ArgumentParser(
609        description="Visualize a 1D/2D slice of a NetCDF variable on various dimension types."
610    )
611    parser.add_argument(
612        "nc_file",
613        nargs="?",
614        help="Path to the NetCDF file (interactive if omitted)."
615    )
616    parser.add_argument(
617        "--variable", "-v",
618        help="Name of the variable to visualize."
619    )
620    parser.add_argument(
621        "--time-index", "-t",
622        type=int,
623        help="Index on the Time dimension, if applicable (ignored for pure 1D time series)."
624    )
625    parser.add_argument(
626        "--alt-index", "-a",
627        type=int,
628        help="Index on the altitude dimension, if applicable."
629    )
630    parser.add_argument(
631        "--cmap", "-c",
632        default="jet",
633        help="Matplotlib colormap (default: 'jet')."
634    )
635    parser.add_argument(
636        "--output", "-o",
637        help="If provided, save the figure to this file instead of displaying it."
638    )
639    parser.add_argument(
640        "--extra-indices", "-e",
641        help="JSON string to fix indices of dimensions outside Time/lat/lon/alt. "
642             "Example: '{\"nslope\":0, \"physical_points\":2}'."
643    )
644
645    args = parser.parse_args()
646
647    # If both nc_file and variable are provided → CLI mode
648    if args.nc_file and args.variable:
649        visualize_variable_cli(
650            nc_path=args.nc_file,
651            varname=args.variable,
652            time_index=args.time_index,
653            alt_index=args.alt_index,
654            colormap=args.cmap,
655            output_path=args.output,
656            extra_json=args.extra_indices
657        )
658    else:
659        # Otherwise → fully interactive mode
660        visualize_variable_interactive(nc_path=args.nc_file)
661
662
663if __name__ == "__main__":
664    main()
665
Note: See TracBrowser for help on using the repository browser.