#!/usr/bin/env python3
#######################################################################################################
### Python script to output stratification data over time from "restartpem#.nc" files               ###
### and to plot orbital parameters from "obl_ecc_lsp.asc"                                           ###
#######################################################################################################

import os
import sys
import numpy as np
from glob import glob
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import LinearSegmentedColormap, LogNorm
from scipy.interpolate import interp1d


def _prompt_existing_path(first_prompt, retry_prompt, default, exists, kind):
    """
    Ask for a path until `exists(path)` is true and return it.
    An empty answer selects `default`; `kind` ("directory" or "file") is only
    used in the error message.
    """
    path = input(first_prompt).strip() or default
    while not exists(path):
        print(f"  » \"{path}\" does not exist or is not a {kind}.")
        path = input(retry_prompt).strip() or default
    return path


def get_user_inputs():
    """
    Prompt the user for:
      - folder_path: directory containing NetCDF files (default: "starts")
      - base_name:   base filename (default: "restartpem")
      - infofile:    name of the PEM info file (default: "info_PEM.txt")
      - orbfile:     name of the orbital parameters ASCII file
                     (default: "obl_ecc_lsp.asc")
    Validates existence of folder, infofile and orbfile before returning;
    base_name is not validated here.
    Returns the tuple (folder_path, base_name, infofile, orbfile).
    """
    folder_path = _prompt_existing_path(
        "Enter the folder path containing the NetCDF files "
        "(press Enter for default [starts]): ",
        "Enter a valid folder path (press Enter for default [starts]): ",
        "starts",
        os.path.isdir,
        "directory",
    )

    base_name = input(
        "Enter the base name of the NetCDF files "
        "(press Enter for default [restartpem]): "
    ).strip() or "restartpem"

    infofile = _prompt_existing_path(
        "Enter the name of the PEM info file "
        "(press Enter for default [info_PEM.txt]): ",
        "Enter a valid PEM info filename (press Enter for default [info_PEM.txt]): ",
        "info_PEM.txt",
        os.path.isfile,
        "file",
    )

    # The retry prompt must fall back to the orbital default, not to the
    # PEM info file: both defaults go through the same helper argument.
    orbfile = _prompt_existing_path(
        "Enter the name of the orbital parameters ASCII file "
        "(press Enter for default [obl_ecc_lsp.asc]): ",
        "Enter a valid orbital parameters ASCII filename (press Enter for default [obl_ecc_lsp.asc]): ",
        "obl_ecc_lsp.asc",
        os.path.isfile,
        "file",
    )

    return folder_path, base_name, infofile, orbfile


def list_netcdf_files(folder_path, base_name):
    """
    Return the full paths of all files in folder_path matching
    {base_name}#.nc, ordered by the integer that follows base_name.
    Files whose suffix is not purely numeric are placed last.
    An empty list is returned when nothing matches.
    """
    prefix_length = len(base_name)
    extension_length = len(".nc")

    def numeric_suffix(path):
        # Strip the base name in front and the ".nc" extension behind.
        suffix = os.path.basename(path)[prefix_length:-extension_length]
        if suffix.isdigit():
            return int(suffix)
        return float('inf')

    matches = glob(os.path.join(folder_path, f"{base_name}[0-9]*.nc"))
    return sorted(matches, key=numeric_suffix)


def open_sample_dataset(file_path):
    """
    Read the grid description from one NetCDF file.

    Returns (ngrid, nslope, longitude_array, latitude_array) where ngrid and
    nslope are the sizes of the 'physical_points' and 'nslope' dimensions and
    the two arrays are copies of the 'longitude' and 'latitude' variables
    (copied so they stay usable after the file is closed).
    """
    with Dataset(file_path, 'r') as sample:
        dims = sample.dimensions
        grid_size = dims['physical_points'].size
        slope_count = dims['nslope'].size

        coords = sample.variables
        lon_values = coords['longitude'][:].copy()
        lat_values = coords['latitude'][:].copy()

    return grid_size, slope_count, lon_values, lat_values


def collect_stratification_variables(files, base_name):
    """
    Scan every file once and gather:
      - the names of the variables holding each stratification property
      - the largest 'nb_str_max' dimension found (max_nb_str)
      - the smallest and largest values found in the top-elevation variables
    Returns:
      - var_info: dict mapping each property_name -> sorted list of var names
      - max_nb_str: int (0 if no file has the 'nb_str_max' dimension)
      - min_base_elev: float (+inf if no top-elevation variable was found)
      - max_top_elev: float (-inf if no top-elevation variable was found)
    base_name is accepted but not used by this function.
    """
    # Substring that identifies each property inside a variable name.
    property_markers = {
        'heights':   'stratif_slope',    # "..._top_elevation"
        'co2_ice':   'h_co2ice',
        'h2o_ice':   'h_h2oice',
        'dust':      'h_dust',
        'pore':      'h_pore',
        'pore_ice':  'poreice_volfrac'
    }
    # 'heights' is collected by exact name below, not by substring match.
    substring_markers = [
        (prop, marker)
        for prop, marker in property_markers.items()
        if prop != 'heights'
    ]
    found = {prop: set() for prop in property_markers}

    strata_max = 0
    lowest = np.inf
    highest = -np.inf

    for path in files:
        with Dataset(path, 'r') as ds:
            if 'nb_str_max' in ds.dimensions:
                strata_max = max(strata_max, ds.dimensions['nb_str_max'].size)

            slope_count = ds.dimensions['nslope'].size
            for slope_number in range(1, slope_count + 1):
                elevation_name = f"stratif_slope{slope_number:02d}_top_elevation"
                if elevation_name not in ds.variables:
                    continue
                elevations = ds.variables[elevation_name][:]
                lowest = min(lowest, np.min(elevations))
                highest = max(highest, np.max(elevations))
                found['heights'].add(elevation_name)

            for candidate in ds.variables:
                for prop, marker in substring_markers:
                    if marker in candidate:
                        found[prop].add(candidate)

    var_info = {prop: sorted(names) for prop, names in found.items()}
    return var_info, strata_max, lowest, highest


def load_full_datasets(files):
    """
    Open all NetCDF files read-only and return a list of Dataset objects,
    in the same order as `files`.
    (They should be closed by the caller after use.)

    If any file fails to open, the datasets opened so far are closed before
    the exception is re-raised, so no file handle is leaked.
    """
    opened = []
    try:
        for fp in files:
            opened.append(Dataset(fp, 'r'))
    except Exception:
        # The caller never receives the partial list, so clean up here.
        for ds in opened:
            ds.close()
        raise
    return opened


def extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str):
    """
    Build:
      - heights_data[t_idx][isl] = 2D array (ngrid, n_strata_current) of top_elevations,
        or None when the corresponding variable is absent from that file.
      - raw_prop_arrays[prop] = 4D array (ngrid, ntime, nslope, max_nb_str) of per-strata
        values (float32), zero where a file has no data for a slope or stratum.

    var_info maps each property name to the variable names holding it;
    a name may be missing from individual datasets, in which case that
    (time, slope) entry is simply left at its default instead of raising KeyError.
    Variables whose slope number falls outside 1..nslope are ignored.

    Returns:
      - heights_data: list (ntime) of lists (nslope) of 2D arrays (or None)
      - raw_prop_arrays: dict mapping each property_name (except 'heights') -> 4D array
      - ntime: number of time steps (files)
    """
    ntime = len(datasets)

    def slope_index_from_var(vname):
        # "..slopeNN_.." -> zero-based slope index NN - 1
        return int(vname.split("slope")[1].split("_")[0]) - 1

    def variables_by_slope(names):
        # Keep only in-range slopes; a later name for the same slope wins.
        mapping = {}
        for vname in names:
            isl = slope_index_from_var(vname)
            if 0 <= isl < nslope:
                mapping[isl] = vname
        return mapping

    heights_data = [[None] * nslope for _ in range(ntime)]
    height_vars = variables_by_slope(var_info['heights'])
    for t_idx, ds in enumerate(datasets):
        for isl, var_name in height_vars.items():
            if var_name not in ds.variables:
                continue
            raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
            heights_data[t_idx][isl] = raw.T  # (ngrid, n_strata)

    raw_prop_arrays = {
        prop: np.zeros((ngrid, ntime, nslope, max_nb_str), dtype=np.float32)
        for prop in var_info
        if prop != 'heights'
    }

    for prop, arr in raw_prop_arrays.items():
        slope_map = variables_by_slope(var_info[prop])
        for t_idx, ds in enumerate(datasets):
            for isl, var_name in slope_map.items():
                if var_name not in ds.variables:
                    continue
                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
                n_strata_current = raw.shape[0]
                arr[:, t_idx, isl, :n_strata_current] = raw.T

    return heights_data, raw_prop_arrays, ntime


def normalize_to_fractions(raw_prop_arrays):
    """
    Convert the 'co2_ice', 'h2o_ice', 'dust' and 'pore' arrays of
    raw_prop_arrays into fractions of their element-wise sum.
    Elements whose sum is not strictly positive get a fraction of 0.
    Returns:
      - frac_arrays: dict with those four keys -> float32 arrays of the
        same shape as the inputs.
    """
    component_names = ('co2_ice', 'h2o_ice', 'dust', 'pore')
    components = {name: raw_prop_arrays[name] for name in component_names}

    total = (
        components['co2_ice']
        + components['h2o_ice']
        + components['dust']
        + components['pore']
    )
    has_material = total > 0.0

    def to_fraction(values):
        # Divide only where there is material; everything else stays 0.
        fraction = np.zeros_like(values, dtype=np.float32)
        fraction[has_material] = values[has_material] / total[has_material]
        return fraction

    return {name: to_fraction(values) for name, values in components.items()}


def read_infofile(file_name):
    """
    Parse the PEM info file ("info_PEM.txt").

    The third whitespace-separated field of the first line is read as the
    martian_to_earth conversion factor. On every following line the first
    field is read as a simulation timestamp; blank lines and lines whose
    first field is not a number are skipped.
    Returns:
      - date_time: 1D float64 numpy array of timestamps
      - martian_to_earth: float conversion factor
    """
    timestamps = []
    with open(file_name, 'r') as handle:
        header_fields = handle.readline().split()
        martian_to_earth = float(header_fields[2])

        for raw_line in handle:
            fields = raw_line.split()
            if not fields:
                continue
            try:
                stamp = float(fields[0])
            except ValueError:
                # Non-numeric leading field: not a data line.
                continue
            timestamps.append(stamp)

    return np.array(timestamps, dtype=np.float64), martian_to_earth


def get_yes_no_input(prompt: str) -> bool:
    """
    Ask a yes/no question until a recognised answer is given.
    Accepts "y"/"yes" (True) and "n"/"no" (False), case-insensitively and
    ignoring surrounding whitespace.
    """
    recognised = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input(f"{prompt} (y/n): ").strip().lower()
        if reply in recognised:
            return recognised[reply]
        print("Please respond with y or n.")


def prompt_discretization_step(max_top_elev):
    """
    Ask repeatedly for the elevation step dz until the answer is a number
    with 0 < dz <= max_top_elev, then return it as a float.
    """
    question = "Enter the discretization step of the reference grid for the elevation [m]: "
    while True:
        answer = input(question).strip()
        try:
            dz = float(answer)
            too_small = dz <= 0
            too_large = (not too_small) and dz > max_top_elev
        except ValueError:
            print("  » Invalid numeric value. Please try again.")
            continue

        if too_small:
            print("  » Discretization step must be strictly positive!")
        elif too_large:
            print(
                f"  » {dz:.3e} m is greater than the maximum top elevation "
                f"({max_top_elev:.3e} m). Please enter a smaller value."
            )
        else:
            return dz


def interpolate_data_on_refgrid(
    heights_data,
    prop_arrays,
    min_base_for_interp,
    max_top_elev,
    dz,
    exclude_sub=False
):
    """
    Build a reference elevation grid and interpolate strata fractions onto it.

    Parameters:
      - heights_data: heights_data[t_idx][isl] is a 2D array indexed
        [grid point, stratum] of top elevations, or None (entry skipped).
      - prop_arrays: dict property_name -> 4D array
        (ngrid, ntime, nslope, max_nb_str) of per-stratum values.
      - min_base_for_interp: lowest elevation of the reference grid.
      - max_top_elev: highest elevation the reference grid must reach.
      - dz: spacing of the reference grid.
      - exclude_sub: if True, only strata with top elevation >= -1e-6 are used;
        otherwise strata whose top elevation is NaN or exactly 0 are dropped.

    Returns:
      - ref_grid: 1D array of elevations (nz,)
      - gridded_data: dict mapping each property_name to 4D array
        (ngrid, ntime, nslope, nz) with interpolated fractions;
        -1.0 marks levels that received no value.
      - top_index: 3D array (ngrid, ntime, nslope) of ints:
        number of reference levels at or below the highest valid
        stratum top (0 where nothing valid was found).

    Prints the number of reference grid points.
    """
    if exclude_sub and (dz > max_top_elev):
        # Step larger than the whole column: fall back to a two-level grid.
        ref_grid = np.array([0.0, max_top_elev], dtype=np.float32)
    else:
        # "+ dz/2" keeps max_top_elev itself inside the half-open arange.
        ref_grid = np.arange(min_base_for_interp, max_top_elev + dz/2, dz)
    nz = len(ref_grid)
    print(f"> Number of reference grid points = {nz}")

    # All property arrays share the same shape; read it from any of them.
    sample_prop = next(iter(prop_arrays.values()))
    ngrid, ntime, nslope, max_nb_str = sample_prop.shape

    # -1.0 is the "no data" marker for every reference level.
    gridded_data = {
        prop: np.full((ngrid, ntime, nslope, nz), -1.0, dtype=np.float32)
        for prop in prop_arrays
    }
    top_index = np.zeros((ngrid, ntime, nslope), dtype=np.int32)

    for ig in range(ngrid):
        for t_idx in range(ntime):
            for isl in range(nslope):
                h_mat = heights_data[t_idx][isl]
                if h_mat is None:
                    continue

                # Pad the elevations with NaN up to max_nb_str so the mask
                # below lines up with the last axis of the property arrays.
                raw_h = h_mat[ig, :]
                h_all = np.full((max_nb_str,), np.nan, dtype=np.float32)
                n_strata_current = raw_h.shape[0]
                h_all[:n_strata_current] = raw_h

                if exclude_sub:
                    # NaN padding compares False here, so it is dropped too.
                    epsilon = 1e-6
                    valid_mask = (h_all >= -epsilon)
                else:
                    valid_mask = (~np.isnan(h_all)) & (h_all != 0.0)

                if not np.any(valid_mask):
                    continue

                h_valid = h_all[valid_mask]
                top_h = np.max(h_valid)
                # Count of reference levels <= top_h (side='right').
                i_zmax = np.searchsorted(ref_grid, top_h, side='right')
                top_index[ig, t_idx, isl] = i_zmax
                if i_zmax == 0:
                    continue

                for prop, arr in prop_arrays.items():
                    prop_profile_all = arr[ig, t_idx, isl, :]
                    prop_profile = prop_profile_all[valid_mask]
                    if prop_profile.size == 0:
                        continue

                    # kind='next': a reference level takes the value attached
                    # to the next stratum top at or above it; levels outside
                    # the range of h_valid get the -1.0 marker.
                    f_interp = interp1d(
                        h_valid,
                        prop_profile,
                        kind='next',
                        bounds_error=False,
                        fill_value=-1.0
                    )
                    # Only levels up to the column top are filled.
                    gridded_data[prop][ig, t_idx, isl, :i_zmax] = f_interp(ref_grid[:i_zmax])

    return ref_grid, gridded_data, top_index


