#!/usr/bin/env python3
##############################################################
### Python script to visualize a variable in a NetCDF file ###
##############################################################

"""
This script can display any numeric variable from a NetCDF file.
It supports the following cases:
  - 1D time series (Time)
  - 1D vertical profiles (e.g., subsurface_layers)
  - 2D latitude/longitude map
  - 2D (Time × another dimension)
  - Variables with dimension “physical_points” as 2D map if lat/lon present,
    or generic 2D plot if the remaining axes are spatial
  - Scalar output (ndim == 0 after slicing)

Usage:
  1) Command-line mode:
       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME \
           [--time-index 0] [--alt-index 0] [--cmap viridis] [--output out.png] \
           [--extra-indices '{"dim1": idx1, "dim2": idx2}']

    --variable     : Name of the variable to visualize.
    --time-index   : Index along the Time dimension (ignored for purely 1D time series).
    --alt-index    : Index along the altitude dimension, if present.
    --cmap         : Matplotlib colormap for contourf (default: "jet").
    --output       : If provided, save the figure to this filename instead of displaying.
    --extra-indices: JSON string to fix indices for any dimensions other than Time, lat, lon, or altitude.
                     Example: '{"nslope": 0, "physical_points": 2}'
                     Omitting a dimension means it remains unfixed (useful to plot a 1D profile).

  2) Interactive mode:
       python display_netcdf.py
       (The script will prompt for the NetCDF file, the variable, etc.)
"""

import os
import sys
import glob
import readline
import argparse
import json
import numpy as np
import matplotlib.pyplot as plt
from netCDF4 import Dataset

# Names used to recognize special dimensions and their coordinate variables.
# find_dim_index() and find_coord_var() compare these case-insensitively,
# so e.g. a dimension called "TIME" also matches TIME_DIMS.
TIME_DIMS = ("Time", "time", "time_counter")  # time axis
ALT_DIMS  = ("altitude",)                     # altitude axis (single-element tuple)
LAT_DIMS  = ("latitude", "lat")               # latitude axis
LON_DIMS  = ("longitude", "lon")              # longitude axis


def complete_filename(text, state):
    """
    Readline completer that expands ``text`` into matching filesystem paths.

    A trailing ``*`` is appended unless ``text`` already contains a wildcard,
    ``~`` is expanded, and directories are returned with a trailing ``/``.
    Returns the ``state``-th candidate, or None once candidates are exhausted.
    """
    pattern = text if "*" in text else text + "*"

    candidates = []
    for path in glob.glob(os.path.expanduser(pattern)):
        if os.path.isdir(path):
            path += "/"
        candidates.append(path)

    try:
        return candidates[state]
    except IndexError:
        return None


def make_varname_completer(varnames):
    """
    Build a readline completer over ``varnames``.

    The returned function takes (text, state) and yields the ``state``-th name
    starting with ``text``, or None when there is no such candidate.
    """
    def completer(text, state):
        hits = list(filter(lambda candidate: candidate.startswith(text), varnames))
        try:
            return hits[state]
        except IndexError:
            return None

    return completer


def find_dim_index(dims, candidates):
    """
    Locate the first entry of ``dims`` whose name equals any of ``candidates``.

    The comparison ignores case. Returns the position within ``dims``,
    or None when no dimension matches.
    """
    return next(
        (
            position
            for position, dim_name in enumerate(dims)
            if any(cand.lower() == dim_name.lower() for cand in candidates)
        ),
        None,
    )


def find_coord_var(dataset, candidates):
    """
    Return the name of the first variable in ``dataset.variables`` that equals
    any of ``candidates`` (case-insensitive), or None if there is no match.
    """
    return next(
        (
            var_name
            for var_name in dataset.variables
            if any(cand.lower() == var_name.lower() for cand in candidates)
        ),
        None,
    )


def _nan_filled(values):
    """Return ``values`` with masked entries replaced by NaN; plain arrays pass through."""
    if hasattr(values, "mask"):
        return np.where(values.mask, np.nan, values.data)
    return values


