#!/usr/bin/env python3
##############################################################
### Python script to visualize a variable in a NetCDF file ###
##############################################################

"""
This script can display any numeric variable from a NetCDF file.
It supports the following cases:
  - 1D time series (Time)
  - 1D vertical profiles (e.g., subsurface_layers)
  - 2D latitude/longitude map
  - 2D (Time × another dimension)
  - Variables with dimension “physical_points” displayed on a 2D map if lat/lon are present
  - Optionally average over latitude and plot longitude vs. time heatmap
  - Scalar output (ndim == 0 after slicing)

Usage:
  1) Command-line mode:
       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
           [--time-index 0] [--alt-index 0] [--cmap viridis] [--avg-lat] \
           [--output out.png] [--extra-indices '{"nslope": 1}']

    --variable     : Name of the variable to visualize.
    --time-index   : Index along the Time dimension (0-based; ignored for purely 1D time series and with --avg-lat).
    --alt-index    : Index along the altitude dimension (0-based), if present.
    --cmap         : Matplotlib colormap (default: "jet").
    --avg-lat      : Average over latitude and plot longitude vs. time heatmap.
    --output       : If provided, save the figure to this filename instead of displaying.
    --extra-indices: JSON string to fix indices for any other dimensions.
                     For any dimension whose name contains "slope", use 1-based numbering here.
                     Example: '{"nslope": 1, "physical_points": 3}'

  2) Interactive mode:
       python display_netcdf.py
       (The script will prompt for everything, including averaging option.)
"""

import os
import sys
import glob
import readline
import argparse
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
from netCDF4 import Dataset

# Constants to recognize dimension names.
# find_dim_index / find_coord_var compare names case-insensitively, so e.g.
# "Time" and "time" both match either entry of TIME_DIMS.
TIME_DIMS = ("Time", "time", "time_counter")
ALT_DIMS  = ("altitude",)
LAT_DIMS  = ("latitude", "lat")
LON_DIMS  = ("longitude", "lon")


def complete_filename(text, state):
    """
    Readline completer for filesystem paths.

    Globs for paths beginning with *text* (or uses *text* as the pattern when
    it already contains '*'), after expanding a leading '~'. Directories are
    suffixed with '/'. Returns the state-th candidate, or None when exhausted.
    """
    pattern = text if "*" in text else text + "*"
    candidates = []
    for path in glob.glob(os.path.expanduser(pattern)):
        if os.path.isdir(path):
            path += "/"
        candidates.append(path)
    try:
        return candidates[state]
    except IndexError:
        return None


def make_varname_completer(varnames):
    """
    Build a readline completer over the variable names in *varnames*.

    The returned callable(text, state) gives the state-th name starting with
    *text* (in the order of *varnames*), or None when there are no more.
    """
    def _complete(prefix, state):
        matching = list(filter(lambda name: name.startswith(prefix), varnames))
        try:
            return matching[state]
        except IndexError:
            return None

    return _complete


def find_dim_index(dims, candidates):
    """
    Return the position of the first entry of *dims* that equals any name in
    *candidates* (case-insensitive), or None when no dimension matches.
    """
    wanted = [cand.lower() for cand in candidates]
    return next((pos for pos, name in enumerate(dims) if name.lower() in wanted), None)


def find_coord_var(dataset, candidates):
    """
    Return the name of the first variable in ``dataset.variables`` whose name
    equals any of *candidates* (case-insensitive), or None if there is none.
    """
    wanted = [cand.lower() for cand in candidates]
    return next((name for name in dataset.variables if name.lower() in wanted), None)