def attach_format_coord(ax, mat, x, y, is_pcolormesh=True):
    """
    Attach a format_coord function to the axes to display x, y, and value at cursor.
    Works for both pcolormesh and imshow style grids.

    Parameters:
      - ax: object whose `format_coord` attribute is replaced.
      - mat: 2D array (ny, nx), or 3D array (ny, nx, 3|4) of colour tuples.
      - x, y: if is_pcolormesh, sorted cell edges used to locate the cursor;
        otherwise arrays whose min/max give the image extent.
      - is_pcolormesh: selects edge lookup (True) or extent scaling (False).

    Raises ValueError if mat is neither 2D nor (ny, nx, 3|4).
    Outside the grid (or for a zero-width extent) only x and y are shown.
    """
    # Determine dimensions
    if mat.ndim == 2:
        ny, nx = mat.shape
    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
        ny, nx = mat.shape[:2]
    else:
        raise ValueError(f"Unsupported mat shape {mat.shape}")
    # Edges or extents
    if is_pcolormesh:
        xedges, yedges = x, y
    else:
        x0, x1 = x.min(), x.max()
        y0, y1 = y.min(), y.max()

    def format_coord(xp, yp):
        # Map to indices
        if is_pcolormesh:
            col = np.searchsorted(xedges, xp) - 1
            row = np.searchsorted(yedges, yp) - 1
        elif x1 == x0 or y1 == y0:
            # Degenerate extent: no cell can be identified.
            col = row = -1
        else:
            # floor (not int) so positions below the lower extent map to a
            # negative index instead of being truncated into cell 0.
            col = int(np.floor((xp - x0) / (x1 - x0) * nx))
            row = int(np.floor((yp - y0) / (y1 - y0) * ny))
        # Within bounds?
        if 0 <= row < ny and 0 <= col < nx:
            if mat.ndim == 2:
                v = mat[row, col]
                return f"x={xp:.3g}, y={yp:.3g}, val={v:.3g}"
            else:
                vals = mat[row, col]
                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
                return f"x={xp:.3g}, y={yp:.3g}, val=({txt})"
        return f"x={xp:.3g}, y={yp:.3g}"

    ax.format_coord = format_coord


