#!/usr/bin/env python3
#################################################################################################
### Python script to output the stratification data over time from the "restartpem#.nc" files ###
#################################################################################################


import os
import sys
import numpy as np
from glob import glob
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d


def get_user_inputs():
    """
    Interactively ask for the three inputs the script needs:
      - folder_path: directory holding the NetCDF files (default: "starts")
      - base_name:   common prefix of the NetCDF filenames (default: "restartpem")
      - infofile:    PEM info file (default: "info_PEM.txt")
    The folder and the info file are re-prompted until they exist on disk;
    the base name is taken as given.
    Returns the tuple (folder_path, base_name, infofile).
    """
    def ask(question, fallback):
        # An empty (or whitespace-only) answer selects the default value.
        answer = input(question).strip()
        return answer if answer else fallback

    folder_path = ask(
        "Enter the folder path containing the NetCDF files "
        "(press Enter for default [starts]): ",
        "starts",
    )
    while True:
        if os.path.isdir(folder_path):
            break
        print(f'  » "{folder_path}" does not exist or is not a directory.')
        folder_path = ask(
            "Enter a valid folder path (press Enter for default [starts]): ",
            "starts",
        )

    base_name = ask(
        "Enter the base name of the NetCDF files "
        "(press Enter for default [restartpem]): ",
        "restartpem",
    )

    infofile = ask(
        "Enter the name of the PEM info file "
        "(press Enter for default [info_PEM.txt]): ",
        "info_PEM.txt",
    )
    while True:
        if os.path.isfile(infofile):
            break
        print(f'  » "{infofile}" does not exist or is not a file.')
        infofile = ask(
            "Enter a valid PEM info filename (press Enter for default [info_PEM.txt]): ",
            "info_PEM.txt",
        )

    return folder_path, base_name, infofile


def list_netcdf_files(folder_path, base_name):
    """
    List the NetCDF files named {base_name}<index>.nc in folder_path, where
    <index> is a purely numeric run index, ordered by increasing index.

    Files that merely start with a digit after the base name (for example
    "restartpem1_old.nc") are ignored so that they cannot be mistaken for an
    additional time step. Glob metacharacters in folder_path or base_name are
    treated literally.

    Returns a list of file paths (empty if nothing matches).
    """
    # Local import: the module only imports `glob` itself from the glob package.
    from glob import escape as glob_escape

    pattern = os.path.join(
        glob_escape(folder_path), f"{glob_escape(base_name)}[0-9]*.nc"
    )
    prefix_len = len(base_name)
    suffix_len = len(".nc")

    indexed = []
    for pathname in glob(pattern):
        # Whatever sits between the base name and the ".nc" extension.
        idx_str = os.path.basename(pathname)[prefix_len:-suffix_len]
        # isdecimal() guarantees that int() below cannot fail.
        if idx_str.isdecimal():
            indexed.append((int(idx_str), pathname))

    # Sort numerically (so 10 comes after 2); ties are broken by path name.
    indexed.sort()
    return [pathname for _, pathname in indexed]


def open_sample_dataset(file_path):
    """
    Read the grid description from one NetCDF file:
      - ngrid:  size of the 'physical_points' dimension
      - nslope: size of the 'nslope' dimension
      - copies of the 'longitude' and 'latitude' variables
    The file is closed before returning.
    Returns (ngrid, nslope, longitude_array, latitude_array).
    """
    with Dataset(file_path, 'r') as nc:
        dims = nc.dimensions
        fields = nc.variables
        ngrid = dims['physical_points'].size
        nslope = dims['nslope'].size
        # Copy so the arrays remain usable once the file has been closed.
        longitude = fields['longitude'][:].copy()
        latitude = fields['latitude'][:].copy()
    return ngrid, nslope, longitude, latitude


