#!/usr/bin/env python3
#######################################################################################################
### Python script to output stratification data over time from "restartpem#.nc" files               ###
### and to plot orbital parameters from "obl_ecc_lsp.asc"                                           ###
#######################################################################################################

import os
import sys
import numpy as np
from glob import glob
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import LinearSegmentedColormap, LogNorm
from scipy.interpolate import interp1d


def _prompt_existing_path(prompt, retry_prompt, default, exists, kind):
    """
    Ask the user for a path until `exists(path)` returns True.

    An empty answer selects `default`. After an invalid answer the user is told
    why and re-prompted with `retry_prompt`, where the same default applies.
    `kind` completes the error message (e.g. "a directory", "a file").
    """
    path = input(prompt).strip() or default
    while not exists(path):
        print(f"  » \"{path}\" does not exist or is not {kind}.")
        path = input(retry_prompt).strip() or default
    return path


def get_user_inputs():
    """
    Prompt the user for:
      - folder_path: directory containing NetCDF files (default: "starts")
      - base_name:   base filename (default: "restartpem")
      - infofile:    name of the PEM info file (default: "info_PEM.txt")
      - orbfile:     orbital parameters ASCII file (default: "obl_ecc_lsp.asc")
    Re-prompts until folder_path is an existing directory and infofile and
    orbfile are existing files. base_name is not validated.

    Returns:
      - (folder_path, base_name, infofile, orbfile)
    """
    folder_path = _prompt_existing_path(
        "Enter the folder path containing the NetCDF files "
        "(press Enter for default [starts]): ",
        "Enter a valid folder path (press Enter for default [starts]): ",
        "starts",
        os.path.isdir,
        "a directory",
    )

    base_name = input(
        "Enter the base name of the NetCDF files "
        "(press Enter for default [restartpem]): "
    ).strip() or "restartpem"

    infofile = _prompt_existing_path(
        "Enter the name of the PEM info file "
        "(press Enter for default [info_PEM.txt]): ",
        "Enter a valid PEM info filename (press Enter for default [info_PEM.txt]): ",
        "info_PEM.txt",
        os.path.isfile,
        "a file",
    )

    # The retry default must be the orbital file, not the info file.
    orbfile = _prompt_existing_path(
        "Enter the name of the orbital parameters ASCII file "
        "(press Enter for default [obl_ecc_lsp.asc]): ",
        "Enter a valid orbital parameters ASCII filename (press Enter for default [obl_ecc_lsp.asc]): ",
        "obl_ecc_lsp.asc",
        os.path.isfile,
        "a file",
    )

    return folder_path, base_name, infofile, orbfile


def list_netcdf_files(folder_path, base_name):
    """
    List and sort all NetCDF files matching the pattern {base_name}#.nc
    in folder_path.

    Both folder_path and base_name are glob-escaped, so names containing
    "[", "*" or "?" are matched literally. Files are ordered by their numeric
    index. Names whose part after base_name is not purely numeric
    (e.g. "restartpem1b.nc") go last. Ties are broken by file name so the
    order is deterministic.

    Returns:
      - sorted list of full file paths ([] if nothing matches)
    """
    from glob import escape as glob_escape

    pattern = os.path.join(
        glob_escape(folder_path), f"{glob_escape(base_name)}[0-9]*.nc"
    )
    matches = glob(pattern)

    def sort_key(pathname):
        fname = os.path.basename(pathname)
        idx_str = fname[len(base_name):-len(".nc")]
        # isdecimal() only admits characters that int() can parse
        # (isdigit() also accepts e.g. superscripts, which int() rejects).
        index = int(idx_str) if idx_str.isdecimal() else float('inf')
        return (index, fname)

    return sorted(matches, key=sort_key)


def open_sample_dataset(file_path):
    """
    Read grid sizes and coordinates from one NetCDF file.

    Reads the sizes of the 'physical_points' and 'nslope' dimensions and
    copies the 'longitude' and 'latitude' variables out of the file before it
    is closed.

    Returns:
      - (ngrid, nslope, longitude_array, latitude_array)
    """
    with Dataset(file_path, 'r') as sample:
        dims = sample.dimensions
        coords = sample.variables
        ngrid, nslope = dims['physical_points'].size, dims['nslope'].size
        longitude = coords['longitude'][:].copy()
        latitude = coords['latitude'][:].copy()
    return ngrid, nslope, longitude, latitude