def plot_stratification_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    For each grid point and slope, generate a 2×2 figure of:
      - CO2 ice fraction
      - H2O ice fraction
      - Dust fraction
      - Pore fraction
    as a function of time (x) and reference elevation (y), and save it as
    "layering_evolution_ig{ig}_is{isl}.png" in output_folder.

    Parameters:
      - gridded_data: dict property -> 4D array (ngrid, ntime, nslope, nz);
        negative values are treated as "no data".
      - ref_grid: 1D array of elevations (nz,).
      - top_index: (ngrid, ntime, nslope) number of filled levels per column;
        everything above is blanked.
      - heights_data: not used by this function; kept for call compatibility.
      - date_time: 1D array of time stamps, one per time step.
      - exclude_sub: if True only reference levels >= 0 are displayed.

    Figures are left open after saving.
    """
    prop_names = ['co2_ice', 'h2o_ice', 'dust', 'pore']
    titles = ["CO2 ice", "H2O ice", "Dust", "Pore"]
    cmap = plt.get_cmap('turbo').copy()
    cmap.set_under('white')
    vmin, vmax = 0.0, 1.0

    sample_prop = next(iter(gridded_data.values()))
    ngrid, ntime, nslope, nz = sample_prop.shape

    if exclude_sub:
        positive_indices = np.where(ref_grid >= 0.0)[0]
        sub_ref_grid = ref_grid[positive_indices]
    else:
        positive_indices = np.arange(nz)
        sub_ref_grid = ref_grid

    def with_trailing_edge(values):
        # Append one extra edge after the last value, reusing the last
        # spacing; a single-element axis gets a unit-wide cell instead of
        # failing on values[-2].
        values = np.asarray(values)
        step = values[-1] - values[-2] if values.size > 1 else 1
        return np.concatenate([values, [values[-1] + step]])

    # Cursor read-out edges are the same for every figure.
    x_edges = with_trailing_edge(date_time)
    y_edges = with_trailing_edge(sub_ref_grid)

    for ig in range(ngrid):
        for isl in range(nslope):
            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
            fig.suptitle(
                f"Content variation over time for (Grid point {ig+1}, Slope {isl+1})",
                fontsize=14,
                fontweight='bold'
            )

            for idx, prop in enumerate(prop_names):
                ax = axes.flat[idx]
                data_3d = gridded_data[prop][ig, :, isl, :]
                mat_full = data_3d.T
                mat = mat_full[positive_indices, :].copy()
                mat[mat < 0.0] = np.nan

                # Mask above top stratum
                for t_idx in range(ntime):
                    i_zmax = top_index[ig, t_idx, isl]
                    if i_zmax <= positive_indices[0]:
                        mat[:, t_idx] = np.nan
                    else:
                        count_z = np.count_nonzero(positive_indices < i_zmax)
                        mat[count_z:, t_idx] = np.nan

                im = ax.pcolormesh(
                    date_time,
                    sub_ref_grid,
                    mat,
                    cmap=cmap,
                    shading='auto',
                    vmin=vmin,
                    vmax=vmax
                )
                attach_format_coord(ax, mat, x_edges, y_edges, is_pcolormesh=True)
                ax.set_title(titles[idx], fontsize=12)
                ax.set_xlabel("Time (Mars years)")
                ax.set_ylabel("Elevation (m)")

            # Reserve the right margin for a colorbar shared by the four panels.
            fig.subplots_adjust(right=0.88)
            fig.tight_layout(rect=[0, 0, 0.88, 1.0])
            cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
            fig.colorbar(im, cax=cbar_ax, orientation='vertical', label="Content")

            fname = os.path.join(
                output_folder, f"layering_evolution_ig{ig+1}_is{isl+1}.png"
            )
            fig.savefig(fname, dpi=150)


def plot_stratification_rgb_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot stratification over time colored using RGB ternary mix of H2O ice (blue), CO2 ice (violet), and dust (orange).
    Includes a triangular legend showing the mix proportions.

    One figure per (grid point, slope) is saved as
    "layering_rgb_evolution_ig{ig}_is{isl}.png" in output_folder.

    Parameters:
      - gridded_data: dict with 'h2o_ice', 'co2_ice', 'dust' ->
        4D arrays (ngrid, ntime, nslope, nz); negative values count as 0.
      - ref_grid: 1D array of elevations (nz,).
      - top_index: (ngrid, ntime, nslope) number of filled levels; a
        non-positive entry inherits the value of the previous time step.
      - heights_data: not used by this function; kept for call compatibility.
      - date_time: 1D array of time stamps, one per time step.
      - exclude_sub: if True only reference levels >= 0 are displayed.

    Figures are left open after saving; each figure's cursor read-out is
    bound to that figure's own data.
    """
    # Define constant colors
    violet = np.array([255,   0, 255], dtype=float) / 255
    blue   = np.array([  0,   0, 255], dtype=float) / 255
    orange = np.array([255, 165,   0], dtype=float) / 255

    # Elevation mask and array
    if exclude_sub:
        elevation_mask = (ref_grid >= 0.0)
        elev = ref_grid[elevation_mask]
    else:
        elevation_mask = np.ones_like(ref_grid, dtype=bool)
        elev = ref_grid

    # Pre-compute legend triangle (barycentric weights on a square raster;
    # alpha hides everything outside the triangle)
    res = 300
    u = np.linspace(0, 1, res)
    v = np.linspace(0, np.sqrt(3)/2, res)
    X, Y = np.meshgrid(u, v)
    V_bary = 2 * Y / np.sqrt(3)
    U_bary = X - 0.5 * V_bary
    W_bary = 1 - U_bary - V_bary
    mask_triangle = (U_bary >= 0) & (V_bary >= 0) & (W_bary >= 0)
    legend_rgb = (
        U_bary[..., None] * violet
        + V_bary[..., None] * orange
        + W_bary[..., None] * blue
    )
    legend_rgb = np.clip(legend_rgb, 0.0, 1.0)
    legend_rgba = np.zeros((res, res, 4))
    legend_rgba[..., :3] = legend_rgb
    legend_rgba[..., 3] = mask_triangle.astype(float)
    u_edges = np.linspace(0, 1, res+1)
    v_edges = np.linspace(0, np.sqrt(3)/2, res+1)

    # Custom coordinate formatter for legend: show barycentric fractions
    def legend_format(x, y):
        # compute barycentric coords from cartesian (x,y)
        V = 2 * y / np.sqrt(3)
        U = x - 0.5 * V
        W = 1 - U - V
        if U >= 0 and V >= 0 and W >= 0:
            return f"H2O: {W:.2f}, Dust: {V:.2f}, CO2: {U:.2f}"
        else:
            return ''

    # Extract data arrays
    h2o = gridded_data['h2o_ice']
    co2 = gridded_data['co2_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape

    # Fill missing depths
    ti = top_index.copy().astype(int)
    for ig in range(ngrid):
        for isl in range(nslope):
            for t in range(1, ntime):
                if ti[ig, t, isl] <= 0:
                    ti[ig, t, isl] = ti[ig, t-1, isl]

    # Edges for pcolormesh: one extra edge after the last time / elevation
    dt = date_time[1] - date_time[0] if len(date_time) > 1 else 1
    x_edges = np.concatenate([date_time, [date_time[-1] + dt]])
    d_e = np.diff(elev)
    last_e = elev[-1] + (d_e[-1] if len(d_e)>0 else 1)
    y_edges = np.concatenate([elev, [last_e]])

    # Loop over grid and slope
    for ig in range(ngrid):
        for isl in range(nslope):
            # Compute RGB stratification over time (white where no data)
            rgb = np.ones((nz, ntime, 3), dtype=float)

            frac_all = np.zeros((nz, ntime, 3), dtype=float)  # store fH2O, fCO2, fDust
            for t in range(ntime):
                depth = ti[ig, t, isl]
                if depth <= 0:
                    continue
                cH2O = np.clip(h2o[ig, t, isl, :depth], 0, None)
                cCO2 = np.clip(co2[ig, t, isl, :depth], 0, None)
                cDust = np.clip(dust[ig, t, isl, :depth], 0, None)
                total = cH2O + cCO2 + cDust
                total[total == 0] = 1.0  # avoid 0/0; fractions stay 0
                fH2O = cH2O / total
                fCO2 = cCO2 / total
                fDust = cDust / total
                frac_all[:depth, t, :] = np.stack([fH2O, fCO2, fDust], axis=1)
                mix = np.outer(fH2O, blue) + np.outer(fCO2, violet) + np.outer(fDust, orange)
                rgb[:depth, t, :] = np.clip(mix, 0, 1)

            # Mask elevation
            display_rgb = rgb[elevation_mask, :, :]
            display_frac = frac_all[elevation_mask, :, :]

            # Create figure with legend
            fig, (ax_main, ax_leg) = plt.subplots(
                1, 2, figsize=(8, 4), dpi=200,
                gridspec_kw={'width_ratios': [5, 1]}
            )

            # Main stratification panel
            ax_main.pcolormesh(
                x_edges,
                y_edges,
                display_rgb,
                shading='auto',
                edgecolors='none'
            )

            # Custom coordinate formatter: show time, elevation, and mixture fractions.
            # The arrays are bound as defaults so that each figure keeps its own
            # data; a plain closure would see only the last loop iteration.
            def main_format(
                x, y,
                x_edges=x_edges,
                y_edges=y_edges,
                n_rows=display_rgb.shape[0],
                n_cols=display_rgb.shape[1],
                display_frac=display_frac
            ):
                # check bounds
                if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
                    return ''
                # locate cell
                i = np.searchsorted(x_edges, x) - 1
                j = np.searchsorted(y_edges, y) - 1
                i = np.clip(i, 0, n_cols-1)
                j = np.clip(j, 0, n_rows-1)
                # get fractions
                fH2O, fCO2, fDust = display_frac[j, i]
                return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.2f}, CO2={fCO2:.2f}, Dust={fDust:.2f}"
            ax_main.format_coord = main_format
            ax_main.set_facecolor('white')
            ax_main.set_title(f"Ternary mix over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
            ax_main.set_xlabel("Time (Mars years)")
            ax_main.set_ylabel("Elevation (m)")

            # Legend panel using proper edges
            ax_leg.pcolormesh(
                u_edges,
                v_edges,
                legend_rgba,
                shading='auto',
                edgecolors='none'
            )
            ax_leg.set_aspect('equal')
            ax_leg.format_coord = legend_format

            # Draw triangle border and gridlines
            triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2], [0, 0]])
            ax_leg.plot(triangle[:, 0], triangle[:, 1], 'k-', linewidth=1, clip_on=False, zorder=10)
            ticks = np.linspace(0.25, 0.75, 3)
            for f in ticks:
                ax_leg.plot([1 - f, 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
                ax_leg.plot([f, f + 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
                y_line = (np.sqrt(3)/2) * f
                ax_leg.plot([0.5 * f, 1 - 0.5 * f], [y_line, y_line], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)

            # Legend labels
            ax_leg.text(0, -0.05, 'H2O ice', ha='center', va='top', fontsize=8)
            ax_leg.text(1, -0.05, 'CO2 ice', ha='center', va='top', fontsize=8)
            ax_leg.text(0.5, np.sqrt(3)/2 + 0.05, 'Dust', ha='center', va='bottom', fontsize=8)
            ax_leg.axis('off')

            # Save figure
            plt.tight_layout()
            fname = os.path.join(output_folder, f"layering_rgb_evolution_ig{ig+1}_is{isl+1}.png")
            fig.savefig(fname, dpi=150, bbox_inches='tight')


def plot_dust_to_ice_ratio_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot the dust-to-ice ratio in the stratification over time,
    using a blue-to-orange colormap:
    - blue: ice-dominated (low dust-to-ice ratio)
    - orange: dust-dominated (high dust-to-ice ratio)

    One figure per (grid point, slope) is saved as
    "dust_to_ice_ratio_grid{ig}_slope{isl}.png" in output_folder.

    Parameters:
      - gridded_data: dict with 'h2o_ice' and 'dust' ->
        4D arrays (ngrid, ntime, nslope, nz); negative values count as 0.
      - ref_grid: 1D array of elevations (nz,).
      - top_index: (ngrid, ntime, nslope) number of filled levels; a
        non-positive entry inherits the value of the previous time step.
      - heights_data: not used by this function; kept for call compatibility.
      - date_time: 1D array of time stamps, one per time step.
      - exclude_sub: if True only reference levels >= 0 are displayed.

    The ratio is clipped to [1e-2, 1e1] and drawn on a logarithmic colour
    scale. The 'managua_r' colormap is used when Matplotlib provides it;
    otherwise a plain blue-to-orange map is used. Figures are left open.
    """
    h2o = gridded_data['h2o_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape

    # Elevation mask
    if exclude_sub:
        elevation_mask = (ref_grid >= 0.0)
        elev = ref_grid[elevation_mask]
    else:
        elevation_mask = np.ones_like(ref_grid, dtype=bool)
        elev = ref_grid

    # Define custom blue-to-orange colormap (fallback only)
    blue = np.array([0, 0, 255], dtype=float) / 255
    orange = np.array([255, 165, 0], dtype=float) / 255
    custom_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)
    # 'managua_r' is not registered in every Matplotlib release.
    ratio_cmap = 'managua_r' if 'managua_r' in plt.colormaps() else custom_cmap

    # Log‑ratio bounds and small epsilon to avoid log(0)
    vmin, vmax = -2, 1
    epsilon = 1e-6

    def with_trailing_edge(values):
        # Append one extra edge after the last value, reusing the last
        # spacing; a single-element axis gets a unit-wide cell instead of
        # failing on values[-2].
        values = np.asarray(values)
        step = values[-1] - values[-2] if values.size > 1 else 1
        return np.concatenate([values, [values[-1] + step]])

    # Cursor read-out edges are the same for every figure.
    x_edges = with_trailing_edge(date_time)
    y_edges = with_trailing_edge(elev)

    # Loop over grids and slopes
    for ig in range(ngrid):
        for isl in range(nslope):
            ti = top_index[ig, :, isl].copy().astype(int)
            for t in range(1, ntime):
                if ti[t] <= 0:
                    ti[t] = ti[t-1]

            # Compute log10(dust/ice) profile at each time step
            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
            for t in range(ntime):
                zmax = ti[t]
                if zmax <= 0:
                    continue

                h2o_profile = np.clip(h2o[ig, t, isl, :zmax], 0, None)
                dust_profile = np.clip(dust[ig, t, isl, :zmax], 0, None)

                with np.errstate(divide='ignore', invalid='ignore'):
                    # No ice at all: use a value above the upper bound so the
                    # clip below pins it to the "dust-dominated" end.
                    ratio_profile = np.where(
                        h2o_profile > 0,
                        dust_profile / h2o_profile,
                        10**(vmax + 1)
                    )
                    log_ratio = np.log10(ratio_profile + epsilon)
                    log_ratio = np.clip(log_ratio, vmin, vmax)

                log_ratio_array[:zmax, t] = log_ratio

            # Convert back to linear ratio and apply elevation mask
            ratio_array = 10**log_ratio_array
            ratio_display = ratio_array[elevation_mask, :]

            # Plot
            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
            im = ax.pcolormesh(
                date_time,
                elev,
                ratio_display,
                shading='auto',
                cmap=ratio_cmap,
                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
            )
            attach_format_coord(ax, ratio_display, x_edges, y_edges, is_pcolormesh=True)
            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
            ax.set_xlabel('Time (Mars years)')
            ax.set_ylabel('Elevation (m)')

            # Add colorbar with simplified ratio labels
            cbar = fig.colorbar(im, ax=ax, orientation='vertical')
            cbar.set_label('Dust / H₂O ice (ratio)')

            # Define custom ticks and labels
            ticks = [1e-2, 1e-1, 1, 1e1]
            labels = ['1:100', '1:10', '1:1', '10:1']
            cbar.set_ticks(ticks)
            cbar.set_ticklabels(labels)

            # Save figure
            plt.tight_layout()
            fname = os.path.join(
                output_folder,
                f"dust_to_ice_ratio_grid{ig+1}_slope{isl+1}.png"
            )
            fig.savefig(fname, dpi=150)


def plot_strata_count_and_total_height(heights_data, date_time, output_folder="."):
    """
    For each grid point and slope, draw two stacked panels sharing the time axis:
      - Number of strata vs time
      - Total deposit height vs time
    and save them as "strata_count_height_ig{ig}_is{isl}.png" in output_folder.

    A stratum is counted when its top elevation is neither NaN nor exactly 0;
    the total height is the largest such elevation (0 when there is none).
    heights_data[t][isl] is indexed [grid point, stratum].
    """
    n_times = len(heights_data)
    n_slopes = len(heights_data[0])
    n_points = heights_data[0][0].shape[0]

    panel_labels = ("Number of strata", "Total height (m)")

    for ig in range(n_points):
        for isl in range(n_slopes):
            strata_counts = np.zeros(n_times, dtype=int)
            deposit_heights = np.zeros(n_times, dtype=float)

            for t_idx, per_slope in enumerate(heights_data):
                column = per_slope[isl][ig, :]
                keep = (~np.isnan(column)) & (column != 0.0)
                if not np.any(keep):
                    continue
                tops = column[keep]
                strata_counts[t_idx] = tops.size
                deposit_heights[t_idx] = np.max(tops)

            fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
            fig.suptitle(
                f"Strata count & total height over time for (Grid point {ig+1}, Slope {isl+1})",
                fontsize=14,
                fontweight='bold'
            )

            series = (strata_counts, deposit_heights)
            for panel, values, label in zip(axes, series, panel_labels):
                panel.plot(date_time, values, marker='+', linestyle='-')
                panel.set_ylabel(label)
                panel.grid(True)
            # Only the bottom panel carries the shared x label.
            axes[1].set_xlabel("Time (Mars years)")

            fig.tight_layout(rect=[0, 0, 1, 0.95])
            target = os.path.join(
                output_folder, f"strata_count_height_ig{ig+1}_is{isl+1}.png"
            )
            fig.savefig(target, dpi=150)


def read_orbital_data(orb_file, martian_to_earth):
    """
    Load the .asc table of orbital parameters.
    Columns:
      0 = time, in thousands of years
      1 = obliquity (deg)
      2 = eccentricity
      3 = Ls p (deg)
    The time column is multiplied by 1e3 and divided by martian_to_earth.
    NOTE(review): which year unit goes in and which comes out depends on
    the definition of martian_to_earth — confirm against the PEM info file.
    Returns (dates_yr, obliquity, eccentricity, lsp) as 1D arrays.
    """
    table = np.loadtxt(orb_file)
    time_kyr, obliquity, eccentricity, lsp = (table[:, col] for col in range(4))
    dates_yr = time_kyr * 1e3 / martian_to_earth
    return dates_yr, obliquity, eccentricity, lsp


def plot_orbital_parameters(infofile, orb_file, date_time, output_folder="."):
    """
    Plot obliquity, eccentricity and Ls p from the orbital .asc file,
    resampled at the simulation dates.

    The conversion factor is taken from 'infofile', the orbital series from
    'orb_file'; each series is linearly interpolated (and extrapolated outside
    the tabulated range) at 'date_time'. The three-panel figure is saved as
    "orbital_parameters_laskar.png" in 'output_folder'; it is left open.
    """
    # Only the conversion factor of the info file is needed here
    _sim_dates, conversion = read_infofile(infofile)

    orb_dates, obliquity, eccentricity, ls_peri = read_orbital_data(orb_file, conversion)

    def resample(series):
        # Linear interpolation onto the simulation dates, extrapolating if needed
        return interp1d(
            orb_dates, series,
            kind='linear', bounds_error=False, fill_value="extrapolate"
        )(date_time)

    # One entry per panel: (resampled values, line format, y-axis label)
    panels = (
        (resample(obliquity), 'r-', "Obliquity (°)"),
        (resample(eccentricity), 'b-', "Eccentricity"),
        (resample(ls_peri), 'g-', "Ls of perihelion  (°)"),
    )

    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')

    for panel_ax, (values, line_fmt, y_label) in zip(axes, panels):
        panel_ax.plot(date_time, values, line_fmt, marker='+')
        panel_ax.set_ylabel(y_label)
        panel_ax.grid(True)

    # The x axis is shared, so only the bottom panel carries its label
    axes[2].set_xlabel("Time (Mars years)")

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(os.path.join(output_folder, "orbital_parameters_laskar.png"), dpi=150)


def mars_ls(pday, peri_day, e_elips, year_day, lsperi=0.0):
    """
    Compute the solar longitude (Ls) for the date(s) 'pday'.

    Parameters:
      - pday:     date(s) to evaluate (scalar or array), same unit as peri_day
      - peri_day: date of perihelion
      - e_elips:  orbital eccentricity
      - year_day: length of the year, same unit as pday
      - lsperi:   angular offset added to the true anomaly, in radians

    Returns Ls in degrees, wrapped to [0, 360).
    """
    two_pi = 2 * np.pi

    # Mean anomaly measured from the nearest perihelion passage, in [-pi, pi]
    orbit_fraction = (pday - peri_day) / year_day
    mean_anomaly = two_pi * (orbit_fraction - np.round(orbit_fraction))
    mean_anomaly_abs = np.abs(mean_anomaly)

    # Kepler's equation E - e*sin(E) = M, solved for E >= 0 by Newton-Raphson
    # (at most 10 iterations, stopping once every update is below 1e-7)
    ecc_anomaly = mean_anomaly_abs + e_elips * np.sin(mean_anomaly_abs)
    for _ in range(10):
        residual = ecc_anomaly - e_elips * np.sin(ecc_anomaly) - mean_anomaly_abs
        slope = 1 - e_elips * np.cos(ecc_anomaly)
        step = -residual / slope
        ecc_anomaly = ecc_anomaly + step
        if np.all(np.abs(step) <= 1e-7):
            break

    # Restore the sign dropped when the absolute mean anomaly was taken
    ecc_anomaly = np.where(mean_anomaly < 0, -ecc_anomaly, ecc_anomaly)

    # Eccentric anomaly -> true anomaly
    axis_factor = np.sqrt((1 + e_elips) / (1 - e_elips))
    true_anomaly = 2 * np.arctan(axis_factor * np.tan(ecc_anomaly / 2))

    return np.degrees(np.mod(true_anomaly + lsperi, two_pi))


def read_orbital_data_nc(starts_folder, infofile=None):
    """
    Read orbital parameters from restartfi_postPEM*.nc files in starts_folder.

    Parameters:
      - starts_folder: directory holding the "restartfi_postPEM<N>.nc" files
      - infofile:      optional PEM info file; when given, its dates are
                       returned as the time axis

    Returns (dates_yr, obliquity, eccentricity, ls_perihelion):
      - dates_yr:      first value returned by read_infofile(infofile), or
                       None when no infofile is given
      - obliquity:     controle[17] of each file
      - eccentricity:  (aphelion - perihelion) / (aphelion + perihelion)
      - ls_perihelion: Ls of perihelion in degrees, from mars_ls()
    Each orbital array has one entry per file, ordered by the number N.

    Raises:
      - ValueError:        starts_folder is not a directory
      - FileNotFoundError: no matching file was found
    """
    if not os.path.isdir(starts_folder):
        raise ValueError(f"Invalid starts_folder '{starts_folder}': not a directory.")

    # Read simulation time mapping if provided (the conversion factor is not needed)
    if infofile:
        dates_yr, _ = read_infofile(infofile)
    else:
        dates_yr = None

    pattern = os.path.join(starts_folder, "restartfi_postPEM*.nc")
    files = glob(pattern)
    if not files:
        raise FileNotFoundError(f"No NetCDF restart files found matching {pattern}")

    def extract_number(path):
        # Numeric suffix of "restartfi_postPEM<N>.nc"; names without a purely
        # numeric suffix sort last.
        name = os.path.basename(path)
        prefix = 'restartfi_postPEM'
        if name.startswith(prefix) and name.endswith('.nc'):
            num_str = name[len(prefix):-3]
            if num_str.isdigit():
                return int(num_str)
        return float('inf')

    # Numeric (not lexicographic) order, so that e.g. 10 comes after 9
    files = sorted(files, key=extract_number)

    all_year_day, all_peri, all_aphe, all_date_peri, all_obl = [], [], [], [], []
    for nc_path in files:
        with Dataset(nc_path, 'r') as nc:
            ctrl = nc.variables['controle'][:]
            all_year_day.append(ctrl[13])
            all_peri.append(ctrl[14])
            all_aphe.append(ctrl[15])
            all_date_peri.append(ctrl[16])
            all_obl.append(ctrl[17])

    year_day      = np.array(all_year_day)
    perihelion    = np.array(all_peri)
    aphelion      = np.array(all_aphe)
    date_peri_day = np.array(all_date_peri)
    obliquity     = np.array(all_obl)

    eccentricity  = (aphelion - perihelion) / (aphelion + perihelion)

    # Ls of perihelion is the angle swept between the date origin and the
    # perihelion date; by symmetry of the orbit about perihelion it equals the
    # true anomaly reached 'date_peri_day' days after perihelion.
    # Passing date_peri_day as both the date and the perihelion date (as was
    # done before) gives a zero offset and therefore Ls p = 0 for every file.
    # NOTE(review): assumes controle[16] counts days from the Ls = 0 origin --
    # confirm against the model's 'controle' layout.
    ls_perihelion = mars_ls(
        date_peri_day,
        0.0,
        eccentricity,
        year_day
    )

    return dates_yr, obliquity, eccentricity, ls_perihelion


def plot_orbital_parameters_nc(starts_folder, infofile, date_time, output_folder="."):
    """
    Plot obliquity, eccentricity and Ls p read from the simulation restart
    files, resampled at the simulation dates.

    The series come from read_orbital_data_nc(starts_folder, infofile) and are
    linearly interpolated (extrapolated outside their range) at 'date_time'.
    The three-panel figure is saved as "orbital_parameters_simu.png" in
    'output_folder'; it is left open.
    """
    sim_times, obliquity, eccentricity, ls_peri = read_orbital_data_nc(starts_folder, infofile)

    # Resample every series on the requested dates
    interp_opts = {'kind': 'linear', 'bounds_error': False, 'fill_value': 'extrapolate'}
    resampled = [
        interp1d(sim_times, series, **interp_opts)(date_time)
        for series in (obliquity, eccentricity, ls_peri)
    ]
    line_formats = ['r-', 'b-', 'g-']
    y_labels = ["Obliquity (°)", "Eccentricity", "Ls of perihelion (°)"]

    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')

    for panel, values, line_fmt, y_label in zip(axes, resampled, line_formats, y_labels):
        panel.plot(date_time, values, line_fmt, marker='+')
        panel.set_ylabel(y_label)
        panel.grid(True)

    # Shared x axis: label the bottom panel only
    axes[-1].set_xlabel("Time (Mars years)")

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(os.path.join(output_folder, "orbital_parameters_simu.png"), dpi=150)


def _column_total_heights(heights_data, ig, isl, ntime):
    """
    Return, for each time step, the largest valid height of column (ig, isl).

    Heights that are NaN or exactly 0.0 are ignored; a time step without any
    valid height keeps the value 0.0.
    """
    total_height = np.zeros(ntime, dtype=float)
    for t_idx in range(ntime):
        raw_h = heights_data[t_idx][isl][ig, :]
        valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
        if np.any(valid_mask):
            total_height[t_idx] = np.max(raw_h[valid_mask])
    return total_height


def _make_dust_ice_format_coord(x_edges, y_edges, display_frac, date_time, obliquity):
    """
    Build the cursor read-out callback of one dust-to-ice figure.

    The arrays are bound here, per figure, so that the callback keeps pointing
    at its own data after the plotting loop has moved on to other columns.
    """
    def format_coord(x, y):
        # Nothing to report outside the plotted area
        if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
            return ''
        # Locate the cell under the cursor
        i = np.clip(np.searchsorted(x_edges, x) - 1, 0, display_frac.shape[1] - 1)
        j = np.clip(np.searchsorted(y_edges, y) - 1, 0, display_frac.shape[0] - 1)
        fH2O  = display_frac[j, i, 0]
        fDust = display_frac[j, i, 2]
        obl   = np.interp(x, date_time, obliquity)
        return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.2f}, Dust={fDust:.2f}, Obl={obl:.2f}°"

    return format_coord


def plot_dust_to_ice_ratio_with_obliquity(
    starts_folder,
    infofile,
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot the dust-to-ice ratio over time as a heatmap, and overlay the evolution of
    obliquity on a secondary y-axis.

    One figure is produced per (grid point, slope) and saved in 'output_folder'
    as "dust_ice_obliquity_grid<ig>_slope<isl>.png" (1-based indices). Each
    obliquity segment is coloured by the change of that column's total height
    over the same interval: green = increase, red = decrease, orange = none.

    Parameters:
      - starts_folder, infofile: forwarded to read_orbital_data_nc()
      - gridded_data:  mapping with 'h2o_ice', 'co2_ice' and 'dust' arrays of
                       shape (ngrid, ntime, nslope, nz)
      - ref_grid:      1-D elevations of the reference grid (used as the y axis)
      - top_index:     array indexed as [ig, t, isl]; number of reference levels
                       used at each time (values <= 0 reuse the previous time)
      - heights_data:  heights_data[t][isl][ig, :] = strata heights
      - date_time:     1-D simulation dates (x axis)
      - exclude_sub:   if True, only elevations >= 0 are displayed
      - output_folder: destination directory, created if missing
    """
    h2o = gridded_data['h2o_ice']
    co2 = gridded_data['co2_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape

    # Obliquity from the restart files, resampled on the simulation dates
    times_yr, obl, _, _ = read_orbital_data_nc(starts_folder, infofile)
    fargs = dict(kind='linear', bounds_error=False, fill_value='extrapolate')
    obliquity = interp1d(times_yr, obl, **fargs)(date_time)

    # Segment colour as a function of the sign of the height change
    color_map = {1: 'green', -1: 'red', 0: 'orange'}

    # Elevation mask
    if exclude_sub:
        elevation_mask = (ref_grid >= 0.0)
        elev = ref_grid[elevation_mask]
    else:
        elevation_mask = np.ones_like(ref_grid, dtype=bool)
        elev = ref_grid

    # Prefer the 'managua_r' colormap; when that name is not registered in the
    # installed Matplotlib, fall back to a blue (ice) -> orange (dust) ramp.
    try:
        ratio_cmap = plt.get_cmap('managua_r')
    except ValueError:
        blue = np.array([0, 0, 255]) / 255
        orange = np.array([255, 165, 0]) / 255
        ratio_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)

    # Log-ratio bounds and small epsilon to avoid log(0)
    vmin, vmax = -2, 1
    epsilon = 1e-6

    # Cell edges used by the cursor read-out; identical for every figure
    x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1] - date_time[-2])]])
    y_edges = np.concatenate([elev, [elev[-1] + (elev[-1] - elev[-2])]])

    os.makedirs(output_folder, exist_ok=True)

    # Loop over grids and slopes
    for ig in range(ngrid):
        for isl in range(nslope):
            # Per-interval sign of the height change of THIS column
            total_height_t = _column_total_heights(heights_data, ig, isl, ntime)
            signs = np.sign(np.diff(total_height_t))

            # Top index per time step; non-positive entries reuse the previous one
            ti = top_index[ig, :, isl].copy().astype(int)
            for t in range(1, ntime):
                if ti[t] <= 0:
                    ti[t] = ti[t-1]

            frac_all = np.zeros((nz, ntime, 3), dtype=float)  # store fH2O, fCO2, fDust

            # Compute log10(dust/ice) profile at each time step
            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
            for t in range(ntime):
                zmax = ti[t]
                if zmax <= 0:
                    continue
                cH2O = np.clip(h2o[ig, t, isl, :zmax], 0, None)
                cCO2 = np.clip(co2[ig, t, isl, :zmax], 0, None)
                cDust = np.clip(dust[ig, t, isl, :zmax], 0, None)
                total = cH2O + cCO2 + cDust
                total[total == 0] = 1.0  # avoid 0/0; all fractions stay 0 there
                frac_all[:zmax, t, :] = np.stack(
                    [cH2O / total, cCO2 / total, cDust / total], axis=1
                )

                with np.errstate(divide='ignore', invalid='ignore'):
                    # Levels without H2O ice get a ratio above the colour scale
                    ratio_profile = np.where(
                        cH2O > 0,
                        cDust / cH2O,
                        10**(vmax + 1)
                    )
                    log_ratio = np.log10(ratio_profile + epsilon)
                    log_ratio = np.clip(log_ratio, vmin, vmax)

                log_ratio_array[:zmax, t] = log_ratio

            # Convert back to linear ratio and mask
            ratio_array = 10**log_ratio_array
            ratio_display = ratio_array[elevation_mask, :]
            display_frac = frac_all[elevation_mask, :, :]

            # Plot
            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
            im = ax.pcolormesh(
                date_time,
                elev,
                ratio_display,
                shading='auto',
                cmap=ratio_cmap,
                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
            )

            format_coord = _make_dust_ice_format_coord(
                x_edges, y_edges, display_frac, date_time, obliquity
            )
            ax.format_coord = format_coord
            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
            ax.set_xlabel('Time (Mars years)')
            ax.set_ylabel('Elevation (m)')

            # Add colorbar
            cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.15)
            cbar.set_label('Dust / H₂O ice (ratio)')
            cbar.set_ticks([1e-2, 1e-1, 1, 1e1])
            cbar.set_ticklabels(['1:100', '1:10', '1:1', '10:1'])

            # Overlay obliquity on secondary y-axis, one coloured segment per interval
            ax2 = ax.twinx()
            for i, sign in enumerate(signs):
                ax2.plot(
                    [date_time[i], date_time[i+1]],
                    [obliquity[i], obliquity[i+1]],
                    color=color_map[sign],
                    marker='+',
                    linewidth=1.5
                )
            ax2.format_coord = format_coord
            ax2.set_ylabel('Obliquity (°)')
            ax2.tick_params(axis='y')
            ax2.grid(False)

            # Save
            outname = os.path.join(
                output_folder,
                f'dust_ice_obliquity_grid{ig+1}_slope{isl+1}.png'
            )
            fig.tight_layout()
            fig.savefig(outname, dpi=150)