def collect_stratification_variables(files, base_name):
    """
    Walk through every file once and gather:
      - the names of the variables belonging to each stratification property
      - the largest 'nb_str_max' dimension encountered (max_nb_str)
      - the smallest and largest values found in the
        "stratif_slopeXX_top_elevation" variables
    base_name is accepted for interface compatibility but is not used.
    Returns:
      - var_info: dict mapping each property_name -> sorted list of var names
      - max_nb_str: int
      - min_base_elev: float (minimum over the top-elevation arrays)
      - max_top_elev: float (maximum over the top-elevation arrays)
    """
    property_markers = {
        'heights':   'stratif_slope',    # "..._top_elevation"
        'co2_ice':   'h_co2ice',
        'h2o_ice':   'h_h2oice',
        'dust':      'h_dust',
        'pore':      'h_pore',
        'pore_ice':  'icepore_volfrac'
    }
    found = {key: set() for key in property_markers}
    # Heights are collected by exact name below, never by substring match.
    substring_markers = [
        (key, marker) for key, marker in property_markers.items() if key != 'heights'
    ]

    largest_nb_str = 0
    lowest = np.inf
    highest = -np.inf

    for path in files:
        with Dataset(path, 'r') as nc:
            if 'nb_str_max' in nc.dimensions:
                largest_nb_str = max(largest_nb_str, nc.dimensions['nb_str_max'].size)

            n_slopes = nc.dimensions['nslope'].size
            for k in range(1, n_slopes + 1):
                name = f"stratif_slope{k:02d}_top_elevation"
                if name not in nc.variables:
                    continue
                values = nc.variables[name][:]
                lowest = min(lowest, np.min(values))
                highest = max(highest, np.max(values))
                found['heights'].add(name)

            for name in nc.variables:
                for key, marker in substring_markers:
                    if marker in name:
                        found[key].add(name)

    var_info = {key: sorted(names) for key, names in found.items()}
    return var_info, largest_nb_str, lowest, highest


def load_full_datasets(files):
    """
    Open all NetCDF files read-only and return a list of Dataset objects,
    in the same order as `files`.
    (They should be closed by the caller after use.)

    If any file fails to open, the datasets opened so far are closed before
    the exception is re-raised, so no handle is leaked on a partial failure.
    """
    opened = []
    try:
        for fp in files:
            opened.append(Dataset(fp, 'r'))
    except Exception:
        # The caller never receives the list, so clean up here.
        for ds in opened:
            ds.close()
        raise
    return opened


def extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str):
    """
    Build:
      - heights_data[t_idx][isl] = 2D array (ngrid, n_strata_current) of top_elevations,
        or None when the corresponding variable is absent from that file.
      - raw_prop_arrays[prop] = 4D array (ngrid, ntime, nslope, max_nb_str) of per-strata
        values, zero-filled where a file has fewer strata or lacks the variable.

    var_info holds the union of variable names over all files, so a given
    file may lack some of them; such variables are skipped for that file
    instead of raising KeyError.

    Args:
      - datasets: sequence of opened datasets, one per time step; each must
          expose a `variables` mapping whose entries are indexable as
          [0, :, :] giving an (n_strata, ngrid) array.
      - var_info: dict property_name -> list of variable names; the slope
          number is parsed from the digits following "slope" in each name.
      - ngrid, nslope, max_nb_str: sizes used to allocate the output arrays.

    Returns:
      - heights_data: list (ntime) of lists (nslope) of 2D arrays (or None)
      - raw_prop_arrays: dict mapping each property_name -> 4D float32 array
      - ntime: number of time steps (files)
    """
    def slope_index_from_var(vname):
        # "..._slope03_..." -> 2 (zero-based slope index)
        return int(vname.split("slope")[1].split("_")[0]) - 1

    def slope_map_for(var_names):
        # Map zero-based slope index -> variable name, dropping out-of-range slopes.
        mapping = {}
        for vname in var_names:
            isl = slope_index_from_var(vname)
            if 0 <= isl < nslope:
                mapping[isl] = vname
        return mapping

    ntime = len(datasets)

    heights_data = [[None] * nslope for _ in range(ntime)]
    height_vars = slope_map_for(var_info['heights'])
    for t_idx, ds in enumerate(datasets):
        for isl, var_name in height_vars.items():
            if var_name not in ds.variables:
                continue  # leave None: no strata recorded for this slope/time
            raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
            heights_data[t_idx][isl] = raw.T       # (ngrid, n_strata)

    raw_prop_arrays = {}
    for prop, var_names in var_info.items():
        if prop == 'heights':
            continue
        arr = np.zeros((ngrid, ntime, nslope, max_nb_str), dtype=np.float32)
        slope_map = slope_map_for(var_names)
        for t_idx, ds in enumerate(datasets):
            for isl, var_name in slope_map.items():
                if var_name not in ds.variables:
                    continue  # keep the zero fill for this slope/time
                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
                n_strata_current = raw.shape[0]
                arr[:, t_idx, isl, :n_strata_current] = raw.T
        raw_prop_arrays[prop] = arr

    return heights_data, raw_prop_arrays, ntime


