##############################################################################################
### Python script to output the stratification data over time from the "startpem.nc" files ###
##############################################################################################

import os
import sys
import numpy as np
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from scipy import interpolate

### Function to get inputs
def get_user_inputs():
    """Interactively ask for the folder, the file base name and the PEM info file.

    An empty answer selects the default shown in the prompt. The folder and
    the info file are asked again until they exist on disk; the base name is
    accepted as typed.

    Returns:
        The tuple (folder_path, base_name, infofile).
    """
    def ask(message,fallback,exists = None,complaint = None):
        # Keep prompting until the (possibly defaulted) answer passes the check
        while True:
            answer = input(message).strip() or fallback
            if exists is None or exists(answer):
                return answer
            print(complaint)

    folder_path = ask(
        "Enter the folder path containing the NetCDF files (press the Enter key for default [starts]): ",
        "starts",
        os.path.isdir,
        "Invalid folder path. Please try again.",
    )
    base_name = ask(
        "Enter the base name of the NetCDF files (press the Enter key for default [restartpem]): ",
        "restartpem",
    )
    infofile = ask(
        "Enter the name of the PEM info file (press the Enter key for default [info_PEM.txt]): ",
        "info_PEM.txt",
        os.path.isfile,
        "Invalid file path. Please try again.",
    )

    return folder_path, base_name, infofile

### Function to read the "startpem.nc" files and process their stratification data
def process_files(folder_path,base_name):
    """Read the "{base_name}{num}.nc" files and gather their stratification data.

    The files are expected to be numbered consecutively from 1 to the number
    of matching files found in the folder. Each stratification variable is
    indexed as [0, stratum, grid point], and strata beyond a file's own
    'nb_str_max' are left at zero.

    Args:
        folder_path: Directory containing the NetCDF files.
        base_name: Common prefix of the file names, before the number.

    Returns:
        None when no matching file is found. Otherwise the tuple
        (stratif_data, min_base_elevation, max_top_elevation, longitude,
        latitude, dz), where stratif_data is the list [heights, co2ice,
        h2oice, dust, pore] of arrays shaped (ngrid, nfile, nslope,
        max_nb_str), and dz is the elevation step entered by the user.
    """
    # Count the files in the directory with the pattern {base_name}{num}.nc
    nfile = 0
    for file_name in os.listdir(folder_path):
        if file_name.startswith(base_name) and file_name.endswith('.nc'):
            if file_name[len(base_name):-3].isdigit():
                nfile += 1

    if nfile == 0:
        print("No files found. Exiting...")
        return None

    min_base_elevation = -56943.759374550937 # Base elevation of the deepest subsurface layer
    max_top_elevation = 0.
    max_nb_str = 0
    with Dataset(os.path.join(folder_path,base_name + "1.nc"),'r') as ds:
        ngrid = ds.dimensions['physical_points'].size # ngrid is the same for all files
        nslope = ds.dimensions['nslope'].size # nslope is the same for all files
        longitude = ds.variables['longitude'][:]
        latitude = ds.variables['latitude'][:]

    datasets = []
    try:
        # Open every file and track the global extents
        for i in range(1,nfile + 1):
            ds = Dataset(os.path.join(folder_path,base_name + str(i) + ".nc"),'r')
            datasets.append(ds)

            # Track max of nb_str_max
            max_nb_str = max(ds.dimensions['nb_str_max'].size,max_nb_str)

            # Track min/max of top_elevation across all slopes (read once per variable)
            for k in range(1,nslope + 1):
                top_elevation = ds.variables[f"stratif_slope{k:02d}_top_elevation"][:]
                min_base_elevation = min(min_base_elevation,np.min(top_elevation))
                max_top_elevation = max(max_top_elevation,np.max(top_elevation))

        print(f"> number of files     = {nfile}")
        print(f"> ngrid               = {ngrid}")
        print(f"> nslope              = {nslope}")
        print(f"> max(nb_str_max)     = {max_nb_str}")
        print(f"> min(base_elevation) = {min_base_elevation}")
        print(f"> max(top_elevation)  = {max_top_elevation}")

        # Gather the stratif variables of all files along a "file" (time) axis
        shape = (ngrid,nfile,nslope,max_nb_str)
        stratif_heights = np.zeros(shape)
        stratif_co2ice = np.zeros(shape)
        stratif_h2oice = np.zeros(shape)
        stratif_dust = np.zeros(shape)
        stratif_pore = np.zeros(shape)
        stratif_poreice = np.zeros(shape) # NOTE: filled below but not part of the returned data

        # Keyword -> destination array; the first matching keyword wins
        targets = (
            ('top_elevation',stratif_heights),
            ('h_co2ice',stratif_co2ice),
            ('h_h2oice',stratif_h2oice),
            ('h_dust',stratif_dust),
            ('h_pore',stratif_pore),
            ('icepore_volfrac',stratif_poreice),
        )
        slope_tags = [f'slope{k + 1:02d}' for k in range(nslope)]
        for var_name in datasets[0].variables:
            target = next((array for key, array in targets if key in var_name),None)
            if target is None:
                continue
            for k, tag in enumerate(slope_tags):
                if tag not in var_name:
                    continue
                for j, ds in enumerate(datasets):
                    var = ds.variables[var_name]
                    # One read per file: (stratum, grid point) -> (grid point, stratum)
                    target[:,j,k,:var.shape[1]] = np.transpose(var[0,:,:ngrid])
            print(f"Processed variable: {var_name}")
    finally:
        # Close the datasets, even if reading one of them failed
        for ds in datasets:
            ds.close()

    stratif_data = [stratif_heights,stratif_co2ice,stratif_h2oice,stratif_dust,stratif_pore]

    dz = _ask_discretization_step(max_top_elevation)
    return stratif_data, min_base_elevation, max_top_elevation, longitude, latitude, dz


