#!/usr/bin/env python3
##############################################################
### Python script to visualize a variable in a NetCDF file ###
##############################################################

"""
This script can display any numeric variable from a NetCDF file.
It supports the following cases:
  - Scalar output
  - 1D time series
  - 1D vertical profiles
  - 2D latitude/longitude map
  - 2D cross-sections
  - Optionally average over latitude and plot longitude vs. time heatmap
  - Optionally display polar stereographic view of 2D maps
  - Optionally display 3D globe view of 2D maps

Automatic setup from the environment file found in the "LMDZ.MARS/util" folder:
  1. Make sure Conda is installed.
  2. In terminal, navigate to the folder containing this script.
  3. Create the environment:
       conda env create -f display_netcdf.yml
  4. Activate the environment:
       conda activate my_env
  5. Run the script:
       python display_netcdf.py

Usage:
  1) Command-line mode:
       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [options]
     Options:
    --variable         : Name of the variable to visualize.
    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
    --alt-index        : Index along the altitude dimension (0-based), if present.
    --cmap             : Matplotlib colormap (default: "jet").
    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
    --slice-lon-index  : Fixed longitude index for altitude×longitude cross-section.
    --slice-lat-index  : Fixed latitude index for altitude×latitude cross-section.
    --show-topo        : Overlay MOLA topography on lat/lon maps.
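    --show-polar       : Enable polar-stereographic views of 2D maps.
    --show-3d          : Enable 3D globe view of 2D maps.
                         (Both flags are accepted by the parser but are not yet forwarded
                         to the plotting code; the views are offered via a prompt instead.)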
    --output           : If provided, save the figure to this filename instead of displaying.
    --extra-indices    : JSON string to fix indices for any other dimensions.
                         For dimensions with "slope", use 1-based numbering here.
                         Example: '{"nslope": 1, "physical_points": 3}'

  2) Interactive mode:
       python display_netcdf.py
     The script will prompt for everything.
"""

import os
import sys
import glob
import readline
import argparse
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path as mpath
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import pandas as pd
from netCDF4 import Dataset

# Optional 3D support: vedo does the rendering and SciPy the regridding.
# Only the names actually used by this script are imported (no wildcard import,
# which would silently inject every public vedo name into this module).
# If either package is missing, the 3D globe view is disabled.
try:
    import vedo
    from vedo import Line, Mesh, Plotter
    from scipy.interpolate import RegularGridInterpolator
    vedo_available = True
except ImportError:
    vedo_available = False

# Dimension/variable names recognized for each axis
# (compared case-insensitively by find_dim_index / find_coord_var)
TIME_DIMS = ("Time", "time", "time_counter")
ALT_DIMS  = ("altitude",)
LAT_DIMS  = ("latitude", "lat")
LON_DIMS  = ("longitude", "lon")

# MOLA data files; relative paths, so they are resolved against the current
# working directory
MOLA_NPY = 'MOLA_1px_per_deg.npy'
MOLA_CSV = 'molaTeam_contour_31rgb_steps.csv'

# Attempt to load MOLA topography.
# MOLA is expected to be a 2D array of shape (nlat, nlon). The coordinate axes
# built below place row 0 at the northernmost latitude and column 0 at the
# westernmost longitude (-180°), using pixel-centre coordinates.
# NOTE(review): confirm this orientation against the actual data file.
MOLA = None          # stays None when loading fails, so the name always exists
topo_loaded = False
try:
    MOLA = np.load(MOLA_NPY)
    nlat, nlon = MOLA.shape
    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
    topo_loaded = True
except Exception as e:
    # Covers a missing file as well as an unreadable or wrongly shaped one
    MOLA = None
    print(f"Warning: could not load '{MOLA_NPY}': {e}")


# Attempt to load contour color table.
# csv_loaded gates the 3D globe view.
color_table = None   # stays None when loading fails, so the name always exists
csv_loaded = False
if os.path.isfile(MOLA_CSV):
    try:
        color_table = pd.read_csv(MOLA_CSV)
        csv_loaded = True
    except (OSError, ValueError) as e:
        # pandas parser errors derive from ValueError; do not abort the whole
        # script because of an optional, malformed data file
        print(f"Warning: could not read '{MOLA_CSV}': {e}. 3D view colors disabled.")
else:
    print(f"Warning: '{MOLA_CSV}' not found. 3D view colors disabled.")


def complete_filename(text, state):
    """
    Readline completer that expands *text* into matching filesystem paths.

    A trailing "*" is appended unless *text* already contains a wildcard, "~" is
    expanded, and directory matches get a trailing "/". Returns the match at
    position *state*, or None when there is no such match.
    """
    glob_pattern = text if "*" in text else text + "*"
    candidates = []
    for found in glob.glob(os.path.expanduser(glob_pattern)):
        candidates.append(found + "/" if os.path.isdir(found) else found)
    if -len(candidates) <= state < len(candidates):
        return candidates[state]
    return None


def make_varname_completer(varnames):
    """
    Build a readline completer bound to the given variable names.

    The returned function takes (text, state) and yields the state-th name that
    starts with *text*, or None when there is no such name.
    """
    def completer(text, state):
        matching = list(filter(lambda candidate: candidate.startswith(text), varnames))
        if -len(matching) <= state < len(matching):
            return matching[state]
        return None

    return completer


def find_dim_index(dims, candidates):
    """
    Locate the first dimension whose name matches one of the candidates.

    Names are compared case-insensitively. Returns the position within *dims*,
    or None when no dimension matches.
    """
    wanted = {cand.lower() for cand in candidates}
    return next(
        (pos for pos, dim_name in enumerate(dims) if dim_name.lower() in wanted),
        None,
    )


def find_coord_var(dataset, candidates):
    """
    Find a variable in *dataset* whose name matches one of the candidates.

    Variables are scanned in the dataset's own iteration order and names are
    compared case-insensitively. Returns the variable name as stored in the
    dataset, or None when nothing matches.
    """
    wanted = {cand.lower() for cand in candidates}
    return next(
        (var_name for var_name in dataset.variables if var_name.lower() in wanted),
        None,
    )


def overlay_topography(ax, transform, levels=10):
    """
    Draw thin black MOLA topography contours on the given axes.

    Does nothing when the MOLA data could not be loaded. *transform* is the
    coordinate system of the MOLA lon/lat grid and *levels* is forwarded to
    the contour call.
    """
    if topo_loaded:
        contour_style = dict(
            levels=levels,
            linewidths=0.5,
            colors='black',
            transform=transform,
        )
        ax.contour(topo_lon2d, topo_lat2d, MOLA, **contour_style)


def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True):
    """
    Show the same lon/lat field in two polar-stereographic figures.

    One figure is created per pole (north: 60..90 degrees, south: -90..-60 degrees),
    each clipped to a circular boundary, with dashed graticule, filled contours of
    data2d, an optional MOLA contour overlay and a colorbar labelled with the
    variable name and units. Both figures are displayed at the end.
    """
    # (pole name, projection class, [lon_min, lon_max, lat_min, lat_max])
    pole_setups = (
        ("north", ccrs.NorthPolarStereo, [-180, 180, 60, 90]),
        ("south", ccrs.SouthPolarStereo, [-180, 180, -90, -60]),
    )
    cbar_label = f"{varname} ({units})" if units else varname

    # Unit circle in axes coordinates, used to clip each map to a disc
    angles = np.linspace(0, 2 * np.pi, 100)
    unit_circle = np.vstack([np.sin(angles), np.cos(angles)]).T

    for pole_name, projection_cls, lonlat_extent in pole_setups:
        # Figure and GeoAxes for this pole
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(
            1, 1, 1,
            projection=projection_cls(central_longitude=180),
            aspect=True
        )
        ax.set_global()
        ax.set_extent(lonlat_extent, ccrs.PlateCarree())

        # Circular boundary centred in the axes
        ax.set_boundary(
            mpath.Path(unit_circle * 0.5 + [0.5, 0.5]),
            transform=ax.transAxes
        )

        # Meridians every 30 degrees, parallels every 10 degrees
        ax.gridlines(
            draw_labels=True,
            color='k',
            xlocs=range(-180, 181, 30),
            ylocs=range(-90, 91, 10),
            linestyle='--',
            linewidth=0.5
        )

        # Data coordinates are plain lon/lat
        filled = ax.contourf(
            lon2d, lat2d, data2d,
            levels=100,
            cmap=colormap,
            transform=ccrs.PlateCarree()
        )

        if topo_overlay:
            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)

        # Colorbar and title
        fig.colorbar(filled, ax=ax, pad=0.1).set_label(cbar_label)
        ax.set_title(f"{varname} — {pole_name.capitalize()} Pole", pad=50)

    # Display both pole figures
    plt.show()


def plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units=None):
    """
    Plot a 3D globe view of the data using vedo, with surface coloring based on data2d
    and overlaid contour lines from MOLA topography.

    Parameters:
      lon2d    : 2D longitude grid of the data. Currently unused: the data columns
                 are assumed to span -180..180 degrees uniformly.
      lat2d    : 2D latitude grid of the data. Only used to detect rows running
                 north-to-south, in which case the data is flipped so that the
                 uniform -90..90 degree row assumption holds.
      data2d   : 2D array (latitude rows, longitude columns) to color the globe with.
      colormap : colormap passed to the vedo mesh.
      varname  : variable name, used as the scalar bar title.
      units    : optional units string appended to the scalar bar title.

    The view is skipped, with a message, when vedo, the color table or the MOLA
    topography is unavailable. Returns None.
    """
    if not vedo_available:
        print("3D view skipped: vedo missing.")
        return
    if not csv_loaded:
        print("3D view skipped: color table missing.")
        return
    if not topo_loaded:
        # Without this guard the MOLA accesses below fail when the file is absent
        print("3D view skipped: MOLA topography missing.")
        return

    # Local import: a standalone Figure lets us compute contour paths without
    # touching (and clearing) whatever pyplot figure is currently active.
    from matplotlib.figure import Figure

    radius = 3389500  # Mars average radius [m]
    relief_factor = 10  # vertical exaggeration applied to MOLA elevations

    # Prepare MOLA grid
    nlat, nlon = MOLA.shape
    lats = np.linspace(90, -90, nlat)
    lons = np.linspace(-180, 180, nlon)
    lon_grid, lat_grid = np.meshgrid(lons, lats)

    # The regridding below assumes row 0 of data2d is the southernmost latitude.
    # Flip the rows when the caller's latitude grid runs north-to-south, otherwise
    # the globe would be rendered upside down.
    lat_axis = np.asarray(lat2d)
    if lat_axis.ndim == 2 and lat_axis.shape[0] > 1 and lat_axis[0, 0] > lat_axis[-1, 0]:
        data2d = data2d[::-1, :]

    # Interpolate data2d onto MOLA grid
    lat_data = np.linspace(-90, 90, data2d.shape[0])
    lon_data = np.linspace(-180, 180, data2d.shape[1])
    interp2d = RegularGridInterpolator((lat_data, lon_data), data2d,
                                       bounds_error=False, fill_value=None)
    newdata2d = interp2d((lat_grid, lon_grid))

    # Single MOLA elevation interpolator shared by contours and grid lines
    # (out-of-grid points, if any, are treated as zero elevation)
    topo_at = RegularGridInterpolator((lats, lons), MOLA,
                                      bounds_error=False, fill_value=0.0)

    def to_xyz(lat_deg, lon_deg, lift):
        """Cartesian points on the exaggerated relief, pushed outward by factor *lift*."""
        colat = np.radians(90 - lat_deg)
        azimuth = np.radians(lon_deg)
        rr = (radius + topo_at((lat_deg, lon_deg)) * relief_factor) * lift
        return np.column_stack([
            rr * np.sin(colat) * np.cos(azimuth),
            rr * np.sin(colat) * np.sin(azimuth),
            rr * np.cos(colat)
        ])

    # Generate contour lines from MOLA on an off-screen figure
    scratch_ax = Figure().add_subplot()
    cs = scratch_ax.contour(lon_grid, lat_grid, MOLA, levels=10)
    contour_lines = []
    for segs in cs.allsegs:
        for verts in segs:
            if verts.shape[0] > 1:
                # Slightly above the surface so the lines are not hidden by it
                pts_c = to_xyz(verts[:, 1], verts[:, 0], 1.002)
                contour_lines.append(Line(pts_c, c='k', lw=0.5))

    # Create sphere surface mesh
    phi = np.deg2rad(90 - lat_grid)
    theta = np.deg2rad(lon_grid)
    r = radius + MOLA * relief_factor
    x = r * np.sin(phi) * np.cos(theta)
    y = r * np.sin(phi) * np.sin(theta)
    z = r * np.cos(phi)
    pts = np.stack([x.ravel(), y.ravel(), z.ravel()], axis=1)

    # Build mesh faces: two triangles per grid cell
    faces = []
    for i in range(nlat - 1):
        for j in range(nlon - 1):
            p0 = i * nlon + j
            p1 = p0 + 1
            p2 = p0 + nlon
            p3 = p2 + 1
            faces.extend([(p0, p2, p1), (p1, p2, p3)])

    mesh = Mesh([pts, faces])
    mesh.cmap(colormap, newdata2d.ravel())
    mesh.add_scalarbar(title=varname + (f' [{units}]' if units else ''), c='white')
    mesh.lighting('default')

    # Geographic grid lines: meridians and parallels every 30 degrees
    meridians = []
    for lon in range(-150, 181, 30):
        lat_line = np.linspace(-90, 90, nlat)
        lon_line = np.full_like(lat_line, lon)
        meridians.append(Line(to_xyz(lat_line, lon_line, 1.005), c='k', lw=1))

    parallels = []
    for lat in range(-60, 91, 30):
        lon_line = np.linspace(-180, 180, nlon)
        lat_line = np.full_like(lon_line, lat)
        parallels.append(Line(to_xyz(lat_line, lon_line, 1.005), c='k', lw=1))

    # Create plotter
    plotter = Plotter(title="3D topography view", bg="bb", axes=0)

    # Configure camera: on the +x axis, looking at the centre, z up
    cam_dist = radius * 3
    plotter.camera.SetPosition([cam_dist, 0, 0])
    plotter.camera.SetFocalPoint([0, 0, 0])
    plotter.camera.SetViewUp([0, 0, 1])

    # Show the globe
    plotter.show(mesh, *contour_lines, *meridians, *parallels)


def _units_label(name, var):
    """Return *name*, followed by " (units)" when *var* has a units attribute."""
    return name + (f" ({var.units})" if hasattr(var, "units") else "")


def _unmask(values):
    """Return *values* with masked entries replaced by NaN (other input is returned as is)."""
    if hasattr(values, "mask"):
        return np.where(values.mask, np.nan, values.data)
    return values


def _save_or_show(fig, output_path):
    """Save *fig* to *output_path* when given, otherwise display the open figures."""
    if output_path:
        # Save this specific figure: the "current" pyplot figure may be a
        # polar view created after the main plot.
        fig.savefig(output_path, bbox_inches="tight")
        print(f"Saved to {output_path}")
    else:
        plt.show()


def _set_lonlat_ticks(ax):
    """Put ticks every 30 degrees on both axes of a lon/lat map."""
    ax.set_xticks(np.arange(-180, 181, 30), crs=ccrs.PlateCarree())
    ax.set_yticks(np.arange(-90, 91, 30), crs=ccrs.PlateCarree())
    for axis_name in ('x', 'y'):
        ax.tick_params(
            axis=axis_name, which='major',
            length=4,
            direction='out',
            pad=2,
            labelsize=8
        )


def _offer_extra_views(lon2d, lat2d, data2d, colormap, varname, units):
    """Ask, only when stdin is a terminal, whether to show the polar and 3D views."""
    if not sys.stdin.isatty():
        return
    if input("Display polar-stereo views? [y/n]: ").strip().lower() == "y":
        plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units)
    if input("Display 3D globe view? [y/n]: ").strip().lower() == "y":
        plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units)


def plot_variable(dataset, varname, time_index=None, alt_index=None,
                  colormap="jet", output_path=None, extra_indices=None,
                  avg_lat=False):
    """
    Core plotting logic: reads the variable, handles masks,
    determines dimensionality, and creates the appropriate plot:
      - 1D time series
      - 1D profiles or physical_points maps
      - 2D lat×lon or generic 2D
      - Time×lon heatmap if avg_lat=True
      - Scalar printing

    Parameters:
      dataset       : open dataset exposing a `variables` mapping.
      varname       : name of the variable to plot.
      time_index    : index along the time dimension; required when the variable
                      has one (except for pure 1D time series and avg_lat).
      alt_index     : index along the altitude dimension; required when present.
      colormap      : Matplotlib colormap for 2D plots.
      output_path   : if set, the figure is saved there instead of being shown.
      extra_indices : mapping {dimension name: index} fixing any other dimension.
      avg_lat       : average over latitude and plot a time × longitude heatmap
                      (only when the variable has time, latitude and longitude).

    Problems are reported with a printed message; the function always returns None.
    """
    var = dataset.variables[varname]
    dims = var.dimensions
    value_label = _units_label(varname, var)

    # Read full data
    try:
        data_full = var[:]
    except Exception as e:
        print(f"Error: Cannot read data for '{varname}': {e}")
        return
    data_full = _unmask(data_full)

    # Pure 1D time series
    if len(dims) == 1 and find_dim_index(dims, TIME_DIMS) is not None:
        time_var = find_coord_var(dataset, TIME_DIMS)
        if time_var:
            tvals = _unmask(dataset.variables[time_var][:])
        else:
            tvals = np.arange(data_full.shape[0])
        fig = plt.figure()
        plt.plot(tvals, data_full, marker="o")
        plt.xlabel(time_var or "Time Index")
        plt.ylabel(value_label)
        plt.title(f"{varname} vs {time_var or 'Index'}")
        _save_or_show(fig, output_path)
        return

    # Identify dims
    t_idx = find_dim_index(dims, TIME_DIMS)
    lat_idx = find_dim_index(dims, LAT_DIMS)
    lon_idx = find_dim_index(dims, LON_DIMS)
    a_idx = find_dim_index(dims, ALT_DIMS)

    # Average over latitude & plot time × lon heatmap
    if avg_lat and t_idx is not None and lat_idx is not None and lon_idx is not None:
        # compute mean over lat axis
        data_avg = np.nanmean(data_full, axis=lat_idx)
        # prepare coordinates
        time_var = find_coord_var(dataset, TIME_DIMS)
        lon_var = find_coord_var(dataset, LON_DIMS)
        tvals = _unmask(dataset.variables[time_var][:])
        lons = _unmask(dataset.variables[lon_var][:])
        fig = plt.figure(figsize=(10, 6))
        plt.pcolormesh(lons, tvals, data_avg, shading="auto", cmap=colormap)
        plt.xlabel(f"Longitude ({getattr(dataset.variables[lon_var], 'units', 'deg')})")
        plt.ylabel(time_var)
        cbar = plt.colorbar()
        cbar.set_label(value_label)
        plt.title(f"{varname} averaged over latitude")
        _save_or_show(fig, output_path)
        return

    # Build slicer for other cases
    slicer = [slice(None)] * len(dims)
    if t_idx is not None:
        if time_index is None:
            print("Error: please supply a time index.")
            return
        slicer[t_idx] = time_index
    if a_idx is not None:
        if alt_index is None:
            print("Error: please supply an altitude index.")
            return
        slicer[a_idx] = alt_index

    if extra_indices is None:
        extra_indices = {}
    for dn, idx_val in extra_indices.items():
        if dn in dims:
            slicer[dims.index(dn)] = idx_val

    # Extract slice
    try:
        dslice = data_full[tuple(slicer)]
    except Exception as e:
        print(f"Error slicing '{varname}': {e}")
        return

    # Scalar
    if np.ndim(dslice) == 0:
        print(f"Scalar '{varname}': {float(dslice)}")
        return

    # Dimensions that were left unsliced, as (position, name) pairs
    rem = [(i, name) for i, name in enumerate(dims) if slicer[i] == slice(None)]

    # 1D: vector, profile, or physical_points
    if dslice.ndim == 1:
        if rem:
            di, dname = rem[0]
            # physical_points → interpolated map
            if dname.lower() == "physical_points":
                latv = find_coord_var(dataset, LAT_DIMS)
                lonv = find_coord_var(dataset, LON_DIMS)
                if latv and lonv:
                    lats = _unmask(dataset.variables[latv][:])
                    lons = _unmask(dataset.variables[lonv][:])

                    # Coordinates are always converted from radians to degrees here
                    # NOTE(review): assumes physical_points files store lat/lon in
                    # radians — confirm against the files this is used with.
                    lats_deg = np.round(np.degrees(lats), 6)
                    lons_deg = np.round(np.degrees(lons), 6)

                    # Build regular grid
                    uniq_lats = np.unique(lats_deg)
                    uniq_lons = np.unique(lons_deg)
                    nlon = len(uniq_lons)

                    data2d = []
                    for lat_val in uniq_lats:
                        mask = lats_deg == lat_val
                        slice_vals = dslice[mask]
                        lons_at_lat = lons_deg[mask]
                        if len(slice_vals) == 1:
                            # Single point at this latitude: replicate it
                            # across all longitudes
                            row = np.full(nlon, slice_vals[0])
                        else:
                            order = np.argsort(lons_at_lat)
                            row = np.full(nlon, np.nan)
                            row[: len(slice_vals)] = slice_vals[order]
                        data2d.append(row)
                    data2d = np.array(data2d)

                    # Wrap longitude if needed: duplicate the -180° column at +180°
                    if -180.0 in uniq_lons:
                        idx = np.where(np.isclose(uniq_lons, -180.0))[0][0]
                        data2d = np.hstack([data2d, data2d[:, [idx]]])
                        uniq_lons = np.append(uniq_lons, 180.0)

                    # Plot interpolated map
                    proj = ccrs.PlateCarree()
                    fig, ax = plt.subplots(subplot_kw=dict(projection=proj), figsize=(8, 6))
                    lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
                    _set_lonlat_ticks(ax)
                    cf = ax.contourf(
                        lon2d, lat2d, data2d,
                        levels=100,
                        cmap=colormap,
                        transform=proj
                    )

                    # Overlay MOLA topography
                    overlay_topography(ax, transform=proj, levels=10)

                    # Colorbar & labels
                    cbar = fig.colorbar(cf, ax=ax, pad=0.02)
                    cbar.set_label(value_label)
                    ax.set_title(f"{varname} (interpolated map over physical_points)")
                    ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
                    ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")

                    # Prompt for polar-stereo / 3D views if interactive
                    _offer_extra_views(lon2d, lat2d, data2d, colormap, varname,
                                       getattr(var, "units", None))

                    _save_or_show(fig, output_path)
                    return
            # vertical profile?
            coord = None
            if dname.lower() == "subsurface_layers" and "soildepth" in dataset.variables:
                coord = "soildepth"
            elif dname in dataset.variables:
                coord = dname
            if coord:
                coords = _unmask(dataset.variables[coord][:])
                fig = plt.figure()
                plt.plot(dslice, coords, marker="o")
                if dname.lower() == "subsurface_layers":
                    # Depth increases downward
                    plt.gca().invert_yaxis()
                plt.xlabel(value_label)
                plt.ylabel(_units_label(coord, dataset.variables[coord]))
                plt.title(f"{varname} vs {coord}")
                _save_or_show(fig, output_path)
                return
        # generic 1D
        fig = plt.figure()
        plt.plot(dslice, marker="o")
        plt.xlabel("Index")
        plt.ylabel(value_label)
        plt.title(f"{varname} (1D)")
        _save_or_show(fig, output_path)
        return

    if dslice.ndim == 2:
        free_positions = {i for i, _ in rem}

        # Geographic lat×lon slice: only when the two remaining dimensions
        # really are latitude and longitude
        if (lat_idx is not None and lon_idx is not None
                and free_positions == {lat_idx, lon_idx}):
            latv = find_coord_var(dataset, LAT_DIMS)
            lonv = find_coord_var(dataset, LON_DIMS)
            lats = _unmask(dataset.variables[latv][:])
            lons = _unmask(dataset.variables[lonv][:])

            # meshgrid below yields (lat, lon) arrays; match that layout when
            # the variable stores longitude before latitude
            map_data = dslice.T if lon_idx < lat_idx else dslice

            # Create map projection
            proj = ccrs.PlateCarree()
            fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=proj))

            # Make meshgrid and plot
            lon2d, lat2d = np.meshgrid(lons, lats)
            cf = ax.contourf(
                lon2d, lat2d, map_data,
                levels=100,
                cmap=colormap,
                transform=proj
            )

            # Overlay topography
            overlay_topography(ax, transform=proj, levels=10)

            # Ticks, colorbar and labels
            _set_lonlat_ticks(ax)
            cbar = fig.colorbar(cf, ax=ax, orientation="vertical", pad=0.02)
            cbar.set_label(value_label)
            ax.set_title(f"{varname} (lat × lon)")
            ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
            ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")

            # Prompt for polar-stereo / 3D views if interactive
            _offer_extra_views(lon2d, lat2d, map_data, colormap, varname,
                               getattr(var, "units", None))

            _save_or_show(fig, output_path)
            return

        # Generic 2D
        fig = plt.figure(figsize=(8, 6))
        plt.imshow(dslice, aspect="auto")
        plt.colorbar(label=value_label)
        plt.xlabel("Dim 2 index")
        plt.ylabel("Dim 1 index")
        plt.title(f"{varname} (2D)")
        _save_or_show(fig, output_path)
        return

    print(f"Error: ndim={dslice.ndim} not supported.")


def _prompt_one_based_index(prompt, size):
    """
    Ask until the user enters an integer in [1, size] or just presses Enter.

    Returns the corresponding 0-based index, or None when Enter was pressed.
    """
    while True:
        reply = input(prompt).strip()
        if reply == '':
            return None
        try:
            number = int(reply)
        except ValueError:
            number = None
        if number is not None and 1 <= number <= size:
            return number - 1
        print("Invalid entry. Please enter a valid number or press Enter.")


def visualize_variable_interactive(nc_path=None):
    """
    Interactive loop: keep prompting for variables to plot until user quits.

    Parameters:
      nc_path : optional path to the NetCDF file; when omitted the user is
                prompted for it (with filename tab-completion).

    For each chosen variable the user is asked for 1-based indices along the
    time, altitude and remaining dimensions (converted to 0-based before
    plotting), and whether to average over latitude when applicable. The
    dataset is closed when the loop ends, including on errors.
    """
    # Split completion words on whitespace only: readline's default delimiters
    # include "/" and "-", which would hand the completers just the last
    # fragment of a path or name.
    readline.set_completer_delims(" \t\n")

    # Open dataset
    if nc_path:
        path = nc_path
    else:
        readline.set_completer(complete_filename)
        readline.parse_and_bind("tab: complete")
        path = input("Enter path to NetCDF file: ").strip()

    if not os.path.isfile(path):
        print(f"Error: '{path}' not found.")
        return

    ds = Dataset(path, "r")
    try:
        var_list = list(ds.variables.keys())
        if not var_list:
            print("No variables found in file.")
            return

        # Enable interactive mode so figures do not block the prompt loop
        plt.ion()

        while True:
            # Enable tab-completion for variable names
            readline.set_completer(make_varname_completer(var_list))
            readline.parse_and_bind("tab: complete")

            print("\nAvailable variables:")
            for name in var_list:
                print(f"  - {name}")
            varname = input("\nEnter variable name to plot (or 'q' to quit): ").strip()
            if varname.lower() in ("q", "quit", "exit"):
                print("Exiting.")
                break
            if varname not in ds.variables:
                print(f"Variable '{varname}' not found. Try again.")
                continue

            # Display dimensions and size
            var = ds.variables[varname]
            dims, shape = var.dimensions, var.shape
            print(f"\nVariable '{varname}' has dimensions:")
            for dim, size in zip(dims, shape):
                print(f"  - {dim}: size {size}")
            print()

            # Prepare slicing parameters
            time_index = None
            alt_index = None
            avg = False
            extra_indices = {}

            # Time index (a singleton dimension is fixed without asking)
            # NOTE(review): pressing Enter leaves the index as None, which
            # plot_variable rejects unless the latitude average is selected.
            t_idx = find_dim_index(dims, TIME_DIMS)
            if t_idx is not None:
                if shape[t_idx] > 1:
                    time_index = _prompt_one_based_index(
                        f"Enter time index [1–{shape[t_idx]}] (press Enter for all): ",
                        shape[t_idx]
                    )
                else:
                    time_index = 0

            # Altitude index
            a_idx = find_dim_index(dims, ALT_DIMS)
            if a_idx is not None:
                if shape[a_idx] > 1:
                    alt_index = _prompt_one_based_index(
                        f"Enter altitude index [1–{shape[a_idx]}] (press Enter for all): ",
                        shape[a_idx]
                    )
                else:
                    alt_index = 0

            # Average over latitude?
            lat_idx = find_dim_index(dims, LAT_DIMS)
            lon_idx = find_dim_index(dims, LON_DIMS)
            if (t_idx is not None and lat_idx is not None and lon_idx is not None and
                shape[t_idx] > 1 and shape[lat_idx] > 1 and shape[lon_idx] > 1):
                resp = input("Average over latitude and plot lon vs time? [y/n]: ").strip().lower()
                avg = (resp == 'y')

            # Other dimensions
            for i, dname in enumerate(dims):
                if i in (t_idx, a_idx):
                    continue
                size = shape[i]
                if size == 1:
                    extra_indices[dname] = 0
                    continue
                chosen = _prompt_one_based_index(
                    f"Enter index [1–{size}] for '{dname}' (press Enter for all): ",
                    size
                )
                if chosen is not None:
                    extra_indices[dname] = chosen
                # else: keep all values along this dimension

            # Plot the variable
            plot_variable(
                ds, varname,
                time_index    = time_index,
                alt_index     = alt_index,
                colormap      = 'jet',
                output_path   = None,
                extra_indices = extra_indices,
                avg_lat       = avg
            )
    finally:
        # Always release the file, even on Ctrl-C / EOF or a plotting error
        ds.close()


def visualize_variable_cli(nc_file, varname, time_index, alt_index,
                           colormap, output_path, extra_json, avg_lat):
    """
    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).

    Parameters:
      nc_file     : path to the NetCDF file.
      varname     : name of the variable to plot.
      time_index  : 0-based index along the time dimension (or None).
      alt_index   : 0-based index along the altitude dimension (or None).
      colormap    : Matplotlib colormap name.
      output_path : if set, save the figure there instead of displaying it.
      extra_json  : JSON object string mapping dimension names to indices;
                    dimensions whose name contains "slope" are 1-based.
      avg_lat     : request the latitude-averaged time × longitude heatmap.

    Errors (missing file or variable, malformed JSON) are reported with a
    printed message. The dataset is always closed before returning.
    """
    if not os.path.isfile(nc_file):
        print(f"Error: '{nc_file}' not found.")
        return
    ds = Dataset(nc_file, "r")
    try:
        if varname not in ds.variables:
            print(f"Variable '{varname}' not in file.")
            return

        # Display dimensions and size
        var_obj = ds.variables[varname]
        dims = var_obj.dimensions
        shape = var_obj.shape
        print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
        for name, size in zip(dims, shape):
            print(f"  - {name}: size {size}")
        print()

        # Special case: time-only → plot directly
        t_idx = find_dim_index(dims, TIME_DIMS)
        if (
            t_idx is not None and shape[t_idx] > 1 and
            all(shape[i] == 1 for i in range(len(dims)) if i != t_idx)
        ):
            print("Detected single-point spatial dims; plotting time series…")
            data = var_obj[:].squeeze()
            time_var = find_coord_var(ds, TIME_DIMS)
            if time_var:
                tvals = ds.variables[time_var][:]
            else:
                tvals = np.arange(data.shape[0])
            if hasattr(data, "mask"):
                data = np.where(data.mask, np.nan, data.data)
            if hasattr(tvals, "mask"):
                tvals = np.where(tvals.mask, np.nan, tvals.data)
            plt.figure()
            plt.plot(tvals, data, marker="o")
            plt.xlabel(time_var or "Time Index")
            plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
            plt.title(f"{varname} vs {time_var or 'Index'}")
            if output_path:
                plt.savefig(output_path, bbox_inches="tight")
                print(f"Saved to {output_path}")
            else:
                plt.show()
            return

        # if --avg-lat but lat/lon/Time not compatible → disable
        lat_idx = find_dim_index(dims, LAT_DIMS)
        lon_idx = find_dim_index(dims, LON_DIMS)
        if avg_lat and not (
            t_idx   is not None and shape[t_idx]  > 1 and
            lat_idx is not None and shape[lat_idx] > 1 and
            lon_idx is not None and shape[lon_idx] > 1
        ):
            print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
            avg_lat = False

        # Parse extra indices JSON
        extra = {}
        if extra_json:
            try:
                parsed = json.loads(extra_json)
            except (TypeError, ValueError):
                # json.JSONDecodeError is a ValueError
                parsed = None
            if isinstance(parsed, dict):
                for k, v in parsed.items():
                    # bool is a subclass of int; true/false is not an index
                    if isinstance(v, int) and not isinstance(v, bool):
                        # "slope" dimensions are given 1-based on the command line
                        extra[k] = v - 1 if "slope" in k.lower() else v
                    else:
                        print(f"Warning: ignoring non-integer index for '{k}'.")
            else:
                print("Warning: bad extra-indices.")

        plot_variable(ds, varname, time_index, alt_index,
                      colormap, output_path, extra, avg_lat)
    finally:
        # Release the file on every exit path, including exceptions
        ds.close()


def main():
    """
    Entry point: parse the command line and pick the execution mode.

    When both a NetCDF file and a variable name are given, the variable is
    plotted directly (command-line mode); otherwise the interactive prompt loop
    is started, with the file path pre-filled if one was given.

    Note: --show-polar and --show-3d are parsed here but their values are not
    passed on to either mode.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('nc_file', nargs='?', help='NetCDF file (omit for interactive)')
    add('-v', '--variable', help='Variable name')
    add('-t', '--time-index', type=int, help='Time index (0-based)')
    add('-a', '--alt-index', type=int, help='Altitude index (0-based)')
    add('-c', '--cmap', default='jet', help='Colormap')
    add('--avg-lat', action='store_true', help='Average over latitude')
    add('--show-polar', action='store_true', help='Enable polar-stereo views')
    add('--show-3d', action='store_true', help='Enable 3D globe view')
    add('-o', '--output', help='Save figure path')
    add('-e', '--extra-indices', help='JSON string for other dims')
    opts = parser.parse_args()

    # Without both a file and a variable, fall back to the interactive loop
    if not (opts.nc_file and opts.variable):
        visualize_variable_interactive(opts.nc_file)
        return

    visualize_variable_cli(
        nc_file=opts.nc_file,
        varname=opts.variable,
        time_index=opts.time_index,
        alt_index=opts.alt_index,
        colormap=opts.cmap,
        output_path=opts.output,
        extra_json=opts.extra_indices,
        avg_lat=opts.avg_lat,
    )


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