def normalize_to_fractions(raw_prop_arrays):
    """
    Convert the per-stratum amounts of 'co2_ice', 'h2o_ice', 'dust' and 'pore'
    into fractions of their sum, so that the four fractions add up to 1 in
    every stratum whose total is positive. Strata with a non-positive total
    get a fraction of 0 for all four components. Any other key of
    raw_prop_arrays is ignored.
    Returns:
      - dict with keys 'co2_ice', 'h2o_ice', 'dust', 'pore' -> float32 arrays
        of fractions, same shape as the inputs.
    """
    components = ('co2_ice', 'h2o_ice', 'dust', 'pore')
    amounts = {name: raw_prop_arrays[name] for name in components}

    total = amounts['co2_ice'] + amounts['h2o_ice'] + amounts['dust'] + amounts['pore']
    nonempty = total > 0.0

    fractions = {}
    for name in components:
        layer = amounts[name]
        share = np.zeros_like(layer, dtype=np.float32)
        # Divide only where the stratum is non-empty; the rest stays at 0.
        share[nonempty] = layer[nonempty] / total[nonempty]
        fractions[name] = share
    return fractions


def read_infofile(file_name):
    """
    Parse the PEM info file ("info_PEM.txt"):
      - the first line is skipped unconditionally;
      - on every following line the first whitespace-separated token is
        read as a timestamp; blank lines and lines whose first token is
        not a number are ignored.
    Returns: 1D float64 numpy array of timestamps, in file order.
    """
    stamps = []
    with open(file_name, 'r') as handle:
        handle.readline()  # header line, not a timestamp
        for raw_line in handle:
            tokens = raw_line.split()
            if tokens:
                try:
                    stamp = float(tokens[0])
                except ValueError:
                    continue
                stamps.append(stamp)
    return np.array(stamps, dtype=np.float64)


def get_yes_no_input(prompt: str) -> bool:
    """
    Ask a yes/no question until a recognizable answer is given.
    'y'/'yes' give True, 'n'/'no' give False (case-insensitive, surrounding
    whitespace ignored); anything else triggers a reminder and a new prompt.
    """
    answers = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input(f"{prompt} (y/n): ").strip().lower()
        if reply in answers:
            return answers[reply]
        print("Please respond with y or n.")


def prompt_discretization_step(max_top_elev):
    """
    Prompt for a positive float dz such that 0 < dz <= max_top_elev.

    The prompt is repeated until the entry parses as a number, is not NaN,
    is strictly positive and does not exceed max_top_elev.

    Raises:
      - ValueError: if max_top_elev is not strictly positive (including NaN
        or -inf). No entry could ever be accepted in that case, so failing
        immediately avoids an endless prompt loop.
    """
    if not max_top_elev > 0:
        raise ValueError(
            f"Maximum top elevation must be strictly positive to choose a "
            f"discretization step (got {max_top_elev})."
        )

    while True:
        entry = input(
            "Enter the discretization step of the reference grid for the elevation [m]: "
        ).strip()
        try:
            dz = float(entry)
        except ValueError:
            print("  » Invalid numeric value. Please try again.")
            continue
        # float("nan") parses fine but fails both range checks below silently,
        # so it has to be rejected explicitly.
        if np.isnan(dz):
            print("  » Invalid numeric value. Please try again.")
            continue
        if dz <= 0:
            print("  » Discretization step must be strictly positive!")
            continue
        if dz > max_top_elev:
            print(
                f"  » {dz:.3e} m is greater than the maximum top elevation "
                f"({max_top_elev:.3e} m). Please enter a smaller value."
            )
            continue
        return dz