def collect_stratification_variables(files, base_name):
    """
    Scan every file once to inventory the stratification variables.

    For each file:
      - the 'nb_str_max' dimension (if present) updates the maximum strata count;
      - every 'stratif_slopeXX_top_elevation' variable (XX = 01..nslope)
        contributes its min/max to the global elevation range;
      - every variable whose name contains a property marker
        ('h_co2ice', 'h_h2oice', 'h_dust', 'h_pore', 'poreice_volfrac')
        is recorded under that property.

    base_name is not used.

    Returns:
      - var_info: dict property_name -> sorted list of variable names
        (keys: heights, co2_ice, h2o_ice, dust, pore, pore_ice)
      - max_nb_str: int (0 if no file has 'nb_str_max')
      - min_base_elev: float (np.inf if no top-elevation variable was found)
      - max_top_elev: float (-np.inf if no top-elevation variable was found)
    """
    markers = {
        'heights':   'stratif_slope',    # "..._top_elevation"
        'co2_ice':   'h_co2ice',
        'h2o_ice':   'h_h2oice',
        'dust':      'h_dust',
        'pore':      'h_pore',
        'pore_ice':  'poreice_volfrac'
    }
    # Height variables are found by exact name below, not by marker.
    layer_markers = {prop: tag for prop, tag in markers.items() if prop != 'heights'}
    found = {prop: set() for prop in markers}

    max_nb_str = 0
    lowest = np.inf
    highest = -np.inf

    for path in files:
        with Dataset(path, 'r') as ds:
            dims = ds.dimensions
            if 'nb_str_max' in dims:
                max_nb_str = max(max_nb_str, dims['nb_str_max'].size)

            for slope in range(1, dims['nslope'].size + 1):
                top_name = f"stratif_slope{slope:02d}_top_elevation"
                if top_name not in ds.variables:
                    continue
                tops = ds.variables[top_name][:]
                lowest = min(lowest, np.min(tops))
                highest = max(highest, np.max(tops))
                found['heights'].add(top_name)

            for name in ds.variables:
                for prop, tag in layer_markers.items():
                    if tag in name:
                        found[prop].add(name)

    var_info = {prop: sorted(names) for prop, names in found.items()}
    return var_info, max_nb_str, lowest, highest


