####################################################################################
### Python script to output the stratification data from the "startpem.nc" files ###
####################################################################################

import netCDF4 as nc
import numpy as np
import matplotlib.pyplot as plt
import sys
import os.path

##############################
### Parameters to fill in
# Input file, e.g. 'starts/restartpem226.nc' for a restart file
filename = 'startpem.nc'
# 1-based indices, checked against the file dimensions after opening it
igrid = 1   # Grid point
islope = 1  # Slope number
istr = 4    # Stratum number
##############################


####################################################################################
### Open the NetCDF file (stop right away if it is missing)
if not os.path.isfile(filename):
    sys.exit(f'The file "{filename}" does not exist!')
nc_file = nc.Dataset(filename,'r')

### Get the dimensions and check the requested 1-based indices against them
Time = len(nc_file.dimensions['Time'])  # number of time records (not used below)
ngrid = len(nc_file.dimensions['physical_points'])
nslope = len(nc_file.dimensions['nslope'])
nb_str_max = len(nc_file.dimensions['nb_str_max'])
if not 1 <= igrid <= ngrid:
    sys.exit('Asked grid point is not possible!')
if not 1 <= islope <= nslope:
    sys.exit('Asked slope number is not possible!')
# NOTE(review): istr is validated here but is not used anywhere else in the code shown
if not 1 <= istr <= nb_str_max:
    sys.exit('Asked stratum number is not possible!')

### Get the stratification properties
# One list per property, holding one array per slope. The variables are named
# 'stratif_slopeXX_<property>', where XX is the 2-digit, 1-based slope number.
stratif_thickness, stratif_top_elevation = [], []
stratif_co2ice_volfrac, stratif_h2oice_volfrac = [], []
stratif_dust_volfrac, stratif_air_volfrac = [], []
for slope in range(1,nslope + 1):
    prefix = f'stratif_slope{slope:02d}_'
    stratif_thickness.append(nc_file.variables[prefix + 'thickness'][:])
    stratif_top_elevation.append(nc_file.variables[prefix + 'top_elevation'][:])
    stratif_co2ice_volfrac.append(nc_file.variables[prefix + 'co2ice_volfrac'][:])
    stratif_h2oice_volfrac.append(nc_file.variables[prefix + 'h2oice_volfrac'][:])
    stratif_dust_volfrac.append(nc_file.variables[prefix + 'dust_volfrac'][:])
    stratif_air_volfrac.append(nc_file.variables[prefix + 'air_volfrac'][:])

### Display the data
# Convert the 1-based indices chosen by the user into 0-based Python indices
igrid = igrid - 1
islope = islope - 1
labels = 'CO2 ice', 'H2O ice', 'Dust', 'Air'
# The indexing below assumes arrays ordered as [time, stratum, grid point];
# profiles are taken at the first time record for the selected slope/grid point
top_elevation = stratif_top_elevation[islope][0,:,igrid]
nb_str = len(top_elevation)
# height[0] is the bottom of the first stratum, height[1:] are the stratum tops.
# The bottom must use the selected grid point only (not every grid point).
height = np.zeros(nb_str + 1)
height[0] = stratif_top_elevation[islope][0,0,igrid] - stratif_thickness[islope][0,0,igrid]
height[1:] = top_elevation
# contents[k,1 + i] is the volume fraction of component k (same order as labels) in stratum i.
# Masked (missing) entries become NaN so they show as gaps instead of raw fill values.
contents = np.zeros([4,nb_str + 1])
volfracs = stratif_co2ice_volfrac, stratif_h2oice_volfrac, stratif_dust_volfrac, stratif_air_volfrac
for k, volfrac in enumerate(volfracs):
    contents[k,1:] = np.ma.filled(volfrac[islope][0,:,igrid], np.nan)
# Repeat the first stratum at the bottom elevation so the step plots start there
contents[:,0] = contents[:,1]

# Simple subplots for a layering: one panel per component, all sharing the elevation axis
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1,4,layout = 'constrained',sharey = True)
fig.suptitle('Simple content profiles for the layering')
for panel, profile, label in zip((ax1, ax2, ax3, ax4), contents, labels):
    # Step plot of the volume fraction against elevation, one step per stratum
    panel.step(profile,height,where = 'post')
    panel.set_xlabel('Volume fraction [m3/m3]')
    panel.set_title(label)
ax1.set_ylabel('Elevation [m]')
plt.savefig('layering_simpleprofiles.png')

# Content profiles for a layering: all components overlaid on a single axis
colors = 'r', 'b', 'y', 'g'
plt.figure()
for profile, color, label in zip(contents, colors, labels):
    plt.step(profile,height,where = 'post',color = color,label = label)
    # Optional: uncomment to also draw each value as a marker joined by a faint dashed line
    #plt.plot(profile,height,'o--',color = color,alpha = 0.3)
plt.grid(axis='x', color='0.95')
plt.grid(axis='y', color='0.95')
plt.xlim(0,1)
plt.xlabel('Volume fraction [m3/m3]')
plt.ylabel('Elevation [m]')
plt.title('Content profiles for the layering')
# Without a legend, the per-component labels set above would never be displayed
plt.legend()
# Use a distinct file name: 'layering_simpleprofiles.png' is already written by the subplot figure
plt.savefig('layering_contentprofiles.png')

# Stack content profiles for a layering: the components are stacked horizontally,
# so each stratum is a bar from 0 split into the volume fraction of each component
fig, ax = plt.subplots()
colors = 'r', 'b', 'y', 'g'
# bounds[k] is the cumulative fraction of components 0..k-1 (bounds[0] is all zeros)
bounds = np.zeros([contents.shape[0] + 1,contents.shape[1]])
bounds[1:] = np.cumsum(contents,axis = 0)
for k, (label, color) in enumerate(zip(labels, colors)):
    # step='pre': the fraction at height[i] applies to the stratum below that height
    ax.fill_betweenx(height,bounds[k],bounds[k + 1],label = label,color = color,step = 'pre')
# Frame at fractions 0 and 1, plus a dashed line at every stratum boundary
ax.vlines(x = [0.,1.],ymin = height[0],ymax = height[-1],color = 'k',linestyle = '-')
ax.hlines(y = height,xmin = 0.0,xmax = 1.0,color = 'k',linestyle = '--')
ax.set_title('Stack content profiles for the layering')
ax.set_xlabel('Volume fraction [m3/m3]')
ax.set_ylabel('Elevation [m]')
ax.legend(loc = 'center left',bbox_to_anchor = (1,0.5))
# The legend sits outside the axes; 'tight' grows the saved image so the legend is not clipped
fig.savefig('layering_stackprofiles.png',bbox_inches = 'tight')

### Close the NetCDF file
# Every variable used by the plots was read with [:] above, so the file can be
# released before plt.show() blocks on the interactive windows
nc_file.close()

plt.show()
