#!/usr/bin/env python3
##############################################################
### Python script to visualize a variable in a NetCDF file ###
##############################################################

"""
This script can display any numeric variable from a NetCDF file on a lat/lon map.
It supports variables with any of the following dimension layouts:
  - (latitude, longitude)
  - (time, latitude, longitude)
  - (altitude, latitude, longitude)
  - (time, altitude, latitude, longitude)

Usage:
  1) Command-line mode:
       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [--time-index 0] [--alt-index 0] [--cmap viridis] [--output out.png]

  2) Interactive mode through the prompt:
       python display_netcdf.py

The script will:
  > Attempt to locate latitude and longitude variables (searching for
    names like 'latitude', 'lat', 'longitude', 'lon').
  > If there is exactly one variable in the dataset, select it automatically.
  > Prompt for time/altitude indices if needed (or accept via CLI).
  > Handle masked arrays, converting masked values to NaN.
  > Plot with a default colormap ('jet'), adjustable via --cmap.
  > Optionally save the figure instead of displaying it interactively.
"""


import os
import sys
import glob
import readline
import argparse
import numpy as np
import matplotlib.pyplot as plt
from netCDF4 import Dataset


def complete_filename(text, state):
    """
    Readline tab-completion function for filesystem paths.

    Parameters
    ----------
    text  : the partial path typed so far; '~' is expanded, and a trailing
            '*' is appended unless the text already contains a '*'
    state : zero-based index of the candidate requested by readline

    Returns the state-th matching path (directories carry a trailing
    separator), or None once the candidates are exhausted.
    """
    pattern = text if "*" in text else text + "*"

    # glob gives no ordering guarantee and this function is re-run for every
    # `state`, so sort to keep the candidate list stable across calls.
    matches = []
    for path in sorted(glob.glob(os.path.expanduser(pattern))):
        # A pattern ending in a separator already yields "dir/"; do not double it.
        if os.path.isdir(path) and not path.endswith(os.sep):
            path += os.sep
        matches.append(path)

    try:
        return matches[state]
    except IndexError:
        return None


def make_varname_completer(varnames):
    """
    Build a readline completer that offers the entries of varnames.

    The returned function takes (text, state) and yields the state-th name
    starting with text, or None when there is no such candidate. The list is
    consulted on every call, so later changes to varnames are picked up.
    """
    def _complete(prefix, state):
        hits = [candidate for candidate in varnames if candidate.startswith(prefix)]
        try:
            chosen = hits[state]
        except IndexError:
            chosen = None
        return chosen

    return _complete


# Candidate names used to recognise each axis. find_dim_index and
# find_coord_var compare against these case-insensitively, so "Time" and
# "time" below are equivalent entries.
TIME_DIMS = ("Time", "time", "time_counter")
ALT_DIMS  = ("altitude",)
LAT_DIMS  = ("latitude", "lat")
LON_DIMS  = ("longitude", "lon")


def find_dim_index(dims, candidates):
    """
    Locate the first dimension whose name matches a candidate, ignoring case.

    Dimensions are examined in order; the position of the first match within
    dims is returned, or None when no dimension name matches.
    """
    wanted = {cand.lower() for cand in candidates}
    return next(
        (position for position, dim in enumerate(dims) if dim.lower() in wanted),
        None,
    )


def find_coord_var(dataset, candidates):
    """
    Return the name of the first dataset variable matching a candidate.

    Variable names are visited in the dataset's own order and compared to the
    candidates case-insensitively. Returns None when nothing matches.
    """
    wanted = {cand.lower() for cand in candidates}
    return next(
        (var_name for var_name in dataset.variables if var_name.lower() in wanted),
        None,
    )


# Core plotting helper

def _masked_to_nan(values):
    """Return values as a plain ndarray, with masked entries replaced by NaN."""
    if hasattr(values, "mask"):
        return np.where(values.mask, np.nan, values.data)
    return np.asarray(values)


def _lat_first(values, dims):
    """
    Transpose a 2D array whose longitude axis precedes its latitude axis.

    dims is the tuple of dimension names the array was taken from; only the
    relative order of its latitude and longitude entries is used. Arrays that
    are not 2D, or whose dims lack either axis, are returned unchanged.
    """
    lat_idx = find_dim_index(dims, LAT_DIMS)
    lon_idx = find_dim_index(dims, LON_DIMS)
    if values.ndim == 2 and lat_idx is not None and lon_idx is not None and lon_idx < lat_idx:
        return values.T
    return values