def _ask_discretization_step(max_top_elevation):
    """Prompt until a valid elevation step (in meters) is entered and return it."""
    # NOTE(review): if max_top_elevation is 0, every positive step is rejected
    # by the second check - confirm this cannot happen with real data.
    while True:
        try:
            dz = float(input("Enter the discretization step of the reference grid for the elevation [m]: ").strip())
        except ValueError:
            print("Invalid value.")
            continue
        if not dz > 0: # Also rejects "nan"
            print("Discretization step must be strictly positive!")
            continue
        if dz > max_top_elevation:
            print("Discretization step is higher than the maximum top elevation: please provide a correct value!")
            continue
        return dz

### Function to interpolate the stratification data on a reference grid
def interpolate_data(stratif_data,min_base_elevation,max_top_elevation,dz):
    """Interpolate the strata properties on a regular elevation grid.

    For each (grid point, file, slope) column, the topmost stratum is taken
    as the last non-zero top elevation. Every property is then mapped on the
    reference grid points lying below that elevation, using 'next'
    interpolation (each grid point takes the value of the stratum whose top
    elevation is the next one at or above it). Grid points that are not
    filled keep the value -1.

    Args:
        stratif_data: Sequence whose first item holds the strata top
            elevations and whose other items hold the strata properties,
            all shaped (ngrid, nfile, nslope, nb_str).
        min_base_elevation: Lower bound of the reference grid.
        max_top_elevation: Upper bound (excluded) of the reference grid.
        dz: Step of the reference grid.

    Returns:
        The tuple (ref_grid, gridded_stratif_data), the latter being shaped
        (number of properties, ngrid, nfile, nslope, len(ref_grid)).

    Raises:
        ValueError: Raised by the interpolator when a reference grid point
            lies below the lowest stratum elevation of a column.
    """
    # Define the reference ref_grid
    ref_grid = np.arange(min_base_elevation,max_top_elevation,dz)
    print(f"> Number of ref_grid points = {len(ref_grid)}")

    # Take the shape from one array: np.shape() on the whole list would copy all the data
    heights = np.asarray(stratif_data[0])
    properties = stratif_data[1:]
    ngrid, nfile, nslope = heights.shape[:3]

    # Interpolate the strata properties on the ref_grid
    gridded_stratif_data = -1.*np.ones((len(properties),ngrid,nfile,nslope,len(ref_grid)))
    for i in range(ngrid):
        for j in range(nfile):
            for k in range(nslope):
                column = heights[i,j,k,:]
                filled = np.nonzero(column)[0]
                if filled.size == 0:
                    # No stratum stored for this column: keep the "no data" value
                    continue
                # The top stratum does not depend on the property: locate it once
                i_hmax = filled[-1]
                hmax = column[i_hmax]
                i_zmax = np.searchsorted(ref_grid,hmax,'left')
                for iprop, prop in enumerate(properties):
                    f = interpolate.interp1d(column[:i_hmax + 1],prop[i,j,k,:i_hmax + 1],kind = 'next')
                    gridded_stratif_data[iprop,i,j,k,:i_zmax] = f(ref_grid[:i_zmax])

    return ref_grid, gridded_stratif_data