def load_full_datasets(files):
    """
    Open all NetCDF files and return a list of Dataset objects.

    The caller owns the returned datasets and must close them. If opening any
    file fails, every dataset opened so far is closed before the exception
    propagates, so nothing leaks.
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        datasets = [stack.enter_context(Dataset(fp, 'r')) for fp in files]
        # Success: hand ownership to the caller instead of closing on exit.
        stack.pop_all()
    return datasets


def extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str):
    """
    Build per-time, per-slope stratification arrays from the open datasets.

    Variable names are expected to look like "stratif_slopeXX_<suffix>". XX is
    the 1-based slope number, and names whose slope falls outside
    [1, nslope] are ignored. var_info may list variables that are absent
    from some datasets (it is the union over all files). Such variables are
    skipped for those datasets. The corresponding heights entries stay None
    and the property values stay 0.

    Parameters:
      - datasets: sequence of objects exposing a `variables` mapping whose
        values index as [0, stratum, grid point]
      - var_info: dict property_name -> list of variable names
        (must contain 'heights')
      - ngrid, nslope, max_nb_str: output dimensions

    Returns:
      - heights_data: list (ntime) of lists (nslope); each entry is None or a
        2D array (ngrid, n_strata_current) of top elevations
      - raw_prop_arrays: dict property_name -> float32 array
        (ngrid, ntime, nslope, max_nb_str), zero-padded past the last stratum
      - ntime: number of time steps (datasets)
    """
    ntime = len(datasets)

    def slope_index_from_var(vname):
        # "stratif_slopeXX_..." -> zero-based slope index XX - 1
        return int(vname.split("slope")[1].split("_")[0]) - 1

    def slope_map_for(var_names):
        # zero-based slope index -> variable name, keeping only valid slopes
        mapping = {}
        for vname in var_names:
            isl = slope_index_from_var(vname)
            if 0 <= isl < nslope:
                mapping[isl] = vname
        return mapping

    heights_data = [[None] * nslope for _ in range(ntime)]
    height_vars = slope_map_for(var_info['heights'])
    for t_idx, ds in enumerate(datasets):
        for isl, var_name in height_vars.items():
            if var_name not in ds.variables:
                continue
            raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
            heights_data[t_idx][isl] = raw.T  # (ngrid, n_strata)

    raw_prop_arrays = {}
    for prop in var_info:
        if prop == 'heights':
            continue
        arr = np.zeros((ngrid, ntime, nslope, max_nb_str), dtype=np.float32)
        slope_map = slope_map_for(var_info[prop])
        for t_idx, ds in enumerate(datasets):
            for isl, var_name in slope_map.items():
                if var_name not in ds.variables:
                    continue
                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
                n_strata_current = raw.shape[0]
                arr[:, t_idx, isl, :n_strata_current] = raw.T
        raw_prop_arrays[prop] = arr

    return heights_data, raw_prop_arrays, ntime


def normalize_to_fractions(raw_prop_arrays):
    """
    Convert the CO2 ice, H2O ice, dust and pore thicknesses into fractions.

    At every cell the four values are divided by their sum, so they add up to
    1 wherever the sum is strictly positive. Cells with a non-positive sum get
    0 for all four. Other keys of raw_prop_arrays are ignored.

    Returns:
      - dict with keys 'co2_ice', 'h2o_ice', 'dust', 'pore' (in that order),
        each a float32 array shaped like the input.
    """
    components = ('co2_ice', 'h2o_ice', 'dust', 'pore')
    co2, h2o, dust, pore = (raw_prop_arrays[key] for key in components)

    total = co2 + h2o + dust + pore
    positive = total > 0.0

    fractions = {}
    for key, layer in zip(components, (co2, h2o, dust, pore)):
        share = np.zeros_like(layer, dtype=np.float32)
        share[positive] = layer[positive] / total[positive]
        fractions[key] = share
    return fractions


def read_infofile(file_name):
    """
    Parse the PEM info file.

    The header line's third whitespace-separated token is read as the
    martian_to_earth conversion factor. On each following line, the first
    token is parsed as a float timestamp. Blank lines and lines whose first
    token is not a number are skipped.

    Returns:
      - date_time: 1D float64 numpy array of timestamps
      - martian_to_earth: float conversion factor
    """
    timestamps = []
    with open(file_name, 'r') as handle:
        header = handle.readline().split()
        martian_to_earth = float(header[2])
        for raw_line in handle:
            tokens = raw_line.split()
            if not tokens:
                continue
            try:
                value = float(tokens[0])
            except ValueError:
                continue
            timestamps.append(value)
    return np.array(timestamps, dtype=np.float64), martian_to_earth


def get_yes_no_input(prompt: str) -> bool:
    """
    Ask a yes/no question until the user answers y/yes or n/no
    (case-insensitive, surrounding whitespace ignored).

    Returns:
      - True for yes, False for no.
    """
    answers = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input(f"{prompt} (y/n): ").strip().lower()
        if reply in answers:
            return answers[reply]
        print("Please respond with y or n.")


def prompt_discretization_step(max_top_elev):
    """
    Prompt for a positive float dz such that 0 < dz <= max_top_elev.

    Re-prompts when:
      - the input is not numeric;
      - it is NaN (float() accepts "nan", but the later np.arange would fail);
      - it is not strictly positive;
      - it exceeds max_top_elev.
    The loop can only terminate when max_top_elev > 0, so callers must check
    that beforehand.

    Returns:
      - dz: float
    """
    while True:
        entry = input(
            "Enter the discretization step of the reference grid for the elevation [m]: "
        ).strip()
        try:
            dz = float(entry)
        except ValueError:
            dz = np.nan
        if np.isnan(dz):
            print("  » Invalid numeric value. Please try again.")
            continue
        if dz <= 0:
            print("  » Discretization step must be strictly positive!")
            continue
        if dz > max_top_elev:
            print(
                f"  » {dz:.3e} m is greater than the maximum top elevation "
                f"({max_top_elev:.3e} m). Please enter a smaller value."
            )
            continue
        return dz


def interpolate_data_on_refgrid(
    heights_data,
    prop_arrays,
    min_base_for_interp,
    max_top_elev,
    dz,
    exclude_sub=False
):
    """
    Build a reference elevation grid and interpolate strata fractions onto it.

    Parameters:
      - heights_data: list (ntime) of lists (nslope); each entry is either None
        (that time/slope is skipped) or a 2D array indexed [ig, stratum] of
        stratum top elevations.
      - prop_arrays: dict property_name -> 4D array
        (ngrid, ntime, nslope, max_nb_str). The shape of the first value sets
        the output dimensions.
      - min_base_for_interp: lowest elevation of the reference grid.
      - max_top_elev: highest elevation of the reference grid.
      - dz: grid spacing.
      - exclude_sub: if True, keep strata whose top is >= -1e-6. Otherwise,
        keep strata whose top is neither NaN nor exactly 0.0.

    Returns:
      - ref_grid: 1D array of elevations (nz,)
      - gridded_data: dict mapping each property_name to 4D array
        (ngrid, ntime, nslope, nz) with interpolated fractions. Cells that
        receive no value keep the sentinel -1.0.
      - top_index: 3D array (ngrid, ntime, nslope) of ints: number of
        reference levels at or below the topmost valid stratum top
        (np.searchsorted with side='right'); 0 where nothing was kept.
    """
    # Degenerate case: the step exceeds the column height, so fall back to a
    # two-point grid [0, max_top_elev].
    if exclude_sub and (dz > max_top_elev):
        ref_grid = np.array([0.0, max_top_elev], dtype=np.float32)
    else:
        # The + dz/2 margin lets arange include max_top_elev despite
        # floating-point rounding of the step accumulation.
        ref_grid = np.arange(min_base_for_interp, max_top_elev + dz/2, dz)
    nz = len(ref_grid)
    print(f"> Number of reference grid points = {nz}")

    sample_prop = next(iter(prop_arrays.values()))
    ngrid, ntime, nslope, max_nb_str = sample_prop.shape

    # -1.0 marks "no data"; the plotting code treats negative values as empty.
    gridded_data = {
        prop: np.full((ngrid, ntime, nslope, nz), -1.0, dtype=np.float32)
        for prop in prop_arrays
    }
    top_index = np.zeros((ngrid, ntime, nslope), dtype=np.int32)

    for ig in range(ngrid):
        for t_idx in range(ntime):
            for isl in range(nslope):
                h_mat = heights_data[t_idx][isl]
                if h_mat is None:
                    continue

                raw_h = h_mat[ig, :]
                # Pad to max_nb_str with NaN so the mask lines up with the
                # property profiles. Assumes n_strata_current <= max_nb_str;
                # otherwise this assignment raises.
                # NOTE(review): if raw_h is a numpy masked array (netCDF4
                # default), masked entries are copied as their underlying data,
                # not NaN — confirm the padding convention in the files.
                h_all = np.full((max_nb_str,), np.nan, dtype=np.float32)
                n_strata_current = raw_h.shape[0]
                h_all[:n_strata_current] = raw_h

                if exclude_sub:
                    # NaN compares False, so padding is excluded here too.
                    epsilon = 1e-6
                    valid_mask = (h_all >= -epsilon)
                else:
                    # Exact 0.0 tops are treated as empty (presumably padding).
                    valid_mask = (~np.isnan(h_all)) & (h_all != 0.0)

                if not np.any(valid_mask):
                    continue

                h_valid = h_all[valid_mask]
                top_h = np.max(h_valid)
                # Number of reference levels <= the highest kept stratum top.
                i_zmax = np.searchsorted(ref_grid, top_h, side='right')
                top_index[ig, t_idx, isl] = i_zmax
                if i_zmax == 0:
                    # The top lies below the whole grid: nothing to fill.
                    continue

                for prop, arr in prop_arrays.items():
                    prop_profile_all = arr[ig, t_idx, isl, :]
                    prop_profile = prop_profile_all[valid_mask]
                    if prop_profile.size == 0:
                        continue

                    # kind='next': a level takes the value of the first stratum
                    # top at or above it, i.e. the stratum containing it. Levels
                    # outside [min(h_valid), max(h_valid)] are out of bounds and
                    # get fill_value -1.0, including levels below the lowest top.
                    f_interp = interp1d(
                        h_valid,
                        prop_profile,
                        kind='next',
                        bounds_error=False,
                        fill_value=-1.0
                    )
                    gridded_data[prop][ig, t_idx, isl, :i_zmax] = f_interp(ref_grid[:i_zmax])

    return ref_grid, gridded_data, top_index


def plot_stratification_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    For each grid point and slope, generate a 2×2 figure of:
      - CO2 ice fraction
      - H2O ice fraction
      - Dust fraction
      - Pore fraction
    and save it as layering_evolution_ig{ig}_is{isl}.png in output_folder.

    Parameters:
      - gridded_data: dict property_name -> (ngrid, ntime, nslope, nz) fractions;
        negative values are "no data" and are left blank.
      - ref_grid: 1D array (nz,) of elevations.
      - top_index: (ngrid, ntime, nslope) count of reference levels at or
        below the topmost stratum; levels from that index up are blanked.
      - heights_data: unused; kept so all plotting functions share one signature.
      - date_time: 1D array of times used as x coordinates.
      - exclude_sub: if True, only elevations >= 0 are displayed.
      - output_folder: where PNG files are written.
    """
    prop_names = ['co2_ice', 'h2o_ice', 'dust', 'pore']
    titles = ["CO2 ice", "H2O ice", "Dust", "Pore"]
    cmap = plt.get_cmap('turbo').copy()
    cmap.set_under('white')
    vmin, vmax = 0.0, 1.0

    sample_prop = next(iter(gridded_data.values()))
    ngrid, ntime, nslope, nz = sample_prop.shape

    if exclude_sub:
        positive_indices = np.where(ref_grid >= 0.0)[0]
    else:
        positive_indices = np.arange(nz)
    sub_ref_grid = ref_grid[positive_indices]

    # Row index of each displayed level, as a column for broadcasting over time.
    level_rows = np.arange(positive_indices.size)[:, np.newaxis]

    for ig in range(ngrid):
        for isl in range(nslope):
            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
            fig.suptitle(
                f"Content variation over time for (Grid point {ig+1}, Slope {isl+1})",
                fontsize=14,
                fontweight='bold'
            )

            # For each time, the number of displayed levels strictly below the
            # top index (positive_indices is ascending). Every displayed level
            # at or above that count is blanked. The mask is the same for all
            # four properties.
            n_visible = np.searchsorted(
                positive_indices, top_index[ig, :, isl], side='left'
            )
            above_top = level_rows >= n_visible[np.newaxis, :]

            for idx, prop in enumerate(prop_names):
                ax = axes.flat[idx]
                # (ntime, nz) -> (nz, ntime); fancy indexing already copies.
                mat = gridded_data[prop][ig, :, isl, :].T[positive_indices, :]
                mat[mat < 0.0] = np.nan
                mat[above_top] = np.nan

                im = ax.pcolormesh(
                    date_time,
                    sub_ref_grid,
                    mat,
                    cmap=cmap,
                    shading='auto',
                    vmin=vmin,
                    vmax=vmax
                )
                ax.set_title(titles[idx], fontsize=12)
                ax.set_xlabel("Time (Mars years)")
                ax.set_ylabel("Elevation (m)")

            fig.subplots_adjust(right=0.88)
            fig.tight_layout(rect=[0, 0, 0.88, 1.0])
            cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
            fig.colorbar(im, cax=cbar_ax, orientation='vertical', label="Content")

            fname = os.path.join(
                output_folder, f"layering_evolution_ig{ig+1}_is{isl+1}.png"
            )
            fig.savefig(fname, dpi=150)