def _save_or_show(output_path):
    """Save the current figure to output_path (or show it), then always close it."""
    try:
        if output_path:
            try:
                plt.savefig(output_path, bbox_inches="tight")
            except Exception as e:
                print(f"Error saving figure: {e}")
            else:
                print(f"Figure saved to '{output_path}'")
        else:
            plt.show()
    finally:
        plt.close()


def _plot_degenerate(varname, data_slice, lat_vals, lon_vals, lat_units, lon_units, output_path):
    """Report a single point, or draw a line plot, for a slice with a length-1 lat or lon axis."""
    nlat, nlon = data_slice.shape
    # Flatten so that 2D coordinate variables can still be indexed/plotted as a line.
    lat_flat = np.ravel(lat_vals)
    lon_flat = np.ravel(lon_vals)

    if nlat == 1 and nlon == 1:
        print(f"Single data point: value={data_slice[0,0]} at (lon={lon_flat[0]}, lat={lat_flat[0]})")
        return

    if nlat == 1 and nlon > 1:
        # 1 x N -> plot against longitude
        x, y = lon_flat, data_slice[0, :]
        xlabel = f"Longitude ({lon_units})"
        title = f"{varname} (lat={lat_flat[0]})"
    elif nlon == 1 and nlat > 1:
        # N x 1 -> plot against latitude
        x, y = lat_flat, data_slice[:, 0]
        xlabel = f"Latitude ({lat_units})"
        title = f"{varname} (lon={lon_flat[0]})"
    else:
        # Reached when an axis has zero length.
        print("Unexpected slice shape.")
        return

    plt.figure()
    plt.plot(x, y, marker='o')
    plt.xlabel(xlabel)
    plt.ylabel(varname)
    plt.title(title)
    _save_or_show(output_path)


def plot_variable(dataset, varname, time_index=None, alt_index=None, colormap="jet", output_path=None):
    """
    Extracts the requested slice from the variable and plots it on a lat/lon grid.

    Only the requested 2D slice is read from the variable. The slice is
    oriented latitude-first regardless of the order in which the variable
    declares its latitude and longitude dimensions. Problems are reported with
    a printed "Error: ..." message and an early return; nothing is raised for
    the conditions checked here.

    Parameters
    ----------
    dataset    : netCDF4.Dataset object (already open)
    varname    : string name of the variable to plot
    time_index : int or None (if variable has a time dimension); may be
                 omitted when that dimension has length 1
    alt_index  : int or None (if variable has an altitude dimension); may be
                 omitted when that dimension has length 1
    colormap   : string colormap name (passed to plt.contourf)
    output_path: string filepath to save figure, or None to display interactively

    Returns
    -------
    None
    """
    var = dataset.variables[varname]
    dims = var.dimensions

    # Check that lat and lon dims exist
    if find_dim_index(dims, LAT_DIMS) is None or find_dim_index(dims, LON_DIMS) is None:
        print("Error: Could not find 'latitude'/'lat' and 'longitude'/'lon' dimensions for plotting.")
        return

    # Index the time/altitude axes with a single integer; keep all other axes whole.
    slicer = [slice(None)] * len(dims)

    t_idx = find_dim_index(dims, TIME_DIMS)
    if t_idx is not None:
        if time_index is None:
            if var.shape[t_idx] != 1:
                print("Error: Variable has a time dimension; please supply a time index.")
                return
            time_index = 0  # a single time step leaves no choice to make
        slicer[t_idx] = time_index

    a_idx = find_dim_index(dims, ALT_DIMS)
    if a_idx is not None:
        if alt_index is None:
            if var.shape[a_idx] != 1:
                print("Error: Variable has an altitude dimension; please supply an altitude index.")
                return
            alt_index = 0  # a single level leaves no choice to make
        slicer[a_idx] = alt_index

    # Read just the requested slice instead of loading the whole variable.
    try:
        data_slice = _masked_to_nan(var[tuple(slicer)])
    except Exception as e:
        print(f"Error: Could not slice variable '{varname}': {e}")
        return

    # After slicing, data_slice should be 2D (any extra, unrecognised dimension breaks this)
    if data_slice.ndim != 2:
        print(f"Error: After slicing, data for '{varname}' is not 2D (ndim={data_slice.ndim}).")
        return

    # The two surviving axes are latitude and longitude in their declared order.
    data_slice = _lat_first(data_slice, dims)
    nlat, nlon = data_slice.shape

    # Locate coordinate variables
    lat_varname = find_coord_var(dataset, LAT_DIMS)
    lon_varname = find_coord_var(dataset, LON_DIMS)
    if lat_varname is None or lon_varname is None:
        print("Error: Could not locate latitude/longitude variables in the dataset.")
        return

    lat_var = dataset.variables[lat_varname]
    lon_var = dataset.variables[lon_varname]
    lat_vals = _lat_first(_masked_to_nan(lat_var[:]), getattr(lat_var, "dimensions", ()))
    lon_vals = _lat_first(_masked_to_nan(lon_var[:]), getattr(lon_var, "dimensions", ()))
    lat_units = getattr(lat_var, "units", "degrees")
    lon_units = getattr(lon_var, "units", "degrees")

    # Handle too-small grid (1x1, 1xN, Nx1): a filled contour needs at least 2x2
    if nlat < 2 or nlon < 2:
        _plot_degenerate(varname, data_slice, lat_vals, lon_vals, lat_units, lon_units, output_path)
        return

    # Build 2D coordinate arrays
    if lat_vals.ndim == 1 and lon_vals.ndim == 1:
        lon2d, lat2d = np.meshgrid(lon_vals, lat_vals)
    elif lat_vals.ndim == 2 and lon_vals.ndim == 2:
        lat2d, lon2d = lat_vals, lon_vals
    else:
        print("Error: Latitude and longitude variables must both be either 1D or 2D.")
        return

    var_units = getattr(var, "units", "")

    # Plot
    plt.figure(figsize=(10, 6))
    try:
        cf = plt.contourf(lon2d, lat2d, data_slice, cmap=colormap)
        cbar = plt.colorbar(cf)
        cbar.set_label(f"{varname} ({var_units})" if var_units else varname)
        plt.xlabel(f"Longitude ({lon_units})")
        plt.ylabel(f"Latitude ({lat_units})")
        plt.title(f"{varname} Visualization")
    except (TypeError, ValueError) as e:
        # Bad plotting inputs (for instance a colormap name or array shapes that
        # matplotlib rejects): report it and do not leave the figure open.
        plt.close()
        print(f"Error: Could not plot variable '{varname}': {e}")
        return

    _save_or_show(output_path)