### Function to read the "info_PEM.txt" file
def read_infofile(file_name):
    """Read the dates listed in the PEM info file.

    The first line (the parameters line) is skipped. For every following
    non-blank line, the first column is converted to float and collected.

    Args:
        file_name: Path of the info file.

    Returns:
        The list of first-column values, in file order.

    Raises:
        ValueError: If the first column of a data line is not a number.
    """
    date_time = []
    with open(file_name,'r') as file:
        # Skip the first line, which holds the parameters
        file.readline()

        # Read the following lines
        for line in file:
            columns = line.split()
            if not columns:
                # Blank line (e.g. trailing newline at the end of the file)
                continue
            date_time.append(float(columns[0]))

    return date_time

### Processing
folder_path, base_name, infofile = get_user_inputs()
processed = process_files(folder_path,base_name)
if processed is None:
    # process_files() found no file to read and already reported it
    sys.exit(1)
stratif_data, min_base_elevation, max_top_elevation, longitude, latitude, dz = processed
ref_grid, gridded_stratif_data = interpolate_data(stratif_data,min_base_elevation,max_top_elevation,dz)
date_time = read_infofile(infofile)

### Figures plotting
subtitle = ['CO2 ice','H2O ice','Dust','Pore']
cmap = plt.get_cmap('viridis').copy()
cmap.set_under('white') # Values below vmin (the -1 "no data" marker) are drawn in white
# The time/elevation mesh is the same for every panel: build it once
time_mesh, elevation_mesh = np.meshgrid(date_time,ref_grid)
for igr in range(np.shape(gridded_stratif_data)[1]):
    for isl in range(np.shape(gridded_stratif_data)[3]):
        # One figure per (grid point, slope) with one panel per property
        fig, axes = plt.subplots(2,2,figsize = (10,8))
        fig.suptitle(f'Contents variation over time in the layered-deposit of grid point {igr + 1} and slope {isl + 1}')
        for iprop, ax in enumerate(axes.flat):
            im = ax.pcolormesh(time_mesh,elevation_mesh,np.transpose(gridded_stratif_data[iprop][igr,:,isl,:]),cmap = cmap,shading = 'auto',vmin = 0,vmax = 1)
            ax.set_title(subtitle[iprop])
            ax.set(xlabel = 'Time (y)',ylabel = 'Elevation (m)')
        # All panels share vmin/vmax, so a single colorbar serves the whole figure
        fig.colorbar(im,ax = axes.ravel().tolist(),label = 'Content value')
        fig.savefig(f"layering_evolution_ig{igr + 1}_is{isl + 1}.png")
        plt.show()
        # Release the figure so that they do not pile up in memory
        plt.close(fig)
