source: trunk/LMDZ.COMMON/libf/evolution/deftank/visu_evol_layering.py @ 3840

Last change on this file since 3840 was 3840, checked in by jbclement, 3 days ago

PEM:

  • Correction of a bug in the launching script.
  • Update of "visu_evol_layering.py", in particular to show value at cursor for 2D heatmaps.
  • Few cleanings.

JBC

  • Property svn:executable set to *
File size: 32.9 KB
Line 
1#!/usr/bin/env python3
2#######################################################################################################
3### Python script to output stratification data over time from "restartpem#.nc" files               ###
4### and to plot orbital parameters from "obl_ecc_lsp.asc"                                           ###
5#######################################################################################################
6
7import os
8import sys
9import numpy as np
10from glob import glob
11from netCDF4 import Dataset
12import matplotlib.pyplot as plt
13from mpl_toolkits.axes_grid1.inset_locator import inset_axes
14from matplotlib.colors import LinearSegmentedColormap, LogNorm
15from scipy.interpolate import interp1d
16
17
18def get_user_inputs():
19    """
20    Prompt the user for:
21      - folder_path: directory containing NetCDF files (default: "starts")
22      - base_name:   base filename (default: "restartpem")
23      - infofile:    name of the PEM info file (default: "info_PEM.txt")
24    Validates existence of folder and infofile before returning.
25    """
26    folder_path = input(
27        "Enter the folder path containing the NetCDF files "
28        "(press Enter for default [starts]): "
29    ).strip() or "starts"
30    while not os.path.isdir(folder_path):
31        print(f"  » \"{folder_path}\" does not exist or is not a directory.")
32        folder_path = input(
33            "Enter a valid folder path (press Enter for default [starts]): "
34        ).strip() or "starts"
35
36    base_name = input(
37        "Enter the base name of the NetCDF files "
38        "(press Enter for default [restartpem]): "
39    ).strip() or "restartpem"
40
41    infofile = input(
42        "Enter the name of the PEM info file "
43        "(press Enter for default [info_PEM.txt]): "
44    ).strip() or "info_PEM.txt"
45    while not os.path.isfile(infofile):
46        print(f"  » \"{infofile}\" does not exist or is not a file.")
47        infofile = input(
48            "Enter a valid PEM info filename (press Enter for default [info_PEM.txt]): "
49        ).strip() or "info_PEM.txt"
50
51    orbfile = input(
52        "Enter the name of the orbital parameters ASCII file "
53        "(press Enter for default [obl_ecc_lsp.asc]): "
54    ).strip() or "obl_ecc_lsp.asc"
55    while not os.path.isfile(orbfile):
56        print(f"  » \"{orbfile}\" does not exist or is not a file.")
57        orbfile = input(
58            "Enter a valid orbital parameters ASCII filename (press Enter for default [obl_ecc_lsp.asc]): "
59        ).strip() or "info_PEM.txt"
60
61    return folder_path, base_name, infofile, orbfile
62
63
64def list_netcdf_files(folder_path, base_name):
65    """
66    List and sort all NetCDF files matching the pattern {base_name}#.nc
67    in folder_path. Returns a sorted list of full file paths.
68    """
69    pattern = os.path.join(folder_path, f"{base_name}[0-9]*.nc")
70    all_files = glob(pattern)
71    if not all_files:
72        return []
73
74    def extract_index(pathname):
75        fname = os.path.basename(pathname)
76        idx_str = fname[len(base_name):-3]
77        return int(idx_str) if idx_str.isdigit() else float('inf')
78
79    sorted_files = sorted(all_files, key=extract_index)
80    return sorted_files
81
82
83def open_sample_dataset(file_path):
84    """
85    Open a single NetCDF file and extract:
86      - ngrid, nslope
87      - longitude, latitude
88    Returns (ngrid, nslope, longitude_array, latitude_array).
89    """
90    with Dataset(file_path, 'r') as ds:
91        ngrid = ds.dimensions['physical_points'].size
92        nslope = ds.dimensions['nslope'].size
93        longitude = ds.variables['longitude'][:].copy()
94        latitude = ds.variables['latitude'][:].copy()
95    return ngrid, nslope, longitude, latitude
96
97
98def collect_stratification_variables(files, base_name):
99    """
100    Scan all files to collect:
101      - variable names for each stratification property
102      - max number of strata (max_nb_str)
103      - global min base elevation and max top elevation
104    Returns:
105      - var_info: dict mapping each property_name -> sorted list of var names
106      - max_nb_str: int
107      - min_base_elev: float
108      - max_top_elev: float
109    """
110    max_nb_str = 0
111    min_base_elev = np.inf
112    max_top_elev = -np.inf
113
114    property_markers = {
115        'heights':   'stratif_slope',    # "..._top_elevation"
116        'co2_ice':   'h_co2ice',
117        'h2o_ice':   'h_h2oice',
118        'dust':      'h_dust',
119        'pore':      'h_pore',
120        'pore_ice':  'poreice_volfrac'
121    }
122    var_info = {prop: set() for prop in property_markers}
123
124    for file_path in files:
125        with Dataset(file_path, 'r') as ds:
126            if 'nb_str_max' in ds.dimensions:
127                max_nb_str = max(max_nb_str, ds.dimensions['nb_str_max'].size)
128
129            nslope = ds.dimensions['nslope'].size
130            for k in range(1, nslope + 1):
131                var_name = f"stratif_slope{k:02d}_top_elevation"
132                if var_name in ds.variables:
133                    arr = ds.variables[var_name][:]
134                    min_base_elev = min(min_base_elev, np.min(arr))
135                    max_top_elev = max(max_top_elev, np.max(arr))
136                    var_info['heights'].add(var_name)
137
138            for full_var in ds.variables:
139                for prop, marker in property_markers.items():
140                    if (marker in full_var) and prop != 'heights':
141                        var_info[prop].add(full_var)
142
143    for prop in var_info:
144        var_info[prop] = sorted(var_info[prop])
145
146    return var_info, max_nb_str, min_base_elev, max_top_elev
147
148
149def load_full_datasets(files):
150    """
151    Open all NetCDF files and return a list of Dataset objects.
152    (They should be closed by the caller after use.)
153    """
154    return [Dataset(fp, 'r') for fp in files]
155
156
157def extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str):
158    """
159    Build:
160      - heights_data[t_idx][isl] = 2D array (ngrid, n_strata_current) of top_elevations.
161      - raw_prop_arrays[prop] = 4D array (ngrid, ntime, nslope, max_nb_str) of per-strata values.
162    Returns:
163      - heights_data: list (ntime) of lists (nslope) of 2D arrays
164      - raw_prop_arrays: dict mapping each property_name -> 4D array
165      - ntime: number of time steps (files)
166    """
167    ntime = len(datasets)
168
169    heights_data = [
170        [None for _ in range(nslope)]
171        for _ in range(ntime)
172    ]
173    for t_idx, ds in enumerate(datasets):
174        for var_name in var_info['heights']:
175            slope_idx = int(var_name.split("slope")[1].split("_")[0]) - 1
176            if 0 <= slope_idx < nslope:
177                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
178                heights_data[t_idx][slope_idx] = raw.# (ngrid, n_strata)
179
180    raw_prop_arrays = {}
181    for prop in var_info:
182        if prop == 'heights':
183            continue
184        raw_prop_arrays[prop] = np.zeros((ngrid, ntime, nslope, max_nb_str), dtype=np.float32)
185
186    def slope_index_from_var(vname):
187        return int(vname.split("slope")[1].split("_")[0]) - 1
188
189    for prop in raw_prop_arrays:
190        slope_map = {}
191        for vname in var_info[prop]:
192            isl = slope_index_from_var(vname)
193            if 0 <= isl < nslope:
194                slope_map[isl] = vname
195
196        arr = raw_prop_arrays[prop]
197        for t_idx, ds in enumerate(datasets):
198            for isl, var_name in slope_map.items():
199                raw = ds.variables[var_name][0, :, :]  # (n_strata, ngrid)
200                n_strata_current = raw.shape[0]
201                arr[:, t_idx, isl, :n_strata_current] = raw.T
202
203    return heights_data, raw_prop_arrays, ntime
204
205
206def normalize_to_fractions(raw_prop_arrays):
207    """
208    Given raw_prop_arrays for 'co2_ice', 'h2o_ice', 'dust', 'pore' (in meters),
209    normalize each set of strata so that the sum of those four = 1 per cell.
210    Returns:
211      - frac_arrays: dict mapping same keys -> 4D arrays of fractions (0..1).
212    """
213    co2 = raw_prop_arrays['co2_ice']
214    h2o = raw_prop_arrays['h2o_ice']
215    dust = raw_prop_arrays['dust']
216    pore = raw_prop_arrays['pore']
217
218    total = co2 + h2o + dust + pore
219    mask = total > 0.0
220
221    frac_co2 = np.zeros_like(co2, dtype=np.float32)
222    frac_h2o = np.zeros_like(h2o, dtype=np.float32)
223    frac_dust = np.zeros_like(dust, dtype=np.float32)
224    frac_pore = np.zeros_like(pore, dtype=np.float32)
225
226    frac_co2[mask] = co2[mask] / total[mask]
227    frac_h2o[mask] = h2o[mask] / total[mask]
228    frac_dust[mask] = dust[mask] / total[mask]
229    frac_pore[mask] = pore[mask] / total[mask]
230
231    return {
232        'co2_ice': frac_co2,
233        'h2o_ice': frac_h2o,
234        'dust':     frac_dust,
235        'pore':     frac_pore
236    }
237
238
239def read_infofile(file_name):
240    """
241    Reads "info_PEM.txt". Expects:
242      - First line: parameters where the 3rd value is martian_to_earth conversion factor.
243      - Each subsequent line: floats where first value is simulation timestamp (in Mars years).
244    Returns:
245      - date_time: 1D numpy array of timestamps (Mars years)
246      - martian_to_earth: float conversion factor
247    """
248    date_time = []
249    with open(file_name, 'r') as fp:
250        first = fp.readline().split()
251        martian_to_earth = float(first[2])
252        for line in fp:
253            parts = line.strip().split()
254            if not parts:
255                continue
256            try:
257                date_time.append(float(parts[0]))
258            except ValueError:
259                continue
260    return np.array(date_time, dtype=np.float64), martian_to_earth
261
262
263def get_yes_no_input(prompt: str) -> bool:
264    """
265    Prompt the user with a yes/no question. Returns True for yes, False for no.
266    """
267    while True:
268        choice = input(f"{prompt} (y/n): ").strip().lower()
269        if choice in ['y', 'yes']:
270            return True
271        elif choice in ['n', 'no']:
272            return False
273        else:
274            print("Please respond with y or n.")
275
276
277def prompt_discretization_step(max_top_elev):
278    """
279    Prompt for a positive float dz such that 0 < dz <= max_top_elev.
280    """
281    while True:
282        entry = input(
283            "Enter the discretization step of the reference grid for the elevation [m]: "
284        ).strip()
285        try:
286            dz = float(entry)
287            if dz <= 0:
288                print("  » Discretization step must be strictly positive!")
289                continue
290            if dz > max_top_elev:
291                print(
292                    f"  » {dz:.3e} m is greater than the maximum top elevation "
293                    f"({max_top_elev:.3e} m). Please enter a smaller value."
294                )
295                continue
296            return dz
297        except ValueError:
298            print("  » Invalid numeric value. Please try again.")
299
300
301def interpolate_data_on_refgrid(
302    heights_data,
303    prop_arrays,
304    min_base_for_interp,
305    max_top_elev,
306    dz,
307    exclude_sub=False
308):
309    """
310    Build a reference elevation grid and interpolate strata fractions onto it.
311
312    Returns:
313      - ref_grid: 1D array of elevations (nz,)
314      - gridded_data: dict mapping each property_name to 4D array
315        (ngrid, ntime, nslope, nz) with interpolated fractions.
316      - top_index: 3D array (ngrid, ntime, nslope) of ints:
317        number of levels covered by the topmost stratum.
318    """
319    if exclude_sub and (dz > max_top_elev):
320        ref_grid = np.array([0.0, max_top_elev], dtype=np.float32)
321    else:
322        ref_grid = np.arange(min_base_for_interp, max_top_elev + dz/2, dz)
323    nz = len(ref_grid)
324    print(f"> Number of reference grid points = {nz}")
325
326    sample_prop = next(iter(prop_arrays.values()))
327    ngrid, ntime, nslope, max_nb_str = sample_prop.shape
328
329    gridded_data = {
330        prop: np.full((ngrid, ntime, nslope, nz), -1.0, dtype=np.float32)
331        for prop in prop_arrays
332    }
333    top_index = np.zeros((ngrid, ntime, nslope), dtype=np.int32)
334
335    for ig in range(ngrid):
336        for t_idx in range(ntime):
337            for isl in range(nslope):
338                h_mat = heights_data[t_idx][isl]
339                if h_mat is None:
340                    continue
341
342                raw_h = h_mat[ig, :]
343                h_all = np.full((max_nb_str,), np.nan, dtype=np.float32)
344                n_strata_current = raw_h.shape[0]
345                h_all[:n_strata_current] = raw_h
346
347                if exclude_sub:
348                    epsilon = 1e-6
349                    valid_mask = (h_all >= -epsilon)
350                else:
351                    valid_mask = (~np.isnan(h_all)) & (h_all != 0.0)
352
353                if not np.any(valid_mask):
354                    continue
355
356                h_valid = h_all[valid_mask]
357                top_h = np.max(h_valid)
358                i_zmax = np.searchsorted(ref_grid, top_h, side='right')
359                top_index[ig, t_idx, isl] = i_zmax
360                if i_zmax == 0:
361                    continue
362
363                for prop, arr in prop_arrays.items():
364                    prop_profile_all = arr[ig, t_idx, isl, :]
365                    prop_profile = prop_profile_all[valid_mask]
366                    if prop_profile.size == 0:
367                        continue
368
369                    f_interp = interp1d(
370                        h_valid,
371                        prop_profile,
372                        kind='next',
373                        bounds_error=False,
374                        fill_value=-1.0
375                    )
376                    gridded_data[prop][ig, t_idx, isl, :i_zmax] = f_interp(ref_grid[:i_zmax])
377
378    return ref_grid, gridded_data, top_index
379
380
381def attach_format_coord(ax, mat, x, y, is_pcolormesh=True):
382    """
383    Attach a format_coord function to the axes to display x, y, and value at cursor.
384    Works for both pcolormesh and imshow style grids.
385    """
386    # Determine dimensions
387    if mat.ndim == 2:
388        ny, nx = mat.shape
389    elif mat.ndim == 3 and mat.shape[2] in (3, 4):
390        ny, nx, nc = mat.shape
391    else:
392        raise ValueError(f"Unsupported mat shape {mat.shape}")
393    # Edges or extents
394    if is_pcolormesh:
395        xedges, yedges = x, y
396    else:
397        x0, x1 = x.min(), x.max()
398        y0, y1 = y.min(), y.max()
399
400    def format_coord(xp, yp):
401        # Map to indices
402        if is_pcolormesh:
403            col = np.searchsorted(xedges, xp) - 1
404            row = np.searchsorted(yedges, yp) - 1
405        else:
406            col = int((xp - x0) / (x1 - x0) * nx)
407            row = int((yp - y0) / (y1 - y0) * ny)
408        # Within bounds?
409        if 0 <= row < ny and 0 <= col < nx:
410            if mat.ndim == 2:
411                v = mat[row, col]
412                return f"x={xp:.3g}, y={yp:.3g}, val={v:.3g}"
413            else:
414                vals = mat[row, col]
415                txt = ", ".join(f"{vv:.3g}" for vv in vals[:3])
416                return f"x={xp:.3g}, y={yp:.3g}, val=({txt})"
417        return f"x={xp:.3g}, y={yp:.3g}"
418
419    ax.format_coord = format_coord
420
421
422def plot_stratification_over_time(
423    gridded_data,
424    ref_grid,
425    top_index,
426    heights_data,
427    date_time,
428    exclude_sub=False,
429    output_folder="."
430):
431    """
432    For each grid point and slope, generate a 2×2 figure of:
433      - CO2 ice fraction
434      - H2O ice fraction
435      - Dust fraction
436      - Pore fraction
437    """
438    prop_names = ['co2_ice', 'h2o_ice', 'dust', 'pore']
439    titles = ["CO2 ice", "H2O ice", "Dust", "Pore"]
440    cmap = plt.get_cmap('turbo').copy()
441    cmap.set_under('white')
442    vmin, vmax = 0.0, 1.0
443
444    sample_prop = next(iter(gridded_data.values()))
445    ngrid, ntime, nslope, nz = sample_prop.shape
446
447    if exclude_sub:
448        positive_indices = np.where(ref_grid >= 0.0)[0]
449        sub_ref_grid = ref_grid[positive_indices]
450    else:
451        positive_indices = np.arange(nz)
452        sub_ref_grid = ref_grid
453
454    for ig in range(ngrid):
455        for isl in range(nslope):
456            fig, axes = plt.subplots(2, 2, figsize=(10, 8))
457            fig.suptitle(
458                f"Content variation over time for (Grid point {ig+1}, Slope {isl+1})",
459                fontsize=14,
460                fontweight='bold'
461            )
462
463            # Precompute valid stratum tops per time
464            valid_tops_per_time = []
465            for t_idx in range(ntime):
466                raw_h = heights_data[t_idx][isl][ig, :]
467                h_all = raw_h[~np.isnan(raw_h)]
468                if exclude_sub:
469                    h_all = h_all[h_all >= 0.0]
470                valid_tops_per_time.append(np.unique(h_all))
471
472            for idx, prop in enumerate(prop_names):
473                ax = axes.flat[idx]
474                data_3d = gridded_data[prop][ig, :, isl, :]
475                mat_full = data_3d.T
476                mat = mat_full[positive_indices, :].copy()
477                mat[mat < 0.0] = np.nan
478
479                # Mask above top stratum
480                for t_idx in range(ntime):
481                    i_zmax = top_index[ig, t_idx, isl]
482                    if i_zmax <= positive_indices[0]:
483                        mat[:, t_idx] = np.nan
484                    else:
485                        count_z = np.count_nonzero(positive_indices < i_zmax)
486                        mat[count_z:, t_idx] = np.nan
487
488                im = ax.pcolormesh(
489                    date_time,
490                    sub_ref_grid,
491                    mat,
492                    cmap=cmap,
493                    shading='auto',
494                    vmin=vmin,
495                    vmax=vmax
496                )
497                x_edges = np.concatenate([date_time, [date_time[-1] + (date_time[-1]-date_time[-2])]])
498                attach_format_coord(ax, mat, x_edges, np.concatenate([sub_ref_grid, [sub_ref_grid[-1] + (sub_ref_grid[-1]-sub_ref_grid[-2])]]), is_pcolormesh=True)
499                ax.set_title(titles[idx], fontsize=12)
500                ax.set_xlabel("Time (Mars years)")
501                ax.set_ylabel("Elevation (m)")
502
503            fig.subplots_adjust(right=0.88)
504            fig.tight_layout(rect=[0, 0, 0.88, 1.0])
505            cbar_ax = fig.add_axes([0.90, 0.15, 0.02, 0.7])
506            fig.colorbar(im, cax=cbar_ax, orientation='vertical', label="Content")
507
508            fname = os.path.join(
509                output_folder, f"layering_evolution_ig{ig+1}_is{isl+1}.png"
510            )
511            fig.savefig(fname, dpi=150)
512
513
514def plot_stratification_rgb_over_time(
515    gridded_data,
516    ref_grid,
517    top_index,
518    heights_data,
519    date_time,
520    exclude_sub=False,
521    output_folder="."
522):
523    """
524    Plot stratification over time colored using RGB ternary mix of H2O ice (blue), CO2 ice (violet), and dust (orange).
525    Includes a triangular legend showing the mix proportions.
526    """
527
528    # Define constant colors
529    violet = np.array([255, 0, 255], dtype=float) / 255
530    blue   = np.array([0, 0, 255], dtype=float) / 255
531    orange = np.array([255, 165, 0], dtype=float) / 255
532
533    # Prepare elevation mask
534    mask_elev = (ref_grid >= 0.0) if exclude_sub else np.ones_like(ref_grid, dtype=bool)
535    elev = ref_grid[mask_elev]
536
537    # Generate legend image once
538    res = 300
539    u = np.linspace(0, 1, res)
540    v = np.linspace(0, np.sqrt(3)/2, res)
541    X, Y = np.meshgrid(u, v)
542    V_bary = 2 * Y / np.sqrt(3)
543    U_bary = X - 0.5 * V_bary
544    W_bary = 1 - U_bary - V_bary
545    mask_triangle = (U_bary >= 0) & (V_bary >= 0) & (W_bary >= 0)
546
547    legend_rgb = (
548        U_bary[..., None] * violet
549        + V_bary[..., None] * orange
550        + W_bary[..., None] * blue
551    )
552    legend_rgb = np.clip(legend_rgb, 0.0, 1.0)
553    legend_rgba = np.zeros((res, res, 4))
554    legend_rgba[..., :3] = legend_rgb
555    legend_rgba[..., 3] = mask_triangle.astype(float)
556
557    # Loop over grid and slope
558    h2o = gridded_data['h2o_ice']
559    co2 = gridded_data['co2_ice']
560    dust = gridded_data['dust']
561    ngrid, ntime, nslope, nz = h2o.shape
562
563    for ig in range(ngrid):
564        for isl in range(nslope):
565            # Compute RGB stratification over time
566            rgb = np.ones((nz, ntime, 3), dtype=float)
567            for t in range(ntime):
568                mask_z = np.arange(nz) < top_index[ig, t, isl]
569                if not mask_z.any():
570                    continue
571                cH2O = np.clip(h2o[ig, t, isl, mask_z], 0, None)
572                cCO2 = np.clip(co2[ig, t, isl, mask_z], 0, None)
573                cDust = np.clip(dust[ig, t, isl, mask_z], 0, None)
574                total = cH2O + cCO2 + cDust
575                total[total == 0] = 1.0
576                fH2O = cH2O / total
577                fCO2 = cCO2 / total
578                fDust = cDust / total
579                mix = (
580                    np.outer(fH2O, blue)
581                    + np.outer(fCO2, violet)
582                    + np.outer(fDust, orange)
583                )
584                mix = np.clip(mix, 0.0, 1.0)
585                rgb[mask_z, t, :] = mix
586
587            display_rgb = rgb[mask_elev, :, :]
588
589            # Create figure with legend
590            fig, (ax_main, ax_leg) = plt.subplots(
591                1, 2, figsize=(12, 5), dpi=200,
592                gridspec_kw={'width_ratios': [5, 1]}
593            )
594
595            # Main stratification panel
596            ax_main.imshow(
597                display_rgb,
598                aspect='auto',
599                extent=[date_time[0], date_time[-1], elev.min(), elev.max()],
600                interpolation='nearest',
601                origin='lower'
602            )
603            x_centers = np.linspace(date_time[0], date_time[-1], display_rgb.shape[1])
604            y_centers = np.linspace(elev.min(), elev.max(), display_rgb.shape[0])
605            attach_format_coord(ax_main, display_rgb, x_centers, y_centers, is_pcolormesh=False)
606            ax_main.set_facecolor('white')
607            ax_main.set_title(f"Ternary mix over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
608            ax_main.set_xlabel("Time (Mars years)")
609            ax_main.set_ylabel("Elevation (m)")
610
611            # Legend panel
612            ax_leg.imshow(
613                legend_rgba,
614                extent=[0, 1, 0, np.sqrt(3)/2],
615                origin='lower',
616                interpolation='nearest'
617            )
618            attach_format_coord(ax_leg, legend_rgba, np.array([0, 1]), np.array([0, np.sqrt(3)/2]), is_pcolormesh=False)
619
620            # Draw triangle border
621            triangle = np.array([[0, 0], [1, 0], [0.5, np.sqrt(3)/2], [0, 0]])
622            ax_leg.plot(triangle[:, 0], triangle[:, 1], 'k-', linewidth=1)
623
624            # Dashed gridlines
625            ticks = np.linspace(0.25, 0.75, 3)
626            for f in ticks:
627                ax_leg.plot([1 - f, 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5)
628                ax_leg.plot([f, f + 0.5 * (1 - f)], [0, (1 - f)*np.sqrt(3)/2], '--', color='k', linewidth=0.5)
629                y = (np.sqrt(3)/2) * f
630                ax_leg.plot([0.5 * f, 1 - 0.5 * f], [y, y], '--', color='k', linewidth=0.5)
631
632            # Legend labels
633            ax_leg.text(0, -0.05, 'H2O ice', ha='center', va='top', fontsize=8)
634            ax_leg.text(1, -0.05, 'CO2 ice', ha='center', va='top', fontsize=8)
635            ax_leg.text(0.5, np.sqrt(3)/2 + 0.05, 'Dust', ha='center', va='bottom', fontsize=8)
636            ax_leg.axis('off')
637
638            plt.tight_layout()
639
640            # Save figure
641            fname = os.path.join(output_folder, f"layering_rgb_evolution_ig{ig+1}_is{isl+1}.png")
642            fig.savefig(fname, dpi=150, bbox_inches='tight')
643
644
645def plot_dust_to_ice_ratio_over_time(
646    gridded_data,
647    ref_grid,
648    top_index,
649    heights_data,
650    date_time,
651    exclude_sub=False,
652    output_folder="."
653):
654    """
655    Plot the dust-to-ice ratio in the stratification over time,
656    using a blue-to-orange colormap:
657    - blue: ice-dominated (low dust-to-ice ratio)
658    - orange: dust-dominated (high dust-to-ice ratio)
659    """
660    h2o = gridded_data['h2o_ice']
661    dust = gridded_data['dust']
662    ngrid, ntime, nslope, nz = h2o.shape
663
664    # Elevation mask
665    if exclude_sub:
666        elevation_mask = (ref_grid >= 0.0)
667        elev = ref_grid[elevation_mask]
668    else:
669        elevation_mask = np.ones_like(ref_grid, dtype=bool)
670        elev = ref_grid
671
672    # Define custom blue-to-orange colormap
673    blue = np.array([0, 0, 255], dtype=float) / 255
674    orange = np.array([255, 165, 0], dtype=float) / 255
675    custom_cmap = LinearSegmentedColormap.from_list('BlueOrange', [blue, orange], N=256)
676
677    # Log‑ratio bounds and small epsilon to avoid log(0)
678    vmin, vmax = -2, 1
679    epsilon = 1e-6
680
681    # Loop over grids and slopes
682    for ig in range(ngrid):
683        for isl in range(nslope):
684            log_ratio_array = np.full((nz, ntime), np.nan, dtype=np.float32)
685
686            # Compute log10(dust/ice) profile at each time step
687            for t in range(ntime):
688                zmax = top_index[ig, t, isl]
689                if zmax <= 0:
690                    continue
691
692                h2o_profile = np.clip(h2o[ig, t, isl, :zmax], 0, None)
693                dust_profile = np.clip(dust[ig, t, isl, :zmax], 0, None)
694
695                with np.errstate(divide='ignore', invalid='ignore'):
696                    ratio_profile = np.where(
697                        h2o_profile > 0,
698                        dust_profile / h2o_profile,
699                        10**(vmax + 1)
700                    )
701                    log_ratio = np.log10(ratio_profile + epsilon)
702                    log_ratio = np.clip(log_ratio, vmin, vmax)
703
704                log_ratio_array[:zmax, t] = log_ratio
705
706            # Convert back to linear ratio and apply elevation mask
707            ratio_array = 10**log_ratio_array
708            ratio_display = ratio_array[elevation_mask, :]
709
710            # Plot
711            fig, ax = plt.subplots(figsize=(8, 6), dpi=150)
712            ax.set_title(f"Dust-to-Ice ratio over time (Grid point {ig+1}, Slope {isl+1})", fontweight='bold')
713            im = ax.imshow(
714                ratio_display,
715                aspect='auto',
716                extent=[date_time[0], date_time[-1], elev.min(), elev.max()],
717                origin='lower',
718                interpolation='nearest',
719                cmap='managua_r',
720                norm=LogNorm(vmin=10**vmin, vmax=10**vmax)
721            )
722            x_centers = np.linspace(date_time[0], date_time[-1], ratio_display.shape[1])
723            y_centers = np.linspace(elev.min(), elev.max(), ratio_display.shape[0])
724            attach_format_coord(ax, ratio_display, x_centers, y_centers, is_pcolormesh=False)
725
726            # Add colorbar with simplified ratio labels
727            cbar = fig.colorbar(im, ax=ax, orientation='vertical')
728            cbar.set_label('Dust / H₂O ice (ratio)')
729
730            # Define custom ticks and labels
731            ticks = [1e-2, 1e-1, 1, 1e1]
732            labels = ['1:100', '1:10', '1:1', '10:1']
733            cbar.set_ticks(ticks)
734            cbar.set_ticklabels(labels)
735
736            # Save figure
737            plt.tight_layout()
738            fname = os.path.join(
739                output_folder,
740                f"dust_to_ice_ratio_grid{ig+1}_slope{isl+1}.png"
741            )
742            fig.savefig(fname, dpi=150)
743
744
745def plot_strata_count_and_total_height(heights_data, date_time, output_folder="."):
746    """
747    For each grid point and slope, plot:
748      - Number of strata vs time
749      - Total deposit height vs time
750    """
751    ntime = len(heights_data)
752    nslope = len(heights_data[0])
753    ngrid = heights_data[0][0].shape[0]
754
755    for ig in range(ngrid):
756        for isl in range(nslope):
757            n_strata_t = np.zeros(ntime, dtype=int)
758            total_height_t = np.zeros(ntime, dtype=float)
759
760            for t_idx in range(ntime):
761                h_mat = heights_data[t_idx][isl]
762                raw_h = h_mat[ig, :]
763                valid_mask = (~np.isnan(raw_h)) & (raw_h != 0.0)
764                if np.any(valid_mask):
765                    h_valid = raw_h[valid_mask]
766                    n_strata_t[t_idx] = h_valid.size
767                    total_height_t[t_idx] = np.max(h_valid)
768
769            fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
770            fig.suptitle(
771                f"Strata count & total height over time for (Grid point {ig+1}, Slope {isl+1})",
772                fontsize=14,
773                fontweight='bold'
774            )
775
776            axes[0].plot(date_time, n_strata_t, marker='+', linestyle='-')
777            axes[0].set_ylabel("Number of strata")
778            axes[0].grid(True)
779
780            axes[1].plot(date_time, total_height_t, marker='+', linestyle='-')
781            axes[1].set_xlabel("Time (Mars years)")
782            axes[1].set_ylabel("Total height (m)")
783            axes[1].grid(True)
784
785            fig.tight_layout(rect=[0, 0, 1, 0.95])
786            fname = os.path.join(
787                output_folder, f"strata_count_height_ig{ig+1}_is{isl+1}.png"
788            )
789            fig.savefig(fname, dpi=150)
790
791
792def read_orbital_data(orb_file, martian_to_earth):
793    """
794    Read the .asc file containing obliquity, eccentricity and Ls p.
795    Columns:
796      0 = time in thousand Martian years
797      1 = obliquity (deg)
798      2 = eccentricity
799      3 = Ls p (deg)
800    Converts times to Earth years.
801    """
802    data = np.loadtxt(orb_file)
803    dates_mka = data[:, 0]
804    dates_yr = dates_mka * 1e3 / martian_to_earth
805    obliquity = data[:, 1]
806    eccentricity = data[:, 2]
807    lsp = data[:, 3]
808    return dates_yr, obliquity, eccentricity, lsp
809
810
811def plot_orbital_parameters(infofile, orb_file, date_time, output_folder="."):
812    """
813    Plot the evolution of obliquity, eccentricity and Ls p
814    versus simulated time.
815    """
816    # Read conversion factor from infofile
817    _, martian_to_earth = read_infofile(infofile)
818
819    # Read orbital data
820    dates_yr, obl, ecc, lsp = read_orbital_data(orb_file, martian_to_earth)
821
822    # Interpolate orbital parameters at simulation dates (date_time)
823    obl_interp = interp1d(dates_yr, obl, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
824    ecc_interp = interp1d(dates_yr, ecc, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
825    lsp_interp = interp1d(dates_yr, lsp, kind='linear', bounds_error=False, fill_value="extrapolate")(date_time)
826
827    # Plot
828    fig, axes = plt.subplots(3, 1, figsize=(8, 10), sharex=True)
829    fig.suptitle("Orbital Parameters vs Simulated Time", fontsize=14, fontweight='bold')
830
831    axes[0].plot(date_time, obl_interp, 'r+', linestyle='-')
832    axes[0].set_ylabel("Obliquity (°)")
833    axes[0].grid(True)
834
835    axes[1].plot(date_time, ecc_interp, 'b+', linestyle='-')
836    axes[1].set_ylabel("Eccentricity")
837    axes[1].grid(True)
838
839    axes[2].plot(date_time, lsp_interp, 'g+', linestyle='-')
840    axes[2].set_ylabel("Ls p (°)")
841    axes[2].set_xlabel("Time (Mars years)")
842    axes[2].grid(True)
843
844    plt.tight_layout(rect=[0, 0, 1, 0.96])
845    fname = os.path.join(output_folder, "orbital_parameters.png")
846    fig.savefig(fname, dpi=150)
847
848
849def main():
850    # 1) Get user inputs
851    folder_path, base_name, infofile, orbfile = get_user_inputs()
852
853    # 2) List and verify NetCDF files
854    files = list_netcdf_files(folder_path, base_name)
855    if not files:
856        print(f"No NetCDF files named \"{base_name}#.nc\" found in \"{folder_path}\".")
857        sys.exit(1)
858    print(f"> Found {len(files)} NetCDF file(s).")
859
860    # 3) Open one sample to get grid dimensions & coordinates
861    sample_file = files[0]
862    ngrid, nslope, longitude, latitude = open_sample_dataset(sample_file)
863    print(f"> ngrid  = {ngrid}, nslope = {nslope}")
864
865    # 4) Collect variable info + global min/max elevations
866    var_info, max_nb_str, min_base_elev, max_top_elev = collect_stratification_variables(files, base_name)
867    print(f"> max strata per slope = {max_nb_str}")
868    print(f"> min base elev = {min_base_elev:.3f} m, max top elev = {max_top_elev:.3f} m")
869
870    # 5) Load full datasets
871    datasets = load_full_datasets(files)
872
873    # 6) Extract stratification data
874    heights_data, raw_prop_arrays, ntime = extract_stratification_data(datasets, var_info, ngrid, nslope, max_nb_str)
875
876    # 7) Close datasets
877    for ds in datasets:
878        ds.close()
879
880    # 8) Normalize to fractions
881    frac_arrays = normalize_to_fractions(raw_prop_arrays)
882
883    # 9) Ask whether to include subsurface
884    show_subsurface = get_yes_no_input("Show subsurface layers?")
885    exclude_sub = not show_subsurface
886    if exclude_sub:
887        min_base_for_interp = 0.0
888        print("> Interpolating only elevations >= 0 m (surface strata).")
889    else:
890        min_base_for_interp = min_base_elev
891        print(f"> Interpolating full depth down to {min_base_elev:.3f} m.")
892
893    # 10) Prompt discretization step
894    dz = prompt_discretization_step(max_top_elev)
895
896    # 11) Build reference grid and interpolate
897    ref_grid, gridded_data, top_index = interpolate_data_on_refgrid(
898        heights_data, frac_arrays, min_base_for_interp, max_top_elev, dz, exclude_sub=exclude_sub
899    )
900
901    # 12) Read timestamps and conversion factor from infofile
902    date_time, martian_to_earth = read_infofile(infofile)
903    if date_time.size != ntime:
904        print(f"Warning: {date_time.size} timestamps vs {ntime} NetCDF files.")
905
906    # 13) Plot stratification data over time
907    plot_stratification_over_time(
908        gridded_data, ref_grid, top_index, heights_data, date_time,
909        exclude_sub=exclude_sub, output_folder="."
910    )
911    plot_stratification_rgb_over_time(
912        gridded_data, ref_grid, top_index, heights_data, date_time,
913        exclude_sub=exclude_sub, output_folder="."
914    )
915    plot_dust_to_ice_ratio_over_time(
916        gridded_data, ref_grid, top_index, heights_data, date_time,
917        exclude_sub=exclude_sub, output_folder="."
918    )
919    plot_strata_count_and_total_height(heights_data, date_time, output_folder=".")
920
921    # 14) Plot orbital parameters
922    plot_orbital_parameters(infofile, orbfile, date_time, output_folder=".")
923
924    # 15) Show all figures
925    plt.show()
926
927
928if __name__ == "__main__":
929    main()
930
Note: See TracBrowser for help on using the repository browser.