def _enable_tab_completion(completer):
    """Install completer and bind Tab to it, splitting words on whitespace only."""
    # The default delimiter set includes characters such as '/', '~' and '-',
    # which would hand the completer only a fragment of a path or name.
    readline.set_completer_delims(" \t\n")
    readline.set_completer(completer)
    readline.parse_and_bind("tab: complete")


def _prompt_index(label, length):
    """
    Ask for an index along the named axis until a valid one is entered.

    Returns the integer index in [0, length - 1], or None if the user submits
    an empty line.
    """
    while True:
        answer = input(f"Enter {label} index [0..{length - 1}]: ").strip()
        if answer == "":
            return None
        try:
            index = int(answer)
        except ValueError:
            index = None
        if index is not None and 0 <= index < length:
            return index
        print(f"Invalid index. Enter an integer between 0 and {length - 1}.")


def visualize_variable_interactive(nc_path=None):
    """
    Interactive mode: if nc_path is provided, skip prompting for filename.
    Otherwise, prompt for filename (with Tab completion; a leading '~' is expanded).
    Then select the variable (automatically if only one), prompt for any
    time/altitude index the variable needs, and plot with the 'jet' colormap
    in an interactive window.

    The dataset is closed on every exit path, including when plotting or
    prompting raises. Returns None.
    """
    # Determine file path
    if nc_path:
        file_input = nc_path
    else:
        _enable_tab_completion(complete_filename)
        file_input = os.path.expanduser(input("Enter the path to the NetCDF file: ").strip())

    if not file_input:
        print("No file specified. Exiting.")
        return
    if not os.path.isfile(file_input):
        print(f"Error: '{file_input}' not found.")
        return

    try:
        ds = Dataset(file_input, mode="r")
    except Exception as e:
        print(f"Error: Unable to open '{file_input}': {e}")
        return

    try:
        varnames = list(ds.variables)
        if not varnames:
            print("Error: No variables found in the dataset.")
            return

        # Auto-select if only one variable
        if len(varnames) == 1:
            var_input = varnames[0]
            print(f"Automatically selected the only variable: '{var_input}'")
        else:
            print("\nAvailable variables:")
            for name in varnames:
                print(f"  - {name}")
            print()
            # Bind Tab here as well: when nc_path was supplied, the filename
            # prompt (and its key binding) never ran.
            _enable_tab_completion(make_varname_completer(varnames))
            var_input = input("Enter the name of the variable to visualize: ").strip()
            if var_input not in ds.variables:
                print(f"Error: Variable '{var_input}' not found. Exiting.")
                return

        var = ds.variables[var_input]
        dims = var.dimensions
        time_idx = None
        alt_idx = None

        t_idx = find_dim_index(dims, TIME_DIMS)
        if t_idx is not None:
            length = var.shape[t_idx]
            if length > 1:
                time_idx = _prompt_index("time", length)
                if time_idx is None:
                    print("No time index entered. Exiting.")
                    return
            else:
                time_idx = 0
                print("Only one time step available; using index 0.")

        a_idx = find_dim_index(dims, ALT_DIMS)
        if a_idx is not None:
            length = var.shape[a_idx]
            if length > 1:
                alt_idx = _prompt_index("altitude", length)
                if alt_idx is None:
                    print("No altitude index entered. Exiting.")
                    return
            else:
                alt_idx = 0
                print("Only one altitude level available; using index 0.")

        plot_variable(
            dataset=ds,
            varname=var_input,
            time_index=time_idx,
            alt_index=alt_idx,
            colormap="jet",
            output_path=None
        )
    finally:
        ds.close()


