####################################################################################
### Python script to output the stratification data from the "startpem.nc" files ###
####################################################################################

import netCDF4 as nc
import numpy as np
import matplotlib.pyplot as plt
import sys

##############################
### User settings (all indices are 1-based)
filename = 'startpem3.nc'  # NetCDF file to read
igrid = 1                  # Index of the grid point to display
islope = 1                 # Index of the slope to display
istr = 4                   # Index of the stratum detailed in the pie chart
##############################

### Open the NetCDF file
nc_file = nc.Dataset(filename,'r')

### Get the dimensions
Time = len(nc_file.dimensions['Time'])
ngrid = len(nc_file.dimensions['physical_points'])
nslope = len(nc_file.dimensions['nslope'])
nb_str_max = len(nc_file.dimensions['nb_str_max'])

### Check that the asked (1-based) indices fit in the dimensions
# The file is closed before aborting so that it is not left open if the
# SystemExit raised by sys.exit is intercepted by the caller.
if igrid > ngrid or igrid < 1:
    nc_file.close()
    sys.exit("Asked grid point is not possible!")
if islope > nslope or islope < 1:
    nc_file.close()
    sys.exit("Asked slope number is not possible!")
if istr > nb_str_max or istr < 1:
    nc_file.close()
    sys.exit("Asked stratum number is not possible!")

### Get the stratification properties
# Every list holds one entry per slope; each entry is the full content of the
# variable named 'stratif_slopeNN_<property>', NN being the zero-padded slope number.
slope_prefixes = ['stratif_slope' + str(k).zfill(2) for k in range(1,nslope + 1)]
stratif_thickness = [nc_file.variables[prefix + '_thickness'][:] for prefix in slope_prefixes]
stratif_top_elevation = [nc_file.variables[prefix + '_top_elevation'][:] for prefix in slope_prefixes]
stratif_co2ice_volfrac = [nc_file.variables[prefix + '_co2ice_volfrac'][:] for prefix in slope_prefixes]
stratif_h2oice_volfrac = [nc_file.variables[prefix + '_h2oice_volfrac'][:] for prefix in slope_prefixes]
stratif_dust_volfrac = [nc_file.variables[prefix + '_dust_volfrac'][:] for prefix in slope_prefixes]
stratif_air_volfrac = [nc_file.variables[prefix + '_air_volfrac'][:] for prefix in slope_prefixes]

### Display the data
# Switch from the 1-based indices given by the user to 0-based array indices
igrid = igrid - 1
islope = islope - 1
istr = istr - 1
labels = 'CO2 ice', 'H2O ice', 'Dust', 'Air'
# Volume fractions of each component for all the strata of the selected slope and
# grid point, at index 0 of the first axis (presumably the time axis - to be confirmed)
contents = stratif_co2ice_volfrac[islope][0,:,igrid], stratif_h2oice_volfrac[islope][0,:,igrid], stratif_dust_volfrac[islope][0,:,igrid], stratif_air_volfrac[islope][0,:,igrid]
nb_str = len(stratif_top_elevation[islope][0,:,igrid])
# x holds the volume fractions (one row per component) and y the elevations.
# Both have one more entry than the number of strata: entry 0 is the bottom of the first stratum.
x = np.zeros([4,nb_str + 1])
y = np.zeros(nb_str + 1)
# Bottom of the first stratum = its top elevation minus its thickness.
# The grid axis must be indexed with 'igrid': taking the whole axis would try
# to store one value per grid point into the single element y[0].
y[0] = stratif_top_elevation[islope][0,0,igrid] - stratif_thickness[islope][0,0,igrid]
y[1:] = stratif_top_elevation[islope][0,:,igrid]
for i in range(nb_str):
    x[0,1 + i] = stratif_co2ice_volfrac[islope][0,i,igrid]
    x[1,1 + i] = stratif_h2oice_volfrac[islope][0,i,igrid]
    x[2,1 + i] = stratif_dust_volfrac[islope][0,i,igrid]
    x[3,1 + i] = stratif_air_volfrac[islope][0,i,igrid]
# The bottom entry takes the content of the first stratum so that x and y have matching lengths
x[:,0] = x[:,1]

# One panel per component: volume fraction against the top elevation of the strata
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4)
fig.suptitle('Volume fractions [m3/m3] in the layering')
strata_tops = stratif_top_elevation[islope][0,:,igrid]
for panel, fractions, panel_title in zip((ax1, ax2, ax3, ax4), contents, labels):
    panel.plot(fractions, strata_tops)
    panel.set_title(panel_title)

# Figure hosting the pie chart of a single stratum (equal aspect keeps the pie circular)
fig, ax = plt.subplots(figsize=(6, 3), subplot_kw={"aspect": "equal"})
def func(pct,allvals):
    """Build the label of a pie wedge: its percentage and the matching volume fraction.

    pct     -- share of the wedge, in percent
    allvals -- all the values of the pie chart; their sum converts 'pct' back to a volume fraction

    Returns a two-line string: the percentage, then the volume fraction in m3/m3.
    """
    # Volume fractions lie between 0 and 1, so the value is kept as a float:
    # rounding it to an integer would always display 0 or 1.
    absolute = pct/100.*np.sum(allvals)
    return f"{pct:.1f}%\n({absolute:.3f} m3/m3)"
# Content of the asked stratum (column 0 of x is the bottom entry, hence the shift by one)
stratum_content = x[:,istr + 1]
wedges, texts, autotexts = ax.pie(
    stratum_content,
    autopct=lambda pct: func(pct, stratum_content),
    textprops={"color": "w"},
)
ax.legend(wedges, labels, title="Content", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
plt.setp(autotexts, size=8, weight="bold")
ax.set_title("Content of the stratum " + str(istr + 1))

# Stacked bands of the components along the elevation
fig, ax = plt.subplots()
# Row k of 'band_ends' is the summed volume fraction of components 0..k,
# i.e. where the band of component k ends; it starts where the previous one ends.
band_ends = np.cumsum(x, axis=0)
band_starts = [0, band_ends[0], band_ends[1], band_ends[2]]
for band_label, start, end in zip(labels, band_starts, band_ends):
    ax.fill_betweenx(y, start, end, label=band_label, step='pre')
# Frame of the column: solid vertical borders at volume fractions 0 and 1 ...
y_first, y_last = y[0], y[-1]
for border in (0., 1.):
    ax.vlines(x=border, ymin=y_first, ymax=y_last, color='k', linestyle='-')
# ... and a dashed horizontal line at every elevation stored in y
for level in y:
    ax.hlines(y=level, xmin=0.0, xmax=1.0, color='k', linestyle='--')
ax.set_title("Layering")
ax.set_xlabel("Volume fraction [m3/m3]")
ax.set_ylabel("Elevation [m]")
ax.legend(loc='center left', bbox_to_anchor=(1,0.5))

# Display all the figures created above
plt.show()

### Close the NetCDF file
nc_file.close()
