#!/usr/bin/env python3
##############################################################
### Python script to visualize a variable in a NetCDF file ###
##############################################################

"""
This script can display any numeric variable from a NetCDF file.
It supports the following cases:
  - Scalar output
  - 1D time series
  - 1D vertical profiles
  - 2D latitude/longitude map
  - 2D cross-sections
  - Optionally average over latitude and plot longitude vs. time heatmap
  - Optionally display polar stereographic view of 2D maps
  - Optionally display 3D globe view of 2D maps

Automatic setup from the environment file found in the "LMDZ.MARS/util" folder:
  1. Make sure Conda is installed.
  2. In a terminal, navigate to the folder containing this script.
  3. Create the environment:
       conda env create -f display_netcdf.yml
  4. Activate the environment (replace "my_env" with the environment name
     defined in display_netcdf.yml):
       conda activate my_env
  5. Run the script:
       python display_netcdf.py

Usage:
  1) Command-line mode:
       python display_netcdf.py /path/to/your_file.nc --variable VAR_NAME [options]
     Options:
    --variable         : Name of the variable to visualize.
    --time-index       : Index along the Time dimension (0-based, ignored for purely 1D time series).
    --alt-index        : Index along the altitude dimension (0-based), if present.
    --cmap             : Matplotlib colormap (default: "jet").
    --avg-lat          : Average over latitude and plot longitude vs. time heatmap.
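    --slice-lon-index  : [not yet implemented] Fixed longitude index for altitude×longitude cross-section.
    --slice-lat-index  : [not yet implemented] Fixed latitude index for altitude×latitude cross-section.
    --show-topo        : [not yet implemented as an option] MOLA topography contours are drawn
                         automatically on lat/lon maps when the MOLA data file is found.
    --show-polar       : Enable polar-stereographic views of 2D lat/lon maps
                         (these views are currently offered through an interactive prompt).
    --show-3d          : Enable a 3D globe view of 2D lat/lon maps (requires vedo;
                         currently offered through an interactive prompt).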
    --output           : If provided, save the figure to this filename instead of displaying.
    --extra-indices    : JSON string to fix indices for any other dimensions.
                         For dimensions with "slope", use 1-based numbering here.
                         Example: '{"nslope": 1, "physical_points": 3}'

  2) Interactive mode:
       python display_netcdf.py
     The script will prompt for everything.
"""

import os
import sys
import glob
import readline
import argparse
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path as mpath
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import pandas as pd
from netCDF4 import Dataset

# Optional 3D-globe dependencies (vedo + scipy).  Only the vedo names actually
# used by plot_3D_globe are imported (instead of `from vedo import *`) so that
# vedo's large public namespace cannot silently shadow names in this script.
try:
    import vedo
    from vedo import Line, Mesh, Plotter
    from scipy.interpolate import RegularGridInterpolator
    vedo_available = True
except ImportError:
    vedo_available = False

# Recognized dimension / coordinate names.  find_dim_index and find_coord_var
# compare them case-insensitively.
TIME_DIMS = (
    "Time",
    "time",
    "time_counter",
)
ALT_DIMS = (
    "altitude",
)
LAT_DIMS = (
    "latitude",
    "lat",
)
LON_DIMS = (
    "longitude",
    "lon",
)

# Paths for MOLA data
MOLA_NPY = 'MOLA_1px_per_deg.npy'
MOLA_CSV = 'molaTeam_contour_31rgb_steps.csv'

# Attempt to load MOLA topography
if os.path.isfile(MOLA_NPY): # shape (nlat, nlon) at 1° per pixel: lat from -90 to 90, lon from 0 to 360
    MOLA = np.load(MOLA_NPY) 
    nlat, nlon = MOLA.shape
    topo_lats = np.linspace(90 - 0.5, -90 + 0.5, nlat)
    topo_lons = np.linspace(-180 + 0.5, 180 - 0.5, nlon)
    topo_lon2d, topo_lat2d = np.meshgrid(topo_lons, topo_lats)
    topo_loaded = True
else:
    print(f"Warning: '{MOLA_NPY}' not found! Topography contours disabled.")
    topo_loaded = False


# Attempt to load contour color table
if os.path.isfile(MOLA_CSV):
    color_table = pd.read_csv(MOLA_CSV)
    csv_loaded = True
else:
    print(f"Warning: '{MOLA_CSV}' not found! 3D view colors disabled.")
    csv_loaded = False


def complete_filename(text, state):
    """
    Tab-completion for filesystem paths (readline completer protocol).

    readline calls this repeatedly with state = 0, 1, 2, ... for the same
    *text* until None is returned.  The glob is re-run on every call, so the
    matches are sorted to give a stable order across calls (glob's own
    order is arbitrary).  Directories get a trailing "/" so completion can
    continue into them; a leading "~" is expanded.
    """
    pattern = text if "*" in text else text + "*"
    matches = sorted(glob.glob(os.path.expanduser(pattern)))
    matches = [m + "/" if os.path.isdir(m) else m for m in matches]
    try:
        return matches[state]
    except IndexError:
        return None


def make_varname_completer(varnames):
    """
    Build a readline completer over *varnames*.

    The returned callable follows readline's protocol: for a given prefix
    *text* it yields the state-th name starting with that prefix, or None
    once the candidates are exhausted.
    """
    def completer(text, state):
        matching = [candidate for candidate in varnames if candidate.startswith(text)]
        if -len(matching) <= state < len(matching):
            return matching[state]
        return None

    return completer


def find_dim_index(dims, candidates):
    """
    Return the position of the first entry of *dims* that matches one of
    *candidates* (case-insensitive), or None when nothing matches.
    """
    for position, dim_name in enumerate(dims):
        if any(cand.lower() == dim_name.lower() for cand in candidates):
            return position
    return None


def find_coord_var(dataset, candidates):
    """
    Return the name of the first variable in *dataset* (in file order) whose
    name equals one of *candidates*, ignoring case; None if there is none.
    """
    return next(
        (
            var_name
            for var_name in dataset.variables
            if any(cand.lower() == var_name.lower() for cand in candidates)
        ),
        None,
    )


def overlay_topography(ax, transform, levels=10):
    """
    Draw thin black MOLA topography contours on *ax*.

    Does nothing when the MOLA file was not loaded at start-up.
    *transform* is the CRS of the MOLA lon/lat grid (passed to ax.contour);
    *levels* is forwarded to ax.contour unchanged.
    """
    if topo_loaded:
        contour_style = dict(levels=levels, linewidths=0.5, colors='black', transform=transform)
        ax.contour(topo_lon2d, topo_lat2d, MOLA, **contour_style)


def _cell_index(coords, n, value):
    """
    Return the index (0..n-1) of the grid cell containing *value*, or -1/n when outside.

    *coords* may hold either the n cell centres (what the callers in this
    script pass) or the n+1 cell edges.  Edges for centre coordinates are put
    half-way between neighbours and extrapolated by half a cell at both ends.
    Ascending and descending axes are both supported.
    """
    c = np.asarray(coords, dtype=float).ravel()
    if n == 0 or not np.isfinite(value):
        return -1
    if c.size == n + 1:
        edges = c
    elif c.size == n:
        if n == 1:
            return 0  # a single cell has no extent information: it covers everything
        mids = 0.5 * (c[:-1] + c[1:])
        edges = np.concatenate(([2 * c[0] - mids[0]], mids, [2 * c[-1] - mids[-1]]))
    else:
        return -1
    if edges[0] > edges[-1]:
        # Descending axis: search the reversed (ascending) edges, then map back.
        k = np.searchsorted(edges[::-1], value) - 1
        return int(n - 1 - k)
    return int(np.searchsorted(edges, value) - 1)


def attach_format_coord(ax, mat, x, y, x_dim, y_dim, varname, is_pcolormesh=True, data_crs=ccrs.PlateCarree()):
    """
    Attach a format_coord function to the axes to display x, y, and value at cursor.
    Works for both pcolormesh and imshow style grids.

    Parameters
    ----------
    ax : matplotlib Axes (possibly a cartopy GeoAxes)
    mat : ndarray
        (ny, nx) values, or (ny, nx, 3|4) RGB(A) image.
    x, y : array-like
        pcolormesh mode: cell centres (nx / ny values) or edges (nx+1 / ny+1);
        imshow mode: only their min/max are used as the image extent.
    x_dim, y_dim : str
        Names shown in the status bar; when they are lon/lat names and *ax*
        has a cartopy projection, cursor positions are converted to *data_crs*.
    varname : str
        Label for the value.
    is_pcolormesh : bool
        Select the pcolormesh (True) or imshow (False) index mapping.
    data_crs : cartopy CRS
        CRS of x/y for geographic axes.

    Raises
    ------
    ValueError
        If *mat* is neither 2D nor an RGB(A) image.
    """
    if mat.ndim == 2:
        ny, nx = mat.shape
    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
        ny, nx = mat.shape[:2]
    else:
        raise ValueError(f"Unsupported mat shape {mat.shape}")

    if not is_pcolormesh:
        x0, x1 = np.min(x), np.max(x)
        y0, y1 = np.min(y), np.max(y)

    def linear_index(v, lo, hi, n):
        """Index of *v* on n equal bins spanning [lo, hi]; -1 if undefined."""
        if not np.isfinite(v) or hi == lo:
            return -1
        if v == hi:
            return n - 1  # include the far edge in the last bin
        return int(np.floor((v - lo) / (hi - lo) * n))

    # Detect if ax is a GeoAxes with a projection we can invert
    proj = getattr(ax, 'projection', None)
    geo_axes = (
        isinstance(proj, ccrs.Projection)
        and x_dim.lower() in LON_DIMS
        and y_dim.lower() in LAT_DIMS
    )

    def format_coord(xp, yp):
        xi, yi = xp, yp
        if geo_axes:
            # Best effort: fall back to raw display coords if the projection
            # cannot invert this point (e.g. outside the projected domain).
            try:
                xi, yi = data_crs.transform_point(xp, yp, src_crs=proj)
            except Exception:
                xi, yi = xp, yp

        if is_pcolormesh:
            col = _cell_index(x, nx, xi)
            row = _cell_index(y, ny, yi)
        else:
            col = linear_index(xi, x0, x1, nx)
            row = linear_index(yi, y0, y1, ny)

        label_xy = f"{x_dim}={xi:.3g}, {y_dim}={yi:.3g}"
        if not (0 <= row < ny and 0 <= col < nx):
            return label_xy
        if mat.ndim == 2:
            return f"{label_xy}, {varname}={mat[row, col]:.3g}"
        txt = ", ".join(f"{vv:.3g}" for vv in mat[row, col][:3])
        return f"{label_xy}, {varname}=({txt})"

    ax.format_coord = format_coord


def plot_polar_views(lon2d, lat2d, data2d, colormap, varname, units=None, topo_overlay=True,
                     lat_limit=60):
    """
    Plot two polar-stereographic views (north & south) of the same data.

    Parameters
    ----------
    lon2d, lat2d : array-like
        Longitudes/latitudes of *data2d* in degrees (1D axes or 2D meshes,
        as accepted by pcolormesh).
    data2d : 2D array
        Values indexed (lat, lon).
    colormap : str
        Matplotlib colormap.
    varname : str
        Used in the title and colorbar label.
    units : str, optional
        Appended to the colorbar label in parentheses.
    topo_overlay : bool
        Draw MOLA topography contours when available.
    lat_limit : float
        Equatorward boundary of each cap in degrees (absolute value).  The
        default of 60 shows the 60-90 degree caps.

    Returns
    -------
    list of matplotlib.figure.Figure
        The north and south figures, so callers can save or close them.
    """
    lat_limit = abs(lat_limit)
    figs = []

    # Circular clip path in axes coordinates, shared by both views
    theta = np.linspace(0, 2 * np.pi, 100)
    center, radius = [0.5, 0.5], 0.5
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
    circle = mpath.Path(verts * radius + center)

    uniq_lons = np.unique(np.ravel(lon2d))
    uniq_lats = np.unique(np.ravel(lat2d))

    for pole in ("north", "south"):
        if pole == "north":
            proj = ccrs.NorthPolarStereo(central_longitude=180)
            extent = [-180, 180, lat_limit, 90]
        else:
            proj = ccrs.SouthPolarStereo(central_longitude=180)
            extent = [-180, 180, -90, -lat_limit]

        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(1, 1, 1, projection=proj, aspect=True)
        ax.set_global()
        ax.set_extent(extent, ccrs.PlateCarree())
        ax.set_boundary(circle, transform=ax.transAxes)

        # Meridians every 30 deg, parallels every 10 deg
        ax.gridlines(
            draw_labels=True,
            color='k',
            xlocs=range(-180, 181, 30),
            ylocs=range(-90, 91, 10),
            linestyle='--',
            linewidth=0.5
        )

        cf = ax.pcolormesh(
            lon2d, lat2d, data2d,
            shading='auto',
            cmap=colormap,
            transform=ccrs.PlateCarree()
        )
        attach_format_coord(ax, data2d, uniq_lons, uniq_lats, 'lon', 'lat', varname, is_pcolormesh=True)

        if topo_overlay:
            overlay_topography(ax, transform=ccrs.PlateCarree(), levels=20)

        cbar = fig.colorbar(cf, ax=ax, pad=0.1)
        cbar.set_label(varname + (f" ({units})" if units else ""))
        ax.set_title(f"{varname} — {pole.capitalize()} polar region", pad=20, y=1.05, fontsize=12, fontweight='bold')

        figs.append(fig)

    # Show both figures
    plt.show()
    return figs


def plot_3D_globe(lon2d, lat2d, data2d, colormap, varname, units=None):
    """
    Plot a 3D globe view of the data using vedo, with surface coloring based on data2d
    and overlaid contour lines from MOLA topography.

    The globe surface is built on the MOLA grid with elevations exaggerated
    10x; *data2d* is resampled onto that grid to colour it.  Meridians every
    30 deg and parallels every 30 deg are drawn slightly above the surface.

    Parameters
    ----------
    lon2d, lat2d : array-like
        Coordinates of *data2d*.  NOTE(review): they are currently unused; the
        resampling assumes data2d rows run from latitude -90 to 90 and columns
        from longitude -180 to 180 — confirm this matches the caller's grid.
    data2d : 2D array
        Field to display, indexed (lat, lon).
    colormap : str
        Colormap name handed to vedo.
    varname : str
        Scalar-bar title.
    units : str, optional
        Appended to the scalar-bar title in brackets.
    """
    # Optional dependencies / data: skip the view gracefully when missing.
    if not vedo_available:
        print("3D view skipped: vedo missing.")
        return
    if not csv_loaded:
        print("3D view skipped: color table missing.")
        return
    if not topo_loaded:
        print("3D view skipped: MOLA topography missing.")
        return

    # Contours are computed on a detached Figure (not registered with pyplot)
    # so the caller's current figure is neither drawn on nor cleared.
    from matplotlib.figure import Figure

    radius = 3389500   # Mars average radius [m]
    relief_scale = 10  # vertical exaggeration applied to MOLA elevations

    # MOLA grid: row 0 is latitude 90
    nlat, nlon = MOLA.shape
    lats = np.linspace(90, -90, nlat)
    lons = np.linspace(-180, 180, nlon)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    # One elevation interpolator, reused for contour lines and grid lines
    mola_elev = RegularGridInterpolator((lats, lons), MOLA, bounds_error=False, fill_value=0.0)

    def to_xyz(lat_deg, lon_deg, elev, scale=1.0):
        """Convert lat/lon [deg] + elevation [m] into an (N, 3) Cartesian point array."""
        phi = np.deg2rad(90 - np.asarray(lat_deg))
        theta = np.deg2rad(np.asarray(lon_deg))
        r = (radius + np.asarray(elev) * relief_scale) * scale
        return np.column_stack([
            (r * np.sin(phi) * np.cos(theta)).ravel(),
            (r * np.sin(phi) * np.sin(theta)).ravel(),
            (r * np.cos(phi)).ravel(),
        ])

    # Interpolate data2d onto the MOLA grid (extents assumed, see docstring)
    lat_data = np.linspace(-90, 90, data2d.shape[0])
    lon_data = np.linspace(-180, 180, data2d.shape[1])
    interp_data = RegularGridInterpolator((lat_data, lon_data), data2d,
                                          bounds_error=False, fill_value=None)
    newdata2d = interp_data((lat_grid, lon_grid))

    # MOLA contour lines, lifted slightly (x1.002) above the surface
    cs = Figure().add_subplot().contour(lon_grid, lat_grid, MOLA, levels=10, linewidths=0)
    contour_lines = []
    for segs in cs.allsegs:
        for verts in segs:
            if verts.shape[0] < 2:
                continue
            lon_c, lat_c = verts[:, 0], verts[:, 1]
            pts = to_xyz(lat_c, lon_c, mola_elev((lat_c, lon_c)), scale=1.002)
            contour_lines.append(Line(pts, c='k', lw=0.5))

    # Sphere surface mesh: two triangles per grid quad
    surface_pts = to_xyz(lat_grid, lon_grid, MOLA)
    ids = np.arange(nlat * nlon).reshape(nlat, nlon)
    p0 = ids[:-1, :-1].ravel()
    p1 = ids[:-1, 1:].ravel()
    p2 = ids[1:, :-1].ravel()
    p3 = ids[1:, 1:].ravel()
    faces = np.stack(
        [np.column_stack([p0, p2, p1]), np.column_stack([p1, p2, p3])], axis=1
    ).reshape(-1, 3)

    mesh = Mesh([surface_pts, faces.tolist()])
    mesh.cmap(colormap, newdata2d.ravel())
    mesh.add_scalarbar(title=varname + (f' [{units}]' if units else ''), c='white')
    mesh.lighting('default')

    # Geographic grid lines, lifted x1.005 above the surface.
    # TODO: text labels (e.g. vedo Line.flagpole) were prototyped but are disabled.
    grid_lines = []
    merid_lats = np.linspace(-90, 90, nlat)
    for lon in range(-150, 181, 30):
        merid_lons = np.full_like(merid_lats, lon)
        pts = to_xyz(merid_lats, merid_lons, mola_elev((merid_lats, merid_lons)), scale=1.005)
        grid_lines.append(Line(pts, c='k', lw=1))
    par_lons = np.linspace(-180, 180, nlon)
    for lat in range(-60, 91, 30):
        par_lats = np.full_like(par_lons, lat)
        pts = to_xyz(par_lats, par_lons, mola_elev((par_lats, par_lons)), scale=1.005)
        grid_lines.append(Line(pts, c='k', lw=1))

    plotter = Plotter(title="3D globe view", bg="bb", axes=0)

    # Camera on the +x axis, looking at the centre, z up
    cam_dist = radius * 3
    plotter.camera.SetPosition([cam_dist, 0, 0])
    plotter.camera.SetFocalPoint([0, 0, 0])
    plotter.camera.SetViewUp([0, 0, 1])

    plotter.show(mesh, *contour_lines, *grid_lines)


def transform_physical_points(dataset, var, data):
    """
    Transform a physical_points 1D array into a 2D grid of shape (nlat, nlon).

    Parameters
    ----------
    dataset : netCDF4.Dataset
        Must contain latitude and longitude variables (names from LAT_DIMS /
        LON_DIMS) giving the position of every physical point.
    var : netCDF4.Variable
        The plotted variable (unused; kept for API compatibility).
    data : array-like
        One value per physical point, in the same order as the lat/lon variables.

    Returns
    -------
    grid, uniq_lats, uniq_lons, lat_var, lon_var
        *grid* has shape (len(uniq_lats), len(uniq_lons)) with NaN where no
        point exists; *uniq_lats*/*uniq_lons* are sorted coordinates in degrees
        (coordinates whose maximum |lat| is <= pi are treated as radians);
        *lat_var*/*lon_var* are the coordinate variable names.

    Raises
    ------
    ValueError
        If the coordinate variables are missing, hold no finite positions,
        or do not have one entry per data value.
    """
    lat_var = find_coord_var(dataset, LAT_DIMS)
    lon_var = find_coord_var(dataset, LON_DIMS)
    if lat_var is None or lon_var is None:
        raise ValueError("Cannot find latitude or longitude variables for physical_points")

    def unmasked(values):
        # Masked entries become NaN; everything is flattened to 1D float.
        if hasattr(values, 'mask'):
            values = np.where(values.mask, np.nan, values.data)
        return np.asarray(values, dtype=float).ravel()

    raw_lats = unmasked(dataset.variables[lat_var][:])
    raw_lons = unmasked(dataset.variables[lon_var][:])
    values = unmasked(data)
    if not (values.size == raw_lats.size == raw_lons.size):
        raise ValueError(
            f"physical_points size mismatch: {values.size} values for "
            f"{raw_lats.size} latitudes and {raw_lons.size} longitudes"
        )

    # Points with missing coordinates cannot be placed on the grid
    finite = np.isfinite(raw_lats) & np.isfinite(raw_lons)
    if not finite.any():
        raise ValueError("No finite latitude/longitude values for physical_points")

    # Convert radians to degrees if in radians (judged on finite values only,
    # so a NaN cannot disable the check)
    if np.max(np.abs(raw_lats[finite])) <= np.pi:
        raw_lats = np.degrees(raw_lats)
        raw_lons = np.degrees(raw_lons)

    uniq_lats = np.unique(raw_lats[finite])
    uniq_lons = np.unique(raw_lons[finite])
    grid = np.full((uniq_lats.size, uniq_lons.size), np.nan)
    for value, lat, lon in zip(values[finite], raw_lats[finite], raw_lons[finite]):
        i = np.flatnonzero(np.isclose(uniq_lats, lat))[0]
        j = np.flatnonzero(np.isclose(uniq_lons, lon))[0]
        grid[i, j] = value

    # A pole stored as a single point: duplicate its value across all longitudes
    for i in (0, -1):
        row = grid[i, :]
        if np.count_nonzero(~np.isnan(row)) == 1:
            grid[i, :] = row[~np.isnan(row)][0]

    # Close the longitude seam: repeat the -180 column at +180
    if -180.0 in uniq_lons:
        idx = np.flatnonzero(np.isclose(uniq_lons, -180.0))[0]
        grid = np.hstack([grid, grid[:, [idx]]])
        uniq_lons = np.append(uniq_lons, 180.0)
    return grid, uniq_lats, uniq_lons, lat_var, lon_var


def get_dimension_indices(ds, varname):
    """
    Interactively choose how each dimension of *varname* is reduced.

    Size-1 dimensions are fixed at index 0 without asking.  For every other
    dimension the user is prompted until a valid answer is given:
      - a number 1..size picks that (1-based) index, stored 0-based;
      - 'a' requests an average over the dimension (stored as 'avg');
      - 'e' or an empty answer keeps all values (stored as None).

    Returns a dict {dim_name: int | 'avg' | None} in dimension order.
    """
    variable = ds.variables[varname]
    chosen = {}
    for dim_name, dim_size in zip(variable.dimensions, variable.shape):
        if dim_size == 1:
            chosen[dim_name] = 0
            continue
        question = (
            f"Available options for '{dim_name}' (size {dim_size}):\n"
            f"  > '1–{dim_size}' to pick that index\n"
            "  > 'a' to average over this dimension\n"
            "  > 'e' or Enter to take all values\n"
            "Choose: "
        )
        while True:
            answer = input(question).strip().lower()
            if answer in ("", "e"):
                chosen[dim_name] = None
                break
            if answer == 'a':
                chosen[dim_name] = 'avg'
                break
            if answer.isdigit() and 1 <= int(answer) <= dim_size:
                chosen[dim_name] = int(answer) - 1
                break
            print(f"  Invalid entry '{answer}'. Please enter a number, 'a', 'e', or just Enter.")
    return chosen


def _as_ndarray(values):
    """Return *values* with masked entries replaced by NaN (plain ndarray)."""
    if hasattr(values, 'mask'):
        return np.where(values.mask, np.nan, values.data)
    return values


def _ask_yes(prompt):
    """Ask a y/n question on an interactive terminal; answer 'no' when stdin is not a TTY."""
    if not sys.stdin.isatty():
        return False
    return input(prompt).strip().lower() == "y"


def _finish_figure(fig, output_path, saved_msg="Saved to"):
    """Save *fig* to *output_path* when given, otherwise show the pyplot figures."""
    if output_path:
        fig.savefig(output_path, bbox_inches='tight')
        print(f"{saved_msg} {output_path}")
    else:
        plt.show()


def _plot_physical_points(dataset, var, varname, data_slice, colormap, output_path):
    """Regrid a 1D physical_points field onto lat/lon and draw it as a map."""
    grid, uniq_lats, uniq_lons, latv, lonv = transform_physical_points(dataset, var, data_slice)
    proj = ccrs.PlateCarree()
    fig, ax = plt.subplots(figsize=(8, 6), subplot_kw=dict(projection=proj))
    lon2d, lat2d = np.meshgrid(uniq_lons, uniq_lats)
    ax.set_xticks(np.arange(-180, 181, 30), crs=proj)
    ax.set_yticks(np.arange(-90, 91, 30), crs=proj)
    for axis in ('x', 'y'):
        ax.tick_params(axis=axis, which='major', length=4, direction='out', pad=2, labelsize=8)
    cf = ax.pcolormesh(lon2d, lat2d, grid, shading='auto', cmap=colormap, transform=proj)
    attach_format_coord(ax, grid, uniq_lons, uniq_lats, 'lon', 'lat', varname, is_pcolormesh=True)
    overlay_topography(ax, transform=proj, levels=10)
    cbar = fig.colorbar(cf, ax=ax, pad=0.02)
    cbar.set_label(varname)
    ax.set_title(f"{varname} (physical_points)", fontweight='bold')
    ax.set_xlabel(f"Longitude ({getattr(dataset.variables[lonv], 'units', 'deg')})")
    ax.set_ylabel(f"Latitude ({getattr(dataset.variables[latv], 'units', 'deg')})")
    units = getattr(var, 'units', None)
    if _ask_yes("Display polar-stereo views? [y = yes, anything else = no]: "):
        plot_polar_views(lon2d, lat2d, grid, colormap, varname, units)
    if _ask_yes("Display 3D globe view? [y = yes, anything else = no]: "):
        plot_3D_globe(lon2d, lat2d, grid, colormap, varname, units)
    _finish_figure(fig, output_path)


def _plot_1d(dataset, varname, data_slice, dim_name, output_path):
    """Line plot of a 1D slice against its coordinate variable (or the index)."""
    coord_var = find_coord_var(dataset, [dim_name])
    # Use the coordinate only when it lines up with the data
    if coord_var and dataset.variables[coord_var].shape == data_slice.shape:
        x = _as_ndarray(dataset.variables[coord_var][:])
        xlabel = coord_var
    else:
        x = np.arange(data_slice.shape[0])
        xlabel = dim_name
    fig = plt.figure(figsize=(8, 4))
    plt.plot(x, data_slice)
    plt.grid(True)
    plt.xlabel(xlabel)
    plt.ylabel(varname)
    plt.title(f"{varname} vs {xlabel}")
    _finish_figure(fig, output_path, "Saved plot to")


def _plot_2d(dataset, varname, data_slice, free_dims, colormap, output_path):
    """pcolormesh of a 2D slice; optional polar / 3D views for lat-lon maps."""
    # X axis chosen interactively; default (or non-TTY) keeps the file order
    resp = input(f"Which dimension on X? {free_dims}: ").strip() if sys.stdin.isatty() else ""
    if resp == free_dims[1]:
        x_dim, y_dim = free_dims[1], free_dims[0]
    else:
        x_dim, y_dim = free_dims[0], free_dims[1]

    def coords_for(dim):
        # Coordinate variable if it gives centres (n) or edges (n+1), else the index
        n = data_slice.shape[free_dims.index(dim)]
        coord_var = find_coord_var(dataset, [dim])
        if coord_var and dataset.variables[coord_var].shape in ((n,), (n + 1,)):
            return _as_ndarray(dataset.variables[coord_var][:])
        return np.arange(n)

    x_coords = coords_for(x_dim)
    y_coords = coords_for(y_dim)
    # Put y on axis 0 and x on axis 1, as pcolormesh expects
    plot_data = np.moveaxis(data_slice, [free_dims.index(y_dim), free_dims.index(x_dim)], [0, 1])
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.pcolormesh(x_coords, y_coords, plot_data, shading='auto', cmap=colormap)
    attach_format_coord(ax, plot_data, x_coords, y_coords, x_dim, y_dim, varname, is_pcolormesh=True)
    cbar = fig.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label(varname)
    ax.set_xlabel(x_dim)
    ax.set_ylabel(y_dim)
    ax.set_title(f"{varname} ({y_dim} vs {x_dim})")
    ax.grid(True)

    pair = {x_dim, y_dim}
    if pair & set(LAT_DIMS) and pair & set(LON_DIMS):
        # The map views expect (lon, lat, data[lat, lon]) whatever axis the user chose
        if x_dim in LON_DIMS:
            lon_c, lat_c, map_data = x_coords, y_coords, plot_data
        else:
            lon_c, lat_c, map_data = y_coords, x_coords, plot_data.T
        units = getattr(dataset.variables[varname], "units", None)
        if _ask_yes("Display polar-stereo views? [y = yes, anything else = no]: "):
            plot_polar_views(lon_c, lat_c, map_data, colormap, varname, units)
        if _ask_yes("Display 3D globe view? [y = yes, anything else = no]: "):
            plot_3D_globe(lon_c, lat_c, map_data, colormap, varname, units)
    _finish_figure(fig, output_path)


def plot_variable(dataset, varname, colormap="jet", output_path=None, extra_indices=None):
    """
    Plot one variable, choosing the plot type from the dimensions left free.

    Parameters
    ----------
    dataset : netCDF4.Dataset
    varname : str
    colormap : str
        Matplotlib colormap for map / 2D plots.
    output_path : str, optional
        Save the main figure here instead of displaying it.
    extra_indices : dict, optional
        {dim_name: int index | 'avg' | None}.  An int fixes that dimension,
        'avg' averages over it (NaN-aware), None or absent keeps it whole.
        The caller's dict is not modified.

    By remaining dimensionality: 0D prints the scalar; 1D over
    physical_points is regridded and mapped; other 1D is a line plot; 2D is
    a pcolormesh (X axis chosen interactively); anything else is reported
    as unsupported.
    """
    var = dataset.variables[varname]
    dims = list(var.dimensions)
    try:
        data_full = var[:]
    except Exception as e:
        print(f"Error: Cannot read data for '{varname}': {e}")
        return
    data_full = _as_ndarray(data_full)

    # Work on a copy: averaged dimensions are rewritten to index 0 below
    selection = dict(extra_indices or {})
    for dim, mode in list(selection.items()):
        if not (isinstance(mode, str) and mode == 'avg'):
            continue
        if dim not in dims:
            print(f"Warning: '{dim}' is not a dimension of '{varname}'; cannot average over it.")
            continue
        data_full = np.nanmean(data_full, axis=dims.index(dim), keepdims=True)
        selection[dim] = 0

    slicer = []
    for dim in dims:
        idx = selection.get(dim)
        slicer.append(idx if isinstance(idx, (int, np.integer)) else slice(None))
    data_slice = data_full[tuple(slicer)]
    free_dims = [dim for dim, s in zip(dims, slicer) if isinstance(s, slice)]
    nd = data_slice.ndim

    # Only regrid when physical_points itself is the remaining free dimension
    if nd == 1 and free_dims == ['physical_points']:
        _plot_physical_points(dataset, var, varname, data_slice, colormap, output_path)
        return
    if nd == 0:
        print(f"\033[36mScalar '{varname}': {float(data_slice)}\033[0m")
        return
    if nd == 1:
        _plot_1d(dataset, varname, data_slice, free_dims[0], output_path)
        return
    if nd == 2:
        _plot_2d(dataset, varname, data_slice, free_dims, colormap, output_path)
        return
    print(f"Plotting for ndim={nd} not yet supported.")


def visualize_variable_interactive(nc_path=None):
    """
    Interactive loop: keep prompting for variables to plot until user quits.

    Parameters
    ----------
    nc_path : str, optional
        NetCDF file to open; when omitted the user is prompted for it (with
        tab-completion of filesystem paths, "~" expanded).

    The dataset is always closed on exit, including on errors.  An empty
    answer, end-of-input (Ctrl-D) or Ctrl-C at the variable prompt quits.
    """
    # readline's default delimiters include "/" and "~", which would make the
    # path completer see only the last path component.
    readline.set_completer_delims(" \t\n")

    if nc_path:
        path = nc_path
    else:
        readline.set_completer(complete_filename)
        readline.parse_and_bind("tab: complete")
        try:
            path = input("Enter path to NetCDF file: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            return
    path = os.path.expanduser(path)

    if not os.path.isfile(path):
        print(f"Error: '{path}' not found.")
        return

    ds = Dataset(path, "r")
    try:
        var_list = list(ds.variables.keys())
        if not var_list:
            print("No variables found in file.")
            return

        # Non-blocking figures so the prompt stays available
        plt.ion()

        while True:
            readline.set_completer(make_varname_completer(var_list))
            readline.parse_and_bind("tab: complete")

            print("\nAvailable variables:")
            for name in var_list:
                print(f"  > {name}")
            try:
                varname = input("\nEnter variable name to plot (or Enter to quit): ").strip()
            except (EOFError, KeyboardInterrupt):
                print()
                varname = ""
            if varname == "":
                print("Exiting.")
                break
            if varname not in ds.variables:
                print(f"Variable '{varname}' not found. Try again.")
                continue

            var = ds.variables[varname]
            print(f"\nVariable '{varname}' has dimensions:")
            for dim, size in zip(var.dimensions, var.shape):
                print(f"  - {dim} (size {size})")
            print()

            selection = get_dimension_indices(ds, varname)
            plot_variable(
                ds,
                varname,
                colormap='jet',
                output_path=None,
                extra_indices=selection,
            )
    finally:
        ds.close()


def visualize_variable_cli(nc_file, varname, time_index, alt_index,
                           colormap, output_path, extra_json, avg_lat):
    """
    Command-line mode: visualize directly, parsing the --extra-indices argument (JSON string).

    Parameters
    ----------
    nc_file : str
        NetCDF file path.
    varname : str
        Variable to plot.
    time_index, alt_index : int or None
        0-based indices applied to the Time / altitude dimensions (if the
        variable has them).  Values given in *extra_json* for the same
        dimension take precedence.  --time-index is ignored with --avg-lat,
        which needs the full time axis.
    colormap : str
        Matplotlib colormap.
    output_path : str or None
        Save the figure here instead of displaying it.
    extra_json : str or None
        JSON object {dim_name: index}; for dimensions whose name contains
        "slope" the index is 1-based, otherwise 0-based.
    avg_lat : bool
        Average over latitude (only when Time, lat and lon all have size > 1).

    The dataset is always closed, including on errors.
    """
    if not os.path.isfile(nc_file):
        print(f"Error: '{nc_file}' not found.")
        return
    ds = Dataset(nc_file, "r")
    try:
        if varname not in ds.variables:
            print(f"Variable '{varname}' not in file.")
            return

        var_obj = ds.variables[varname]
        dims = var_obj.dimensions
        shape = var_obj.shape
        print(f"\nVariable '{varname}' has {len(dims)} dimensions:")
        for name, size in zip(dims, shape):
            print(f"  - {name}: size {size}")
        print()

        # Special case: only Time is longer than 1 -> plot the time series directly
        t_idx = find_dim_index(dims, TIME_DIMS)
        if (
            t_idx is not None and shape[t_idx] > 1 and
            all(size == 1 for i, size in enumerate(shape) if i != t_idx)
        ):
            print("Detected single-point spatial dims; plotting time series…")
            data = var_obj[:].squeeze()
            time_var = find_coord_var(ds, TIME_DIMS)
            tvals = ds.variables[time_var][:] if time_var else np.arange(data.shape[0])
            if hasattr(data, "mask"):
                data = np.where(data.mask, np.nan, data.data)
            if hasattr(tvals, "mask"):
                tvals = np.where(tvals.mask, np.nan, tvals.data)
            plt.figure()
            plt.plot(tvals, data, marker="o")
            plt.xlabel(time_var or "Time Index")
            plt.ylabel(varname + (f" ({var_obj.units})" if hasattr(var_obj, "units") else ""))
            plt.title(f"{varname} vs {time_var or 'Index'}", fontweight='bold')
            if output_path:
                plt.savefig(output_path, bbox_inches="tight")
                print(f"Saved to {output_path}")
            else:
                plt.show()
            return

        # --avg-lat needs Time, lat and lon each longer than 1
        lat_idx = find_dim_index(dims, LAT_DIMS)
        lon_idx = find_dim_index(dims, LON_DIMS)
        if avg_lat and not (
            t_idx   is not None and shape[t_idx]  > 1 and
            lat_idx is not None and shape[lat_idx] > 1 and
            lon_idx is not None and shape[lon_idx] > 1
        ):
            print("Note: disabling --avg-lat (requires Time, lat & lon each >1).")
            avg_lat = False

        # Parse extra indices JSON (must be an object of integer indices)
        extra = {}
        if extra_json:
            try:
                parsed = json.loads(extra_json)
                if not isinstance(parsed, dict):
                    raise ValueError("extra-indices must be a JSON object")
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                print("Warning: bad extra-indices.")
                parsed = {}
            for k, v in parsed.items():
                if isinstance(v, int):
                    # slope dimensions are given 1-based on the command line
                    extra[k] = v - 1 if "slope" in k.lower() else v

        # Apply --time-index / --alt-index (explicit --extra-indices win)
        alt_idx = find_dim_index(dims, ALT_DIMS)
        for axis_pos, value, label in ((t_idx, time_index, "--time-index"),
                                       (alt_idx, alt_index, "--alt-index")):
            if value is None:
                continue
            if axis_pos is None:
                print(f"Note: ignoring {label}; '{varname}' has no matching dimension.")
                continue
            if label == "--time-index" and avg_lat:
                print("Note: ignoring --time-index because --avg-lat keeps the full time axis.")
                continue
            size = shape[axis_pos]
            if not -size <= value < size:
                print(f"Error: {label} {value} is out of range for '{dims[axis_pos]}' (size {size}).")
                return
            extra.setdefault(dims[axis_pos], value)

        if avg_lat:
            extra[dims[lat_idx]] = 'avg'

        plot_variable(ds, varname, colormap=colormap, output_path=output_path,
                      extra_indices=extra)
    finally:
        ds.close()


def main():
    """
    Parse the command line and dispatch.

    With both a file and --variable the non-interactive CLI path is used;
    otherwise the interactive loop starts (opening nc_file if one was given).
    Note: --show-polar and --show-3d are parsed but not read here.
    """
    parser = argparse.ArgumentParser()
    arg_specs = (
        (('nc_file',), {'nargs': '?', 'help': 'NetCDF file (omit for interactive)'}),
        (('-v', '--variable'), {'help': 'Variable name'}),
        (('-t', '--time-index'), {'type': int, 'help': 'Time index (0-based)'}),
        (('-a', '--alt-index'), {'type': int, 'help': 'Altitude index (0-based)'}),
        (('-c', '--cmap'), {'default': 'jet', 'help': 'Colormap'}),
        (('--avg-lat',), {'action': 'store_true', 'help': 'Average over latitude'}),
        (('--show-polar',), {'action': 'store_true', 'help': 'Enable polar-stereo views'}),
        (('--show-3d',), {'action': 'store_true', 'help': 'Enable 3D globe view'}),
        (('-o', '--output'), {'help': 'Save figure path'}),
        (('-e', '--extra-indices'), {'help': 'JSON string for other dims'}),
    )
    for flags, options in arg_specs:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    if not (args.nc_file and args.variable):
        visualize_variable_interactive(args.nc_file)
        return
    visualize_variable_cli(
        args.nc_file, args.variable,
        args.time_index, args.alt_index,
        args.cmap, args.output,
        args.extra_indices, args.avg_lat
    )


# Script entry point (not executed on import); main() returns None -> exit status 0.
if __name__ == "__main__":
    raise SystemExit(main())

