#! /usr/bin/env python
import os
from    netCDF4               import    Dataset
from	numpy		      import	*
import  numpy                 as        np
import  matplotlib.pyplot     as        mpl
from matplotlib.cm import get_cmap
import pylab
from matplotlib import ticker
import matplotlib.colors as colors
import datetime
from mpl_toolkits.basemap import Basemap, shiftgrid
from matplotlib.cm import get_cmap
from FV3_utils import *
from input import * # import name
print("Running "+os.path.basename(__file__))

############################
# Plot styling shared by every map produced by this script.
fa='sans-serif'
hfont = {'fontname':'Arial'}  # NOTE(review): not referenced in this file
mpl.rc('font',family=fa)
mpl.rc('pdf',fonttype=42)  # embed TrueType (Type 42) fonts in PDF output
font=30  # font size used for titles, axis labels and tick labels
cc=['k']  # NOTE(review): not referenced in this file
pal=get_cmap(name="rainbow")  # colour map for the filled contours
# Log-spaced levels; plot_alt computes its own linear levels and does not
# use this module-level value.
lvls=np.logspace(-6,-4,21)
# Colour normalisation passed to contourf: None gives a linear scale.
# Set to colors.LogNorm() for a logarithmic colour scale.
norm=None

### Data
# Open "<name>_A.nc" if possible, otherwise fall back to "<name>.nc".
# Dataset raises OSError (e.g. FileNotFoundError) when a file cannot be
# opened; anything else (such as an undefined `name`) is not swallowed.
try:
    fname=name+"_A.nc"
    nc1=Dataset(fname)
except OSError:
    fname=name+".nc"
    nc1=Dataset(fname)
print("Plotting "+fname)

# Coordinate arrays read through FV3_utils.getvar (presumably 1-D — confirm).
alt=getvar(nc1,"altitude")
lat=getvar(nc1,"latitude")
lon=getvar(nc1,"longitude")
# temp=switchlon(temp)

# NOTE(review): this shifts the longitude labels by +180 but does not
# reorder the data arrays (the switchlon call above is commented out),
# so plotted positions move by 180 degrees relative to the file's
# longitudes. Confirm this relabelling is the intended convention.
if lon[0]<0:
    lon=lon+180.

def plot_alt(altitude = 1):
    """Plot the time-mean temperature map at one altitude and save it.

    Reads one vertical level of the ``temperature`` variable from the
    module-level dataset ``nc1``, averages it over the leading axis, and
    draws filled contours plus labelled line contours on the module-level
    ``lon``/``lat`` grid. The plot uses 21 linear levels spanning the
    field's minimum to maximum.

    Parameters
    ----------
    altitude : float, optional
        Target altitude. It is passed with the module-level ``alt`` array to
        ``getind`` to select the level index (presumably the nearest level;
        confirm in FV3_utils). It is also shown in the title (as km) and
        used in the output file name.

    Side effects
    ------------
    Saves the figure as ``maptemp<altitude>`` in matplotlib's default
    output format, prints the saved name, and closes the figure.
    """
    temp_var = nc1.variables["temperature"]
    numalt = getind(altitude, alt)
    # Read only the selected level from disk rather than the full 4-D
    # variable. Axis 0 is assumed to be time (TODO confirm) and is
    # averaged out.
    temp = np.mean(temp_var[:, numalt, :, :], axis=0)

    min_t = temp.min()
    max_t = temp.max()
    lvls = np.linspace(min_t, max_t, 21)

    ### Figure
    fig = mpl.figure(figsize=(15, 10))
    CF = mpl.contourf(lon, lat, temp, lvls, cmap=pal, norm=norm)
    cbar = mpl.colorbar(CF, shrink=1, ticks=lvls[::2], format="%1.1f")
    # Title the colour bar with the variable's own units attribute
    # (assumed to be K when the file does not provide one).
    units = getattr(temp_var, "units", "K")
    cbar.ax.set_title(f"[{units}]", y=1.04, fontsize=font)
    for tick_label in cbar.ax.get_yticklabels():
        tick_label.set_fontsize(font)

    # Thin black contour lines with horizontal inline labels.
    CS = mpl.contour(lon, lat, temp, lvls, colors='k', linewidths=0.5)
    lab = mpl.clabel(CS, inline=1, fontsize=20, fmt='%1.1f', inline_spacing=1)
    for label in lab:
        label.set_rotation(0)

    mpl.grid()
    mpl.title(f"Temperatures @ z={altitude}km", fontsize=font)
    mpl.ylabel(r'Latitude', labelpad=10, fontsize=font)
    mpl.xlabel('Longitude', labelpad=10, fontsize=font)
    mpl.ylim([-90, 90])
    mpl.xlim([0, 360])
    mpl.yticks(np.linspace(-90, 90, 13), fontsize=font)
    mpl.xticks(np.linspace(0, 360, 7), fontsize=font)

    output = f"maptemp{altitude}"
    mpl.savefig(output, bbox_inches='tight', dpi=70)
    print(f"Saved {output}")
    # Release the figure: without this every call leaves one figure open.
    mpl.close(fig)


# Produce one temperature map per target altitude, in increasing order.
for target_altitude in (1, 5, 20, 50, 100):
    plot_alt(target_altitude)