def interpolate_data_on_refgrid(
    heights_data,
    prop_arrays,
    min_base_for_interp,
    max_top_elev,
    dz,
    exclude_sub=False
):
    """
    Resample per-stratum fractions onto a regular elevation grid.

    Each stratum is identified by its top elevation; a grid level takes the
    value of the first stratum whose top lies at or above it (step-wise
    'next' interpolation). Levels below the lowest valid top, and levels
    that are never written, hold the fill value -1.

    Args:
      - heights_data: heights_data[t][isl] is a 2D array
          (ngrid, n_strata_current) of top elevations, or None to skip
          that (time, slope) pair.
      - prop_arrays: dict property_name -> 4D array
          (ngrid, ntime, nslope, max_nb_str) of fractions.
      - min_base_for_interp: lowest elevation of the reference grid.
      - max_top_elev: highest elevation of the reference grid.
      - dz: spacing of the reference grid.
      - exclude_sub: if True, only strata with top elevation >= -1e-6 are
          used; otherwise strata whose top is NaN or exactly 0 are dropped.

    Returns:
      - ref_grid: 1D array of elevations (nz,)
      - gridded_data: dict property_name -> float32 array
          (ngrid, ntime, nslope, nz) of resampled fractions (-1 = no data).
      - top_index: int32 array (ngrid, ntime, nslope) giving the number of
          ref_grid levels at or below the highest valid stratum top.
    """
    # Two-point grid when only the surface is shown and dz overshoots the top.
    if exclude_sub and (dz > max_top_elev):
        ref_grid = np.array([0.0, max_top_elev], dtype=np.float32)
    else:
        ref_grid = np.arange(min_base_for_interp, max_top_elev + dz / 2, dz)
    nz = len(ref_grid)
    print(f"> Number of reference grid points = {nz}")

    first_array = next(iter(prop_arrays.values()))
    ngrid, ntime, nslope, max_nb_str = first_array.shape[0:4]

    gridded_data = {}
    for name in prop_arrays:
        gridded_data[name] = np.full((ngrid, ntime, nslope, nz), -1.0, dtype=np.float32)
    top_index = np.zeros((ngrid, ntime, nslope), dtype=np.int32)

    # np.ndindex iterates (ig, t_idx, isl) with ig outermost and isl innermost.
    for ig, t_idx, isl in np.ndindex(ngrid, ntime, nslope):
        layer_tops = heights_data[t_idx][isl]
        if layer_tops is None:
            continue

        column = layer_tops[ig, :]  # (n_strata_current,)
        # Pad to max_nb_str with NaN so the mask lines up with prop_arrays.
        padded = np.full((max_nb_str,), np.nan, dtype=np.float32)
        padded[:column.shape[0]] = column

        if exclude_sub:
            keep = padded >= -1e-6
        else:
            keep = (~np.isnan(padded)) & (padded != 0.0)
        if not keep.any():
            continue

        kept_tops = padded[keep]
        highest = np.max(kept_tops)

        # Number of reference levels z with z <= highest.
        n_levels = np.searchsorted(ref_grid, highest, side='right')
        top_index[ig, t_idx, isl] = n_levels
        if n_levels == 0:
            # Highest top lies below the whole reference grid.
            continue

        levels = ref_grid[:n_levels]
        for name, values in prop_arrays.items():
            profile = values[ig, t_idx, isl, :][keep]  # (n_valid_strata,)
            if profile.size == 0:
                continue

            step_fn = interp1d(
                kept_tops,
                profile,
                kind='next',
                bounds_error=False,
                fill_value=-1.0
            )
            gridded_data[name][ig, t_idx, isl, :n_levels] = step_fn(levels)

    return ref_grid, gridded_data, top_index