def plot_stratification_rgb_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot stratification over time colored using an RGB ternary mix of H2O ice
    (blue), CO2 ice (violet) and dust (orange), with a triangular legend
    showing the mix proportions. One PNG per (grid point, slope) is saved as
    layering_rgb_evolution_ig{ig}_is{isl}.png in output_folder.

    Parameters:
      - gridded_data: dict with 'h2o_ice', 'co2_ice', 'dust', each a 4D array
        (ngrid, ntime, nslope, nz); negative values are the "no data" fill.
      - ref_grid: 1D array (nz,) of elevations.
      - top_index: (ngrid, ntime, nslope) number of levels below the top stratum.
      - heights_data: unused; kept so all plotting functions share one signature.
      - date_time: 1D array of times; only first and last set the x extent.
      - exclude_sub: if True, only elevations >= 0 are displayed.
      - output_folder: where PNG files are written.

    Levels where all three components carry the negative "no data" fill stay
    white, like levels above the top stratum. They are not mixed into black.
    """

    # Define constant colors
    violet = np.array([255, 0, 255], dtype=float) / 255
    blue   = np.array([0, 0, 255], dtype=float) / 255
    orange = np.array([255, 165, 0], dtype=float) / 255

    # Prepare elevation mask
    mask_elev = (ref_grid >= 0.0) if exclude_sub else np.ones_like(ref_grid, dtype=bool)
    elev = ref_grid[mask_elev]

    # Legend image (barycentric coordinates over an equilateral triangle),
    # built once and reused for every figure.
    res = 300
    u = np.linspace(0, 1, res)
    v = np.linspace(0, np.sqrt(3)/2, res)
    X, Y = np.meshgrid(u, v)
    V_bary = 2 * Y / np.sqrt(3)
    U_bary = X - 0.5 * V_bary
    W_bary = 1 - U_bary - V_bary
    mask_triangle = (U_bary >= 0) & (V_bary >= 0) & (W_bary >= 0)

    legend_rgb = (
        U_bary[..., None] * violet
        + V_bary[..., None] * orange
        + W_bary[..., None] * blue
    )
    legend_rgb = np.clip(legend_rgb, 0.0, 1.0)
    legend_rgba = np.zeros((res, res, 4))
    legend_rgba[..., :3] = legend_rgb
    legend_rgba[..., 3] = mask_triangle.astype(float)

    h2o = gridded_data['h2o_ice']
    co2 = gridded_data['co2_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape
    level_idx = np.arange(nz)

    for ig in range(ngrid):
        for isl in range(nslope):
            # White background; filled only where data exist below the top.
            rgb = np.ones((nz, ntime, 3), dtype=float)
            for t in range(ntime):
                mask_z = level_idx < top_index[ig, t, isl]
                if not mask_z.any():
                    continue
                raw_h2o = h2o[ig, t, isl, mask_z]
                raw_co2 = co2[ig, t, isl, mask_z]
                raw_dust = dust[ig, t, isl, mask_z]
                # Levels where every component is the negative fill have no data.
                has_data = (raw_h2o >= 0) | (raw_co2 >= 0) | (raw_dust >= 0)

                cH2O = np.clip(raw_h2o, 0, None)
                cCO2 = np.clip(raw_co2, 0, None)
                cDust = np.clip(raw_dust, 0, None)
                total = cH2O + cCO2 + cDust
                total[total == 0] = 1.0
                fH2O = cH2O / total
                fCO2 = cCO2 / total
                fDust = cDust / total
                mix = (
                    np.outer(fH2O, blue)
                    + np.outer(fCO2, violet)
                    + np.outer(fDust, orange)
                )
                mix = np.clip(mix, 0.0, 1.0)
                rgb[level_idx[mask_z][has_data], t, :] = mix[has_data]

            display_rgb = rgb[mask_elev, :, :]

            # Create figure with legend
            fig, (ax_main, ax_leg) = plt.subplots(
                1, 2, figsize=(12, 5), dpi=200,
                gridspec_kw={'width_ratios': [5, 1]}
            )

            # Main stratification panel
            ax_main.imshow(
                display_rgb,
                aspect='auto',
                extent=[date_time[0], date_time[-1], elev.min(), elev.max()],
                interpolation='nearest',
                origin='lower'
            )
            ax_main.set_facecolor('white')
            ax_main.set_title(f"Ternary mix over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
            ax_main.set_xlabel("Time (Mars years)")
            ax_main.set_ylabel("Elevation (m)")

            # Legend panel
            ax_leg.imshow(
                legend_rgba,
                extent=[0, 1, 0, np.sqrt(3)/2],
                origin='lower',
                interpolation='nearest'
            )

            # Draw triangle border
            triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2], [0, 0]])
            ax_leg.plot(triangle[:, 0], triangle[:, 1], 'k-', linewidth=1)

            # Dashed gridlines at 25/50/75 %
            ticks = np.linspace(0.25, 0.75, 3)
            for f in ticks:
                ax_leg.plot([1 - f, 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5)
                ax_leg.plot([f, f + 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5)
                y = (np.sqrt(3)/2) * f
                ax_leg.plot([0.5 * f, 1 - 0.5 * f], [y, y], '--', color='k', linewidth=0.5)

            # Legend labels (corners: blue, violet, orange)
            ax_leg.text(0, -0.05, 'H2O ice', ha='center', va='top', fontsize=8)
            ax_leg.text(1, -0.05, 'CO2 ice', ha='center', va='top', fontsize=8)
            ax_leg.text(0.5, np.sqrt(3)/2 + 0.05, 'Dust', ha='center', va='bottom', fontsize=8)
            ax_leg.axis('off')

            plt.tight_layout()

            fname = os.path.join(output_folder, f"layering_rgb_evolution_ig{ig+1}_is{isl+1}.png")
            fig.savefig(fname, dpi=150, bbox_inches='tight')


def plot_dust_to_ice_ratio_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    Plot the dust-to-ice ratio in the stratification over time,
    using a blue-to-orange colormap:
    - blue: ice-dominated (low dust-to-ice ratio)
    - orange: dust-dominated (high dust-to-ice ratio)

    The ratio is clipped to [1e-2, 1e1]. Levels with dust but no H2O ice are
    shown at the dust-dominated end. Levels with neither H2O ice nor dust
    (including the negative "no data" fill) are left blank, because the ratio
    is undefined there.

    Parameters:
      - gridded_data: dict with 'h2o_ice' and 'dust' 4D arrays
        (ngrid, ntime, nslope, nz).
      - ref_grid: 1D array (nz,) of elevations.
      - top_index: (ngrid, ntime, nslope) number of levels below the top stratum.
      - heights_data: unused; kept so all plotting functions share one signature.
      - date_time: 1D array of times; only first and last set the x extent.
      - exclude_sub: if True, only elevations >= 0 are displayed.
      - output_folder: where dust_to_ice_ratio_grid{ig}_slope{isl}.png files go.
    """
    h2o = gridded_data['h2o_ice']
    dust = gridded_data['dust']
    ngrid, ntime, nslope, nz = h2o.shape

    # Elevation mask
    if exclude_sub:
        elevation_mask = (ref_grid >= 0.0)
        elev = ref_grid[elevation_mask]
    else:
        elevation_mask = np.ones_like(ref_grid, dtype=bool)
        elev = ref_grid

    # 'managua_r' only ships with recent matplotlib releases (3.10+). Fall
    # back to an explicit blue-to-orange map so older versions do not fail
    # with an unknown-colormap error.
    blue = np.array([0, 0, 255], dtype=float) / 255
    orange = np.array([255, 165, 0], dtype=float) / 255
    custom_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)
    cmap = 'managua_r' if 'managua_r' in plt.colormaps() else custom_cmap

    # Log-ratio bounds and small epsilon to avoid log(0)
    vmin, vmax = -2, 1
    epsilon = 1e-6

    for ig in range(ngrid):
        for isl in range(nslope):
            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)

            # Compute log10(dust/ice) profile at each time step
            for t in range(ntime):
                zmax = top_index[ig, t, isl]
                if zmax <= 0:
                    continue

                h2o_profile = np.clip(h2o[ig, t, isl, :zmax], 0, None)
                dust_profile = np.clip(dust[ig, t, isl, :zmax], 0, None)

                with np.errstate(divide='ignore', invalid='ignore'):
                    # No ice: push the ratio beyond vmax so it clips to the
                    # dust-dominated end.
                    ratio_profile = np.where(
                        h2o_profile > 0,
                        dust_profile / h2o_profile,
                        10**(vmax + 1)
                    )
                    log_ratio = np.log10(ratio_profile + epsilon)
                    log_ratio = np.clip(log_ratio, vmin, vmax)

                # Neither ice nor dust: 0/0 is undefined, so leave it blank.
                log_ratio[(h2o_profile <= 0) & (dust_profile <= 0)] = np.nan
                log_ratio_array[:zmax, t] = log_ratio

            # Convert back to linear ratio and apply elevation mask
            ratio_array = 10**log_ratio_array
            ratio_display = ratio_array[elevation_mask, :]

            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
            im = ax.imshow(
                ratio_display,
                aspect='auto',
                extent=[date_time[0], date_time[-1], elev.min(), elev.max()],
                origin='lower',
                interpolation='nearest',
                cmap=cmap,
                norm=LogNorm(vmin=10**vmin, vmax=10**vmax)
            )

            # Add colorbar with simplified ratio labels
            cbar = fig.colorbar(im, ax=ax, orientation='vertical')
            cbar.set_label('Dust / H₂O ice (ratio)')

            ticks = [1e-2, 1e-1, 1, 1e1]
            labels = ['1:100', '1:10', '1:1', '10:1']
            cbar.set_ticks(ticks)
            cbar.set_ticklabels(labels)

            plt.tight_layout()
            fname = os.path.join(
                output_folder,
                f"dust_to_ice_ratio_grid{ig+1}_slope{isl+1}.png"
            )
            fig.savefig(fname, dpi=150)