def _label_with_units(name, variable):
    """Return ``name``, suffixed with "(units)" when ``variable`` has a ``units`` attribute."""
    if hasattr(variable, "units"):
        return f"{name} ({variable.units})"
    return name


def _finish_figure(output_path):
    """Save the current figure to ``output_path`` (or show it when falsy), then close it."""
    try:
        if output_path:
            # Best effort: a failed save is reported, not raised.
            try:
                plt.savefig(output_path, bbox_inches="tight")
                print(f"Figure saved to '{output_path}'")
            except Exception as e:
                print(f"Error saving figure: {e}")
        else:
            plt.show()
    finally:
        # Always release the figure, even if showing it failed.
        plt.close()


def plot_variable(dataset, varname, time_index=None, alt_index=None, colormap="jet",
                  output_path=None, extra_indices=None):
    """
    Extract the requested slice of a variable and plot it according to its shape.

    - Pure 1D time series → time-series plot (time_index is ignored)
    - After slicing:
        • ndim == 0 → the scalar value is printed, nothing is plotted
        • ndim == 1:
            • if a coordinate variable exists for the remaining dimension
              ("soildepth" for "subsurface_layers", otherwise a variable named
              like the dimension) and has a matching shape → profile plot
              (data on X, coordinate on Y)
            • else → simple plot vs. index
        • ndim == 2:
            • if the two remaining axes are exactly latitude and longitude
              → contourf map (data is transposed when longitude comes first)
            • else → generic imshow plot
    - Any other ndim → error message

    All problems are reported with print() and the function returns None;
    nothing is raised for bad indices or unreadable data.

    Parameters
    ----------
    dataset       : netCDF4.Dataset object (already open)
    varname       : name of the variable to plot
    time_index    : int or None (required if the variable has a time dimension,
                     ignored for pure time series)
    alt_index     : int or None (required if the variable has an altitude dimension)
    colormap      : string colormap name (passed to plt.contourf)
    output_path   : string filepath to save figure, or None to display interactively
    extra_indices : dict { dimension_name (str) : index (int) } for slicing any
                     other dimension. Names that are not dimensions of the
                     variable are ignored; a dimension that is not listed
                     remains unsliced.
    """
    var = dataset.variables[varname]
    dims = var.dimensions  # tuple of dimension names

    # NOTE(review): the whole variable is read before slicing; for very large
    # variables, slicing `var` directly would avoid loading everything.
    try:
        data_full = var[:]
    except Exception as e:
        print(f"Error: Cannot read data for variable '{varname}': {e}")
        return
    data_full = _nan_filled(data_full)

    t_idx = find_dim_index(dims, TIME_DIMS)

    # ------------------------------------------------------------------------
    # 1) Pure 1D time series (dims == ('Time',) or equivalent)
    if len(dims) == 1 and t_idx is not None:
        time_varname = find_coord_var(dataset, TIME_DIMS)
        time_vals = None
        if time_varname:
            time_vals = _nan_filled(dataset.variables[time_varname][:])
            # A time coordinate of a different shape cannot be used as X axis.
            if np.shape(time_vals) != np.shape(data_full):
                print(f"Warning: Coordinate '{time_varname}' does not match the shape of "
                      f"'{varname}'; plotting against the time index instead.")
                time_varname = None
        if time_varname is None:
            time_vals = np.arange(data_full.shape[0])

        xlabel = time_varname if time_varname else "Time Index"
        plt.figure()
        plt.plot(time_vals, data_full, marker='o')
        plt.xlabel(xlabel)
        plt.ylabel(_label_with_units(varname, var))
        plt.title(f"{varname} vs {xlabel}")
        _finish_figure(output_path)
        return
    # ------------------------------------------------------------------------

    a_idx = find_dim_index(dims, ALT_DIMS)

    # One selector per dimension; slice(None) keeps the dimension whole.
    slicer = [slice(None)] * len(dims)

    if t_idx is not None:
        if time_index is None:
            print("Error: Variable has a time dimension; please supply a time index.")
            return
        slicer[t_idx] = time_index
    if a_idx is not None:
        if alt_index is None:
            print("Error: Variable has an altitude dimension; please supply an altitude index.")
            return
        slicer[a_idx] = alt_index

    # Fix any explicitly requested extra dimension; unknown names are ignored.
    for dim_name, idx_val in (extra_indices or {}).items():
        if dim_name in dims:
            slicer[dims.index(dim_name)] = idx_val

    try:
        data_slice = data_full[tuple(slicer)]
    except Exception as e:
        print(f"Error: Could not slice variable '{varname}': {e}")
        return

    # Axes of the original variable that survive the slicing, in order.
    free_axes = [axis for axis, sel in enumerate(slicer) if isinstance(sel, slice)]

    # CASE: scalar
    if np.ndim(data_slice) == 0:
        try:
            scalar_val = float(data_slice)
        except Exception:
            scalar_val = data_slice
        print(f"Scalar result for '{varname}': {scalar_val}")
        return

    # CASE: 1D (profile along a coordinate, or simple vector)
    if data_slice.ndim == 1:
        if not free_axes:
            print(f"Error: After slicing, data for '{varname}' is 1D but remaining dimension is unknown.")
            return
        dname = dims[free_axes[0]]

        coord_var = None
        if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
            coord_var = "soildepth"
        elif dname in dataset.variables:
            coord_var = dname

        coord_vals = None
        if coord_var:
            coord_vals = _nan_filled(dataset.variables[coord_var][:])
            # plt.plot needs x and y of equal shape; otherwise fall back to indices.
            if np.shape(coord_vals) != data_slice.shape:
                print(f"Warning: Coordinate '{coord_var}' does not match the shape of "
                      f"the data slice; plotting against the index instead.")
                coord_var = None

        plt.figure()
        if coord_var:
            plt.plot(data_slice, coord_vals, marker='o')
            # Invert Y-axis if it's a depth coordinate
            if dname.lower() == "subsurface_layers":
                plt.gca().invert_yaxis()
            plt.xlabel(_label_with_units(varname, var))
            plt.ylabel(_label_with_units(coord_var, dataset.variables[coord_var]))
            plt.title(f"{varname} vs {coord_var}")
        else:
            plt.plot(data_slice, marker='o')
            plt.xlabel("Index")
            plt.ylabel(_label_with_units(varname, var))
            plt.title(f"{varname} (1D)")
        _finish_figure(output_path)
        return

    # CASE: 2D
    if data_slice.ndim == 2:
        lat_idx = find_dim_index(dims, LAT_DIMS)
        lon_idx = find_dim_index(dims, LON_DIMS)

        # Draw a map only when the two surviving axes really are lat and lon;
        # having them somewhere in the original dims is not enough.
        is_latlon = (
            lat_idx is not None
            and lon_idx is not None
            and free_axes == sorted((lat_idx, lon_idx))
        )

        if is_latlon:
            lat_varname = find_coord_var(dataset, LAT_DIMS)
            lon_varname = find_coord_var(dataset, LON_DIMS)
            if lat_varname is None or lon_varname is None:
                print("Error: Could not locate latitude/longitude variables in the dataset.")
                return

            lat_var = _nan_filled(dataset.variables[lat_varname][:])
            lon_var = _nan_filled(dataset.variables[lon_varname][:])

            # contourf expects rows = latitude (Y), columns = longitude (X).
            field = data_slice.T if lon_idx < lat_idx else data_slice

            # Build 2D coordinate arrays
            if np.ndim(lat_var) == 1 and np.ndim(lon_var) == 1:
                lon2d, lat2d = np.meshgrid(lon_var, lat_var)
            elif np.ndim(lat_var) == 2 and np.ndim(lon_var) == 2:
                lat2d, lon2d = lat_var, lon_var
            else:
                print("Error: Latitude and longitude must both be either 1D or 2D.")
                return

            if np.shape(lat2d) != field.shape or np.shape(lon2d) != field.shape:
                print("Error: Latitude/longitude coordinates do not match the shape of the data slice.")
                return

            plt.figure(figsize=(10, 6))
            cf = plt.contourf(lon2d, lat2d, field, cmap=colormap)
            cbar = plt.colorbar(cf)
            cbar.set_label(_label_with_units(varname, var))

            lon_label = f"Longitude ({getattr(dataset.variables[lon_varname], 'units', 'degrees')})"
            lat_label = f"Latitude ({getattr(dataset.variables[lat_varname], 'units', 'degrees')})"
            plt.xlabel(lon_label)
            plt.ylabel(lat_label)
            plt.title(f"{varname} (lat × lon)")
            _finish_figure(output_path)
            return

        # Two non-geographical dimensions; plot with imshow
        plt.figure(figsize=(8, 6))
        plt.imshow(data_slice, aspect='auto')
        plt.colorbar(label=_label_with_units(varname, var))
        plt.xlabel("Dimension 2 Index")
        plt.ylabel("Dimension 1 Index")
        plt.title(f"{varname} (2D without lat/lon)")
        _finish_figure(output_path)
        return

    # CASE: data_slice.ndim is neither 0, 1, nor 2
    print(f"Error: After slicing, data for '{varname}' has ndim={data_slice.ndim}, which is not supported.")
    return