def plot_stratification_over_time(
    gridded_data,
    ref_grid,
    top_index,
    heights_data,
    date_time,
    exclude_sub=False,
    output_folder="."
):
    """
    For each grid point (ig) and slope (isl), generate a 2×2 figure:
      - CO2 ice fraction
      - H2O ice fraction
      - Dust fraction
      - Pore fraction

    Fractions are in [0..1]. Values < 0 (fill) are masked.
    Using top_index, any elevation above the last stratum is forced to NaN (white).

    Additionally, draw horizontal violet line segments at each stratum top elevation
    only over the interval [date_time[t_idx], date_time[t_idx+1]] where that stratum
    exists at time t_idx. This way, boundaries appear only where the strata exist.

    Args:
      - gridded_data: dict with at least the keys 'co2_ice', 'h2o_ice', 'dust',
          'pore', each a 4D array (ngrid, ntime, nslope, nz).
      - ref_grid: 1D array (nz,) of elevations matching the last axis above.
      - top_index: 3D array (ngrid, ntime, nslope); number of ref_grid levels
          at or below the highest stratum.
      - heights_data: heights_data[t][isl] is a 2D array (ngrid, n_strata) of
          stratum top elevations, or None when no strata are available for
          that (time, slope) pair; None entries simply produce no boundary lines.
      - date_time: 1D array of time values used as the horizontal axis.
      - exclude_sub: if True, only elevations >= 0 are displayed.
      - output_folder: directory where the PNG files are written.

    Each figure is saved as "layering_evolution_ig<ig>_is<isl>.png" in
    output_folder, shown on screen, then closed. Returns None.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    prop_names = ['co2_ice', 'h2o_ice', 'dust', 'pore']
    titles = ["CO2 ice", "H2O ice", "Dust", "Pore"]
    cmap = plt.get_cmap('hot_r').copy()
    cmap.set_under('white')
    vmin, vmax = 0.0, 1.0

    sample_prop = next(iter(gridded_data.values()))
    ngrid, ntime, nslope, nz = sample_prop.shape

    if exclude_sub:
        positive_indices = np.where(ref_grid >= 0.0)[0]
        if positive_indices.size == 0:
            print("Warning: no positive elevations in ref_grid → nothing to display.")
            return
        sub_ref_grid = ref_grid[positive_indices]
    else:
        positive_indices = np.arange(nz)
        sub_ref_grid = ref_grid

    for ig in range(ngrid):
        for isl in range(nslope):
            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
            fig.suptitle(
                f"Content variation over time for (Grid Point {ig+1}, Slope {isl+1})",
                fontsize=14
            )

            # Precompute, for each t_idx, the distinct stratum top elevations
            # at this grid point/slope. They are later drawn as segments from
            # date_time[t_idx] to date_time[t_idx+1] (the last time step has
            # no next point and is therefore never drawn).
            valid_tops_per_time = []
            for t_idx in range(ntime):
                h_mat = heights_data[t_idx][isl]
                if h_mat is None:
                    # No strata recorded for this (time, slope): nothing to draw.
                    valid_tops_per_time.append(np.empty(0))
                    continue
                raw_h = h_mat[ig, :]  # (n_strata_current,)
                # Drop NaNs (zeros are kept).
                # NOTE(review): the interpolation step also ignores tops that are
                # exactly 0 when the subsurface is shown — confirm whether the
                # boundary line at 0 m is wanted here.
                h_all = raw_h[~np.isnan(raw_h)]
                if exclude_sub:
                    h_all = h_all[h_all >= 0.0]
                valid_tops_per_time.append(np.unique(h_all))

            for idx, prop in enumerate(prop_names):
                ax = axes.flat[idx]
                data_3d = gridded_data[prop][ig, :, isl, :]    # shape (ntime, nz)
                mat_full = data_3d.T                           # shape (nz, ntime)
                mat = mat_full[positive_indices, :].copy()     # (nz_pos, ntime)

                # Mask fill values (< 0) as NaN
                mat[mat < 0.0] = np.nan

                # Mask everything above the top stratum using top_index
                for t_idx in range(ntime):
                    i_zmax = top_index[ig, t_idx, isl]
                    if i_zmax <= positive_indices[0]:
                        mat[:, t_idx] = np.nan
                    else:
                        count_z = np.count_nonzero(positive_indices < i_zmax)
                        mat[count_z:, t_idx] = np.nan

                # Draw pcolormesh
                im = ax.pcolormesh(
                    date_time,
                    sub_ref_grid,
                    mat,
                    cmap=cmap,
                    shading='auto',
                    vmin=vmin,
                    vmax=vmax
                )
                ax.set_title(titles[idx], fontsize=12)
                ax.set_xlabel("Time (y)")
                ax.set_ylabel("Elevation (m)")

                # Draw horizontal violet segments only where strata exist
                for t_idx in range(ntime - 1):
                    h_vals = valid_tops_per_time[t_idx]
                    if h_vals.size == 0:
                        continue
                    t_left = date_time[t_idx]
                    t_right = date_time[t_idx + 1]
                    for h in h_vals:
                        # Only draw if h falls within sub_ref_grid
                        if h < sub_ref_grid[0] or h > sub_ref_grid[-1]:
                            continue
                        ax.hlines(
                            y=h,
                            xmin=t_left,
                            xmax=t_right,
                            color='violet',
                            linewidth=0.7,
                            linestyle='-'
                        )

            # Reserve extra space on the right for the colorbar
            fig.subplots_adjust(right=0.88)

            # Place a single shared colorbar in its own axes; all four panels
            # share cmap/vmin/vmax, so the last mesh is representative.
            cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
            fig.colorbar(
                im,
                cax=cbar_ax,
                orientation='vertical',
                label="Content"
            )

            # Tight layout excluding the region we reserved (0.88)
            fig.tight_layout(rect=[0, 0, 0.88, 1.0])

            fname = os.path.join(
                output_folder, f"layering_evolution_ig{ig+1}_is{isl+1}.png"
            )
            fig.savefig(fname, dpi=150)
            plt.show()
            plt.close(fig)


def main():
    """
    Interactive entry point: locate the restart files, extract and normalize
    the stratification data, resample it on a regular elevation grid and
    plot its evolution over time (one figure per grid point and slope,
    written to the current directory).

    Exits with status 1 when no matching NetCDF file is found or when the
    files contain no stratification top-elevation variable.
    """
    # 1) Get user inputs
    folder_path, base_name, infofile = get_user_inputs()

    # 2) List and verify NetCDF files
    files = list_netcdf_files(folder_path, base_name)
    if not files:
        print(f"No NetCDF files named \"{base_name}#.nc\" found in \"{folder_path}\". Exiting.")
        sys.exit(1)
    nfile = len(files)
    print(f"> Found {nfile} NetCDF file(s) matching \"{base_name}#.nc\".")

    # 3) Open one sample to get ngrid, nslope (lon/lat are not needed here)
    sample_file = files[0]
    ngrid, nslope, _, _ = open_sample_dataset(sample_file)
    print(f"> ngrid  = {ngrid}")
    print(f"> nslope = {nslope}")

    # 4) Scan all files to collect variable info + global min/max elevations
    var_info, max_nb_str, min_base_elev, max_top_elev = collect_stratification_variables(
        files, base_name
    )
    if not var_info['heights']:
        # Without any top-elevation variable the elevation bounds stay at
        # +/-inf and no reference grid can be built.
        print("No stratification top-elevation variables found in the NetCDF files. Exiting.")
        sys.exit(1)
    print(f"> max(nb_str_max)      = {max_nb_str}")
    print(f"> min(base_elevation)  = {min_base_elev:.3f}")
    print(f"> max(top_elevation)   = {max_top_elev:.3f}")

    # 5) Open all datasets for extraction
    datasets = load_full_datasets(files)

    # 6) Extract raw stratification data
    # 7) Close all datasets, even if the extraction fails
    try:
        heights_data, raw_prop_arrays, ntime = extract_stratification_data(
            datasets, var_info, ngrid, nslope, max_nb_str
        )
    finally:
        for ds in datasets:
            ds.close()

    # 8) Normalize raw prop arrays to volume fractions
    frac_arrays = normalize_to_fractions(raw_prop_arrays)

    # 9) Ask whether to show subsurface
    show_subsurface = get_yes_no_input("Show subsurface layers?")
    exclude_sub = not show_subsurface

    if exclude_sub:
        min_base_for_interp = 0.0
        print("> Will interpolate only elevations >= 0 m (surface strata).")
    else:
        min_base_for_interp = min_base_elev
        print(f"> Will interpolate full depth (min base = {min_base_elev:.3f} m).")

    # 10) Prompt for discretization step
    dz = prompt_discretization_step(max_top_elev)

    # 11) Build reference grid and interpolate (returns top_index as well)
    ref_grid, gridded_data, top_index = interpolate_data_on_refgrid(
        heights_data,
        frac_arrays,
        min_base_for_interp,
        max_top_elev,
        dz,
        exclude_sub=exclude_sub
    )

    # 12) Read time stamps from "info_PEM.txt"
    date_time = read_infofile(infofile)
    if date_time.size != ntime:
        print(
            "Warning: number of timestamps does not match number of NetCDF files "
            f"({date_time.size} vs {ntime})."
        )

    # 13) Plot and save figures (passing top_index and heights_data)
    plot_stratification_over_time(
        gridded_data,
        ref_grid,
        top_index,
        heights_data,
        date_time,
        exclude_sub=exclude_sub,
        output_folder="."
    )


if __name__ == "__main__":
    main()