def plot_strata_count_and_total_height(heights_data, date_time, output_folder="."):
    """
    For each grid point and slope, plot:
      - Number of strata vs time
      - Total deposit height (highest stratum top) vs time
    and save strata_count_height_ig{ig}_is{isl}.png in output_folder.

    A stratum counts when its top elevation is neither NaN nor exactly 0.0.
    Time steps whose heights entry is None (variable missing from that file)
    are plotted as 0 strata and 0 m. If heights_data is empty or holds no
    arrays at all, nothing is plotted.
    """
    ntime = len(heights_data)
    if ntime == 0:
        return
    nslope = len(heights_data[0])

    # The grid size comes from the first available heights array.
    sample = next(
        (h_mat for per_time in heights_data for h_mat in per_time if h_mat is not None),
        None,
    )
    if sample is None:
        return
    ngrid = sample.shape[0]

    for ig in range(ngrid):
        for isl in range(nslope):
            n_strata_t = np.zeros(ntime, dtype=int)
            total_height_t = np.zeros(ntime, dtype=float)

            for t_idx in range(ntime):
                h_mat = heights_data[t_idx][isl]
                if h_mat is None:
                    continue
                raw_h = h_mat[ig, :]
                valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
                if np.any(valid_mask):
                    h_valid = raw_h[valid_mask]
                    n_strata_t[t_idx] = h_valid.size
                    total_height_t[t_idx] = np.max(h_valid)

            fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
            fig.suptitle(
                f"Strata count & total height over time for (Grid point {ig+1}, Slope {isl+1})",
                fontsize=14,
                fontweight='bold'
            )

            axes[0].plot(date_time, n_strata_t, marker='+', linestyle='-')
            axes[0].set_ylabel("Number of strata")
            axes[0].grid(True)

            axes[1].plot(date_time, total_height_t, marker='+', linestyle='-')
            axes[1].set_xlabel("Time (Mars years)")
            axes[1].set_ylabel("Total height (m)")
            axes[1].grid(True)

            fig.tight_layout(rect=[0, 0, 1, 0.95])
            fname = os.path.join(
                output_folder, f"strata_count_height_ig{ig+1}_is{isl+1}.png"
            )
            fig.savefig(fname, dpi=150)


