source: trunk/LMDZ.MARS/util/display_netcdf.py @ 3808

Last change on this file since 3808 was 3808, checked in by jbclement, 28 hours ago

Mars PCM:

  • Bug corrections for the Python script displaying variables in a NetCDF file regarding the dimensions + addition of options (for example to average over longitude).
  • Improvement for the Python script analyzing variables in a NetCDF file.

JBC

  • Property svn:executable set to *
File size: 21.6 KB
Line 
1#!/usr/bin/env python3
2##############################################################
3### Python script to visualize a variable in a NetCDF file ###
4##############################################################
5
6"""
7This script can display any numeric variable from a NetCDF file.
8It supports the following cases:
9  - 1D time series (Time)
10  - 1D vertical profiles (e.g., subsurface_layers)
11  - 2D latitude/longitude map
12  - 2D (Time × another dimension)
13  - Variables with dimension “physical_points” displayed on a 2D map if lat/lon are present
14  - Optionally average over latitude and plot longitude vs. time heatmap
15  - Scalar output (ndim == 0 after slicing)
16
17Usage:
18  1) Command-line mode:
19       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
20           [--time-index 0] [--alt-index 0] [--cmap viridis] [--avg-lat] \
21           [--output out.png] [--extra-indices '{"nslope": 1}']
22
23    --variable     : Name of the variable to visualize.
24    --time-index   : Index along the Time dimension (0-based, ignored for purely 1D time series).
25    --alt-index    : Index along the altitude dimension (0-based), if present.
26    --cmap         : Matplotlib colormap (default: "jet").
27    --avg-lat      : Average over latitude and plot longitude vs. time heatmap.
28    --output       : If provided, save the figure to this filename instead of displaying.
29    --extra-indices: JSON string to fix indices for any other dimensions.
30                     For any dimension whose name contains "slope", use 1-based numbering here.
31                     Example: '{"nslope": 1, "physical_points": 3}'
32
33  2) Interactive mode:
34       python display_netcdf.py
35       (The script will prompt for everything, including averaging option.)
36"""
37
38import os
39import sys
40import glob
41import readline
42import argparse
43import json
44import numpy as np
45import matplotlib.pyplot as plt
46import matplotlib.tri as mtri
47from netCDF4 import Dataset
48
49# Constants to recognize dimension names
50TIME_DIMS = ("Time", "time", "time_counter")
51ALT_DIMS  = ("altitude",)
52LAT_DIMS  = ("latitude", "lat")
53LON_DIMS  = ("longitude", "lon")
54
55
56def complete_filename(text, state):
57    """
58    Readline tab-completion function for filesystem paths.
59    """
60    if "*" not in text:
61        pattern = text + "*"
62    else:
63        pattern = text
64    matches = glob.glob(os.path.expanduser(pattern))
65    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
66    try:
67        return matches[state]
68    except IndexError:
69        return None
70
71
72def make_varname_completer(varnames):
73    """
74    Returns a readline completer function for the given list of variable names.
75    """
76    def completer(text, state):
77        options = [name for name in varnames if name.startswith(text)]
78        try:
79            return options[state]
80        except IndexError:
81            return None
82    return completer
83
84
85def find_dim_index(dims, candidates):
86    """
87    Search through dims tuple for any name in candidates.
88    Returns the index if found, else returns None.
89    """
90    for idx, dim in enumerate(dims):
91        for cand in candidates:
92            if cand.lower() == dim.lower():
93                return idx
94    return None
95
96
97def find_coord_var(dataset, candidates):
98    """
99    Among dataset variables, return the first variable whose name matches any candidate.
100    Returns None if none found.
101    """
102    for name in dataset.variables:
103        for cand in candidates:
104            if cand.lower() == name.lower():
105                return name
106    return None
107
108
109def plot_variable(dataset, varname, time_index=None, alt_index=None,
110                  colormap="jet", output_path=None, extra_indices=None,
111                  avg_lat=False):
112    """
113    Core plotting logic: reads the variable, handles masks,
114    determines dimensionality, and creates the appropriate plot:
115      - 1D time series
116      - 1D profiles or physical_points maps
117      - 2D lat×lon or generic 2D
118      - Time×lon heatmap if avg_lat=True
119      - Scalar printing
120    """
121    var = dataset.variables[varname]
122    dims = var.dimensions
123
124    # Read full data
125    try:
126        data_full = var[:]
127    except Exception as e:
128        print(f"Error: Cannot read data for '{varname}': {e}")
129        return
130    if hasattr(data_full, "mask"):
131        data_full = np.where(data_full.mask, np.nan, data_full.data)
132
133    # Pure 1D time series
134    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
135        time_var = find_coord_var(dataset, TIME_DIMS)
136        tvals = (dataset.variables[time_var][:] if time_var
137                 else np.arange(data_full.shape[0]))
138        if hasattr(tvals, "mask"):
139            tvals = np.where(tvals.mask, np.nan, tvals.data)
140        plt.figure()
141        plt.plot(tvals, data_full, marker="o")
142        plt.xlabel(time_var or "Time Index")
143        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
144        plt.title(f"{varname} vs {time_var or 'Index'}")
145        if output_path:
146            plt.savefig(output_path, bbox_inches="tight")
147            print(f"Saved to {output_path}")
148        else:
149            plt.show()
150        return
151
152    # Identify dims
153    t_idx = find_dim_index(dims, TIME_DIMS)
154    lat_idx = find_dim_index(dims, LAT_DIMS)
155    lon_idx = find_dim_index(dims, LON_DIMS)
156    a_idx = find_dim_index(dims, ALT_DIMS)
157
158    # Average over latitude & plot time × lon heatmap
159    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
160        # mean over lat axis
161        data_avg = np.nanmean(data_full, axis=lat_idx)
162        # data_avg shape: (time, lon, ...)
163        # we assume no other unfixed dims
164        # get coordinates
165        time_var = find_coord_var(dataset, TIME_DIMS)
166        lon_var = find_coord_var(dataset, LON_DIMS)
167        tvals = dataset.variables[time_var][:]
168        lons = dataset.variables[lon_var][:]
169        if hasattr(tvals, "mask"):
170            tvals = np.where(tvals.mask, np.nan, tvals.data)
171        if hasattr(lons, "mask"):
172            lons = np.where(lons.mask, np.nan, lons.data)
173        plt.figure(figsize=(10, 6))
174        plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
175        plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
176        plt.ylabel(time_var)
177        cbar = plt.colorbar()
178        cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
179        plt.title(f"{varname} averaged over latitude")
180        if output_path:
181            plt.savefig(output_path, bbox_inches="tight")
182            print(f"Saved to {output_path}")
183        else:
184            plt.show()
185        return
186
187    # Build slicer for other cases
188    slicer = [slice(None)] * len(dims)
189    if t_idx is not None:
190        if time_index is None:
191            print("Error: please supply a time index.")
192            return
193        slicer[t_idx] = time_index
194    if a_idx is not None:
195        if alt_index is None:
196            print("Error: please supply an altitude index.")
197            return
198        slicer[a_idx] = alt_index
199
200    if extra_indices is None:
201        extra_indices = {}
202    for dn, idx_val in extra_indices.items():
203        if dn in dims:
204            slicer[dims.index(dn)] = idx_val
205
206    # Extract slice
207    try:
208        dslice = data_full[tuple(slicer)]
209    except Exception as e:
210        print(f"Error slicing '{varname}': {e}")
211        return
212
213    # Scalar
214    if np.ndim(dslice) == 0:
215        print(f"Scalar '{varname}': {float(dslice)}")
216        return
217
218    # 1D: vector, profile, or physical_points
219    if dslice.ndim == 1:
220        rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]
221        if rem:
222            di, dname = rem[0]
223            # physical_points → interpolated map
224            if dname.lower() == "physical_points":
225                latv = find_coord_var(dataset, LAT_DIMS)
226                lonv = find_coord_var(dataset, LON_DIMS)
227                if latv and lonv:
228                    lats = dataset.variables[latv][:]
229                    lons = dataset.variables[lonv][:]
230                    if hasattr(lats, "mask"):
231                        lats = np.where(lats.mask, np.nan, lats.data)
232                    if hasattr(lons, "mask"):
233                        lons = np.where(lons.mask, np.nan, lons.data)
234                    triang = mtri.Triangulation(lons, lats)
235                    plt.figure(figsize=(8, 6))
236                    cf = plt.tricontourf(triang, dslice, cmap=colormap)
237                    cbar = plt.colorbar(cf)
238                    cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
239                    plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
240                    plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
241                    plt.title(f"{varname} (interpolated map over physical_points)")
242                    if output_path:
243                        plt.savefig(output_path, bbox_inches="tight")
244                        print(f"Saved to {output_path}")
245                    else:
246                        plt.show()
247                    return
248            # vertical profile?
249            coord = None
250            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
251                coord = "soildepth"
252            elif dname in dataset.variables:
253                coord = dname
254            if coord:
255                coords = dataset.variables[coord][:]
256                if hasattr(coords, "mask"):
257                    coords = np.where(coords.mask, np.nan, coords.data)
258                plt.figure()
259                plt.plot(dslice, coords, marker="o")
260                if dname.lower() == "subsurface_layers":
261                    plt.gca().invert_yaxis()
262                plt.xlabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
263                plt.ylabel(coord + (f" ({dataset.variables[coord].units})" if hasattr(dataset.variables[coord], "units") else ""))
264                plt.title(f"{varname} vs {coord}")
265                if output_path:
266                    plt.savefig(output_path, bbox_inches="tight")
267                    print(f"Saved to {output_path}")
268                else:
269                    plt.show()
270                return
271        # generic 1D
272        plt.figure()
273        plt.plot(dslice, marker="o")
274        plt.xlabel("Index")
275        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
276        plt.title(f"{varname} (1D)")
277        if output_path:
278            plt.savefig(output_path, bbox_inches="tight")
279            print(f"Saved to {output_path}")
280        else:
281            plt.show()
282        return
283
284    # 2D: map or generic
285    if dslice.ndim == 2:
286        lat_idx2 = find_dim_index(dims, LAT_DIMS)
287        lon_idx2 = find_dim_index(dims, LON_DIMS)
288        if lat_idx2 is not None and lon_idx2 is not None:
289            latv = find_coord_var(dataset, LAT_DIMS)
290            lonv = find_coord_var(dataset, LON_DIMS)
291            lats = dataset.variables[latv][:]
292            lons = dataset.variables[lonv][:]
293            if hasattr(lats, "mask"):
294                lats = np.where(lats.mask, np.nan, lats.data)
295            if hasattr(lons, "mask"):
296                lons = np.where(lons.mask, np.nan, lons.data)
297            if lats.ndim == 1 and lons.ndim == 1:
298                lon2d, lat2d = np.meshgrid(lons, lats)
299            else:
300                lat2d, lon2d = lats, lons
301            plt.figure(figsize=(10, 6))
302            cf = plt.contourf(lon2d, lat2d, dslice, cmap=colormap)
303            cbar = plt.colorbar(cf)
304            cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
305            plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
306            plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
307            plt.title(f"{varname} (lat × lon)")
308            if output_path:
309                plt.savefig(output_path, bbox_inches="tight")
310                print(f"Saved to {output_path}")
311            else:
312                plt.show()
313            return
314        # generic 2D
315        plt.figure(figsize=(8, 6))
316        plt.imshow(dslice, aspect="auto")
317        plt.colorbar(label=varname + (f" ({var.units})" if hasattr(var, "units") else ""))
318        plt.xlabel("Dim 2 index")
319        plt.ylabel("Dim 1 index")
320        plt.title(f"{varname} (2D)")
321        if output_path:
322            plt.savefig(output_path, bbox_inches="tight")
323            print(f"Saved to {output_path}")
324        else:
325            plt.show()
326        return
327
328    print(f"Error: ndim={dslice.ndim} not supported.")
329
330
331def visualize_variable_interactive(nc_path=None):
332    """
333    Interactive mode: prompts for file, variable, displays dims,
334    handles special case of pure time series, then guides user
335    through any needed index selections.
336    """
337    # File selection
338    if nc_path:
339        path = nc_path
340    else:
341        readline.set_completer(complete_filename)
342        readline.parse_and_bind("tab: complete")
343        path = input("Enter path to NetCDF file: ").strip()
344    if not os.path.isfile(path):
345        print(f"Error: '{path}' not found."); return
346    ds = Dataset(path, "r")
347
348    # Variable selection with autocomplete
349    vars_ = list(ds.variables.keys())
350    if not vars_:
351        print("No variables found."); ds.close(); return
352    if len(vars_) == 1:
353        var = vars_[0]; print(f"Selected '{var}'")
354    else:
355        print("Available variables:")
356        for v in vars_:
357            print(f"  - {v}")
358        readline.set_completer(make_varname_completer(vars_))
359        readline.parse_and_bind("tab: complete")
360        var = input("Variable name: ").strip()
361        if var not in ds.variables:
362            print("Unknown variable."); ds.close(); return
363
364    # DISPLAY DIMENSIONS AND SIZES
365    dims  = ds.variables[var].dimensions
366    shape = ds.variables[var].shape
367    print(f"\nVariable '{var}' has {len(dims)} dimensions:")
368    for name, size in zip(dims, shape):
369        print(f"  - {name}: size {size}")
370    print()
371
372    # Identify dimension indices
373    t_idx   = find_dim_index(dims, TIME_DIMS)
374    lat_idx = find_dim_index(dims, LAT_DIMS)
375    lon_idx = find_dim_index(dims, LON_DIMS)
376    a_idx   = find_dim_index(dims, ALT_DIMS)
377
378    # SPECIAL CASE: time-only series (all others singleton) → plot directly
379    if (
380        t_idx is not None and shape[t_idx] > 1 and
381        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
382    ):
383        print("Detected single-point spatial dims; plotting time series…")
384        # récupérer les valeurs
385        var_obj = ds.variables[var]
386        data = var_obj[:].squeeze()   # shape (time,)
387        # temps
388        time_var = find_coord_var(ds, TIME_DIMS)
389        if time_var:
390            tvals = ds.variables[time_var][:]
391        else:
392            tvals = np.arange(data.shape[0])
393        # masque éventuel
394        if hasattr(data, "mask"):
395            data = np.where(data.mask, np.nan, data.data)
396        if hasattr(tvals, "mask"):
397            tvals = np.where(tvals.mask, np.nan, tvals.data)
398        # tracé
399        plt.figure()
400        plt.plot(tvals, data, marker="o")
401        plt.xlabel(time_var or "Time Index")
402        plt.ylabel(var + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
403        plt.title(f"{var} vs {time_var or 'Index'}")
404        plt.show()
405        ds.close()
406        return
407
408    # Ask average over latitude only if Time, lat AND lon each >1
409    avg_lat = False
410    if (
411        t_idx   is not None and shape[t_idx]  > 1 and
412        lat_idx is not None and shape[lat_idx] > 1 and
413        lon_idx is not None and shape[lon_idx] > 1
414    ):
415        u = input("Average over latitude & plot lon vs time? [y/n]: ").strip().lower()
416        avg_lat = (u == "y")
417
418    # Time index prompt
419    ti = None
420    if t_idx is not None:
421        L = shape[t_idx]
422        if L > 1:
423            while True:
424                u = input(f"Enter time index [0..{L-1}]: ").strip()
425                try:
426                    ti = int(u)
427                    if 0 <= ti < L:
428                        break
429                except:
430                    pass
431                print("Invalid.")
432        else:
433            ti = 0; print("Only one time; using 0.")
434
435    # Altitude index prompt
436    ai = None
437    if a_idx is not None:
438        L = shape[a_idx]
439        if L > 1:
440            while True:
441                u = input(f"Enter altitude index [0..{L-1}]: ").strip()
442                try:
443                    ai = int(u)
444                    if 0 <= ai < L:
445                        break
446                except:
447                    pass
448                print("Invalid.")
449        else:
450            ai = 0; print("Only one altitude; using 0.")
451
452    # Other dims
453    extra = {}
454    for idx, dname in enumerate(dims):
455        if idx in (t_idx, a_idx):
456            continue
457        if dname.lower() in LAT_DIMS + LON_DIMS and shape[idx] == 1:
458            extra[dname] = 0
459            continue
460        L = shape[idx]
461        if L == 1:
462            extra[dname] = 0
463            continue
464        if "slope" in dname.lower():
465            prompt = f"Enter slope number [1..{L}] for '{dname}': "
466        else:
467            prompt = f"Enter index [0..{L-1}] or 'f' to plot '{dname}': "
468        while True:
469            u = input(prompt).strip().lower()
470            if u == "f" and "slope" not in dname.lower():
471                break
472            try:
473                iv = int(u)
474                if "slope" in dname.lower():
475                    if 1 <= iv <= L:
476                        extra[dname] = iv - 1
477                        break
478                else:
479                    if 0 <= iv < L:
480                        extra[dname] = iv
481                        break
482            except:
483                pass
484            print("Invalid.")
485
486    plot_variable(ds, var, time_index=ti, alt_index=ai,
487                  colormap="jet", output_path=None,
488                  extra_indices=extra, avg_lat=avg_lat)
489    ds.close()
490
491
492def visualize_variable_cli(nc_file, varname, time_index, alt_index,
493                           colormap, output_path, extra_json, avg_lat):
494    """
495    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
496    """
497    if not os.path.isfile(nc_file):
498        print(f"Error: '{nc_file}' not found."); return
499    ds = Dataset(nc_file, "r")
500    if varname not in ds.variables:
501        print(f"Variable '{varname}' not in file."); ds.close(); return
502
503    # DISPLAY DIMENSIONS AND SIZES
504    dims  = ds.variables[varname].dimensions
505    shape = ds.variables[varname].shape
506    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
507    for name, size in zip(dims, shape):
508        print(f"  - {name}: size {size}")
509    print()
510
511    # SPECIAL CASE: time-only → plot directly
512    t_idx = find_dim_index(dims, TIME_DIMS)
513    if (
514        t_idx is not None and shape[t_idx] > 1 and
515        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
516    ):
517        print("Detected single-point spatial dims; plotting time series…")
518        # même logique que ci‑dessus
519        var_obj = ds.variables[varname]
520        data = var_obj[:].squeeze()
521        time_var = find_coord_var(ds, TIME_DIMS)
522        if time_var:
523            tvals = ds.variables[time_var][:]
524        else:
525            tvals = np.arange(data.shape[0])
526        if hasattr(data, "mask"):
527            data = np.where(data.mask, np.nan, data.data)
528        if hasattr(tvals, "mask"):
529            tvals = np.where(tvals.mask, np.nan, tvals.data)
530        plt.figure()
531        plt.plot(tvals, data, marker="o")
532        plt.xlabel(time_var or "Time Index")
533        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
534        plt.title(f"{varname} vs {time_var or 'Index'}")
535        if output_path:
536            plt.savefig(output_path, bbox_inches="tight")
537            print(f"Saved to {output_path}")
538        else:
539            plt.show()
540        ds.close()
541        return
542
543    # Si --avg-lat mais lat/lon/Time non compatibles → désactive
544    lat_idx = find_dim_index(dims, LAT_DIMS)
545    lon_idx = find_dim_index(dims, LON_DIMS)
546    if avg_lat and not (
547        t_idx   is not None and shape[t_idx]  > 1 and
548        lat_idx is not None and shape[lat_idx] > 1 and
549        lon_idx is not None and shape[lon_idx] > 1
550    ):
551        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
552        avg_lat = False
553
554    # Parse extra indices JSON
555    extra = {}
556    if extra_json:
557        try:
558            parsed = json.loads(extra_json)
559            for k, v in parsed.items():
560                if isinstance(v, int):
561                    if "slope" in k.lower():
562                        extra[k] = v - 1
563                    else:
564                        extra[k] = v
565        except:
566            print("Warning: bad extra-indices.")
567
568    plot_variable(ds, varname, time_index, alt_index,
569                  colormap, output_path, extra, avg_lat)
570    ds.close()
571
572
573def main():
574    parser = argparse.ArgumentParser()
575    parser.add_argument("nc_file", nargs="?", help="NetCDF file (omit for interactive)")
576    parser.add_argument("-v", "--variable", help="Variable name")
577    parser.add_argument("-t", "--time-index", type=int, help="Time index (0-based)")
578    parser.add_argument("-a", "--alt-index", type=int, help="Altitude index (0-based)")
579    parser.add_argument("-c", "--cmap", default="jet", help="Colormap")
580    parser.add_argument("--avg-lat", action="store_true",
581                        help="Average over latitude (time × lon heatmap)")
582    parser.add_argument("-o", "--output", help="Save figure path")
583    parser.add_argument("-e", "--extra-indices", help="JSON for other dims")
584    args = parser.parse_args()
585
586    if args.nc_file and args.variable:
587        visualize_variable_cli(
588            args.nc_file, args.variable,
589            args.time_index, args.alt_index,
590            args.cmap, args.output,
591            args.extra_indices, args.avg_lat
592        )
593    else:
594        visualize_variable_interactive(args.nc_file)
595
596
597if __name__ == "__main__":
598    main()
Note: See TracBrowser for help on using the repository browser.