def visualize_variable_cli(nc_path, varname, time_index, alt_index, colormap, output_path):
    """
    Command-line mode: directly visualize based on provided arguments.

    Parameters
    ----------
    nc_path    : path to the NetCDF file
    varname    : name of the variable to plot
    time_index : int or None, forwarded to plot_variable
    alt_index  : int or None, forwarded to plot_variable
    colormap   : colormap name, forwarded to plot_variable
    output_path: file to save the figure to, or None to display it

    A missing file, an unreadable file or an unknown variable is reported with
    a printed error message. The dataset is closed on every exit path,
    including when plotting raises. Returns None.
    """
    if not os.path.isfile(nc_path):
        print(f"Error: '{nc_path}' not found.")
        return
    try:
        ds = Dataset(nc_path, mode="r")
    except Exception as e:
        print(f"Error: Unable to open '{nc_path}': {e}")
        return

    # try/finally so an exception inside plot_variable cannot leak the open file.
    try:
        if varname not in ds.variables:
            print(f"Error: Variable '{varname}' not found in '{nc_path}'.")
            return

        plot_variable(
            dataset=ds,
            varname=varname,
            time_index=time_index,
            alt_index=alt_index,
            colormap=colormap,
            output_path=output_path
        )
    finally:
        ds.close()


def main():
    """
    Parse the command line and dispatch to CLI or interactive mode.

    CLI mode runs only when both the NetCDF file and --variable are given.
    Otherwise interactive mode runs, prompting for whatever is missing; in
    that mode --time-index, --alt-index, --cmap and --output (and --variable
    given without a file) are not used, so a warning naming them is written
    to stderr instead of dropping them silently.
    """
    parser = argparse.ArgumentParser(
        description="Visualize a 2D slice of a NetCDF variable on a latitude-longitude map."
    )
    parser.add_argument(
        "nc_file",
        nargs="?",
        help="Path to the NetCDF file (interactive if omitted)."
    )
    parser.add_argument(
        "--variable", "-v",
        help="Name of the variable to visualize."
    )
    parser.add_argument(
        "--time-index", "-t",
        type=int,
        help="Index along the time dimension (if applicable)."
    )
    parser.add_argument(
        "--alt-index", "-a",
        type=int,
        help="Index along the altitude dimension (if applicable)."
    )
    parser.add_argument(
        "--cmap", "-c",
        default="jet",
        help="Matplotlib colormap (default: 'jet')."
    )
    parser.add_argument(
        "--output", "-o",
        help="If provided, save the plot to this file instead of displaying it."
    )

    args = parser.parse_args()

    # Both file and variable supplied: fully non-interactive.
    if args.nc_file and args.variable:
        visualize_variable_cli(
            nc_path=args.nc_file,
            varname=args.variable,
            time_index=args.time_index,
            alt_index=args.alt_index,
            colormap=args.cmap,
            output_path=args.output
        )
        return

    # Interactive mode takes only the file path; tell the user which of the
    # options they passed will have no effect.
    ignored = [
        flag
        for flag, was_given in (
            ("--variable", bool(args.variable)),
            ("--time-index", args.time_index is not None),
            ("--alt-index", args.alt_index is not None),
            ("--cmap", args.cmap != parser.get_default("cmap")),
            ("--output", bool(args.output)),
        )
        if was_given
    ]
    if ignored:
        print(
            f"Warning: {', '.join(ignored)} ignored in interactive mode "
            "(supply both the NetCDF file and --variable to use them).",
            file=sys.stderr,
        )

    # A missing nc_file is passed as None, which makes interactive mode prompt for it.
    visualize_variable_interactive(nc_path=args.nc_file)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