def read_orbital_data(orb_file, martian_to_earth):
    """
    Load the orbital-parameter table (obliquity, eccentricity, Ls p).

    The file is read with np.loadtxt. Columns:
      0 = time, returned as column * 1e3 / martian_to_earth
      1 = obliquity (deg)
      2 = eccentricity
      3 = Ls p (deg)

    NOTE(review): the caller interpolates the returned times against
    simulation dates given in Mars years. Column 0 is therefore presumably in
    thousands of Earth years, and martian_to_earth is Earth years per Mars
    year — confirm the units of the .asc file.

    Returns:
      - (dates_yr, obliquity, eccentricity, lsp) as 1D arrays
    """
    table = np.loadtxt(orb_file)
    time_col, obl_col, ecc_col, lsp_col = (table[:, j] for j in range(4))
    return time_col * 1e3 / martian_to_earth, obl_col, ecc_col, lsp_col


def plot_orbital_parameters(infofile, orb_file, date_time, output_folder="."):
    """
    Plot obliquity, eccentricity and Ls p at the simulation dates.

    The conversion factor is taken from infofile. The orbital series are
    linearly interpolated (and extrapolated outside their range) at each
    entry of date_time. The three stacked panels are saved as
    orbital_parameters.png in output_folder.
    """
    _, conversion = read_infofile(infofile)
    orbit_dates, obliquity, eccentricity, ls_peri = read_orbital_data(orb_file, conversion)

    def at_sim_dates(series):
        return interp1d(orbit_dates, series, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)

    sampled = [at_sim_dates(series) for series in (obliquity, eccentricity, ls_peri)]
    styles = (('r+', "Obliquity (°)"), ('b+', "Eccentricity"), ('g+', "Ls p (°)"))

    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
    fig.suptitle("Orbital Parameters vs Simulated Time", fontsize=14, fontweight='bold')

    for ax, values, (fmt, label) in zip(axes, sampled, styles):
        ax.plot(date_time, values, fmt, linestyle='-')
        ax.set_ylabel(label)
        ax.grid(True)
    axes[-1].set_xlabel("Time (Mars years)")

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(os.path.join(output_folder, "orbital_parameters.png"), dpi=150)