def _prompt_index(prompt, size, allow_free=False):
    """
    Ask repeatedly until the user enters a valid index in [0, size).

    Returns the index as int, None when the user submits an empty line, or the
    string "f" when ``allow_free`` is set and the user typed 'f' (any case).
    """
    hint = f"Invalid index. Enter an integer between 0 and {size - 1}"
    hint += ", or 'f'." if allow_free else "."
    while True:
        answer = input(prompt).strip()
        if allow_free and answer.lower() == "f":
            return "f"
        if answer == "":
            return None
        try:
            value = int(answer)
        except ValueError:
            value = None
        if value is not None and 0 <= value < size:
            return value
        print(hint)


def visualize_variable_interactive(nc_path=None):
    """
    Interactive mode: prompts for the NetCDF file if not provided, then for the variable,
    then for Time/altitude indices (skipped entirely if variable is purely 1D over Time),
    and for each other dimension offers to fix an index or to plot along that dimension
    (by typing 'f'). Latitude/longitude dimensions are never prompted for.

    If a dimension has length 1, the index 0 is chosen automatically.
    An empty answer at any prompt aborts. The dataset is always closed before
    returning, including when a prompt is interrupted.

    Parameters
    ----------
    nc_path : path to the NetCDF file, or None to prompt for it (with tab completion)
    """
    # readline's default delimiters include '/', '-' and '~', which would hand
    # the completers only the last fragment of a path or variable name.
    readline.set_completer_delims(" \t\n")
    # Bind Tab unconditionally so the variable-name completer also works when
    # the file was given on the command line. libedit uses its own syntax.
    if "libedit" in (readline.__doc__ or ""):
        readline.parse_and_bind("bind ^I rl_complete")
    else:
        readline.parse_and_bind("tab: complete")

    # Determine file path
    if nc_path:
        file_input = nc_path
    else:
        readline.set_completer(complete_filename)
        # The shell does not expand '~' for text typed at a prompt.
        file_input = os.path.expanduser(input("Enter the path to the NetCDF file: ").strip())

    if not file_input:
        print("No file specified. Exiting.")
        return
    if not os.path.isfile(file_input):
        print(f"Error: '{file_input}' not found.")
        return

    try:
        ds = Dataset(file_input, mode="r")
    except Exception as e:
        print(f"Error: Unable to open '{file_input}': {e}")
        return

    try:
        varnames = list(ds.variables.keys())
        if not varnames:
            print("Error: No variables found in the dataset.")
            return

        # Auto-select if only one variable
        if len(varnames) == 1:
            var_input = varnames[0]
            print(f"Automatically selected the only variable: '{var_input}'")
        else:
            print("\nAvailable variables:")
            for name in varnames:
                print(f"  - {name}")
            print()
            readline.set_completer(make_varname_completer(varnames))
            var_input = input("Enter the name of the variable to visualize: ").strip()
            if var_input not in ds.variables:
                print(f"Error: Variable '{var_input}' not found. Exiting.")
                return

        variable = ds.variables[var_input]
        dims = variable.dimensions  # tuple of dimension names
        t_idx = find_dim_index(dims, TIME_DIMS)

        # If the variable is purely 1D over Time, plot immediately without asking for time index
        if len(dims) == 1 and t_idx is not None:
            plot_variable(
                dataset=ds,
                varname=var_input,
                time_index=None,
                alt_index=None,
                colormap="jet",
                output_path=None,
                extra_indices=None
            )
            return

        # Prompt for time index if applicable
        time_idx = None
        if t_idx is not None:
            time_len = variable.shape[t_idx]
            if time_len > 1:
                time_idx = _prompt_index(f"Enter time index [0..{time_len - 1}]: ", time_len)
                if time_idx is None:
                    print("No time index entered. Exiting.")
                    return
            else:
                time_idx = 0
                print("Only one time step available; using index 0.")

        # Prompt for altitude index if applicable
        alt_idx = None
        a_idx = find_dim_index(dims, ALT_DIMS)
        if a_idx is not None:
            alt_len = variable.shape[a_idx]
            if alt_len > 1:
                alt_idx = _prompt_index(f"Enter altitude index [0..{alt_len - 1}]: ", alt_len)
                if alt_idx is None:
                    print("No altitude index entered. Exiting.")
                    return
            else:
                alt_idx = 0
                print("Only one altitude level available; using index 0.")

        # For each remaining dimension (not Time/alt/lat/lon), ask the user to
        # fix an index or type 'f' to plot along that dimension.
        horizontal_names = {name.lower() for name in LAT_DIMS + LON_DIMS}
        extra_indices = {}
        for idx_dim, dim_name in enumerate(dims):
            if idx_dim == t_idx or idx_dim == a_idx:
                continue
            if dim_name.lower() in horizontal_names:
                continue

            dim_len = variable.shape[idx_dim]
            if dim_len == 1:
                extra_indices[dim_name] = 0
                print(f"Dimension '{dim_name}' has length 1; using index 0.")
                continue

            prompt = (
                f"Enter index for '{dim_name}' [0..{dim_len - 1}] "
                f"or 'f' to plot along '{dim_name}': "
            )
            choice = _prompt_index(prompt, dim_len, allow_free=True)
            if choice is None:
                print("No index entered. Exiting.")
                return
            if choice != "f":
                # 'f' leaves the dimension unfixed → no key in extra_indices
                extra_indices[dim_name] = choice

        # Finally, call plot_variable with collected indices
        plot_variable(
            dataset=ds,
            varname=var_input,
            time_index=time_idx,
            alt_index=alt_idx,
            colormap="jet",
            output_path=None,
            extra_indices=extra_indices
        )
    finally:
        # Runs on every exit path above, including Ctrl-C/EOF at a prompt.
        ds.close()


