source: trunk/LMDZ.COMMON/libf/evolution/deftank/visu_evol_layering.py

Last change on this file was 3860, checked in by jbclement, 12 days ago

PEM:

  • Correction of "visu_evol_layering.py" to compute Lsp variations.
  • Few cleanings to open files throughout the PEM.

JBC

  • Property svn:executable set to *
File size: 45.3 KB
Line 
1#!/usr/bin/env python3
2#######################################################################################################
3### Python script to output stratification data over time from "restartpem#.nc" files               ###
4### and to plot orbital parameters from "obl_ecc_lsp.asc"                                           ###
5#######################################################################################################
6
7import os
8import sys
9import numpy as np
10from glob import glob
11from netCDF4 import Dataset
12import matplotlib.pyplot as plt
13from mpl_toolkits.axes_grid1.inset_locator import inset_axes
14from matplotlib.colors import LinearSegmentedColormap, LogNorm
15from scipy.interpolate import interp1d
16
17
18def get_user_inputs():
19    """
20    Prompt the user for:
21      - folder_path: directory containing NetCDF files (default: "starts")
22      - base_name:   base filename (default: "restartpem")
23      - infofile:    name of the PEM info file (default: "info_PEM.txt")
24    Validates existence of folder and infofile before returning.
25    """
26    folder_path = input(
27        "Enter the folder path containing the NetCDF files "
28        "(press Enter for default [starts]): "
29    ).strip() or "starts"
30    while not os.path.isdir(folder_path):
31        print(f"  » \"{folder_path}\" does not exist or is not a directory.")
32        folder_path = input(
33            "Enter a valid folder path (press Enter for default [starts]): "
34        ).strip() or "starts"
35
36    base_name = input(
37        "Enter the base name of the NetCDF files "
38        "(press Enter for default [restartpem]): "
39    ).strip() or "restartpem"
40
41    infofile = input(
42        "Enter the name of the PEM info file "
43        "(press Enter for default [info_PEM.txt]): "
44    ).strip() or "info_PEM.txt"
45    while not os.path.isfile(infofile):
46        print(f"  » \"{infofile}\" does not exist or is not a file.")
47        infofile = input(
48            "Enter a valid PEM info filename (press Enter for default [info_PEM.txt]): "
49        ).strip() or "info_PEM.txt"
50
51    orbfile = input(
52        "Enter the name of the orbital parameters ASCII file "
53        "(press Enter for default [obl_ecc_lsp.asc]): "
54    ).strip() or "obl_ecc_lsp.asc"
55    while not os.path.isfile(orbfile):
56        print(f"  » \"{orbfile}\" does not exist or is not a file.")
57        orbfile = input(
58            "Enter a valid orbital parameters ASCII filename (press Enter for default [obl_ecc_lsp.asc]): "
59        ).strip() or "info_PEM.txt"
60
61    return folder_path, base_name, infofile, orbfile
62
63
64def list_netcdf_files(folder_path, base_name):
65    """
66    List and sort all NetCDF files matching the pattern {base_name}#.nc
67    in folder_path. Returns a sorted list of full file paths.
68    """
69    pattern = os.path.join(folder_path, f"{base_name}[0-9]*.nc")
70    all_files = glob(pattern)
71    if not all_files:
72        return []
73
74    def extract_index(pathname):
75        fname = os.path.basename(pathname)
76        idx_str = fname[len(base_name):-3]
77        return int(idx_str) if idx_str.isdigit() else float('inf')
78
79    sorted_files = sorted(all_files, key=extract_index)
80    return sorted_files
81
82
83def open_sample_dataset(file_path):
84    """
85    Open a single NetCDF file and extract:
86      - ngrid, nslope
87      - longitude, latitude
88    Returns (ngrid, nslope, longitude_array, latitude_array).
89    """
90    with Dataset(file_path, 'r') as ds:
91        ngrid = ds.dimensions['physical_points'].size
92        nslope = ds.dimensions['nslope'].size
93        longitude = ds.variables['longitude'][:].copy()
94        latitude = ds.variables['latitude'][:].copy()
95    return ngrid, nslope, longitude, latitude
96
97
98def collect_stratification_variables(files, base_name):
99    """
100    Scan all files to collect:
101      - variable names for each stratification property
102      - max number of strata (max_nb_str)
103      - global min base elevation and max top elevation
104    Returns:
105      - var_info: dict mapping each property_name -> sorted list of var names
106      - max_nb_str: int
107      - min_base_elev: float
108      - max_top_elev: float
109    """
110    max_nb_str = 0
111    min_base_elev = np.inf
112    max_top_elev = -np.inf
113
114    property_markers = {
115        'heights':   'stratif_slope',    # "..._top_elevation"
116        'co2_ice':   'h_co2ice',
117        'h2o_ice':   'h_h2oice',
118        'dust':      'h_dust',
119        'pore':      'h_pore',
120        'pore_ice':  'poreice_volfrac'
121    }
122    var_info = {prop: set() for prop in property_markers}
123
124    for file_path in files:
125        with Dataset(file_path, 'r') as ds:
126            if 'nb_str_max' in ds.dimensions:
127                max_nb_str = max(max_nb_str, ds.dimensions['nb_str_max'].size)
128
129            nslope = ds.dimensions['nslope'].size
130            for k in range(1, nslope + 1):
131                var_name = f"stratif_slope{k:02d}_top_elevation"
132                if var_name in ds.variables:
133                    arr = ds.variables[var_name][:]
134                    min_base_elev = min(min_base_elev, np.min(arr))
135                    max_top_elev = max(max_top_elev, np.max(arr))
136                    var_info['heights'].add(var_name)
137
138            for full_var in ds.variables:
139                for prop, marker in property_markers.items():
140                    if (marker in full_var) and prop != 'heights':
141                        var_info[prop].add(full_var)
142
143    for prop in var_info:
144        var_info[prop] = sorted(var_info[prop])
145
146    return var_info, max_nb_str, min_base_elev, max_top_elev
147
148
149def load_full_datasets(files):
150    """
151    Open all NetCDF files and return a list of Dataset objects.
152    (They should be closed by the caller after use.)
153    """
154    return [Dataset(fp, 'r') for fp in files]
155
156
157def extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str):
158    """
159    Build:
160      - heights_data[t_idx][isl] = 2D array (ngrid, n_strata_current) of top_elevations.
161      - raw_prop_arrays[prop] = 4D array (ngrid, ntime, nslope, max_nb_str) of per-strata values.
162    Returns:
163      - heights_data: list (ntime) of lists (nslope) of 2D arrays
164      - raw_prop_arrays: dict mapping each property_name -> 4D array
165      - ntime: number of time steps (files)
166    """
167    ntime = len(datasets)
168
169    heights_data = [
170        [None for _ in range(nslope)]
171        for _ in range(ntime)
172    ]
173    for t_idx, ds in enumerate(datasets):
174        for var_name in var_info['heights']:
175            slope_idx = int(var_name.split("slope")[1].split("_")[0]) - 1
176            if 0 <= slope_idx < nslope:
177                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
178                heights_data[t_idx][slope_idx] = raw.# (ngrid, n_strata)
179
180    raw_prop_arrays = {}
181    for prop in var_info:
182        if prop == 'heights':
183            continue
184        raw_prop_arrays[prop] = np.zeros((ngrid, ntime, nslope, max_nb_str), dtype=np.float32)
185
186    def slope_index_from_var(vname):
187        return int(vname.split("slope")[1].split("_")[0]) - 1
188
189    for prop in raw_prop_arrays:
190        slope_map = {}
191        for vname in var_info[prop]:
192            isl = slope_index_from_var(vname)
193            if 0 <= isl < nslope:
194                slope_map[isl] = vname
195
196        arr = raw_prop_arrays[prop]
197        for t_idx, ds in enumerate(datasets):
198            for isl, var_name in slope_map.items():
199                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
200                n_strata_current = raw.shape[0]
201                arr[:, t_idx, isl, :n_strata_current] = raw.T
202
203    return heights_data, raw_prop_arrays, ntime
204
205
206def normalize_to_fractions(raw_prop_arrays):
207    """
208    Given raw_prop_arrays for 'co2_ice', 'h2o_ice', 'dust', 'pore' (in meters),
209    normalize each set of strata so that the sum of those four = 1 per cell.
210    Returns:
211      - frac_arrays: dict mapping same keys -> 4D arrays of fractions (0..1).
212    """
213    co2 = raw_prop_arrays['co2_ice']
214    h2o = raw_prop_arrays['h2o_ice']
215    dust = raw_prop_arrays['dust']
216    pore = raw_prop_arrays['pore']
217
218    total = co2 + h2o + dust + pore
219    mask = total > 0.0
220
221    frac_co2 = np.zeros_like(co2, dtype=np.float32)
222    frac_h2o = np.zeros_like(h2o, dtype=np.float32)
223    frac_dust = np.zeros_like(dust, dtype=np.float32)
224    frac_pore = np.zeros_like(pore, dtype=np.float32)
225
226    frac_co2[mask] = co2[mask] / total[mask]
227    frac_h2o[mask] = h2o[mask] / total[mask]
228    frac_dust[mask] = dust[mask] / total[mask]
229    frac_pore[mask] = pore[mask] / total[mask]
230
231    return {
232        'co2_ice': frac_co2,
233        'h2o_ice': frac_h2o,
234        'dust':     frac_dust,
235        'pore':     frac_pore
236    }
237
238
239def read_infofile(file_name):
240    """
241    Reads "info_PEM.txt". Expects:
242      - First line: parameters where the 3rd value is martian_to_earth conversion factor.
243      - Each subsequent line: floats where first value is simulation timestamp (in Mars years).
244    Returns:
245      - date_time: 1D numpy array of timestamps (Mars years)
246      - martian_to_earth: float conversion factor
247    """
248    date_time = []
249    with open(file_name, 'r') as fp:
250        first = fp.readline().split()
251        martian_to_earth = float(first[2])
252        for line in fp:
253            parts = line.strip().split()
254            if not parts:
255                continue
256            try:
257                date_time.append(float(parts[0]))
258            except ValueError:
259                continue
260    return np.array(date_time, dtype=np.float64), martian_to_earth
261
262
263def get_yes_no_input(prompt: str) -> bool:
264    """
265    Prompt the user with a yes/no question. Returns True for yes, False for no.
266    """
267    while True:
268        choice = input(f"{prompt} (y/n): ").strip().lower()
269        if choice in ['y', 'yes']:
270            return True
271        elif choice in ['n', 'no']:
272            return False
273        else:
274            print("Please respond with y or n.")
275
276
277def prompt_discretization_step(max_top_elev):
278    """
279    Prompt for a positive float dz such that 0 < dz <= max_top_elev.
280    """
281    while True:
282        entry = input(
283            "Enter the discretization step of the reference grid for the elevation [m]: "
284        ).strip()
285        try:
286            dz = float(entry)
287            if dz <= 0:
288                print("  » Discretization step must be strictly positive!")
289                continue
290            if dz > max_top_elev:
291                print(
292                    f"  » {dz:.3e} m is greater than the maximum top elevation "
293                    f"({max_top_elev:.3e} m). Please enter a smaller value."
294                )
295                continue
296            return dz
297        except ValueError:
298            print("  » Invalid numeric value. Please try again.")
299
300
301def interpolate_data_on_refgrid(
302    heights_data,
303    prop_arrays,
304    min_base_for_interp,
305    max_top_elev,
306    dz,
307    exclude_sub=False
308):
309    """
310    Build a reference elevation grid and interpolate strata fractions onto it.
311
312    Returns:
313      - ref_grid: 1D array of elevations (nz,)
314      - gridded_data: dict mapping each property_name to 4D array
315        (ngrid, ntime, nslope, nz) with interpolated fractions.
316      - top_index: 3D array (ngrid, ntime, nslope) of ints:
317        number of levels covered by the topmost stratum.
318    """
319    if exclude_sub and (dz > max_top_elev):
320        ref_grid = np.array([0.0, max_top_elev], dtype=np.float32)
321    else:
322        ref_grid = np.arange(min_base_for_interp, max_top_elev + dz/2, dz)
323    nz = len(ref_grid)
324    print(f"> Number of reference grid points = {nz}")
325
326    sample_prop = next(iter(prop_arrays.values()))
327    ngrid, ntime, nslope, max_nb_str = sample_prop.shape
328
329    gridded_data = {
330        prop: np.full((ngrid, ntime, nslope, nz), -1.0, dtype=np.float32)
331        for prop in prop_arrays
332    }
333    top_index = np.zeros((ngrid, ntime, nslope), dtype=np.int32)
334
335    for ig in range(ngrid):
336        for t_idx in range(ntime):
337            for isl in range(nslope):
338                h_mat = heights_data[t_idx][isl]
339                if h_mat is None:
340                    continue
341
342                raw_h = h_mat[ig, :]
343                h_all = np.full((max_nb_str,), np.nan, dtype=np.float32)
344                n_strata_current = raw_h.shape[0]
345                h_all[:n_strata_current] = raw_h
346
347                if exclude_sub:
348                    epsilon = 1e-6
349                    valid_mask = (h_all >= -epsilon)
350                else:
351                    valid_mask = (~np.isnan(h_all)) & (h_all != 0.0)
352
353                if not np.any(valid_mask):
354                    continue
355
356                h_valid = h_all[valid_mask]
357                top_h = np.max(h_valid)
358                i_zmax = np.searchsorted(ref_grid, top_h, side='right')
359                top_index[ig, t_idx, isl] = i_zmax
360                if i_zmax == 0:
361                    continue
362
363                for prop, arr in prop_arrays.items():
364                    prop_profile_all = arr[ig, t_idx, isl, :]
365                    prop_profile = prop_profile_all[valid_mask]
366                    if prop_profile.size == 0:
367                        continue
368
369                    f_interp = interp1d(
370                        h_valid,
371                        prop_profile,
372                        kind='next',
373                        bounds_error=False,
374                        fill_value=-1.0
375                    )
376                    gridded_data[prop][ig, t_idx, isl, :i_zmax] = f_interp(ref_grid[:i_zmax])
377
378    return ref_grid, gridded_data, top_index
379
380
381def attach_format_coord(ax, mat, x, y, is_pcolormesh=True):
382    """
383    Attach a format_coord function to the axes to display x, y, and value at cursor.
384    Works for both pcolormesh and imshow style grids.
385    """
386    # Determine dimensions
387    if mat.ndim == 2:
388        ny, nx = mat.shape
389    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
390        ny, nx, nc = mat.shape
391    else:
392        raise ValueError(f"Unsupported mat shape {mat.shape}")
393    # Edges or extents
394    if is_pcolormesh:
395        xedges, yedges = x, y
396    else:
397        x0, x1 = x.min(), x.max()
398        y0, y1 = y.min(), y.max()
399
400    def format_coord(xp, yp):
401        # Map to indices
402        if is_pcolormesh:
403            col = np.searchsorted(xedges, xp) - 1
404            row = np.searchsorted(yedges, yp) - 1
405        else:
406            col = int((xp - x0) / (x1 - x0) * nx)
407            row = int((yp - y0) / (y1 - y0) * ny)
408        # Within bounds?
409        if 0 <= row < ny and 0 <= col < nx:
410            if mat.ndim == 2:
411                v = mat[row, col]
412                return f"x={xp:.3g}, y={yp:.3g}, val={v:.3g}"
413            else:
414                vals = mat[row, col]
415                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
416                return f"x={xp:.3g}, y={yp:.3g}, val=({txt})"
417        return f"x={xp:.3g}, y={yp:.3g}"
418
419    ax.format_coord = format_coord
420
421
422def plot_stratification_over_time(
423    gridded_data,
424    ref_grid,
425    top_index,
426    heights_data,
427    date_time,
428    exclude_sub=False,
429    output_folder="."
430):
431    """
432    For each grid point and slope, generate a 2×2 figure of:
433      - CO2 ice fraction
434      - H2O ice fraction
435      - Dust fraction
436      - Pore fraction
437    """
438    prop_names = ['co2_ice', 'h2o_ice', 'dust', 'pore']
439    titles = ["CO2 ice", "H2O ice", "Dust", "Pore"]
440    cmap = plt.get_cmap('turbo').copy()
441    cmap.set_under('white')
442    vmin, vmax = 0.0, 1.0
443
444    sample_prop = next(iter(gridded_data.values()))
445    ngrid, ntime, nslope, nz = sample_prop.shape
446
447    if exclude_sub:
448        positive_indices = np.where(ref_grid >= 0.0)[0]
449        sub_ref_grid = ref_grid[positive_indices]
450    else:
451        positive_indices = np.arange(nz)
452        sub_ref_grid = ref_grid
453
454    for ig in range(ngrid):
455        for isl in range(nslope):
456            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
457            fig.suptitle(
458                f"Content variation over time for (Grid point {ig+1}, Slope {isl+1})",
459                fontsize=14,
460                fontweight='bold'
461            )
462
463            # Precompute valid stratum tops per time
464            valid_tops_per_time = []
465            for t_idx in range(ntime):
466                raw_h = heights_data[t_idx][isl][ig, :]
467                h_all = raw_h[~np.isnan(raw_h)]
468                if exclude_sub:
469                    h_all = h_all[h_all >= 0.0]
470                valid_tops_per_time.append(np.unique(h_all))
471
472            for idx, prop in enumerate(prop_names):
473                ax = axes.flat[idx]
474                data_3d = gridded_data[prop][ig, :, isl, :]
475                mat_full = data_3d.T
476                mat = mat_full[positive_indices, :].copy()
477                mat[mat < 0.0] = np.nan
478
479                # Mask above top stratum
480                for t_idx in range(ntime):
481                    i_zmax = top_index[ig, t_idx, isl]
482                    if i_zmax <= positive_indices[0]:
483                        mat[:, t_idx] = np.nan
484                    else:
485                        count_z = np.count_nonzero(positive_indices < i_zmax)
486                        mat[count_z:, t_idx] = np.nan
487
488                im = ax.pcolormesh(
489                    date_time,
490                    sub_ref_grid,
491                    mat,
492                    cmap=cmap,
493                    shading='auto',
494                    vmin=vmin,
495                    vmax=vmax
496                )
497                x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1]-date_time[-2])]])
498                attach_format_coord(ax, mat, x_edges, np.concatenate([sub_ref_grid, [sub_ref_grid[-1] + (sub_ref_grid[-1]-sub_ref_grid[-2])]]), is_pcolormesh=True)
499                ax.set_title(titles[idx], fontsize=12)
500                ax.set_xlabel("Time (Mars years)")
501                ax.set_ylabel("Elevation (m)")
502
503            fig.subplots_adjust(right=0.88)
504            fig.tight_layout(rect=[0, 0, 0.88, 1.0])
505            cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
506            fig.colorbar(im, cax=cbar_ax, orientation='vertical', label="Content")
507
508            fname = os.path.join(
509                output_folder, f"layering_evolution_ig{ig+1}_is{isl+1}.png"
510            )
511            fig.savefig(fname, dpi=150)
512
513
514def plot_stratification_rgb_over_time(
515    gridded_data,
516    ref_grid,
517    top_index,
518    heights_data,
519    date_time,
520    exclude_sub=False,
521    output_folder="."
522):
523    """
524    Plot stratification over time colored using RGB ternary mix of H2O ice (blue), CO2 ice (violet), and dust (orange).
525    Includes a triangular legend showing the mix proportions.
526    """
527    # Define constant colors
528    violet = np.array([255,   0, 255], dtype=float) / 255
529    blue   = np.array([  0,   0, 255], dtype=float) / 255
530    orange = np.array([255, 165,   0], dtype=float) / 255
531
532    # Elevation mask and array
533    if exclude_sub:
534        elevation_mask = (ref_grid >= 0.0)
535        elev = ref_grid[elevation_mask]
536    else:
537        elevation_mask = np.ones_like(ref_grid, dtype=bool)
538        elev = ref_grid
539
540    # Pre-compute legend triangle
541    res = 300
542    u = np.linspace(0, 1, res)
543    v = np.linspace(0, np.sqrt(3)/2, res)
544    X, Y = np.meshgrid(u, v)
545    V_bary = 2 * Y / np.sqrt(3)
546    U_bary = X - 0.5 * V_bary
547    W_bary = 1 - U_bary - V_bary
548    mask_triangle = (U_bary >= 0) & (V_bary >= 0) & (W_bary >= 0)
549    legend_rgb = (
550        U_bary[..., None] * violet
551        + V_bary[..., None] * orange
552        + W_bary[..., None] * blue
553    )
554    legend_rgb = np.clip(legend_rgb, 0.0, 1.0)
555    legend_rgba = np.zeros((res, res, 4))
556    legend_rgba[..., :3] = legend_rgb
557    legend_rgba[..., 3] = mask_triangle.astype(float)
558
559    # Extract data arrays
560    h2o = gridded_data['h2o_ice']
561    co2 = gridded_data['co2_ice']
562    dust = gridded_data['dust']
563    ngrid, ntime, nslope, nz = h2o.shape
564
565    # Fill missing depths
566    ti = top_index.copy().astype(int)
567    for ig in range(ngrid):
568        for isl in range(nslope):
569            for t in range(1, ntime):
570                if ti[ig, t, isl] <= 0:
571                    ti[ig, t, isl] = ti[ig, t-1, isl]
572
573    # Loop over grid and slope
574    for ig in range(ngrid):
575        for isl in range(nslope):
576            # Compute RGB stratification over time
577            rgb = np.ones((nz, ntime, 3), dtype=float)
578
579            frac_all = np.zeros((nz, ntime, 3), dtype=float)  # store fH2O, fCO2, fDust
580            for t in range(ntime):
581                depth = ti[ig, t, isl]
582                if depth <= 0:
583                    continue
584                cH2O = np.clip(h2o[ig, t, isl, :depth], 0, None)
585                cCO2 = np.clip(co2[ig, t, isl, :depth], 0, None)
586                cDust = np.clip(dust[ig, t, isl, :depth], 0, None)
587                total = cH2O + cCO2 + cDust
588                total[total == 0] = 1.0
589                fH2O = cH2O / total
590                fCO2 = cCO2 / total
591                fDust = cDust / total
592                frac_all[:depth, t, :] = np.stack([fH2O, fCO2, fDust], axis=1)
593                mix = np.outer(fH2O, blue) + np.outer(fCO2, violet) + np.outer(fDust, orange)
594                rgb[:depth, t, :] = np.clip(mix, 0, 1)
595
596            # Mask elevation
597            display_rgb = rgb[elevation_mask, :, :]
598            display_frac = frac_all[elevation_mask, :, :]
599
600            display_rgb = rgb[elevation_mask, :, :]
601
602            # Compute edges for pcolormesh
603            dt = date_time[1] - date_time[0] if len(date_time) > 1 else 1
604            x_edges = np.concatenate([date_time, [date_time[-1] + dt]])
605            d_e = np.diff(elev)
606            last_e = elev[-1] + (d_e[-1] if len(d_e)>0 else 1)
607            y_edges = np.concatenate([elev, [last_e]])
608
609            # Create figure with legend
610            fig, (ax_main, ax_leg) = plt.subplots(
611                1, 2, figsize=(8, 4), dpi=200,
612                gridspec_kw={'width_ratios': [5, 1]}
613            )
614
615            # Main stratification panel
616            mesh = ax_main.pcolormesh(
617                x_edges,
618                y_edges,
619                display_rgb,
620                shading='auto',
621                edgecolors='none'
622            )
623
624            # Custom coordinate formatter: show time, elevation, and mixture fractions
625            def main_format(x, y):
626                # check bounds
627                if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
628                    return ''
629                # locate cell
630                i = np.searchsorted(x_edges, x) - 1
631                j = np.searchsorted(y_edges, y) - 1
632                i = np.clip(i, 0, display_rgb.shape[1]-1)
633                j = np.clip(j, 0, display_rgb.shape[0]-1)
634                # get fractions
635                fH2O, fCO2, fDust = display_frac[j, i]
636                return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.2f}, CO2={fCO2:.2f}, Dust={fDust:.2f}"
637            ax_main.format_coord = main_format
638            ax_main.set_facecolor('white')
639            ax_main.set_title(f"Ternary mix over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
640            ax_main.set_xlabel("Time (Mars years)")
641            ax_main.set_ylabel("Elevation (m)")
642
643            # Legend panel using proper edges
644            u_edges = np.linspace(0, 1, res+1)
645            v_edges = np.linspace(0, np.sqrt(3)/2, res+1)
646            ax_leg.pcolormesh(
647                u_edges,
648                v_edges,
649                legend_rgba,
650                shading='auto',
651                edgecolors='none'
652            )
653            ax_leg.set_aspect('equal')
654
655            # Custom coordinate formatter for legend: show barycentric fractions
656            def legend_format(x, y):
657                # compute barycentric coords from cartesian (x,y)
658                V = 2 * y / np.sqrt(3)
659                U = x - 0.5 * V
660                W = 1 - U - V
661                if U >= 0 and V >= 0 and W >= 0:
662                    return f"H2O: {W:.2f}, Dust: {V:.2f}, CO2: {U:.2f}"
663                else:
664                    return ''
665            ax_leg.format_coord = legend_format
666
667            # Draw triangle border and gridlines
668            triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2], [0, 0]])
669            ax_leg.plot(triangle[:, 0], triangle[:, 1], 'k-', linewidth=1, clip_on=False, zorder=10)
670            ticks = np.linspace(0.25, 0.75, 3)
671            for f in ticks:
672                ax_leg.plot([1 - f, 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
673                ax_leg.plot([f, f + 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
674                y = (np.sqrt(3)/2) * f
675                ax_leg.plot([0.5 * f, 1 - 0.5 * f], [y, y], '--', color='k', linewidth=0.5, clip_on=False, zorder=9)
676
677            # Legend labels
678            ax_leg.text(0, -0.05, 'H2O ice', ha='center', va='top', fontsize=8)
679            ax_leg.text(1, -0.05, 'CO2 ice', ha='center', va='top', fontsize=8)
680            ax_leg.text(0.5, np.sqrt(3)/2 + 0.05, 'Dust', ha='center', va='bottom', fontsize=8)
681            ax_leg.axis('off')
682
683            # Save figure
684            plt.tight_layout()
685            fname = os.path.join(output_folder, f"layering_rgb_evolution_ig{ig+1}_is{isl+1}.png")
686            fig.savefig(fname, dpi=150, bbox_inches='tight')
687
688
689def plot_dust_to_ice_ratio_over_time(
690    gridded_data,
691    ref_grid,
692    top_index,
693    heights_data,
694    date_time,
695    exclude_sub=False,
696    output_folder="."
697):
698    """
699    Plot the dust-to-ice ratio in the stratification over time,
700    using a blue-to-orange colormap:
701    - blue: ice-dominated (low dust-to-ice ratio)
702    - orange: dust-dominated (high dust-to-ice ratio)
703    """
704    h2o = gridded_data['h2o_ice']
705    dust = gridded_data['dust']
706    ngrid, ntime, nslope, nz = h2o.shape
707
708    # Elevation mask
709    if exclude_sub:
710        elevation_mask = (ref_grid >= 0.0)
711        elev = ref_grid[elevation_mask]
712    else:
713        elevation_mask = np.ones_like(ref_grid, dtype=bool)
714        elev = ref_grid
715
716    # Define custom blue-to-orange colormap
717    blue = np.array([0, 0, 255], dtype=float) / 255
718    orange = np.array([255, 165, 0], dtype=float) / 255
719    custom_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)
720
721    # Log‑ratio bounds and small epsilon to avoid log(0)
722    vmin, vmax = -2, 1
723    epsilon = 1e-6
724   
725    # Loop over grids and slopes
726    for ig in range(ngrid):
727        for isl in range(nslope):
728            ti = top_index[ig, :, isl].copy().astype(int)
729            for t in range(1, ntime):
730                if ti[t] <= 0:
731                    ti[t] = ti[t-1]
732
733            # Compute log10(dust/ice) profile at each time step
734            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
735            for t in range(ntime):
736                zmax = ti[t]
737                if zmax <= 0:
738                    continue
739
740                h2o_profile = np.clip(h2o[ig, t, isl, :zmax], 0, None)
741                dust_profile = np.clip(dust[ig, t, isl, :zmax], 0, None)
742
743                with np.errstate(divide='ignore', invalid='ignore'):
744                    ratio_profile = np.where(
745                        h2o_profile > 0,
746                        dust_profile / h2o_profile,
747                        10**(vmax + 1)
748                    )
749                    log_ratio = np.log10(ratio_profile + epsilon)
750                    log_ratio = np.clip(log_ratio, vmin, vmax)
751
752                log_ratio_array[:zmax, t] = log_ratio
753
754            # Convert back to linear ratio and apply elevation mask
755            ratio_array = 10**log_ratio_array
756            ratio_display = ratio_array[elevation_mask, :]
757
758            # Plot
759            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
760            im = ax.pcolormesh(
761                date_time,
762                elev,
763                ratio_display,
764                shading='auto',
765                cmap='managua_r',
766                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
767            )
768            x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1]-date_time[-2])]])
769            attach_format_coord(ax, ratio_display, x_edges, np.concatenate([elev, [elev[-1] + (elev[-1]-elev[-2])]]), is_pcolormesh=True)
770            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
771            ax.set_xlabel('Time (Mars years)')
772            ax.set_ylabel('Elevation (m)')
773
774            # Add colorbar with simplified ratio labels
775            cbar = fig.colorbar(im, ax=ax, orientation='vertical')
776            cbar.set_label('Dust / H₂O ice (ratio)')
777
778            # Define custom ticks and labels
779            ticks = [1e-2, 1e-1, 1, 1e1]
780            labels = ['1:100', '1:10', '1:1', '10:1']
781            cbar.set_ticks(ticks)
782            cbar.set_ticklabels(labels)
783
784            # Save figure
785            plt.tight_layout()
786            fname = os.path.join(
787                output_folder,
788                f"dust_to_ice_ratio_grid{ig+1}_slope{isl+1}.png"
789            )
790            fig.savefig(fname, dpi=150)
791
792
793def plot_strata_count_and_total_height(heights_data, date_time, output_folder="."):
794    """
795    For each grid point and slope, plot:
796      - Number of strata vs time
797      - Total deposit height vs time
798    """
799    ntime = len(heights_data)
800    nslope = len(heights_data[0])
801    ngrid = heights_data[0][0].shape[0]
802
803    for ig in range(ngrid):
804        for isl in range(nslope):
805            n_strata_t = np.zeros(ntime, dtype=int)
806            total_height_t = np.zeros(ntime, dtype=float)
807
808            for t_idx in range(ntime):
809                h_mat = heights_data[t_idx][isl]
810                raw_h = h_mat[ig, :]
811                valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
812                if np.any(valid_mask):
813                    h_valid = raw_h[valid_mask]
814                    n_strata_t[t_idx] = h_valid.size
815                    total_height_t[t_idx] = np.max(h_valid)
816
817            fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
818            fig.suptitle(
819                f"Strata count & total height over time for (Grid point {ig+1}, Slope {isl+1})",
820                fontsize=14,
821                fontweight='bold'
822            )
823
824            axes[0].plot(date_time, n_strata_t, marker='+', linestyle='-')
825            axes[0].set_ylabel("Number of strata")
826            axes[0].grid(True)
827
828            axes[1].plot(date_time, total_height_t, marker='+', linestyle='-')
829            axes[1].set_xlabel("Time (Mars years)")
830            axes[1].set_ylabel("Total height (m)")
831            axes[1].grid(True)
832
833            fig.tight_layout(rect=[0, 0, 1, 0.95])
834            fname = os.path.join(
835                output_folder, f"strata_count_height_ig{ig+1}_is{isl+1}.png"
836            )
837            fig.savefig(fname, dpi=150)
838
839
840def read_orbital_data(orb_file, martian_to_earth):
841    """
842    Read the .asc file containing obliquity, eccentricity and Ls p.
843    Columns:
844      0 = time in thousand Martian years
845      1 = obliquity (deg)
846      2 = eccentricity
847      3 = Ls p (deg)
848    Converts times to Earth years.
849    """
850    data = np.loadtxt(orb_file)
851    dates_mka = data[:, 0]
852    dates_yr = dates_mka * 1e3 / martian_to_earth
853    obliquity = data[:, 1]
854    eccentricity = data[:, 2]
855    lsp = data[:, 3]
856    return dates_yr, obliquity, eccentricity, lsp
857
858
859def plot_orbital_parameters(infofile, orb_file, date_time, output_folder="."):
860    """
861    Plot the evolution of obliquity, eccentricity and Ls p
862    versus simulated time.
863    """
864    # Read conversion factor from infofile
865    _, martian_to_earth = read_infofile(infofile)
866
867    # Read orbital data
868    dates_yr, obl, ecc, lsp = read_orbital_data(orb_file, martian_to_earth)
869
870    # Interpolate orbital parameters at simulation dates (date_time)
871    obl_interp = interp1d(dates_yr, obl, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
872    ecc_interp = interp1d(dates_yr, ecc, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
873    lsp_interp = interp1d(dates_yr, lsp, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
874
875    # Plot
876    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
877    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')
878
879    axes[0].plot(date_time, obl_interp, 'r-', marker='+')
880    axes[0].set_ylabel("Obliquity (°)")
881    axes[0].grid(True)
882
883    axes[1].plot(date_time, ecc_interp, 'b-', marker='+')
884    axes[1].set_ylabel("Eccentricity")
885    axes[1].grid(True)
886
887    axes[2].plot(date_time, lsp_interp, 'g-', marker='+')
888    axes[2].set_ylabel("Ls of perihelion  (°)")
889    axes[2].set_xlabel("Time (Mars years)")
890    axes[2].grid(True)
891
892    plt.tight_layout(rect=[0, 0, 1, 0.96])
893    fname = os.path.join(output_folder, "orbital_parameters_laskar.png")
894    fig.savefig(fname, dpi=150)
895
896
897def mars_ls(pday, peri_day, e_elips, year_day, lsperi=0.0):
898    """
899    Compute solar longitude (Ls) in radians for a given Mars date array 'pday'.
900    Returns Ls in degrees [0, 360).
901    """
902    zz = (pday - peri_day) / year_day
903    zanom = 2 * np.pi * (zz - np.round(zz))
904    xref = np.abs(zanom)
905
906    # Solve Kepler's equation via Newton–Raphson
907    zx0 = xref + e_elips * np.sin(xref)
908    for _ in range(10):
909        f  = zx0 - e_elips * np.sin(zx0) - xref
910        fp = 1 - e_elips * np.cos(zx0)
911        dz = -f / fp
912        zx0 += dz
913        if np.all(np.abs(dz) <= 1e-7):
914            break
915
916    zx0 = np.where(zanom < 0, -zx0, zx0)
917    zteta = 2 * np.arctan(
918        np.sqrt((1 + e_elips) / (1 - e_elips)) * np.tan(zx0 / 2)
919    )
920    psollong = np.mod(zteta + lsperi, 2 * np.pi)
921
922    return np.degrees(psollong)
923
924
925def read_orbital_data_nc(starts_folder, infofile=None):
926    """
927    Read orbital parameters from restartfi_postPEM*.nc files in starts_folder.
928    """
929    if not os.path.isdir(starts_folder):
930        raise ValueError(f"Invalid starts_folder '{starts_folder}': not a directory.")
931
932    # Read simulation time mapping if provided
933    if infofile:
934        dates_yr, martian_to_earth = read_infofile(infofile)
935    else:
936        dates_yr = None
937
938    pattern = os.path.join(starts_folder, "restartfi_postPEM*.nc")
939    files = glob(pattern)
940    if not files:
941        raise FileNotFoundError(f"No NetCDF restart files found matching {pattern}")
942
943    def extract_number(path):
944        name = os.path.basename(path)
945        prefix = 'restartfi_postPEM'
946        if name.startswith(prefix) and name.endswith('.nc'):
947            num_str = name[len(prefix):-3]
948            if num_str.isdigit():
949                return int(num_str)
950        return float('inf')
951
952    files = sorted(files, key=extract_number)
953
954    all_year_day, all_peri, all_aphe, all_date_peri, all_obl = [], [], [], [], []
955    for nc_path in files:
956        with Dataset(nc_path, 'r') as nc:
957            ctrl = nc.variables['controle'][:]
958            all_year_day.append(ctrl[13])
959            all_peri.append(ctrl[14])
960            all_aphe.append(ctrl[15])
961            all_date_peri.append(ctrl[16])
962            all_obl.append(ctrl[17])
963
964    year_day      = np.array(all_year_day)
965    perihelion    = np.array(all_peri)
966    aphelion      = np.array(all_aphe)
967    date_peri_day = np.array(all_date_peri)
968    obliquity     = np.array(all_obl)
969
970    eccentricity  = (aphelion - perihelion) / (aphelion + perihelion)
971    ls_perihelion = mars_ls(date_peri_day,0.,eccentricity,year_day)
972
973    return dates_yr, obliquity, eccentricity, ls_perihelion
974
975
976def plot_orbital_parameters_nc(starts_folder, infofile, date_time, output_folder="."):
977    """
978    Plot the evolution of obliquity, eccentricity and Ls p coming from simulation data
979    versus simulated time.
980    """
981    # Read orbital data
982    times_yr, obl, ecc, lsp = read_orbital_data_nc(starts_folder, infofile)
983
984    fargs = dict(kind='linear', bounds_error=False, fill_value='extrapolate')
985    obl_i = interp1d(times_yr, obl, **fargs)(date_time)
986    ecc_i = interp1d(times_yr, ecc, **fargs)(date_time)
987    lsp_i = interp1d(times_yr, lsp, **fargs)(date_time)
988
989    fig, axes = plt.subplots(3,1, figsize=(8,10), sharex=True)
990    fig.suptitle("Orbital parameters vs simulated time", fontsize=14, fontweight='bold')
991
992    # Plot
993    axes[0].plot(date_time, obl_i, 'r-', marker='+')
994    axes[0].set_ylabel("Obliquity (°)")
995    axes[0].grid(True)
996
997    axes[1].plot(date_time, ecc_i, 'b-', marker='+')
998    axes[1].set_ylabel("Eccentricity")
999    axes[1].grid(True)
1000
1001    axes[2].plot(date_time, lsp_i, 'g-', marker='+')
1002    axes[2].set_ylabel("Ls of perihelion (°)")
1003    axes[2].set_xlabel("Time (Mars years)")
1004    axes[2].grid(True)
1005
1006    plt.tight_layout(rect=[0,0,1,0.96])
1007    outname = os.path.join(output_folder, "orbital_parameters_simu.png")
1008    fig.savefig(outname, dpi=150)
1009
1010
1011def plot_dust_to_ice_ratio_with_obliquity(
1012    starts_folder,
1013    infofile,
1014    gridded_data,
1015    ref_grid,
1016    top_index,
1017    heights_data,
1018    date_time,
1019    exclude_sub=False,
1020    output_folder="."
1021):
1022    """
1023    Plot the dust-to-ice ratio over time as a heatmap, and overlay the evolution of
1024    obliquity on a secondary y-axis.
1025    """
1026    h2o = gridded_data['h2o_ice']
1027    co2 = gridded_data['co2_ice']
1028    dust = gridded_data['dust']
1029    ngrid, ntime, nslope, nz = h2o.shape
1030
1031    # Read orbital data
1032    times_yr, obl, _, _ = read_orbital_data_nc(starts_folder, infofile)
1033    fargs = dict(kind='linear', bounds_error=False, fill_value='extrapolate')
1034    obliquity = interp1d(times_yr, obl, **fargs)(date_time)
1035
1036    # Computed total height
1037    for ig in range(ngrid):
1038        for isl in range(nslope):
1039            total_height_t = np.zeros(ntime, dtype=float)
1040
1041            for t_idx in range(ntime):
1042                h_mat = heights_data[t_idx][isl]
1043                raw_h = h_mat[ig, :]
1044                valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
1045                if np.any(valid_mask):
1046                    h_valid = raw_h[valid_mask]
1047                    total_height_t[t_idx] = np.max(h_valid)
1048
1049    # Compute the per-interval sign of height change
1050    dh = np.diff(total_height_t)
1051    signs = np.sign(dh)
1052    color_map = { 1: 'green', -1: 'red', 0: 'orange' }
1053
1054    # Elevation mask
1055    if exclude_sub:
1056        elevation_mask = (ref_grid >= 0.0)
1057        elev = ref_grid[elevation_mask]
1058    else:
1059        elevation_mask = np.ones_like(ref_grid, dtype=bool)
1060        elev = ref_grid
1061
1062    # Custom colormap: blue (ice) to orange (dust)
1063    blue = np.array([0, 0, 255]) / 255
1064    orange = np.array([255, 165, 0]) / 255
1065    custom_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)
1066
1067    # Log‑ratio bounds and small epsilon to avoid log(0)
1068    vmin, vmax = -2, 1
1069    epsilon = 1e-6
1070
1071    # Loop over grids and slopes
1072    for ig in range(ngrid):
1073        for isl in range(nslope):
1074            ti = top_index[ig, :, isl].copy().astype(int)
1075            frac_all = np.zeros((nz, ntime, 3), dtype=float)  # store fH2O, fCO2, fDust
1076            for t in range(1, ntime):
1077                if ti[t] <= 0:
1078                    ti[t] = ti[t-1]
1079
1080            # Compute log10(dust/ice) profile at each time step
1081            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
1082            for t in range(ntime):
1083                zmax = ti[t]
1084                if zmax <= 0:
1085                    continue
1086                cH2O = np.clip(h2o[ig, t, isl, :zmax], 0, None)
1087                cCO2 = np.clip(co2[ig, t, isl, :zmax], 0, None)
1088                cDust = np.clip(dust[ig, t, isl, :zmax], 0, None)
1089                total = cH2O + cCO2 + cDust
1090                total[total == 0] = 1.0
1091                fH2O = cH2O / total
1092                fCO2 = cCO2 / total
1093                fDust = cDust / total
1094                frac_all[:zmax, t, :] = np.stack([fH2O, fCO2, fDust], axis=1)
1095
1096                h2o_profile = np.clip(h2o[ig, t, isl, :zmax], 0, None)
1097                dust_profile = np.clip(dust[ig, t, isl, :zmax], 0, None)
1098
1099                with np.errstate(divide='ignore', invalid='ignore'):
1100                    ratio_profile = np.where(
1101                        h2o_profile > 0,
1102                        dust_profile / h2o_profile,
1103                        10**(vmax + 1)
1104                    )
1105                    log_ratio = np.log10(ratio_profile + epsilon)
1106                    log_ratio = np.clip(log_ratio, vmin, vmax)
1107
1108                log_ratio_array[:zmax, t] = log_ratio
1109
1110            # Convert back to linear ratio and mask
1111            ratio_array = 10**log_ratio_array
1112            ratio_display = ratio_array[elevation_mask, :]
1113            display_frac = frac_all[elevation_mask, :, :]
1114
1115            # Plot
1116            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
1117            im = ax.pcolormesh(
1118                date_time,
1119                elev,
1120                ratio_display,
1121                shading='auto',
1122                cmap='managua_r',
1123                norm=LogNorm(vmin=10**vmin, vmax=10**vmax),
1124            )
1125            x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1]-date_time[-2])]])
1126            y_edges = np.concatenate([elev, [elev[-1] + (elev[-1]-elev[-2])]])
1127            def format_coord_all(x, y):
1128                # check bounds
1129                if x < x_edges[0] or x > x_edges[-1] or y < y_edges[0] or y > y_edges[-1]:
1130                    return ''
1131                # locate cell
1132                i = np.searchsorted(x_edges, x) - 1
1133                j = np.searchsorted(y_edges, y) - 1
1134                i = np.clip(i, 0, display_frac.shape[1]-1)
1135                j = np.clip(j, 0, display_frac.shape[0]-1)
1136                # get fractions
1137                fH2O  = display_frac[j, i, 0]
1138                fDust = display_frac[j, i, 2]
1139                obl   = np.interp(x, date_time, obliquity)
1140                return f"Time={x:.2f}, Elev={y:.2f}, H2O={fH2O:.2f}, Dust={fDust:.2f}, Obl={obl:.2f}°"
1141
1142            ax.format_coord = format_coord_all
1143            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
1144            ax.set_xlabel('Time (Mars years)')
1145            ax.set_ylabel('Elevation (m)')
1146
1147            # Add colorbar
1148            cbar = fig.colorbar(im, ax=ax, orientation='vertical', pad=0.15)
1149            cbar.set_label('Dust / H₂O ice (ratio)')
1150            cbar.set_ticks([1e-2, 1e-1, 1, 1e1])
1151            cbar.set_ticklabels(['1:100', '1:10', '1:1', '10:1'])
1152
1153            # Overlay obliquity on secondary y-axis
1154            ax2 = ax.twinx()
1155            for i in range(len(dh)):
1156                ax2.plot(
1157                    [date_time[i], date_time[i+1]],
1158                    [obliquity[i], obliquity[i+1]],
1159                    color=color_map[signs[i]],
1160                    marker='+',
1161                    linewidth=1.5
1162                )
1163            ax2.format_coord = format_coord_all
1164            ax2.set_ylabel('Obliquity (°)')
1165            ax2.tick_params(axis='y')
1166            ax2.grid(False)
1167
1168            # Save
1169            os.makedirs(output_folder, exist_ok=True)
1170            outname = os.path.join(
1171                output_folder,
1172                f'dust_ice_obliquity_grid{ig+1}_slope{isl+1}.png'
1173            )
1174            plt.tight_layout()
1175            fig.savefig(outname, dpi=150)
1176
1177
1178def main():
1179    # 1) Get user inputs
1180    folder_path, base_name, infofile, orbfile = get_user_inputs()
1181
1182    # 2) List and verify NetCDF files
1183    files = list_netcdf_files(folder_path, base_name)
1184    if not files:
1185        print(f"No NetCDF files named \"{base_name}#.nc\" found in \"{folder_path}\".")
1186        sys.exit(1)
1187    print(f"> Found {len(files)} NetCDF file(s).")
1188
1189    # 3) Open one sample to get grid dimensions & coordinates
1190    sample_file = files[0]
1191    ngrid, nslope, longitude, latitude = open_sample_dataset(sample_file)
1192    print(f"> ngrid  = {ngrid}, nslope = {nslope}")
1193
1194    # 4) Collect variable info + global min/max elevations
1195    var_info, max_nb_str, min_base_elev, max_top_elev = collect_stratification_variables(files, base_name)
1196    print(f"> max strata per slope = {max_nb_str}")
1197    print(f"> min base elev = {min_base_elev:.3f} m, max top elev = {max_top_elev:.3f} m")
1198
1199    # 5) Load full datasets
1200    datasets = load_full_datasets(files)
1201
1202    # 6) Extract stratification data
1203    heights_data, raw_prop_arrays, ntime = extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str)
1204
1205    # 7) Close datasets
1206    for ds in datasets:
1207        ds.close()
1208
1209    # 8) Normalize to fractions
1210    frac_arrays = normalize_to_fractions(raw_prop_arrays)
1211
1212    # 9) Ask whether to include subsurface
1213    show_subsurface = get_yes_no_input("Show subsurface layers?")
1214    exclude_sub = not show_subsurface
1215    if exclude_sub:
1216        min_base_for_interp = 0.0
1217        print("> Interpolating only elevations >= 0 m (surface strata).")
1218    else:
1219        min_base_for_interp = min_base_elev
1220        print(f"> Interpolating full depth down to {min_base_elev:.3f} m.")
1221
1222    # 10) Prompt discretization step
1223    dz = prompt_discretization_step(max_top_elev)
1224
1225    # 11) Build reference grid and interpolate
1226    ref_grid, gridded_data, top_index = interpolate_data_on_refgrid(
1227        heights_data, frac_arrays, min_base_for_interp, max_top_elev, dz, exclude_sub=exclude_sub
1228    )
1229
1230    # 12) Read timestamps and conversion factor from infofile
1231    date_time, martian_to_earth = read_infofile(infofile)
1232    if date_time.size != ntime:
1233        print(f"Warning: {date_time.size} timestamps vs {ntime} NetCDF files.")
1234
1235    # 13) Plot stratification data over time
1236    plot_stratification_over_time(
1237        gridded_data, ref_grid, top_index, heights_data, date_time,
1238        exclude_sub=exclude_sub, output_folder="."
1239    )
1240    plot_stratification_rgb_over_time(
1241        gridded_data, ref_grid, top_index, heights_data, date_time,
1242        exclude_sub=exclude_sub, output_folder="."
1243    )
1244    #plot_dust_to_ice_ratio_over_time(
1245    #    gridded_data, ref_grid, top_index, heights_data, date_time,
1246    #    exclude_sub=exclude_sub, output_folder="."
1247    #)
1248    plot_dust_to_ice_ratio_with_obliquity(
1249        folder_path, infofile,
1250        gridded_data, ref_grid, top_index, heights_data, date_time,
1251        exclude_sub=exclude_sub, output_folder="."
1252    )
1253    #plot_strata_count_and_total_height(heights_data, date_time, output_folder=".")
1254
1255    # 14) Plot orbital parameters
1256    #plot_orbital_parameters(infofile, orbfile, date_time, output_folder=".")
1257    plot_orbital_parameters_nc(folder_path, infofile, date_time, output_folder=".")
1258
1259    # 15) Show all figures
1260    plt.show()
1261
1262
1263if __name__ == "__main__":
1264    main()
1265
Note: See TracBrowser for help on using the repository browser.