def main():
    """
    Interactive driver: collect inputs, read all restart files, build the
    reference elevation grid, and write the stratification and orbital
    plots to the current directory before showing them.

    Exits with status 1 in two cases: no matching NetCDF files are found, or
    no stratum top lies above 0 m (no valid discretization step exists then).
    """
    # 1) Get user inputs
    folder_path, base_name, infofile, orbfile = get_user_inputs()

    # 2) List and verify NetCDF files
    files = list_netcdf_files(folder_path, base_name)
    if not files:
        print(f"No NetCDF files named \"{base_name}#.nc\" found in \"{folder_path}\".")
        sys.exit(1)
    print(f"> Found {len(files)} NetCDF file(s).")

    # 3) Open one sample to get grid dimensions (coordinates are not used here)
    ngrid, nslope, _, _ = open_sample_dataset(files[0])
    print(f"> ngrid  = {ngrid}, nslope = {nslope}")

    # 4) Collect variable info + global min/max elevations
    var_info, max_nb_str, min_base_elev, max_top_elev = collect_stratification_variables(files, base_name)
    print(f"> max strata per slope = {max_nb_str}")
    print(f"> min base elev = {min_base_elev:.3f} m, max top elev = {max_top_elev:.3f} m")
    # The step prompt only accepts 0 < dz <= max_top_elev. Without a positive
    # top elevation (or with none found at all) it would loop forever.
    if not max_top_elev > 0.0:
        print("No stratum with a positive top elevation was found; cannot build a reference grid.")
        sys.exit(1)

    # 5-7) Load full datasets, extract stratification data, always close them
    datasets = load_full_datasets(files)
    try:
        heights_data, raw_prop_arrays, ntime = extract_stratification_data(
            datasets, var_info, ngrid, nslope, max_nb_str
        )
    finally:
        for ds in datasets:
            ds.close()

    # 8) Normalize to fractions
    frac_arrays = normalize_to_fractions(raw_prop_arrays)

    # 9) Ask whether to include subsurface
    show_subsurface = get_yes_no_input("Show subsurface layers?")
    exclude_sub = not show_subsurface
    if exclude_sub:
        min_base_for_interp = 0.0
        print("> Interpolating only elevations >= 0 m (surface strata).")
    else:
        min_base_for_interp = min_base_elev
        print(f"> Interpolating full depth down to {min_base_elev:.3f} m.")

    # 10) Prompt discretization step
    dz = prompt_discretization_step(max_top_elev)

    # 11) Build reference grid and interpolate
    ref_grid, gridded_data, top_index = interpolate_data_on_refgrid(
        heights_data, frac_arrays, min_base_for_interp, max_top_elev, dz, exclude_sub=exclude_sub
    )

    # 12) Read timestamps from infofile (conversion factor is re-read by the orbital plot)
    date_time, _ = read_infofile(infofile)
    if date_time.size != ntime:
        print(f"Warning: {date_time.size} timestamps vs {ntime} NetCDF files.")

    # 13) Plot stratification data over time
    plot_stratification_over_time(
        gridded_data, ref_grid, top_index, heights_data, date_time,
        exclude_sub=exclude_sub, output_folder="."
    )
    plot_stratification_rgb_over_time(
        gridded_data, ref_grid, top_index, heights_data, date_time,
        exclude_sub=exclude_sub, output_folder="."
    )
    plot_dust_to_ice_ratio_over_time(
        gridded_data, ref_grid, top_index, heights_data, date_time,
        exclude_sub=exclude_sub, output_folder="."
    )
    plot_strata_count_and_total_height(heights_data, date_time, output_folder=".")

    # 14) Plot orbital parameters
    plot_orbital_parameters(infofile, orbfile, date_time, output_folder=".")

    # 15) Show all figures
    plt.show()


if __name__ == "__main__":
    # Run the interactive workflow only when executed as a script, not on import.
    main()