def _parse_extra_indices(extra_json):
    """Decode the --extra-indices JSON string into a {dimension_name: int_index} dict."""
    if not extra_json:
        return {}
    try:
        parsed = json.loads(extra_json)
    except json.JSONDecodeError:
        print("Warning: --extra-indices is not valid JSON. Ignored.")
        return {}
    if not isinstance(parsed, dict):
        print("Warning: --extra-indices is not a JSON object. Ignored.")
        return {}

    indices = {}
    for name, value in parsed.items():
        # bool is a subclass of int: JSON true/false must not pass as 1/0.
        if isinstance(value, int) and not isinstance(value, bool):
            indices[name] = value
        else:
            print(f"Warning: --extra-indices entry '{name}' is not an integer index. Ignored.")
    return indices


def visualize_variable_cli(nc_path, varname, time_index, alt_index, colormap, output_path, extra_json):
    """
    Command-line mode: open the file, validate the variable and plot it directly.

    Problems (missing file, unreadable file, unknown variable, malformed
    --extra-indices) are reported with print(); the function returns None in
    every case. The dataset is closed even if plotting raises.

    Parameters
    ----------
    nc_path     : path to the NetCDF file
    varname     : name of the variable to plot
    time_index  : int or None, forwarded to plot_variable
    alt_index   : int or None, forwarded to plot_variable
    colormap    : Matplotlib colormap name, forwarded to plot_variable
    output_path : file to save the figure to, or None to display it
    extra_json  : JSON object string mapping dimension names to integer
                   indices, or None/empty for no extra slicing
    """
    if not os.path.isfile(nc_path):
        print(f"Error: '{nc_path}' not found.")
        return

    try:
        ds = Dataset(nc_path, mode="r")
    except Exception as e:
        print(f"Error: Unable to open '{nc_path}': {e}")
        return

    try:
        if varname not in ds.variables:
            print(f"Error: Variable '{varname}' not found in '{nc_path}'.")
            return

        extra_indices = _parse_extra_indices(extra_json)

        # plot_variable skips unknown dimension names silently; tell the user
        # so a typo does not look like a plotting failure.
        var_dims = ds.variables[varname].dimensions
        for dim_name in extra_indices:
            if dim_name not in var_dims:
                print(f"Warning: '{dim_name}' is not a dimension of '{varname}'. Ignored.")

        plot_variable(
            dataset=ds,
            varname=varname,
            time_index=time_index,
            alt_index=alt_index,
            colormap=colormap,
            output_path=output_path,
            extra_indices=extra_indices
        )
    finally:
        ds.close()