def main():
    """
    Interactive entry point.

    Prompts for the inputs, loads the stratification data from the NetCDF
    files, interpolates it on a reference elevation grid, saves the
    stratification and orbital-parameter figures in the current directory and
    finally displays every open figure. Exits with status 1 when no NetCDF
    file is found.
    """
    # 1) Get user inputs
    folder_path, base_name, infofile, orbfile = get_user_inputs()

    # 2) List and verify NetCDF files
    files = list_netcdf_files(folder_path, base_name)
    if not files:
        print(f"No NetCDF files named \"{base_name}#.nc\" found in \"{folder_path}\".")
        sys.exit(1)
    print(f"> Found {len(files)} NetCDF file(s).")

    # 3) Open one sample to get grid dimensions & coordinates
    sample_file = files[0]
    ngrid, nslope, longitude, latitude = open_sample_dataset(sample_file)
    print(f"> ngrid  = {ngrid}, nslope = {nslope}")

    # 4) Collect variable info + global min/max elevations
    var_info, max_nb_str, min_base_elev, max_top_elev = collect_stratification_variables(files, base_name)
    print(f"> max strata per slope = {max_nb_str}")
    print(f"> min base elev = {min_base_elev:.3f} m, max top elev = {max_top_elev:.3f} m")

    # 5) Load full datasets
    datasets = load_full_datasets(files)

    try:
        # 6) Extract stratification data
        heights_data, raw_prop_arrays, ntime = extract_stratification_data(
            datasets, var_info, ngrid, nslope, max_nb_str
        )
    finally:
        # 7) Close datasets (also when the extraction fails)
        for ds in datasets:
            ds.close()

    # 8) Normalize to fractions
    frac_arrays = normalize_to_fractions(raw_prop_arrays)

    # 9) Ask whether to include subsurface
    show_subsurface = get_yes_no_input("Show subsurface layers?")
    exclude_sub = not show_subsurface
    if exclude_sub:
        min_base_for_interp = 0.0
        print("> Interpolating only elevations >= 0 m (surface strata).")
    else:
        min_base_for_interp = min_base_elev
        print(f"> Interpolating full depth down to {min_base_elev:.3f} m.")

    # 10) Prompt discretization step
    dz = prompt_discretization_step(max_top_elev)

    # 11) Build reference grid and interpolate
    ref_grid, gridded_data, top_index = interpolate_data_on_refgrid(
        heights_data, frac_arrays, min_base_for_interp, max_top_elev, dz, exclude_sub=exclude_sub
    )

    # 12) Read timestamps and conversion factor from infofile
    date_time, martian_to_earth = read_infofile(infofile)
    if date_time.size != ntime:
        # Only a warning: the plots below are still attempted
        print(f"Warning: {date_time.size} timestamps vs {ntime} NetCDF files.")

    # 13) Plot stratification data over time
    plot_stratification_over_time(
        gridded_data, ref_grid, top_index, heights_data, date_time,
        exclude_sub=exclude_sub, output_folder="."
    )
    plot_stratification_rgb_over_time(
        gridded_data, ref_grid, top_index, heights_data, date_time,
        exclude_sub=exclude_sub, output_folder="."
    )
    #plot_dust_to_ice_ratio_over_time(
    #    gridded_data, ref_grid, top_index, heights_data, date_time,
    #    exclude_sub=exclude_sub, output_folder="."
    #)
    plot_dust_to_ice_ratio_with_obliquity(
        folder_path, infofile,
        gridded_data, ref_grid, top_index, heights_data, date_time,
        exclude_sub=exclude_sub, output_folder="."
    )
    #plot_strata_count_and_total_height(heights_data, date_time, output_folder=".")

    # 14) Plot orbital parameters
    #plot_orbital_parameters(infofile, orbfile, date_time, output_folder=".")
    plot_orbital_parameters_nc(folder_path, infofile, date_time, output_folder=".")

    # 15) Show all figures
    plt.show()


# Run the interactive workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()

