source: trunk/LMDZ.COMMON/libf/evolution/deftank/visu_evol_layering.py @ 3958

Last change on this file since 3958 was 3926, checked in by jbclement, 7 weeks ago

PEM:
Updates of files in the deftank + parameter values for the layering algorithm.
JBC

  • Property svn:executable set to *
File size: 46.1 KB
Line 
1#!/usr/bin/env python3
2#######################################################################################################
3### Python script to output stratification data over time from "restartpem#.nc" files               ###
4### and to plot orbital parameters from "obl_ecc_lsp.asc"                                           ###
5#######################################################################################################
6
7import os
8import sys
9import numpy as np
10from glob import glob
11from netCDF4 import Dataset
12import matplotlib.pyplot as plt
13from mpl_toolkits.axes_grid1.inset_locator import inset_axes
14from matplotlib.colors import LinearSegmentedColormap, LogNorm
15from matplotlib.ticker import FuncFormatter
16from scipy.interpolate import interp1d
17
18
19def get_user_inputs():
20    """
21    Prompt the user for:
22      - folder_path: directory containing NetCDF files (default: "starts")
23      - base_name:   base filename (default: "restartpem")
24      - infofile:    name of the PEM info file (default: "info_PEM.txt")
25    Validates existence of folder and infofile before returning.
26    """
27    folder_path = input(
28        "Enter the folder path containing the NetCDF files "
29        "(press Enter for default [starts]): "
30    ).strip() or "starts"
31    while not os.path.isdir(folder_path):
32        print(f"  » \"{folder_path}\" does not exist or is not a directory.")
33        folder_path = input(
34            "Enter a valid folder path (press Enter for default [starts]): "
35        ).strip() or "starts"
36
37    base_name = input(
38        "Enter the base name of the NetCDF files "
39        "(press Enter for default [restartpem]): "
40    ).strip() or "restartpem"
41
42    infofile = input(
43        "Enter the name of the PEM info file "
44        "(press Enter for default [info_PEM.txt]): "
45    ).strip() or "info_PEM.txt"
46    while not os.path.isfile(infofile):
47        print(f"  » \"{infofile}\" does not exist or is not a file.")
48        infofile = input(
49            "Enter a valid PEM info filename (press Enter for default [info_PEM.txt]): "
50        ).strip() or "info_PEM.txt"
51
52    orbfile = input(
53        "Enter the name of the orbital parameters ASCII file "
54        "(press Enter for default [obl_ecc_lsp.asc]): "
55    ).strip() or "obl_ecc_lsp.asc"
56    while not os.path.isfile(orbfile):
57        print(f"  » \"{orbfile}\" does not exist or is not a file.")
58        orbfile = input(
59            "Enter a valid orbital parameters ASCII filename (press Enter for default [obl_ecc_lsp.asc]): "
60        ).strip() or "info_PEM.txt"
61
62    return folder_path, base_name, infofile, orbfile
63
64
65def list_netcdf_files(folder_path, base_name):
66    """
67    List and sort all NetCDF files matching the pattern {base_name}#.nc
68    in folder_path. Returns a sorted list of full file paths.
69    """
70    pattern = os.path.join(folder_path, f"{base_name}[0-9]*.nc")
71    all_files = glob(pattern)
72    if not all_files:
73        return []
74
75    def extract_index(pathname):
76        fname = os.path.basename(pathname)
77        idx_str = fname[len(base_name):-3]
78        return int(idx_str) if idx_str.isdigit() else float('inf')
79
80    sorted_files = sorted(all_files, key=extract_index)
81    return sorted_files
82
83
84def open_sample_dataset(file_path):
85    """
86    Open a single NetCDF file and extract:
87      - ngrid, nslope
88      - longitude, latitude
89    Returns (ngrid, nslope, longitude_array, latitude_array).
90    """
91    with Dataset(file_path, 'r') as ds:
92        ngrid = ds.dimensions['physical_points'].size
93        nslope = ds.dimensions['nslope'].size
94        longitude = ds.variables['longitude'][:].copy()
95        latitude = ds.variables['latitude'][:].copy()
96    return ngrid, nslope, longitude, latitude
97
98
99def collect_stratification_variables(files, base_name):
100    """
101    Scan all files to collect:
102      - variable names for each stratification property
103      - max number of strata (max_nb_str)
104      - global min base elevation and max top elevation
105    Returns:
106      - var_info: dict mapping each property_name -> sorted list of var names
107      - max_nb_str: int
108      - min_base_elev: float
109      - max_top_elev: float
110    """
111    max_nb_str = 0
112    min_base_elev = np.inf
113    max_top_elev = -np.inf
114
115    property_markers = {
116        'heights':   'stratif_slope',    # "..._top_elevation"
117        'co2_ice':   'h_co2ice',
118        'h2o_ice':   'h_h2oice',
119        'dust':      'h_dust',
120        'pore':      'h_pore',
121        'pore_ice':  'poreice_volfrac'
122    }
123    var_info = {prop: set() for prop in property_markers}
124
125    for file_path in files:
126        with Dataset(file_path, 'r') as ds:
127            if 'nb_str_max' in ds.dimensions:
128                max_nb_str = max(max_nb_str, ds.dimensions['nb_str_max'].size)
129
130            nslope = ds.dimensions['nslope'].size
131            for k in range(1, nslope + 1):
132                var_name = f"stratif_slope{k:02d}_top_elevation"
133                if var_name in ds.variables:
134                    arr = ds.variables[var_name][:]
135                    min_base_elev = min(min_base_elev, np.min(arr))
136                    max_top_elev = max(max_top_elev, np.max(arr))
137                    var_info['heights'].add(var_name)
138
139            for full_var in ds.variables:
140                for prop, marker in property_markers.items():
141                    if (marker in full_var) and prop != 'heights':
142                        var_info[prop].add(full_var)
143
144    for prop in var_info:
145        var_info[prop] = sorted(var_info[prop])
146
147    return var_info, max_nb_str, min_base_elev, max_top_elev
148
149
150def load_full_datasets(files):
151    """
152    Open all NetCDF files and return a list of Dataset objects.
153    (They should be closed by the caller after use.)
154    """
155    return [Dataset(fp, 'r') for fp in files]
156
157
158def extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str):
159    """
160    Build:
161      - heights_data[t_idx][isl] = 2D array (ngrid, n_strata_current) of top_elevations.
162      - raw_prop_arrays[prop] = 4D array (ngrid, ntime, nslope, max_nb_str) of per-strata values.
163    Returns:
164      - heights_data: list (ntime) of lists (nslope) of 2D arrays
165      - raw_prop_arrays: dict mapping each property_name -> 4D array
166      - ntime: number of time steps (files)
167    """
168    ntime = len(datasets)
169
170    heights_data = [
171        [None for _ in range(nslope)]
172        for _ in range(ntime)
173    ]
174    for t_idx, ds in enumerate(datasets):
175        for var_name in var_info['heights']:
176            slope_idx = int(var_name.split("slope")[1].split("_")[0]) - 1
177            if 0 <= slope_idx < nslope:
178                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
179                heights_data[t_idx][slope_idx] = raw.# (ngrid, n_strata)
180
181    raw_prop_arrays = {}
182    for prop in var_info:
183        if prop == 'heights':
184            continue
185        raw_prop_arrays[prop] = np.zeros((ngrid, ntime, nslope, max_nb_str), dtype=np.float32)
186
187    def slope_index_from_var(vname):
188        return int(vname.split("slope")[1].split("_")[0]) - 1
189
190    for prop in raw_prop_arrays:
191        slope_map = {}
192        for vname in var_info[prop]:
193            isl = slope_index_from_var(vname)
194            if 0 <= isl < nslope:
195                slope_map[isl] = vname
196
197        arr = raw_prop_arrays[prop]
198        for t_idx, ds in enumerate(datasets):
199            for isl, var_name in slope_map.items():
200                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
201                n_strata_current = raw.shape[0]
202                arr[:, t_idx, isl, :n_strata_current] = raw.T
203
204    return heights_data, raw_prop_arrays, ntime
205
206
207def normalize_to_fractions(raw_prop_arrays):
208    """
209    Given raw_prop_arrays for 'co2_ice', 'h2o_ice', 'dust', 'pore' (in meters),
210    normalize each set of strata so that the sum of those four = 1 per cell.
211    Returns:
212      - frac_arrays: dict mapping same keys -> 4D arrays of fractions (0..1).
213    """
214    co2 = raw_prop_arrays['co2_ice']
215    h2o = raw_prop_arrays['h2o_ice']
216    dust = raw_prop_arrays['dust']
217    pore = raw_prop_arrays['pore']
218
219    total = co2 + h2o + dust + pore
220    mask = total > 0.0
221
222    frac_co2 = np.zeros_like(co2, dtype=np.float32)
223    frac_h2o = np.zeros_like(h2o, dtype=np.float32)
224    frac_dust = np.zeros_like(dust, dtype=np.float32)
225    frac_pore = np.zeros_like(pore, dtype=np.float32)
226
227    frac_co2[mask] = co2[mask] / total[mask]
228    frac_h2o[mask] = h2o[mask] / total[mask]
229    frac_dust[mask] = dust[mask] / total[mask]
230    frac_pore[mask] = pore[mask] / total[mask]
231
232    return {
233        'co2_ice': frac_co2,
234        'h2o_ice': frac_h2o,
235        'dust':     frac_dust,
236        'pore':     frac_pore
237    }
238
239
240def read_infofile(file_name):
241    """
242    Reads "info_PEM.txt". Expects:
243      - First line: parameters where the 3rd value is martian_to_earth conversion factor.
244      - Each subsequent line: floats where first value is simulation timestamp (in Mars years).
245    Returns:
246      - date_time: 1D numpy array of timestamps (Mars years)
247      - martian_to_earth: float conversion factor
248    """
249    date_time = []
250    with open(file_name, 'r') as fp:
251        first = fp.readline().split()
252        martian_to_earth = float(first[2])
253        for line in fp:
254            parts = line.strip().split()
255            if not parts:
256                continue
257            try:
258                date_time.append(float(parts[0]))
259            except ValueError:
260                continue
261    return np.array(date_time, dtype=np.float64), martian_to_earth
262
263
264def get_yes_no_input(prompt: str) -> bool:
265    """
266    Prompt the user with a yes/no question. Returns True for yes, False for no.
267    """
268    while True:
269        choice = input(f"{prompt} (y/n): ").strip().lower()
270        if choice in ['y', 'yes']:
271            return True
272        elif choice in ['n', 'no']:
273            return False
274        else:
275            print("Please respond with y or n.")
276
277
278def prompt_discretization_step(max_top_elev):
279    """
280    Prompt for a positive float dz such that 0 < dz <= max_top_elev.
281    """
282    while True:
283        entry = input(
284            "Enter the discretization step of the reference grid for the elevation [m]: "
285        ).strip()
286        try:
287            dz = float(entry)
288            if dz <= 0:
289                print("  » Discretization step must be strictly positive!")
290                continue
291            if dz > max_top_elev:
292                print(
293                    f"  » {dz:.3e} m is greater than the maximum top elevation "
294                    f"({max_top_elev:.3e} m). Please enter a smaller value."
295                )
296                continue
297            return dz
298        except ValueError:
299            print("  » Invalid numeric value. Please try again.")
300
301
302def interpolate_data_on_refgrid(
303    heights_data,
304    prop_arrays,
305    min_base_for_interp,
306    max_top_elev,
307    dz,
308    exclude_sub=False
309):
310    """
311    Build a reference elevation grid and interpolate strata fractions onto it.
312
313    Returns:
314      - ref_grid: 1D array of elevations (nz,)
315      - gridded_data: dict mapping each property_name to 4D array
316        (ngrid, ntime, nslope, nz) with interpolated fractions.
317      - top_index: 3D array (ngrid, ntime, nslope) of ints:
318        number of levels covered by the topmost stratum.
319    """
320    if exclude_sub and (dz > max_top_elev):
321        ref_grid = np.array([0.0, max_top_elev], dtype=np.float32)
322    else:
323        ref_grid = np.arange(min_base_for_interp, max_top_elev + dz/2, dz)
324    nz = len(ref_grid)
325    print(f"> Number of reference grid points = {nz}")
326
327    sample_prop = next(iter(prop_arrays.values()))
328    ngrid, ntime, nslope, max_nb_str = sample_prop.shape
329
330    gridded_data = {
331        prop: np.full((ngrid, ntime, nslope, nz), -1.0, dtype=np.float32)
332        for prop in prop_arrays
333    }
334    top_index = np.zeros((ngrid, ntime, nslope), dtype=np.int32)
335
336    for ig in range(ngrid):
337        for t_idx in range(ntime):
338            for isl in range(nslope):
339                h_mat = heights_data[t_idx][isl]
340                if h_mat is None:
341                    continue
342
343                raw_h = h_mat[ig, :]
344                h_all = np.full((max_nb_str,), np.nan, dtype=np.float32)
345                n_strata_current = raw_h.shape[0]
346                h_all[:n_strata_current] = raw_h
347
348                if exclude_sub:
349                    epsilon = 1e-6
350                    valid_mask = (h_all >= -epsilon)
351                else:
352                    valid_mask = (~np.isnan(h_all)) & (h_all != 0.0)
353
354                if not np.any(valid_mask):
355                    continue
356
357                h_valid = h_all[valid_mask]
358                top_h = np.max(h_valid)
359                i_zmax = np.searchsorted(ref_grid, top_h, side='right')
360                top_index[ig, t_idx, isl] = i_zmax
361                if i_zmax == 0:
362                    continue
363
364                for prop, arr in prop_arrays.items():
365                    prop_profile_all = arr[ig, t_idx, isl, :]
366                    prop_profile = prop_profile_all[valid_mask]
367                    if prop_profile.size == 0:
368                        continue
369
370                    f_interp = interp1d(
371                        h_valid,
372                        prop_profile,
373                        kind='next',
374                        bounds_error=False,
375                        fill_value=-1.0
376                    )
377                    gridded_data[prop][ig, t_idx, isl, :i_zmax] = f_interp(ref_grid[:i_zmax])
378
379    return ref_grid, gridded_data, top_index
380
381
382def attach_format_coord(ax, mat, x, y, is_pcolormesh=True):
383    """
384    Attach a format_coord function to the axes to display x, y, and value at cursor.
385    Works for both pcolormesh and imshow style grids.
386    """
387    # Determine dimensions
388    if mat.ndim == 2:
389        ny, nx = mat.shape
390    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
391        ny, nx, nc = mat.shape
392    else:
393        raise ValueError(f"Unsupported mat shape {mat.shape}")
394    # Edges or extents
395    if is_pcolormesh:
396        xedges, yedges = x, y
397    else:
398        x0, x1 = x.min(), x.max()
399        y0, y1 = y.min(), y.max()
400
401    def format_coord(xp, yp):
402        # Map to indices
403        if is_pcolormesh:
404            col = np.searchsorted(xedges, xp) - 1
405            row = np.searchsorted(yedges, yp) - 1
406        else:
407            col = int((xp - x0) / (x1 - x0) * nx)
408            row = int((yp - y0) / (y1 - y0) * ny)
409        # Within bounds?
410        if 0 <= row < ny and 0 <= col < nx:
411            if mat.ndim == 2:
412                v = mat[row, col]
413                return f"x={xp:.3g}, y={yp:.3g}, val={v:.3g}"
414            else:
415                vals = mat[row, col]
416                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
417                return f"x={xp:.3g}, y={yp:.3g}, val=({txt})"
418        return f"x={xp:.3g}, y={yp:.3g}"
419
420    ax.format_coord = format_coord
421
422
423def plot_stratification_over_time(
424    gridded_data,
425    ref_grid,
426    top_index,
427    heights_data,
428    date_time,
429    exclude_sub=False,
430    output_folder="."
431):
432    """
433    For each grid point and slope, generate a 2×2 figure of:
434      - CO2 ice fraction
435      - H2O ice fraction
436      - Dust fraction
437      - Pore fraction
438    """
439    prop_names = ['co2_ice', 'h2o_ice', 'dust', 'pore']
440    titles = ["CO2 ice", "H2O ice", "Dust", "Pore"]
441    cmap = plt.get_cmap('turbo').copy()
442    cmap.set_under('white')
443    vmin, vmax = 0.0, 1.0
444
445    sample_prop = next(iter(gridded_data.values()))
446    ngrid, ntime, nslope, nz = sample_prop.shape
447
448    if exclude_sub:
449        positive_indices = np.where(ref_grid >= 0.0)[0]
450        sub_ref_grid = ref_grid[positive_indices]
451    else:
452        positive_indices = np.arange(nz)
453        sub_ref_grid = ref_grid
454
455    for ig in range(ngrid):
456        for isl in range(nslope):
457            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
458            fig.suptitle(
459                f"Content variation over time for (Grid point {ig+1}, Slope {isl+1})",
460                fontsize=14,
461                fontweight='bold'
462            )
463
464            # Precompute valid stratum tops per time
465            valid_tops_per_time = []
466            for t_idx in range(ntime):
467                raw_h = heights_data[t_idx][isl][ig, :]
468                h_all = raw_h[~np.isnan(raw_h)]
469                if exclude_sub:
470                    h_all = h_all[h_all >= 0.0]
471                valid_tops_per_time.append(np.unique(h_all))
472
473            for idx, prop in enumerate(prop_names):
474                ax = axes.flat[idx]
475                data_3d = gridded_data[prop][ig, :, isl, :]
476                mat_full = data_3d.T
477                mat = mat_full[positive_indices, :].copy()
478                mat[mat < 0.0] = np.nan
479
480                # Mask above top stratum
481                for t_idx in range(ntime):
482                    i_zmax = top_index[ig, t_idx, isl]
483                    if i_zmax <= positive_indices[0]:
484                        mat[:, t_idx] = np.nan
485                    else:
486                        count_z = np.count_nonzero(positive_indices < i_zmax)
487                        mat[count_z:, t_idx] = np.nan
488
489                im = ax.pcolormesh(
490                    date_time,
491                    sub_ref_grid,
492                    mat,
493                    cmap=cmap,
494                    shading='auto',
495                    vmin=vmin,
496                    vmax=vmax
497                )
498                x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1]-date_time[-2])]])
499                attach_format_coord(ax, mat, x_edges, np.concatenate([sub_ref_grid, [sub_ref_grid[-1] + (sub_ref_grid[-1]-sub_ref_grid[-2])]]), is_pcolormesh=True)
500                ax.set_title(titles[idx], fontsize=12)
501                ax.set_xlabel("Time (Mars years)")
502                ax.set_ylabel("Elevation (m)")
503
504            fig.subplots_adjust(right=0.88)
505            fig.tight_layout(rect=[0, 0, 0.88, 1.0])
506            cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
507            fig.colorbar(im, cax=cbar_ax, orientation='vertical', label="Content")
508
509            fname = os.path.join(
510                output_folder, f"layering_evolution_ig{ig+1}_is{isl+1}.png"
511            )
512            fig.savefig(fname, dpi=1200, bbox_inches='tight')
513
514
515def plot_stratification_rgb_over_time(
516    gridded_data,
517    ref_grid,
518    top_index,
519    heights_data,
520    date_time,
521    exclude_sub=False,
522    output_folder="."
523):
524    """
525    Plot stratification over time colored using RGB ternary mix of H2O ice (blue), CO2 ice (violet), and dust (orange).
526    Includes a triangular legend showing the mix proportions.
527    """
528    # Define constant colors
529    violet = np.array([255,   0, 255], dtype=float) / 255
530    blue   = np.array([  0,   0, 255], dtype=float) / 255
531    orange = np.array([255, 165,   0], dtype=float) / 255
532
533    # Elevation mask and array
534    if exclude_sub:
535        elevation_mask = (ref_grid >= 0.0)
536        elev = ref_grid[elevation_mask]
537    else:
538        elevation_mask = np.ones_like(ref_grid, dtype=bool)
539        elev = ref_grid
540
541    # Pre-compute legend triangle
542    res = 300
543    u = np.linspace(0, 1, res)
544    v = np.linspace(0, np.sqrt(3)/2, res)
545    X, Y = np.meshgrid(u, v)
546    V_bary = 2 * Y / np.sqrt(3)
547    U_bary = X - 0.5 * V_bary
548    W_bary = 1 - U_bary - V_bary
549    mask_triangle = (U_bary >= 0) & (V_bary >= 0) & (W_bary >= 0)
550    legend_rgb = (
551        U_bary[..., None] * violet
552        + V_bary[..., None] * orange
553        + W_bary[..., None] * blue
554    )
555    legend_rgb = np.clip(legend_rgb, 0.0, 1.0)
556    legend_rgba = np.zeros((res, res, 4))
557    legend_rgba[..., :3] = legend_rgb
558    legend_rgba[..., 3] = mask_triangle.astype(float)
559
560    # Extract data arrays
561    h2o = gridded_data['h2o_ice']
562    co2 = gridded_data['co2_ice']
563    dust = gridded_data['dust']
564    ngrid, ntime, nslope, nz = h2o.shape
565
566    # Fill missing depths
567    ti = top_index.copy().astype(int)
568    for ig in range(ngrid):
569        for isl in range(nslope):
570            for t in range(1, ntime):
571                if ti[ig, t, isl] <= 0:
572                    ti[ig, t, isl] = ti[ig, t-1, isl]
573
574    # Loop over grid and slope
575    for ig in range(ngrid):
576        for isl in range(nslope):
577            # Compute RGB stratification over time
578            rgb = np.ones((nz, ntime, 3), dtype=float)
579            frac_all = np.full((nz, ntime, 3), np.nan, dtype=float)  # store fH2O, fCO2, fDust
580            for t in range(ntime):
581                depth = ti[ig, t, isl]
582                if depth <= 0:
583                    continue
584                cH2O = np.clip(h2o[ig, t, isl, :depth], 0, None)
585                cCO2 = np.clip(co2[ig, t, isl, :depth], 0, None)
586                cDust = np.clip(dust[ig, t, isl, :depth], 0, None)
587                total = cH2O + cCO2 + cDust
588                total[total == 0] = 1.0
589                fH2O = cH2O / total
590                fCO2 = cCO2 / total
591                fDust = cDust / total
592                frac_all[:depth, t, 0] = fH2O
593                frac_all[:depth, t, 1] = fCO2
594                frac_all[:depth, t, 2] = fDust
595                rgb[:depth, t, 0] = fH2O * blue[0] + fCO2 * violet[0] + fDust * orange[0]
596                rgb[:depth, t, 1] = fH2O * blue[1] + fCO2 * violet[1] + fDust * orange[1]
597                rgb[:depth, t, 2] = fH2O * blue[2] + fCO2 * violet[2] + fDust * orange[2]
598
599            # Mask elevation
600            display_rgb = rgb[elevation_mask, :, :]
601            display_frac = frac_all[elevation_mask, :, :]
602
603            # Compute edges for pcolormesh
604            dt = date_time[1] - date_time[0] if len(date_time) > 1 else 1
605            x_edges = np.concatenate([date_time, [date_time[-1] + dt]])
606            d_e = np.diff(elev)
607            last_e = elev[-1] + (d_e[-1] if len(d_e)>0 else 1)
608            y_edges = np.concatenate([elev, [last_e]])
609
610            # Create figure with legend
611            fig, (ax_main, ax_leg) = plt.subplots(
612                1, 2, figsize=(8, 4), dpi=150,
613                gridspec_kw={'width_ratios': [5, 1]}
614            )
615
616            # Main stratification panel
617            mesh = ax_main.pcolormesh(
618                x_edges,
619                y_edges,
620                display_rgb,
621                shading='auto',
622                edgecolors='none'
623            )
624
625            # Custom coordinate formatter: show time, elevation, and mixture fractions
626            def main_format(x, y):
627                # check bounds
628                if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
629                    return ''
630                # locate cell
631                i = np.searchsorted(x_edges, x) - 1
632                j = np.searchsorted(y_edges, y) - 1
633                i = np.clip(i, 0, display_rgb.shape[1] - 1)
634                j = np.clip(j, 0, display_rgb.shape[0] - 1)
635                # get fractions
636                fH2O, fCO2, fDust = display_frac[j, i]
637                return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.4f}, CO2={fCO2:.4f}, Dust={fDust:.4f}"
638            ax_main.format_coord = main_format
639            ax_main.set_facecolor('white')
640            ax_main.set_title(f"Ternary mix over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
641            ax_main.set_xlabel("Time (Mars years)")
642            ax_main.set_ylabel("Elevation (m)")
643
644            # Legend panel using proper edges
645            u_edges = np.linspace(0, 1, res+1)
646            v_edges = np.linspace(0, np.sqrt(3)/2, res+1)
647            ax_leg.pcolormesh(
648                u_edges,
649                v_edges,
650                legend_rgba,
651                shading='auto',
652                edgecolors='none'
653            )
654            ax_leg.set_aspect('equal')
655
656            # Custom coordinate formatter for legend: show barycentric fractions
657            def legend_format(x, y):
658                # compute barycentric coords from cartesian (x,y)
659                V = 2 * y / np.sqrt(3)
660                U = x - 0.5 * V
661                W = 1 - U - V
662                if U >= 0 and V >= 0 and W >= 0:
663                    return f"H2O: {W:.2f}, Dust: {V:.2f}, CO2: {U:.2f}"
664                else:
665                    return ''
666            ax_leg.format_coord = legend_format
667
668            # Draw triangle border and gridlines
669            triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2], [0, 0]])
670            ax_leg.plot(triangle[:, 0], triangle[:, 1], 'k-', linewidth=1, clip_on=False, zorder=10)
671            ticks = np.linspace(0.25, 0.75, 3)
672            for f in ticks:
673                ax_leg.plot([1 - f, 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
674                ax_leg.plot([f, f + 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
675                y = (np.sqrt(3)/2) * f
676                ax_leg.plot([0.5 * f, 1 - 0.5 * f], [y, y], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
677
678            # Legend labels
679            ax_leg.text(0, -0.05, 'H2O ice', ha='center', va='top', fontsize=8)
680            ax_leg.text(1, -0.05, 'CO2 ice', ha='center', va='top', fontsize=8)
681            ax_leg.text(0.5, np.sqrt(3)/2 + 0.05, 'Dust', ha='center', va='bottom', fontsize=8)
682            ax_leg.axis('off')
683
684            # Save figure
685            plt.tight_layout()
686            fname = os.path.join(output_folder, f"layering_rgb_evolution_ig{ig+1}_is{isl+1}.png")
687            fig.savefig(fname, dpi=1200, bbox_inches='tight')
688
689
690def plot_dust_to_ice_ratio_over_time(
691    gridded_data,
692    ref_grid,
693    top_index,
694    heights_data,
695    date_time,
696    exclude_sub=False,
697    output_folder="."
698):
699    """
700    Plot the dust-to-ice ratio in the stratification over time,
701    using a blue-to-orange colormap:
702    - blue: ice-dominated (low dust-to-ice ratio)
703    - orange: dust-dominated (high dust-to-ice ratio)
704    """
705    h2o = gridded_data['h2o_ice']
706    co2 = gridded_data['co2_ice']
707    dust = gridded_data['dust']
708    ngrid, ntime, nslope, nz = h2o.shape
709
710    # Define custom blue-to-orange colormap
711    blue = np.array([0, 0, 255], dtype=float) / 255
712    orange = np.array([255, 165, 0], dtype=float) / 255
713    custom_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)
714
715    # Log‑ratio bounds and small epsilon to avoid log(0)
716    vmin, vmax = -2, 1
717    epsilon = 1e-6
718
719    # Loop over grids and slopes
720    for ig in range(ngrid):
721        for isl in range(nslope):
722            ti = top_index[ig, :, isl].copy().astype(int)
723
724            # Compute log10(dust/ice) profile at each time step
725            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
726            for t in range(ntime):
727                if t > 0 and ti[t] <= 0:
728                    ti[t] = ti[t-1]
729                elif ti[t] <= 0:
730                    continue
731                zmax = ti[t]
732                if zmax <= 0:
733                    continue
734
735                cH2O = np.clip(h2o[ig, t, isl, :zmax], 0, None)
736                cCO2 = np.clip(co2[ig, t, isl, :zmax], 0, None)
737                cDust = np.clip(dust[ig, t, isl, :zmax], 0, None)
738
739                with np.errstate(divide='ignore', invalid='ignore'):
740                    ratio = np.where(
741                        cH2O > 0,
742                        cDust / cH2O,
743                        10**(vmax + 1)
744                    )
745                    log_ratio = np.log10(ratio + epsilon)
746                    log_ratio = np.clip(log_ratio, vmin, vmax)
747
748                log_ratio_array[:zmax, t] = log_ratio
749
750            ratio_array = 10**log_ratio_array
751
752            # Compute edges for pcolormesh
753            x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1] - date_time[-2])]]) * martian_to_earth
754            y_edges = np.concatenate([ref_grid, [ref_grid[-1] + (ref_grid[-1] - ref_grid[-2])]])
755
756            # Plot
757            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
758            im = ax.pcolormesh(
759                date_time,
760                elev,
761                ratio_array,
762                shading='auto',
763                cmap='managua_r',
764                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
765            )
766            attach_format_coord(ax, ratio_array, x_edges, y_edges, is_pcolormesh=True)
767            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
768            ax.set_xlabel('Time (Mars years)')
769            ax.set_ylabel('Elevation (m)')
770
771            # Add colorbar
772            cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.15)
773            cbar.set_label('Dust / H₂O ice (ratio)')
774            cbar.set_ticks([1e-2, 1e-1, 1, 1e1])
775            cbar.set_ticklabels(['1:100', '1:10', '1:1', '10:1'])
776
777            # Save figure
778            plt.tight_layout()
779            outname = os.path.join(
780                output_folder,
781                f"dust_to_ice_ratio_grid{ig+1}_slope{isl+1}.png"
782            )
783            fig.savefig(outname, dpi=1200, bbox_inches='tight')
784
785
786def plot_strata_count_and_total_height(heights_data, date_time, output_folder="."):
787    """
788    For each grid point and slope, plot:
789      - Number of strata vs time
790      - Total deposit height vs time
791    """
792    ntime = len(heights_data)
793    nslope = len(heights_data[0])
794    ngrid = heights_data[0][0].shape[0]
795
796    for ig in range(ngrid):
797        for isl in range(nslope):
798            n_strata_t = np.zeros(ntime, dtype=int)
799            total_height_t = np.zeros(ntime, dtype=float)
800
801            for t_idx in range(ntime):
802                h_mat = heights_data[t_idx][isl]
803                raw_h = h_mat[ig, :]
804                valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
805                if np.any(valid_mask):
806                    h_valid = raw_h[valid_mask]
807                    n_strata_t[t_idx] = h_valid.size
808                    total_height_t[t_idx] = np.max(h_valid)
809
810            fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
811            fig.suptitle(
812                f"Strata count & total height over time for (Grid point {ig+1}, Slope {isl+1})",
813                fontsize=14,
814                fontweight='bold'
815            )
816
817            axes[0].plot(date_time, n_strata_t, marker='+', linestyle='-')
818            axes[0].set_ylabel("Number of strata")
819            axes[0].grid(True)
820
821            axes[1].plot(date_time, total_height_t, marker='+', linestyle='-')
822            axes[1].set_xlabel("Time (Mars years)")
823            axes[1].set_ylabel("Total height (m)")
824            axes[1].grid(True)
825
826            fig.tight_layout(rect=[0, 0, 1, 0.95])
827            fname = os.path.join(
828                output_folder, f"strata_count_height_ig{ig+1}_is{isl+1}.png"
829            )
830            fig.savefig(fname, dpi=150)
831
832
833def read_orbital_data(orb_file, martian_to_earth):
834    """
835    Read the .asc file containing obliquity, eccentricity and Ls p.
836    Columns:
837      0 = time in thousand Martian years
838      1 = obliquity (deg)
839      2 = eccentricity
840      3 = Ls p (deg)
841    Converts times to Earth years.
842    """
843    data = np.loadtxt(orb_file)
844    dates_mka = data[:, 0]
845    dates_yr = dates_mka * 1e3 / martian_to_earth
846    obliquity = data[:, 1]
847    eccentricity = data[:, 2]
848    lsp = data[:, 3]
849    return dates_yr, obliquity, eccentricity, lsp
850
851
852def plot_orbital_parameters(infofile, orb_file, date_time, output_folder="."):
853    """
854    Plot the evolution of obliquity, eccentricity and Ls p
855    versus simulated time.
856    """
857    # Read conversion factor from infofile
858    _, martian_to_earth = read_infofile(infofile)
859
860    # Read orbital data
861    dates_yr, obl, ecc, lsp = read_orbital_data(orb_file, martian_to_earth)
862
863    # Interpolate orbital parameters at simulation dates (date_time)
864    obl_interp = interp1d(dates_yr, obl, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
865    ecc_interp = interp1d(dates_yr, ecc, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
866    lsp_interp = interp1d(dates_yr, lsp, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
867
868    # Plot
869    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
870    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')
871
872    axes[0].plot(date_time, obl_interp, 'r-', marker='+')
873    axes[0].set_ylabel("Obliquity (°)")
874    axes[0].grid(True)
875
876    axes[1].plot(date_time, ecc_interp, 'b-', marker='+')
877    axes[1].set_ylabel("Eccentricity")
878    axes[1].grid(True)
879
880    axes[2].plot(date_time, lsp_interp, 'g-', marker='+')
881    axes[2].set_ylabel("Ls of perihelion  (°)")
882    axes[2].set_xlabel("Time (Mars years)")
883    axes[2].grid(True)
884
885    plt.tight_layout(rect=[0, 0, 1, 0.96])
886    fname = os.path.join(output_folder, "orbital_parameters_laskar.png")
887    fig.savefig(fname, dpi=150)
888
889
890def mars_ls(pday, peri_day, e_elips, year_day, lsperi=0.0):
891    """
892    Compute solar longitude (Ls) in radians for a given Mars date array 'pday'.
893    Returns Ls in degrees [0, 360).
894    """
895    zz = (pday - peri_day) / year_day
896    zanom = 2 * np.pi * (zz - np.round(zz))
897    xref = np.abs(zanom)
898
899    # Solve Kepler's equation via Newton–Raphson
900    zx0 = xref + e_elips * np.sin(xref)
901    for _ in range(10):
902        f  = zx0 - e_elips * np.sin(zx0) - xref
903        fp = 1 - e_elips * np.cos(zx0)
904        dz = -f / fp
905        zx0 += dz
906        if np.all(np.abs(dz) <= 1e-7):
907            break
908
909    zx0 = np.where(zanom < 0, -zx0, zx0)
910    zteta = 2 * np.arctan(
911        np.sqrt((1 + e_elips) / (1 - e_elips)) * np.tan(zx0 / 2)
912    )
913    psollong = np.mod(zteta + lsperi, 2 * np.pi)
914
915    return np.degrees(psollong)
916
917
918def read_orbital_data_nc(starts_folder, infofile=None):
919    """
920    Read orbital parameters from restartfi_postPEM*.nc files in starts_folder.
921    """
922    if not os.path.isdir(starts_folder):
923        raise ValueError(f"Invalid starts_folder '{starts_folder}': not a directory.")
924
925    # Read simulation time mapping if provided
926    if infofile:
927        dates_yr, martian_to_earth = read_infofile(infofile)
928    else:
929        dates_yr = None
930
931    pattern = os.path.join(starts_folder, "restartfi_postPEM*.nc")
932    files = glob(pattern)
933    if not files:
934        raise FileNotFoundError(f"No NetCDF restart files found matching {pattern}")
935
936    def extract_number(path):
937        name = os.path.basename(path)
938        prefix = 'restartfi_postPEM'
939        if name.startswith(prefix) and name.endswith('.nc'):
940            num_str = name[len(prefix):-3]
941            if num_str.isdigit():
942                return int(num_str)
943        return float('inf')
944
945    files = sorted(files, key=extract_number)
946
947    all_year_day, all_peri, all_aphe, all_date_peri, all_obl = [], [], [], [], []
948    for nc_path in files:
949        with Dataset(nc_path, 'r') as nc:
950            ctrl = nc.variables['controle'][:]
951            all_year_day.append(ctrl[13])
952            all_peri.append(ctrl[14])
953            all_aphe.append(ctrl[15])
954            all_date_peri.append(ctrl[16])
955            all_obl.append(ctrl[17])
956
957    year_day      = np.array(all_year_day)
958    perihelion    = np.array(all_peri)
959    aphelion      = np.array(all_aphe)
960    date_peri_day = np.array(all_date_peri)
961    obliquity     = np.array(all_obl)
962
963    eccentricity  = (aphelion - perihelion) / (aphelion + perihelion)
964    ls_perihelion = mars_ls(date_peri_day,0.,eccentricity,year_day)
965
966    return dates_yr, obliquity, eccentricity, ls_perihelion, martian_to_earth
967
968
969def plot_orbital_parameters_nc(starts_folder, infofile, date_time, output_folder="."):
970    """
971    Plot the evolution of obliquity, eccentricity and Ls p coming from simulation data
972    versus simulated time, plus an additional figure of sin(eccentricity)*Lsp.
973    versus simulated time.
974    """
975    # Read orbital data
976    times_yr, obl, ecc, lsp, martian_to_earth = read_orbital_data_nc(starts_folder, infofile)
977
978    fargs = dict(kind='linear', bounds_error=False, fill_value='extrapolate')
979    obl_i = interp1d(times_yr, obl, **fargs)(date_time)
980    ecc_i = interp1d(times_yr, ecc, **fargs)(date_time)
981    lsp_i = interp1d(times_yr, lsp, **fargs)(date_time)
982
983    date_time = date_time * martian_to_earth / 1e6
984
985    fig, axes = plt.subplots(3,1, figsize=(8,10), sharex=True)
986    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')
987
988    # Plot
989    axes[0].plot(date_time, obl_i, 'r-', marker='+')
990    axes[0].set_ylabel("Obliquity (°)")
991    axes[0].grid(True)
992
993    axes[1].plot(date_time, ecc_i, 'b-', marker='+')
994    axes[1].set_ylabel("Eccentricity")
995    axes[1].grid(True)
996
997    axes[2].plot(date_time, lsp_i, 'g-', marker='+')
998    axes[2].set_ylabel("Ls of perihelion (°)")
999    axes[2].set_xlabel("Time (Myr)")
1000    axes[2].grid(True)
1001
1002    plt.tight_layout(rect=[0,0,1,0.96])
1003    outname = os.path.join(output_folder, "orbital_parameters_simu.png")
1004    fig.savefig(outname, dpi=150)
1005
1006    eps_sin_lsp = ecc_i * np.sin(np.radians(lsp_i)) 
1007
1008    fig2, ax2 = plt.subplots(figsize=(8,5))
1009    fig2.suptitle(r"$\epsilon \times \sin(L_{sp})$", fontweight='bold')
1010
1011    ax2.plot(date_time, eps_sin_lsp, 'm-', marker='+')
1012    ax2.set_ylabel(r"$\epsilon \cdot \sin(L_{sp})$")
1013    ax2.set_xlabel("Time (Myr)")
1014    ax2.grid(True)
1015
1016    plt.tight_layout(rect=[0,0,1,0.95])
1017    outname2 = os.path.join(output_folder, "sin_ecc_times_Lsp.png")
1018    fig2.savefig(outname2, dpi=150)
1019
1020
1021def plot_dust_to_ice_ratio_with_obliquity(
1022    starts_folder,
1023    infofile,
1024    gridded_data,
1025    ref_grid,
1026    top_index,
1027    heights_data,
1028    date_time,
1029    exclude_sub=False,
1030    output_folder="."
1031):
1032    """
1033    Plot the dust-to-ice ratio over time as a heatmap, and overlay the evolution of
1034    obliquity on a secondary y-axis.
1035    """
1036    h2o = gridded_data['h2o_ice']
1037    co2 = gridded_data['co2_ice']
1038    dust = gridded_data['dust']
1039    ngrid, ntime, nslope, nz = h2o.shape
1040
1041    # Read orbital data
1042    times_yr, obl, _, _, martian_to_earth = read_orbital_data_nc(starts_folder, infofile)
1043    fargs = dict(kind='linear', bounds_error=False, fill_value='extrapolate')
1044    obliquity = interp1d(times_yr, obl, **fargs)(date_time)
1045
1046    # Define custom blue-to-orange colormap
1047    blue = np.array([0, 0, 255], dtype=float) / 255
1048    orange = np.array([255, 165, 0], dtype=float) / 255
1049    custom_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)
1050    color_map = { 1: 'green', -1: 'red', 0: 'orange' }
1051
1052    # Log‑ratio bounds and small epsilon to avoid log(0)
1053    vmin, vmax = -2, 1
1054    epsilon = 1e-6
1055
1056    # Loop over grids and slopes
1057    for ig in range(ngrid):
1058        for isl in range(nslope):
1059            # Compute total height time series
1060            total_height_t = np.zeros(ntime, dtype=float)
1061            for t_idx in range(ntime):
1062                h_mat = heights_data[t_idx][isl]
1063                raw_h = h_mat[ig, :]
1064                valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
1065                if np.any(valid_mask):
1066                    h_valid = raw_h[valid_mask]
1067                    total_height_t[t_idx] = np.max(h_valid)
1068
1069            # Compute the per-interval sign of height change
1070            if ntime > 1:
1071                dh = np.diff(total_height_t)
1072                signs = np.sign(dh).astype(int)
1073            else:
1074                dh = np.array([], dtype=float)
1075                signs = np.array([], dtype=int)
1076
1077            # Prepare fraction and ratio arrays
1078            ti = top_index[ig, :, isl].copy().astype(int)
1079            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
1080            frac_all = np.full((nz, ntime, 3), np.nan, dtype=float)  # store fH2O, fCO2, fDust
1081            for t in range(ntime):
1082                if t > 0 and ti[t] <= 0:
1083                    ti[t] = ti[t-1]
1084                elif ti[t] <= 0:
1085                    continue
1086                zmax = ti[t]
1087                if zmax <= 0:
1088                    continue
1089
1090                cH2O = np.clip(h2o[ig, t, isl, :zmax], 0, None)
1091                cCO2 = np.clip(co2[ig, t, isl, :zmax], 0, None)
1092                cDust = np.clip(dust[ig, t, isl, :zmax], 0, None)
1093                total = cH2O + cCO2 + cDust
1094                total[total == 0] = 1.0
1095                fH2O = cH2O / total
1096                fCO2 = cCO2 / total
1097                fDust = cDust / total
1098                frac_all[:zmax, t, 0] = fH2O
1099                frac_all[:zmax, t, 1] = fCO2
1100                frac_all[:zmax, t, 2] = fDust
1101
1102                with np.errstate(divide='ignore', invalid='ignore'):
1103                    ratio = np.where(cH2O > 0, cDust / cH2O, 10**(vmax + 1)
1104                    )
1105                    log_ratio = np.log10(ratio + epsilon)
1106                    log_ratio = np.clip(log_ratio, vmin, vmax)
1107
1108                log_ratio_array[:zmax, t] = log_ratio
1109
1110            ratio_array = 10**log_ratio_array
1111
1112            # Compute edges for pcolormesh
1113            dt = date_time[1] - date_time[0] if len(date_time) > 1 else 1
1114            x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1] - date_time[-2])]]) * martian_to_earth
1115            y_edges = np.concatenate([ref_grid, [ref_grid[-1] + (ref_grid[-1] - ref_grid[-2])]])
1116
1117            # Plot
1118            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
1119            im = ax.pcolormesh(
1120                x_edges,
1121                y_edges,
1122                ratio_array,
1123                shading='auto',
1124                cmap='managua_r',
1125                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
1126            )
1127
1128            # Custom formatter for millions of Earth years
1129            def millions_formatter(x, pos):
1130                return f"{x/1e6:.1f}"
1131
1132            def format_coord_custom(x_input, y_input):
1133                # map onto the main axis
1134                if plt.gca() is ax2:
1135                    x_pix, y_pix = ax2.transData.transform((x_input, y_input))
1136                    x, y = ax.transData.inverted().transform((x_pix, y_pix))
1137                else:
1138                    x, y = x_input, y_input
1139                # check bounds
1140                if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
1141                    return ''
1142                # locate cell
1143                i = np.searchsorted(x_edges, x) - 1
1144                j = np.searchsorted(y_edges, y) - 1
1145                i = np.clip(i, 0, ratio_array.shape[1] - 1)
1146                j = np.clip(j, 0, ratio_array.shape[0] - 1)
1147                # get fractions and obliquity
1148                fH2O, fCO2, fDust = frac_all[j, i]
1149                obl   = np.interp(x / martian_to_earth, date_time, obliquity)
1150                return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.4f}, Dust={fDust:.4f}, Obl={obl:.2f}°"
1151
1152            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
1153            ax.xaxis.set_major_formatter(FuncFormatter(millions_formatter))
1154            ax.set_xlabel('Time (Myr)')
1155            ax.set_ylabel('Elevation (m)')
1156
1157            # Add colorbar
1158            cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.15)
1159            cbar.set_label('Dust / H₂O ice (ratio)')
1160            cbar.set_ticks([1e-2, 1e-1, 1, 1e1])
1161            cbar.set_ticklabels(['1:100', '1:10', '1:1', '10:1'])
1162
1163            # Overlay obliquity on secondary y-axis
1164            ax2 = ax.twinx()
1165            for i in range(len(dh)):
1166                ax2.plot(
1167                    [date_time[i] * martian_to_earth, date_time[i+1] * martian_to_earth],
1168                    [obliquity[i], obliquity[i+1]],
1169                    color=color_map[signs[i]],
1170                    marker='+',
1171                    linewidth=1.5
1172                )
1173            ax2.format_coord = format_coord_custom
1174            ax2.set_ylabel('Obliquity (°)')
1175            ax2.tick_params(axis='y')
1176            ax2.grid(False)
1177
1178            # Save figure
1179            plt.tight_layout()
1180            outname = os.path.join(
1181                output_folder,
1182                f'dust_ice_obliquity_grid{ig+1}_slope{isl+1}.png'
1183            )
1184            fig.savefig(outname, dpi=1200, bbox_inches='tight')
1185
1186
1187def main():
1188    # 1) Get user inputs
1189    folder_path, base_name, infofile, orbfile = get_user_inputs()
1190
1191    # 2) List and verify NetCDF files
1192    files = list_netcdf_files(folder_path, base_name)
1193    if not files:
1194        print(f"No NetCDF files named \"{base_name}#.nc\" found in \"{folder_path}\".")
1195        sys.exit(1)
1196    print(f"> Found {len(files)} NetCDF file(s).")
1197
1198    # 3) Open one sample to get grid dimensions & coordinates
1199    sample_file = files[0]
1200    ngrid, nslope, longitude, latitude = open_sample_dataset(sample_file)
1201    print(f"> ngrid  = {ngrid}, nslope = {nslope}")
1202
1203    # 4) Collect variable info + global min/max elevations
1204    var_info, max_nb_str, min_base_elev, max_top_elev = collect_stratification_variables(files, base_name)
1205    print(f"> max strata per slope = {max_nb_str}")
1206    print(f"> min base elev = {min_base_elev:.3f} m, max top elev = {max_top_elev:.3f} m")
1207
1208    # 5) Load full datasets
1209    datasets = load_full_datasets(files)
1210
1211    # 6) Extract stratification data
1212    heights_data, raw_prop_arrays, ntime = extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str)
1213
1214    # 7) Close datasets
1215    for ds in datasets:
1216        ds.close()
1217
1218    # 8) Normalize to fractions
1219    frac_arrays = normalize_to_fractions(raw_prop_arrays)
1220
1221    # 9) Ask whether to include subsurface
1222    show_subsurface = get_yes_no_input("Show subsurface layers?")
1223    exclude_sub = not show_subsurface
1224    if exclude_sub:
1225        min_base_for_interp = 0.0
1226        print("> Interpolating only elevations >= 0 m (surface strata).")
1227    else:
1228        min_base_for_interp = min_base_elev
1229        print(f"> Interpolating full depth down to {min_base_elev:.3f} m.")
1230
1231    # 10) Prompt discretization step
1232    dz = prompt_discretization_step(max_top_elev)
1233
1234    # 11) Build reference grid and interpolate
1235    ref_grid, gridded_data, top_index = interpolate_data_on_refgrid(
1236        heights_data, frac_arrays, min_base_for_interp, max_top_elev, dz, exclude_sub=exclude_sub
1237    )
1238
1239    # 12) Read timestamps and conversion factor from infofile
1240    date_time, martian_to_earth = read_infofile(infofile)
1241    if date_time.size != ntime:
1242        print(f"Warning: {date_time.size} timestamps vs {ntime} NetCDF files.")
1243
1244    # 13) Plot stratification data over time
1245    plot_stratification_over_time(
1246        gridded_data, ref_grid, top_index, heights_data, date_time,
1247        exclude_sub=exclude_sub, output_folder="."
1248    )
1249    plot_stratification_rgb_over_time(
1250        gridded_data, ref_grid, top_index, heights_data, date_time,
1251        exclude_sub=exclude_sub, output_folder="."
1252    )
1253    #plot_dust_to_ice_ratio_over_time(
1254    #    gridded_data, ref_grid, top_index, heights_data, date_time,
1255    #    exclude_sub=exclude_sub, output_folder="."
1256    #)
1257    plot_dust_to_ice_ratio_with_obliquity(
1258        folder_path, infofile,
1259        gridded_data, ref_grid, top_index, heights_data, date_time,
1260        exclude_sub=exclude_sub, output_folder="."
1261    )
1262    #plot_strata_count_and_total_height(heights_data, date_time, output_folder=".")
1263
1264    # 14) Plot orbital parameters
1265    #plot_orbital_parameters(infofile, orbfile, date_time, output_folder=".")
1266    plot_orbital_parameters_nc(folder_path, infofile, date_time, output_folder=".")
1267
1268    # 15) Show all figures
1269    plt.show()
1270
1271
1272if __name__ == "__main__":
1273    main()
1274
Note: See TracBrowser for help on using the repository browser.