def main():
    """
    Parse the command line and dispatch to CLI or interactive mode.

    CLI mode is used only when both the NetCDF file and --variable are given;
    in every other case the interactive mode runs (with the file path
    pre-filled when one was supplied).
    """
    # (flags, keyword arguments) for each add_argument() call, in help order.
    argument_specs = (
        (("nc_file",), {
            "nargs": "?",
            "help": "Path to the NetCDF file (interactive if omitted).",
        }),
        (("--variable", "-v"), {
            "help": "Name of the variable to visualize.",
        }),
        (("--time-index", "-t"), {
            "type": int,
            "help": "Index on the Time dimension, if applicable (ignored for pure 1D time series).",
        }),
        (("--alt-index", "-a"), {
            "type": int,
            "help": "Index on the altitude dimension, if applicable.",
        }),
        (("--cmap", "-c"), {
            "default": "jet",
            "help": "Matplotlib colormap (default: 'jet').",
        }),
        (("--output", "-o"), {
            "help": "If provided, save the figure to this file instead of displaying it.",
        }),
        (("--extra-indices", "-e"), {
            "help": "JSON string to fix indices of dimensions outside Time/lat/lon/alt. "
                    "Example: '{\"nslope\":0, \"physical_points\":2}'.",
        }),
    )

    parser = argparse.ArgumentParser(
        description="Visualize a 1D/2D slice of a NetCDF variable on various dimension types."
    )
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    # Anything short of "file AND variable" falls back to interactive mode.
    if not (args.nc_file and args.variable):
        visualize_variable_interactive(nc_path=args.nc_file)
        return

    visualize_variable_cli(
        nc_path=args.nc_file,
        varname=args.variable,
        time_index=args.time_index,
        alt_index=args.alt_index,
        colormap=args.cmap,
        output_path=args.output,
        extra_json=args.extra_indices
    )


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