def plot_variable(dataset, varname, time_index=None, alt_index=None,
                  colormap="jet", output_path=None, extra_indices=None,
                  avg_lat=False):
    """
    Core plotting logic: reads the variable, handles masks,
    determines dimensionality, and creates the appropriate plot:
      - 1D time series
      - 1D profiles or physical_points maps
      - 2D lat×lon or generic 2D
      - Time×lon heatmap if avg_lat=True
      - Scalar printing
    """
    var = dataset.variables[varname]
    dims = var.dimensions

    # Read full data
    try:
        data_full = var[:]
    except Exception as e:
        print(f"Error: Cannot read data for '{varname}': {e}")
        return
    if hasattr(data_full, "mask"):
        data_full = np.where(data_full.mask, np.nan, data_full.data)

    # Pure 1D time series
    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
        time_var = find_coord_var(dataset, TIME_DIMS)
        tvals = (dataset.variables[time_var][:] if time_var
                 else np.arange(data_full.shape[0]))
        if hasattr(tvals, "mask"):
            tvals = np.where(tvals.mask, np.nan, tvals.data)
        plt.figure()
        plt.plot(tvals, data_full, marker="o")
        plt.xlabel(time_var or "Time Index")
        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
        plt.title(f"{varname} vs {time_var or 'Index'}")
        if output_path:
            plt.savefig(output_path, bbox_inches="tight")
            print(f"Saved to {output_path}")
        else:
            plt.show()
        return

    # Identify dims
    t_idx = find_dim_index(dims, TIME_DIMS)
    lat_idx = find_dim_index(dims, LAT_DIMS)
    lon_idx = find_dim_index(dims, LON_DIMS)
    a_idx = find_dim_index(dims, ALT_DIMS)

    # Average over latitude & plot time × lon heatmap
    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
        # mean over lat axis
        data_avg = np.nanmean(data_full, axis=lat_idx)
        # data_avg shape: (time, lon, ...)
        # we assume no other unfixed dims
        # get coordinates
        time_var = find_coord_var(dataset, TIME_DIMS)
        lon_var = find_coord_var(dataset, LON_DIMS)
        tvals = dataset.variables[time_var][:]
        lons = dataset.variables[lon_var][:]
        if hasattr(tvals, "mask"):
            tvals = np.where(tvals.mask, np.nan, tvals.data)
        if hasattr(lons, "mask"):
            lons = np.where(lons.mask, np.nan, lons.data)
        plt.figure(figsize=(10, 6))
        plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
        plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
        plt.ylabel(time_var)
        cbar = plt.colorbar()
        cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
        plt.title(f"{varname} averaged over latitude")
        if output_path:
            plt.savefig(output_path, bbox_inches="tight")
            print(f"Saved to {output_path}")
        else:
            plt.show()
        return

    # Build slicer for other cases
    slicer = [slice(None)] * len(dims)
    if t_idx is not None:
        if time_index is None:
            print("Error: please supply a time index.")
            return
        slicer[t_idx] = time_index
    if a_idx is not None:
        if alt_index is None:
            print("Error: please supply an altitude index.")
            return
        slicer[a_idx] = alt_index

    if extra_indices is None:
        extra_indices = {}
    for dn, idx_val in extra_indices.items():
        if dn in dims:
            slicer[dims.index(dn)] = idx_val

    # Extract slice
    try:
        dslice = data_full[tuple(slicer)]
    except Exception as e:
        print(f"Error slicing '{varname}': {e}")
        return

    # Scalar
    if np.ndim(dslice) == 0:
        print(f"Scalar '{varname}': {float(dslice)}")
        return

    # 1D: vector, profile, or physical_points
    if dslice.ndim == 1:
        rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]
        if rem:
            di, dname = rem[0]
            # physical_points → interpolated map
            if dname.lower() == "physical_points":
                latv = find_coord_var(dataset, LAT_DIMS)
                lonv = find_coord_var(dataset, LON_DIMS)
                if latv and lonv:
                    lats = dataset.variables[latv][:]
                    lons = dataset.variables[lonv][:]
                    if hasattr(lats, "mask"):
                        lats = np.where(lats.mask, np.nan, lats.data)
                    if hasattr(lons, "mask"):
                        lons = np.where(lons.mask, np.nan, lons.data)
                    triang = mtri.Triangulation(lons, lats)
                    plt.figure(figsize=(8, 6))
                    cf = plt.tricontourf(triang, dslice, cmap=colormap)
                    cbar = plt.colorbar(cf)
                    cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
                    plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
                    plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
                    plt.title(f"{varname} (interpolated map over physical_points)")
                    if output_path:
                        plt.savefig(output_path, bbox_inches="tight")
                        print(f"Saved to {output_path}")
                    else:
                        plt.show()
                    return
            # vertical profile?
            coord = None
            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
                coord = "soildepth"
            elif dname in dataset.variables:
                coord = dname
            if coord:
                coords = dataset.variables[coord][:]
                if hasattr(coords, "mask"):
                    coords = np.where(coords.mask, np.nan, coords.data)
                plt.figure()
                plt.plot(dslice, coords, marker="o")
                if dname.lower() == "subsurface_layers":
                    plt.gca().invert_yaxis()
                plt.xlabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
                plt.ylabel(coord + (f" ({dataset.variables[coord].units})" if hasattr(dataset.variables[coord], "units") else ""))
                plt.title(f"{varname} vs {coord}")
                if output_path:
                    plt.savefig(output_path, bbox_inches="tight")
                    print(f"Saved to {output_path}")
                else:
                    plt.show()
                return
        # generic 1D
        plt.figure()
        plt.plot(dslice, marker="o")
        plt.xlabel("Index")
        plt.ylabel(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
        plt.title(f"{varname} (1D)")
        if output_path:
            plt.savefig(output_path, bbox_inches="tight")
            print(f"Saved to {output_path}")
        else:
            plt.show()
        return

    # 2D: map or generic
    if dslice.ndim == 2:
        lat_idx2 = find_dim_index(dims, LAT_DIMS)
        lon_idx2 = find_dim_index(dims, LON_DIMS)
        if lat_idx2 is not None and lon_idx2 is not None:
            latv = find_coord_var(dataset, LAT_DIMS)
            lonv = find_coord_var(dataset, LON_DIMS)
            lats = dataset.variables[latv][:]
            lons = dataset.variables[lonv][:]
            if hasattr(lats, "mask"):
                lats = np.where(lats.mask, np.nan, lats.data)
            if hasattr(lons, "mask"):
                lons = np.where(lons.mask, np.nan, lons.data)
            if lats.ndim == 1 and lons.ndim == 1:
                lon2d, lat2d = np.meshgrid(lons, lats)
            else:
                lat2d, lon2d = lats, lons
            plt.figure(figsize=(10, 6))
            cf = plt.contourf(lon2d, lat2d, dslice, cmap=colormap)
            cbar = plt.colorbar(cf)
            cbar.set_label(varname + (f" ({var.units})" if hasattr(var, "units") else ""))
            plt.xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
            plt.ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
            plt.title(f"{varname} (lat × lon)")
            if output_path:
                plt.savefig(output_path, bbox_inches="tight")
                print(f"Saved to {output_path}")
            else:
                plt.show()
            return
        # generic 2D
        plt.figure(figsize=(8, 6))
        plt.imshow(dslice, aspect="auto")
        plt.colorbar(label=varname + (f" ({var.units})" if hasattr(var, "units") else ""))
        plt.xlabel("Dim 2 index")
        plt.ylabel("Dim 1 index")
        plt.title(f"{varname} (2D)")
        if output_path:
            plt.savefig(output_path, bbox_inches="tight")
            print(f"Saved to {output_path}")
        else:
            plt.show()
        return

    print(f"Error: ndim={dslice.ndim} not supported.")


def _ask_index(prompt, lo, hi, allow_free=False):
    """
    Prompt until the user enters an integer in [lo, hi] and return it.

    When allow_free is True, answering 'f' (any case) returns None, meaning
    "keep this dimension whole". Any other invalid answer prints "Invalid."
    and prompts again.
    """
    while True:
        answer = input(prompt).strip().lower()
        if allow_free and answer == "f":
            return None
        try:
            value = int(answer)
        except ValueError:
            print("Invalid.")
            continue
        if lo <= value <= hi:
            return value
        print("Invalid.")


def _interactive_session(ds):
    """Prompt for a variable of the open dataset *ds* and its indices, then plot it."""
    varnames = list(ds.variables.keys())
    if not varnames:
        print("No variables found.")
        return
    if len(varnames) == 1:
        var = varnames[0]
        print(f"Selected '{var}'")
    else:
        print("Available variables:")
        for v in varnames:
            print(f"  - {v}")
        readline.set_completer(make_varname_completer(varnames))
        readline.parse_and_bind("tab: complete")
        var = input("Variable name: ").strip()
        if var not in ds.variables:
            print("Unknown variable.")
            return

    # Display dimensions and sizes.
    var_obj = ds.variables[var]
    dims = var_obj.dimensions
    shape = var_obj.shape
    print(f"\nVariable '{var}' has {len(dims)} dimensions:")
    for name, size in zip(dims, shape):
        print(f"  - {name}: size {size}")
    print()

    t_idx = find_dim_index(dims, TIME_DIMS)
    lat_idx = find_dim_index(dims, LAT_DIMS)
    lon_idx = find_dim_index(dims, LON_DIMS)
    a_idx = find_dim_index(dims, ALT_DIMS)

    # Only time varies (every other dimension has size 1): plot the series directly.
    if (t_idx is not None and shape[t_idx] > 1
            and all(size == 1 for i, size in enumerate(shape) if i != t_idx)):
        print("Detected single-point spatial dims; plotting time series…")
        data = var_obj[:].squeeze()  # -> (time,)
        if hasattr(data, "mask"):
            data = np.where(data.mask, np.nan, data.data)
        time_var = find_coord_var(ds, TIME_DIMS)
        tvals = ds.variables[time_var][:] if time_var else np.arange(data.shape[0])
        if hasattr(tvals, "mask"):
            tvals = np.where(tvals.mask, np.nan, tvals.data)
        plt.figure()
        plt.plot(tvals, data, marker="o")
        plt.xlabel(time_var or "Time Index")
        plt.ylabel(var + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
        plt.title(f"{var} vs {time_var or 'Index'}")
        plt.show()
        return

    # Offer latitude averaging only if time, lat AND lon each have size > 1.
    avg_lat = False
    if all(i is not None and shape[i] > 1 for i in (t_idx, lat_idx, lon_idx)):
        answer = input("Average over latitude & plot lon vs time? [y/n]: ").strip().lower()
        avg_lat = answer == "y"

    # The latitude-averaged heatmap plots every time step, so no time index is needed.
    ti = None
    if t_idx is not None and not avg_lat:
        n = shape[t_idx]
        if n > 1:
            ti = _ask_index(f"Enter time index [0..{n-1}]: ", 0, n - 1)
        else:
            ti = 0
            print("Only one time; using 0.")

    ai = None
    if a_idx is not None:
        n = shape[a_idx]
        if n > 1:
            ai = _ask_index(f"Enter altitude index [0..{n-1}]: ", 0, n - 1)
        else:
            ai = 0
            print("Only one altitude; using 0.")

    # Remaining dimensions; lat/lon stay whole (and unprompted) when averaging.
    skip = {t_idx, a_idx}
    if avg_lat:
        skip |= {lat_idx, lon_idx}
    extra = {}
    for idx, (dname, n) in enumerate(zip(dims, shape)):
        if idx in skip:
            continue
        if n == 1:
            extra[dname] = 0
            continue
        if "slope" in dname.lower():
            # Slopes are numbered from 1 for the user, stored 0-based.
            slope = _ask_index(f"Enter slope number [1..{n}] for '{dname}': ", 1, n)
            extra[dname] = slope - 1
        else:
            choice = _ask_index(f"Enter index [0..{n-1}] or 'f' to plot '{dname}': ",
                                0, n - 1, allow_free=True)
            if choice is not None:
                extra[dname] = choice

    plot_variable(ds, var, time_index=ti, alt_index=ai,
                  colormap="jet", output_path=None,
                  extra_indices=extra, avg_lat=avg_lat)


def visualize_variable_interactive(nc_path=None):
    """
    Interactive mode: prompt for a file (unless *nc_path* is given) and a
    variable, list its dimensions, plot a time-only series directly, and
    otherwise ask for the indices needed before calling plot_variable.

    The dataset is closed when the function returns, even on error.
    """
    # readline's default delimiters include "/", "~" and "-", which would hand
    # only the last path component to the completer; split on whitespace only.
    readline.set_completer_delims(" \t\n")
    if nc_path:
        path = nc_path
    else:
        readline.set_completer(complete_filename)
        readline.parse_and_bind("tab: complete")
        path = input("Enter path to NetCDF file: ").strip()
    path = os.path.expanduser(path)
    if not os.path.isfile(path):
        print(f"Error: '{path}' not found.")
        return
    with Dataset(path, "r") as ds:
        _interactive_session(ds)


def visualize_variable_cli(nc_file, varname, time_index, alt_index,
                           colormap, output_path, extra_json, avg_lat):
    """
    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).
    """
    if not os.path.isfile(nc_file):
        print(f"Error: '{nc_file}' not found."); return
    ds = Dataset(nc_file, "r")
    if varname not in ds.variables:
        print(f"Variable '{varname}' not in file."); ds.close(); return

    # DISPLAY DIMENSIONS AND SIZES
    dims  = ds.variables[varname].dimensions
    shape = ds.variables[varname].shape
    print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
    for name, size in zip(dims, shape):
        print(f"  - {name}: size {size}")
    print()

    # SPECIAL CASE: time-only → plot directly
    t_idx = find_dim_index(dims, TIME_DIMS)
    if (
        t_idx is not None and shape[t_idx] > 1 and
        all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
    ):
        print("Detected single-point spatial dims; plotting time series…")
        # même logique que ci‑dessus
        var_obj = ds.variables[varname]
        data = var_obj[:].squeeze()
        time_var = find_coord_var(ds, TIME_DIMS)
        if time_var:
            tvals = ds.variables[time_var][:]
        else:
            tvals = np.arange(data.shape[0])
        if hasattr(data, "mask"):
            data = np.where(data.mask, np.nan, data.data)
        if hasattr(tvals, "mask"):
            tvals = np.where(tvals.mask, np.nan, tvals.data)
        plt.figure()
        plt.plot(tvals, data, marker="o")
        plt.xlabel(time_var or "Time Index")
        plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
        plt.title(f"{varname} vs {time_var or 'Index'}")
        if output_path:
            plt.savefig(output_path, bbox_inches="tight")
            print(f"Saved to {output_path}")
        else:
            plt.show()
        ds.close()
        return

    # Si --avg-lat mais lat/lon/Time non compatibles → désactive
    lat_idx = find_dim_index(dims, LAT_DIMS)
    lon_idx = find_dim_index(dims, LON_DIMS)
    if avg_lat and not (
        t_idx   is not None and shape[t_idx]  > 1 and
        lat_idx is not None and shape[lat_idx] > 1 and
        lon_idx is not None and shape[lon_idx] > 1
    ):
        print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
        avg_lat = False

    # Parse extra indices JSON
    extra = {}
    if extra_json:
        try:
            parsed = json.loads(extra_json)
            for k, v in parsed.items():
                if isinstance(v, int):
                    if "slope" in k.lower():
                        extra[k] = v - 1
                    else:
                        extra[k] = v
        except:
            print("Warning: bad extra-indices.")

    plot_variable(ds, varname, time_index, alt_index,
                  colormap, output_path, extra, avg_lat)
    ds.close()


def main():
    """Entry point: parse arguments, then run CLI mode or fall back to interactive mode."""
    cli = argparse.ArgumentParser()
    cli.add_argument("nc_file", nargs="?", help="NetCDF file (omit for interactive)")
    cli.add_argument("-v", "--variable", help="Variable name")
    cli.add_argument("-t", "--time-index", type=int, help="Time index (0-based)")
    cli.add_argument("-a", "--alt-index", type=int, help="Altitude index (0-based)")
    cli.add_argument("-c", "--cmap", default="jet", help="Colormap")
    cli.add_argument(
        "--avg-lat",
        action="store_true",
        help="Average over latitude (time × lon heatmap)",
    )
    cli.add_argument("-o", "--output", help="Save figure path")
    cli.add_argument("-e", "--extra-indices", help="JSON for other dims")
    opts = cli.parse_args()

    # Both a file and a variable are required to skip the prompts.
    if not (opts.nc_file and opts.variable):
        visualize_variable_interactive(opts.nc_file)
        return

    visualize_variable_cli(
        nc_file=opts.nc_file,
        varname=opts.variable,
        time_index=opts.time_index,
        alt_index=opts.alt_index,
        colormap=opts.cmap,
        output_path=opts.output,
        extra_json=opts.extra_indices,
        avg_lat=opts.avg_lat,
    )


# Run main() only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
