import numpy as np
import math
pi=math.pi
import matplotlib 
matplotlib.use('Agg')
from input import *
from display_netcdf import *

def fmtsci(x, pos):
    r"""Format a number as a LaTeX scientific-notation string.

    The value is rendered with one decimal digit in the mantissa, e.g.
    12345.0 becomes ``$1.2 \times 10^{4}$``.

    Args:
        x: the number to format.
        pos: not used by this function (the two-argument signature is
             presumably there so it can be used as a matplotlib tick
             formatter callback -- confirm against callers).
    Returns:
        The LaTeX math string ``$<mantissa> \times 10^{<exponent>}$``.
    """
    mantissa, _, exponent = '{:.1e}'.format(x).partition('e')
    # int() drops the sign padding and leading zeros ('+04' -> 4, '-04' -> -4)
    return r'${} \times 10^{{{}}}$'.format(mantissa, int(exponent))

def calcaire(lat, lon, radius=3390.):
    """Area of each latitude band of a regular grid, and the total area.

    Band edges are the midpoints between consecutive latitudes; the first and
    last bands are extended to the poles, so the bands always tile the whole
    sphere.  Latitudes may be ordered south-to-north or north-to-south (the
    pole assigned to each end follows the ordering).

    Args:
        lat: 1D sequence of latitudes in degrees.
        lon: longitudes; accepted for interface compatibility, not used.
        radius: sphere radius.  The default 3390 is presumably the Mars
                radius in km (areas then come out in km^2) -- confirm.
    Returns:
        alpha: float32 array, area of each latitude band (all longitudes).
        airetot: sum of alpha.
    Raises:
        ValueError: if lat is empty.
    """
    lat = np.asarray(lat, dtype=float).ravel()
    if lat.size == 0:
        raise ValueError('calcaire() needs at least one latitude')

    # Interior edges: midpoints between neighbouring grid latitudes
    mids = (lat[:-1] + lat[1:]) / 2.

    # Outer edges sit on the poles; which pole comes first depends on the
    # ordering of the latitude axis.
    if lat[0] > lat[-1]:
        first_pole, last_pole = 90., -90.
    else:
        first_pole, last_pole = -90., 90.

    # latsup[i] / latinf[i] are the two edges of band i (single precision,
    # as in the historical implementation)
    latsup = np.concatenate(([first_pole], mids)).astype('f')
    latinf = np.concatenate((mids, [last_pole])).astype('f')

    # Diagnostic output kept from the historical implementation
    print(latsup)
    print(latinf)

    # Area of a latitudinal band: 2*pi*R^2*|sin(edge1)-sin(edge2)|
    sin_sup = np.sin(np.radians(latsup.astype(float)))
    sin_inf = np.sin(np.radians(latinf.astype(float)))
    alpha = (np.abs(sin_sup - sin_inf) * 2. * radius**2 * np.pi).astype('f')
    airetot = alpha.sum(dtype=np.float64)
    return alpha, airetot

def calcairepluto(lat, lon, radius=1187.):
    """Area of each latitude band of a regular grid, and the total area.

    Same computation as for the other band-area helpers in this file but with
    a default radius of 1187 (presumably the Pluto radius in km -- confirm).
    Band edges are the midpoints between consecutive latitudes; the first and
    last bands are extended to the poles.  Latitudes may be ordered
    south-to-north or north-to-south.

    Args:
        lat: 1D sequence of latitudes in degrees.
        lon: longitudes; accepted for interface compatibility, not used.
        radius: sphere radius (areas are in radius units squared).
    Returns:
        alpha: float32 array, area of each latitude band (all longitudes).
        airetot: sum of alpha.
    Raises:
        ValueError: if lat is empty.
    """
    lat = np.asarray(lat, dtype=float).ravel()
    if lat.size == 0:
        raise ValueError('calcairepluto() needs at least one latitude')

    # Interior edges: midpoints between neighbouring grid latitudes
    mids = (lat[:-1] + lat[1:]) / 2.

    # Outer edges sit on the poles; which pole comes first depends on the
    # ordering of the latitude axis.
    if lat[0] > lat[-1]:
        first_pole, last_pole = 90., -90.
    else:
        first_pole, last_pole = -90., 90.

    # Edges are stored in single precision, as in the historical code
    latsup = np.concatenate(([first_pole], mids)).astype('f')
    latinf = np.concatenate((mids, [last_pole])).astype('f')

    # Area of a latitudinal band: 2*pi*R^2*|sin(edge1)-sin(edge2)|
    sin_sup = np.sin(np.radians(latsup.astype(float)))
    sin_inf = np.sin(np.radians(latinf.astype(float)))
    alpha = (np.abs(sin_sup - sin_inf) * 2. * radius**2 * np.pi).astype('f')
    airetot = alpha.sum(dtype=np.float64)
    return alpha, airetot

def switchlon(arr):
    """Shift a [lat, lon] field by half the longitude axis.

    Output column j takes input column j + n//2 for the first n//2 columns
    and input column j - n//2 for the remaining ones (n = number of
    longitudes), which moves the middle of the longitude axis to the first
    column.  Latitudes (first axis) are left untouched.

    Args:
        arr: array whose first two axes are (lat, lon).
    Returns:
        A new float32 array with the same shape as arr.
    """
    shape = np.shape(arr)
    half = int(shape[1] / 2)
    shifted = np.zeros(shape, dtype='f')
    # Columns 0 .. half-1 come from half .. 2*half-1
    shifted[:, :half] = arr[:, half:2 * half]
    # Columns half .. n-1 come from 0 .. n-1-half
    shifted[:, half:] = arr[:, :shape[1] - half]
    return shifted

def switchlon3D(arr):
    """Shift a [*, lat, lon] field by half the longitude axis.

    Same shift as the 2D version, applied along the third axis: output
    column j takes input column j + n//2 for the first n//2 columns and
    input column j - n//2 for the others (n = number of longitudes).  The
    first two axes are left untouched.

    Args:
        arr: array whose third axis is longitude.
    Returns:
        A new float32 array with the same shape as arr.
    """
    shape = np.shape(arr)
    nlon = shape[2]
    half = int(nlon / 2)
    shifted = np.zeros(shape, dtype='f')
    # Columns 0 .. half-1 come from half .. 2*half-1
    shifted[:, :, :half] = arr[:, :, half:2 * half]
    # Columns half .. n-1 come from 0 .. n-1-half
    shifted[:, :, half:] = arr[:, :, :nlon - half]
    return shifted

def extractpal(pal,lev):
    """Sample a matplotlib colormap and return a list of RGB tuples.

    The first colour is always cmap(0); each further colour is cmap(lev[i])
    for i = 1 .. size(lev)-1, so the returned list has one entry per level.
    The levels are passed to the colormap unchanged (no normalisation is
    done here).

    Args:
        pal: colormap name (or anything accepted by pyplot.get_cmap).
        lev: sequence of levels used to sample the colormap.
    Returns:
        List of (r, g, b) tuples, one per level.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in both old and new releases.
    from matplotlib import pyplot as plt
    cmap = plt.get_cmap(pal)

    rgb = [cmap(0)[0:3]]
    for i in range(1, np.size(lev)):
        # Drop the alpha channel, keep (r, g, b)
        rgb.append(cmap(lev[i])[0:3])
    return rgb

def make_colormap(seq):
    """Build a LinearSegmentedColormap named 'CustomMap' from anchors and colours.

    seq: a sequence alternating RGB-tuples and floats. The floats should be
    increasing and in the interval (0,1); each float is an anchor position
    whose colour on the left is the item just before it and whose colour on
    the right is the item just after it.  Only items that are instances of
    ``float`` are treated as anchors.

    The padded sequence is printed before the colormap is built.
    """
    from matplotlib.colors import LinearSegmentedColormap

    blank = (None,) * 3
    # Pad with the 0.0 and 1.0 end anchors
    seq = [blank, 0.0] + list(seq) + [1.0, blank]
    print('\nNew sequence=',seq)

    channels = {'red': [], 'green': [], 'blue': []}
    for idx in range(len(seq)):
        anchor = seq[idx]
        if not isinstance(anchor, float):
            continue
        left_r, left_g, left_b = seq[idx - 1]
        right_r, right_g, right_b = seq[idx + 1]
        channels['red'].append([anchor, left_r, right_r])
        channels['green'].append([anchor, left_g, right_g])
        channels['blue'].append([anchor, left_b, right_b])
    return LinearSegmentedColormap('CustomMap', channels)

def getcol(lev,myc):
    """Build a colormap whose colour changes sit at the requested levels.

    The levels are rescaled to tick positions (lev[i]-lev[0])/(lev[-1]-lev[0])
    and interleaved with the palette colours into the sequence expected by
    make_colormap().  The requested levels and the computed ticks are printed.

    Args:
        lev: sequence of n levels.
        myc: palette colours; lev and myc must both be of dimension n.
    Returns:
        The colormap returned by make_colormap().
    """
    nbl = np.size(lev)

    # Tick position of every level; the first one is pinned to 0
    onticks = [0] + [(lev[k]-lev[0])/(lev[-1]-lev[0]) for k in range(1, nbl)]
    print('\nlevels asked=',lev)
    print('\nlevels ticks=',onticks)

    # Sequence: colour, tick, colour, tick, ..., colour (interior ticks only)
    gg = [myc[0]]
    for k in range(1, nbl-1):
        gg += [onticks[k], myc[k]]

    return make_colormap(gg)

def name_regions():
    """Return the list of region labels to draw on a map.

    Each entry is a list [lon, lat, name, rotation, font size, alignment].
    Rotation is 0 and alignment is 'center' for every entry; the font size is
    18 minus a per-label reduction (0, 2 or 4).
    """
    font = 18
    small = 2    # mild size reduction
    smaller = 4  # strong size reduction

    # (lon, lat, name, font-size reduction)
    labels = [
        (72, -43, 'Hellas', 0),
        (90, -15, 'Tyrrhena\nTerra', 0),
        (110, -28, 'Hesperia\nPlanum', 0),
        (115, -50, 'Promethei\nTerra', 0),
        (155, -45, 'Cimmeria\nTerra', 0),
        (195, -40, 'Sirenum\nTerra', 0),
        (63, 8, 'Syrtis\nMajor', smaller),
        # disabled: (220, -10, 'Arsia\nMons', smaller)
        (242, 40, 'Alba\nPatera', smaller),
        (285, 35, 'Tempe\nTerra', small),
        (41, -3, 'Sabaea\nTerra', 0),
        (150, 12, 'Elysium Planitia', 0),
        (120, 50, 'Utopia Planitia', 0),
        # disabled: (280, -8, 'Valles Marineris', rotation -10, font-3)
        (220, -25, 'Daedalia\nPlanum', 0),
        (232, -40, 'Icaria\nPlanum', smaller),
        (330, 45, 'Acidalia\nPlanitia', 0),
        (315, 25, 'Chryse\nPlanitia', small),
        (90, 12, 'Isidis\nPlanitia', smaller),
        (185, 25, 'Amazonis\nPlanitia', 0),
        (190, 50, 'Arcadia\nPlanitia', 0),
        (20, 30, 'Arabia\nTerra', 0),
        (17, -45, 'Noachis\nTerra', 0),
        (260, 5, 'Tharsis', small),
        (315, 0, 'Xanthe\nTerra', small),
        (270, -25, 'Solis', smaller),
        (255, -10, 'Syria', smaller),
        (270, -15, 'Sinai', smaller),
        (315, -50, 'Argyre', 0),
        (270, -55, 'Aonia\nTerra', 0),
        # disabled: (270, -45, 'Thaumasia\nHighlands', small)
        # disabled: (250, -15, 'Claritas Fossae', rotation -80, font-3)
    ]

    return [[lon, lat, name, 0, font - reduction, 'center']
            for lon, lat, name, reduction in labels]


def getwinds(lon,lat,vecx,vecy,svx,svy,scale,width,val):
    """Draw a subsampled wind-vector field and its reference arrow.

    The vectors are drawn with pyplot.quiver on the current axes, then a
    quiver key labelled '<val> m/s' is added at axes position (0.95, 1.03).

    Args:
        lon, lat: 1D coordinates; np.meshgrid(lon, lat) gives arrow positions.
        vecx, vecy: 2D vector components on the (lat, lon) grid.
        svx: subsampling stride along the second (lon) axis.
        svy: subsampling stride along the first (lat) axis.
        scale: quiver ``scale`` (data units per arrow length unit; a smaller
               value makes the arrows longer; None lets matplotlib autoscale).
        width: quiver ``width`` (shaft width).
        val: length of the reference arrow, also used in its label.
    Returns:
        None.
    """
    import matplotlib.pyplot as plt

    grid_x, grid_y = np.meshgrid(lon, lat)
    # Keep one point every svy rows and svx columns
    stride = (slice(None, None, svy), slice(None, None, svx))

    arrows = plt.quiver(
        grid_x[stride], grid_y[stride], vecx[stride], vecy[stride],
        angles='uv',        # alternative: 'xy'
        color='black',      # arrow colour
        pivot='mid',        # arrow centred on the grid point (alternative: tip)
        scale=scale,
        width=width,
        linewidths=0.5,     # thickness of the arrow outline
        edgecolors='k',     # colour of the arrow outline
    )

    # Reference arrow (vector key); other positions used in the past:
    # (1.025, 1.05) upper right over the colorbar, (-0.03, 1.08) upper left
    plt.quiverkey(
        arrows, 0.95, 1.03, val, str(val)+' m/s',
        fontproperties={'size': 28,'weight': 'bold'},
        color='black',
        labelpos='E',       # label position relative to the arrow: N S E W
        labelsep=0.07,
    )

###############################################################################
###############################################################################

DEFAULT = object()
## function : get local times array
def tshift2(array, lon=DEFAULT, timex=DEFAULT, nsteps_out=DEFAULT):
    """Interpolate a field onto a uniform local-time grid.

    Longitude must be the first dimension of ``array`` and time the last one.
    For every longitude the time axis is treated as a periodic 24 h cycle and
    linearly interpolated onto ``nsteps_out`` samples per sol:
    new time = [0 ... (nsteps_out-1)/nsteps_out]*24.

    If the time axis is longer than ``timex`` it must hold a whole number of
    days; the result then has the time axis split into (days, local time).

    Args:
        array: input field, shape (lon, ..., time).
        lon: east longitudes in degrees, one per entry of the first axis.
             Default: np.linspace(0, 360, nlon, endpoint=False).
        timex: sample times of one day, in hours (or in fractions of a day if
               all values are <= 1).  Only its length and timex[0] are used.
               Default: uniform samples over 24 h, one per time step.
        nsteps_out: number of output samples per sol.  Default: len(timex).
    Returns:
        Float array of shape (lon, ..., nsteps_out), or
        (lon, ..., ndays, nsteps_out) for multi-day input.  None (after
        printing a message) if the array has fewer than two dimensions or the
        time axis is not a whole number of days.
    Raises:
        ValueError: if lon does not have one value per entry of the first axis.
    """
    dims = np.shape(array)
    if len(dims) < 2:
        print('Need longitude and time dimensions')
        return

    end = len(dims) - 1
    nlon = dims[0]   # number of longitudes in file

    if lon is DEFAULT:
        lon = np.linspace(0., 360., num=nlon, endpoint=False)
    lon = np.asarray(lon, dtype=float).ravel()
    if lon.size != nlon:
        raise ValueError('lon must have one value per entry of the first '
                         'dimension of array')

    if timex is DEFAULT:
        nsteps = dims[end]
        timex = np.linspace(0., 24., num=nsteps, endpoint=False)
    else:
        nsteps = len(timex)
    timex = np.atleast_1d(np.squeeze(np.asarray(timex, dtype=float)))

    if timex.max() <= 1.:   # timex is in fractions of day
        timex = 24.*timex

    if nsteps_out is DEFAULT:
        nsteps_out = nsteps
    nsteps_out = int(nsteps_out)

    # Time is the last dimension; if it holds several days, split it into
    # (days, local time).  In C order the time index is day*nsteps + step.
    print('dims,nsteps=',dims[end],nsteps)
    if dims[end] != nsteps:
        ndays, leftover = divmod(dims[end], nsteps)
        print('ndays=',ndays,dims[end])
        if leftover != 0:
            print('Time dimensions do not conform')
            return
        print('dims=',dims,end)
        array = np.reshape(array, tuple(dims[:end]) + (ndays, nsteps))

    dims = np.shape(array)   # new dims if the array was reshaped

    # Collapse every dimension between longitude and time into one
    recl = int(np.prod(dims[1:-1]))
    array = np.asarray(np.reshape(array, (nlon, recl, nsteps)), dtype=float)

    dt_samp = 24.0/nsteps      # time increment of input data (in hours)
    dt_save = 24.0/nsteps_out  # time increment of output data (in hours)

    # Convert east longitude to equivalent hours
    xshif = 24.0*lon/360.
    xshif[xshif < 0] += 24.

    # Time offset of every (longitude, output step) pair.
    # NOTE(review): the "+ dt_samp" term is kept from the historical code; it
    # shifts the selected sample by one input step (possibly a left-over from
    # 1-based indexing) -- confirm the intended convention.
    out_times = np.arange(nsteps_out)*dt_save
    dtt = out_times[np.newaxis, :] - xshif[:, np.newaxis] - timex[0] + dt_samp
    # Insure that data local time is bounded by [0,24] hours
    dtt[dtt < 0.] += 24.

    im = np.floor(dtt/dt_samp)              # sample at or before the target
    fraction = (dtt - im*dt_samp)/dt_samp   # uniform spacing of input samples
    # The day is periodic: wrap both bracketing indices into [0, nsteps)
    imm = im.astype(int) % nsteps
    ipp = (imm + 1) % nsteps

    # Linear interpolation between the two bracketing samples
    before = np.take_along_axis(array, imm[:, np.newaxis, :], axis=2)
    after = np.take_along_axis(array, ipp[:, np.newaxis, :], axis=2)
    weight = fraction[:, np.newaxis, :]
    narray = (1. - weight)*before + weight*after

    return np.reshape(narray, tuple(dims[:-1]) + (nsteps_out,))

####################################################################
####################################################################
def getareaff2(nc,latabs):
    """
    Compute the area of each latitude band with |lat| < latabs.

    Band edges are the midpoints between consecutive latitudes; the first and
    last bands stop at the first and last grid latitude.  Bands whose latitude
    satisfies abs(lat) >= latabs get a zero area.  The radius is 3390.

    Args:
        nc: object exposing nc.variables['lat'] and nc.variables['lon']
            (latitudes in degrees).
        latabs: latitude limit in degrees.
    Returns:
        area: float32 array, area of each latitude band.
        areatot: sum of area.
        nblon, nblat: number of longitudes and latitudes.
    """
    lats = nc.variables['lat'][:]
    lons = nc.variables['lon'][:]
    nblat = np.size(lats)
    nblon = np.size(lons)

    # Edge of each band towards the previous / next grid latitude
    edge_prev = np.zeros(nblat, dtype='f')
    edge_next = np.zeros(nblat, dtype='f')
    edge_prev[0] = lats[0]
    edge_next[nblat-1] = lats[nblat-1]
    for k in range(1, nblat):
        midpoint = (lats[k-1]+lats[k])/2.
        edge_prev[k] = midpoint
        edge_next[k-1] = midpoint

    # Area: 2*pi*rad^2*|sin(edge1)-sin(edge2)|, only inside the latitude limit
    rad = 3390.
    alpha = np.zeros(nblat, dtype='f')
    area = np.zeros(nblat, dtype='f')
    for k, (e_prev, e_next) in enumerate(zip(edge_prev, edge_next)):
        if not abs(lats[k]) < latabs:
            continue
        alpha[k] = abs(np.sin(e_prev*pi/180.)-np.sin(e_next*pi/180.))*rad
        area[k] = alpha[k]*rad*2.*pi
    areatot = sum(area)
    return area,areatot,nblon,nblat

def getareaff(nc):
    """
    Compute the area of each latitude band.

    Band edges are the midpoints between consecutive latitudes; the first and
    last bands stop at the first and last grid latitude.  The radius is 3390.

    Args:
        nc: object exposing nc.variables['lat'] and nc.variables['lon']
            (latitudes in degrees).
    Returns:
        area: float32 array, area of each latitude band.
        areatot: sum of area.
        nblon, nblat: number of longitudes and latitudes.
    """
    lats = nc.variables['lat'][:]
    lons = nc.variables['lon']
    nblat = np.size(lats)
    nblon = np.size(lons)

    # Edge of each band towards the previous / next grid latitude
    edge_prev = np.zeros(nblat, dtype='f')
    edge_next = np.zeros(nblat, dtype='f')
    edge_prev[0] = lats[0]
    edge_next[nblat-1] = lats[nblat-1]
    for k in range(1, nblat):
        midpoint = (lats[k-1]+lats[k])/2.
        edge_prev[k] = midpoint
        edge_next[k-1] = midpoint

    # Area: 2*pi*rad^2*|sin(edge1)-sin(edge2)|
    rad = 3390.
    alpha = np.zeros(nblat, dtype='f')
    area = np.zeros(nblat, dtype='f')
    for k, (e_prev, e_next) in enumerate(zip(edge_prev, edge_next)):
        alpha[k] = abs(np.sin(e_prev*pi/180.)-np.sin(e_next*pi/180.))*rad
        area[k] = alpha[k]*rad*2.*pi
    areatot = sum(area)
    return area,areatot,nblon,nblat

def getarea_r(lat0,lon0,rad):
    """
    Compute the area of each latitude band for a sphere of radius rad.

    Band edges are the midpoints between consecutive latitudes; the first and
    last bands stop at the first and last grid latitude.

    Args:
        lat0: latitudes in degrees.
        lon0: longitudes (only their number is used).
        rad: sphere radius (areas are in rad units squared).
    Returns:
        area: float32 array, area of each latitude band.
        areatot: sum of area.
        nblon, nblat: number of longitudes and latitudes.
    """
    nblat = np.size(lat0)
    nblon = np.size(lon0)

    # Edge of each band towards the previous / next grid latitude
    edge_prev = np.zeros(nblat, dtype='f')
    edge_next = np.zeros(nblat, dtype='f')
    edge_prev[0] = lat0[0]
    edge_next[nblat-1] = lat0[nblat-1]
    for k in range(1, nblat):
        midpoint = (lat0[k-1]+lat0[k])/2.
        edge_prev[k] = midpoint
        edge_next[k-1] = midpoint

    # Area: 2*pi*rad^2*|sin(edge1)-sin(edge2)|
    alpha = np.zeros(nblat, dtype='f')
    area = np.zeros(nblat, dtype='f')
    for k, (e_prev, e_next) in enumerate(zip(edge_prev, edge_next)):
        alpha[k] = abs(np.sin(e_prev*pi/180.)-np.sin(e_next*pi/180.))*rad
        area[k] = alpha[k]*rad*2.*pi
    areatot = sum(area)
    return area,areatot,nblon,nblat


def fms_press_calc(psfc,ak,bk,lev_type='full'):
    """
    Return the 3d pressure field from the surface pressure and the ak/bk coefficients.

    Half-level (interface) pressures are p_half = psfc*bk + ak.  Full-level
    (layer centre) pressures are (p_below - p_above)/ln(p_below/p_above),
    except for the top layer when ak[0] == bk[0] == 0 (zero pressure at the
    top), where the arithmetic mean of the two interfaces is used.

    Args:
        psfc: the surface pressure in [Pa] or array of surface pressures 1D or 2D, or 3D (if time dimension)
        ak: 1st vertical coordinate parameter (array of size Nk)
        bk: 2nd vertical coordinate parameter (array of size Nk)
        lev_type: "full" (centers of the levels) or "half" (layer interfaces)
                  Default is "full"
    Returns:
        The pressure field at the full levels, shape psfc.shape + (Nk-1,), or
        at the half levels, shape psfc.shape + (Nk,), in [Pa]; singleton
        dimensions are squeezed out.
    Raises:
        Exception: if lev_type is neither 'full' nor 'half'.

    --- 0 --- TOP        ========  p_half
    --- 1 ---
                         --------  p_full

                         ========  p_half
    ---Nk-1---           --------  p_full
    --- Nk --- SFC       ========  p_half
                        / / / / /

    *NOTE*
        Some literature uses pk (pressure) instead of ak.
        With p3d=  ps*bk +pref*ak  vs the current  p3d= ps*bk +ak
    """
    n_half = len(ak)

    # A single value (e.g. psfc=7.) becomes a one-element array (e.g. [7.])
    if len(np.atleast_1d(psfc)) == 1:
        psfc = np.array([np.squeeze(psfc)])

    # Work on a flat list of columns so any number of dimensions is supported
    columns = psfc.flatten()
    n_col = len(columns)

    # Interfaces: (Np,1)*(1,Nk) + (1,Nk) broadcasts to (Np,Nk)
    p_half = columns[:, np.newaxis]*bk[np.newaxis, :] + ak[np.newaxis, :]

    # Layer centres: one fewer than the interfaces
    p_full = np.zeros((n_col, n_half-1))

    # Top layer
    p_top, p_next = p_half[:, 0], p_half[:, 1]
    if ak[0] == 0 and bk[0] == 0:
        p_full[:, 0] = 0.5*(p_top+p_next)
    else:
        p_full[:, 0] = (p_next-p_top)/np.log(p_next/p_top)

    # Rest of the column: interface k+1 (below) against interface k (above)
    p_below, p_above = p_half[:, 2:], p_half[:, 1:-1]
    p_full[:, 1:] = (p_below-p_above)/np.log(p_below/p_above)

    # Restore the original pressure shape with the vertical axis appended
    if lev_type == "full":
        return np.squeeze(p_full.reshape(np.append(psfc.shape, n_half-1)))
    if lev_type == "half":
        return np.squeeze(p_half.reshape(np.append(psfc.shape, n_half)))
    raise Exception("""Pressure levels type not recognized in press_lev(): use 'full' or 'half' """)

def fms_Z_calc(psfc,ak,bk,T,topo=0.,lev_type='full'):
    """
    Return the 3d altitude field in [m]

    Args:
        psfc: the surface pressure in [Pa] or array of surface pressures 1D or 2D, or 3D (if time dimension)
        ak: 1st vertical coordinate parameter
        bk: 2nd vertical coordinate parameter
        T : the air temperature profile, 1D array (for a single grid point) or 2D, 3D 4D;
            its total size must be psfc.size*(Nk-1)
        topo: the surface elevation; a scalar (applied to every column) or an
              array broadcastable to the shape of psfc
        lev_type: "full" (centers of the levels) or "half" (layer interfaces)
                  Default is "full"
    Returns:
        The layers' altitude at the full levels, shape psfc.shape + (Nk-1,),
        or at the half levels, shape psfc.shape + (Nk,), in [m]; singleton
        dimensions are squeezed out.
    Raises:
        Exception: if lev_type is neither 'full' nor 'half'.

    --- 0 --- TOP        ========  z_half
    --- 1 ---
                         --------  z_full

                         ========  z_half
    ---Nk-1---           --------  z_full
    --- Nk --- SFC       ========  z_half
                        / / / / /


    *NOTE*
        Calculation is derived from ./atmos_cubed_sphere_mars/Mars_phys.F90:
        We have dp/dz = -rho g => dz= dp/(-rho g) and rho= p/(r T)  => dz=rT/g *(-dp/p)
        Let's define the log-pressure u as u = ln(p). We have du = du/dp *dp = (1/p)*dp =dp/p

        Finally, we have dz for the half layers:  dz=rT/g *-(du) => dz=rT/g *(+dp/p)   with N the layers defined from top to bottom.
    """
    g = 3.72        # gravitational acceleration [m/s2]
    # Gas constant r in rho = p/(r T); with p in Pa and T in K this is a
    # specific gas constant, presumably in J/kg/K (the historical comment
    # said kg/mol) -- confirm.
    r_co2 = 191.00
    Nk = len(ak)

    # Half and full pressure levels
    PRESS_f = fms_press_calc(psfc,ak,bk,'full')
    PRESS_h = fms_press_calc(psfc,ak,bk,'half')

    # If psfc is a single value, turn it into a one-element array
    if len(np.atleast_1d(psfc)) == 1:
        psfc = np.array([np.squeeze(psfc)])

    psfc_flat = psfc.flatten()
    n_col = len(psfc_flat)

    # One surface elevation per column; a scalar topo (including the default
    # 0.) is spread over every column.
    topo = np.asarray(topo, dtype=float)
    if topo.size == 1:
        topo = topo.reshape(())
    topo_flat = np.broadcast_to(topo, psfc.shape).reshape(n_col)

    # Reshape arrays for vector calculations and compute the log pressure
    PRESS_h = PRESS_h.reshape((n_col, Nk))
    PRESS_f = PRESS_f.reshape((n_col, Nk-1))
    T = np.reshape(T, (n_col, Nk-1))

    logPPRESS_h = np.log(PRESS_h)

    # Output arrays
    Z_f = np.zeros((n_col, Nk-1))
    Z_h = np.zeros((n_col, Nk))

    # The lowest half level sits on the surface
    Z_h[:, -1] = topo_flat

    # Other levels, from the bottom up: each interface is stacked on the one
    # below it, so this loop is inherently sequential
    for k in range(Nk-2, -1, -1):
        Z_h[:, k] = Z_h[:, k+1]+(r_co2*T[:, k]/g)*(logPPRESS_h[:, k+1]-logPPRESS_h[:, k])
        Z_f[:, k] = Z_h[:, k+1]+(r_co2*T[:, k]/g)*(1-PRESS_h[:, k]/PRESS_f[:, k])

    # Return the levels in Z coordinates [m], in the original psfc shape
    if lev_type == "full":
        new_dim_f = np.append(psfc.shape, Nk-1)
        return np.squeeze(Z_f.reshape(new_dim_f))
    elif lev_type == "half":
        new_dim_h = np.append(psfc.shape, Nk)
        return np.squeeze(Z_h.reshape(new_dim_h))
    else:
        raise Exception("""Altitudes levels type not recognized: use 'full' or 'half' """)


def akbk_loader(NLAY,data_dir='/u/mkahre/MCMC/data_files'):
    """
    Return the ak and bk values given a number of layers for standards resolutions.

    The values are read from the variables 'pk' and 'bk' of the file
    <data_dir>/akbk_L<NLAY>.nc.

    Args:
        NLAY: the number of layers (float or integer)
        data_dir: directory holding the akbk_L*.nc files
                  (default /u/mkahre/MCMC/data_files)
    Returns:
        ak: 1st vertical coordinate parameter [Pa]
        bk: 2nd vertical coordinate parameter [none]

    *NOTE*    ak,bk have a size NLAY+1 since they define the position of the layer interfaces (half layers):
              p_half = ak + bk*p_sfc
    """
    from netCDF4 import Dataset

    NLAY = int(NLAY)
    path = data_dir+'/akbk_L%i.nc'%(NLAY)
    # The context manager closes the file even if a variable is missing
    with Dataset(path, 'r', format='NETCDF4') as akbk_file:
        ak = akbk_file.variables['pk'][:]
        bk = akbk_file.variables['bk'][:]
    return ak,bk


def zonal_avg_P_lat(Ls,var,Ls_target,Ls_angle,symmetric=True):
    """
    Return the zonally averaged mean value of a pressure interpolated 4D variable.

    Args:
        Ls: 1D array of solar longitude of the input variable in degree (0->360)
        var: a 4D variable var [time,levels,lat,lon] interpolated on the pressure levels (f_average_plevs file)
        Ls_target: central solar longitude of interest.
        Ls_angle:  requested window angle centered around Ls_target.
                   Expl: Ls_target = 90., Ls_angle = 10. (window goes from Ls 85 to Ls 95)
        symmetric: a boolean (default =True) If True, and if the requested window is out of range, Ls_angle is reduced
                                             If False, the time average is done on the data available
    Returns:
        The zonally and time-averaged field zpvar[level,lat] (all zeros if no
        time step falls inside the window)
    Raises:
        Exception: if Ls_target is not strictly inside [Ls.min(), Ls.max()].

    Expl:  Ls_target= 90.
           Ls_angle = 10.

           ---> Nominally, the time average is done over solar longitudes      85 <Ls_target < 95 (10 degree)

           ---> If  symmetric =True and the input data ranges from Ls 88 to 100     88 <Ls_target < 92 (4  degree, symmetric)
                If  symmetric =False and the input data ranges from Ls 88 to 100    88 <Ls_target < 95 (7  degree, asymmetric)
    *NOTE*

    [Alex] as of 6/8/18, the routine will bin data from multiple Mars years if provided

    A window that crosses Ls = 0/360 (e.g. Ls_target = 2, Ls_angle = 20) is
    handled by selecting Ls >= Ls_min or Ls <= Ls_max.
    """
    Ls = np.asarray(Ls)

    # Compute bounds from Ls_target and Ls_angle, folded into [0, 360]
    Ls_min = Ls_target-Ls_angle/2.
    Ls_max = Ls_target+Ls_angle/2.

    if (Ls_min<0.):Ls_min+=360.
    if (Ls_max>360.):Ls_max-=360.

    # Check that the Ls of interest is within the data provided, raise exception otherwise
    if Ls_target <= Ls.min() or Ls_target >=Ls.max() :
        raise Exception("Error \nNo data found, requested  data :       Ls %.2f <-- (%.2f)--> %.2f\nHowever, data in file only ranges      Ls %.2f <-- (%.2f)--> %.2f"%(Ls_min,Ls_target,Ls_max,Ls.min(),(Ls.min()+Ls.max())/2.,Ls.max()))

    # If only some of the requested data is outside the range, process this data
    if Ls_min <Ls.min() or Ls_max >Ls.max():
        print(("In zonal_avg_P_lat() Warning: \nRequested  data ranging    Ls %.2f <-- (%.2f)--> %.2f"%(Ls_min,Ls_target,Ls_max)))
        if symmetric: #Case 1: reduce the window
            if Ls_min <Ls.min():
                Ls_min =Ls.min()
                Ls_angle=2*(Ls_target-Ls_min)
                Ls_max= Ls_target+Ls_angle/2.

            if Ls_max >Ls.max():
                Ls_max =Ls.max()
                Ls_angle=2*(Ls_max-Ls_target)
                Ls_min= Ls_target-Ls_angle/2.

            print(("Reshaping data ranging     Ls %.2f <-- (%.2f)--> %.2f"%(Ls_min,Ls_target,Ls_max)))
        else: #Case 2: Use all data available
            print(("I am only using            Ls %.2f <-- (%.2f)--> %.2f \n"%(max(Ls.min(),Ls_min),Ls_target,min(Ls.max(),Ls_max))))

    # Time steps inside the window.  When the folded bounds are inverted the
    # window straddles Ls = 0/360, so both ends of the Ls axis are selected.
    if Ls_min <= Ls_max:
        in_window = (Ls >= Ls_min) & (Ls <= Ls_max)
    else:
        in_window = (Ls >= Ls_min) | (Ls <= Ls_max)
    selected = np.flatnonzero(in_window)
    count = len(selected)

    # Perform longitude average on the field, then accumulate over time
    zvar = np.mean(var,axis=3)
    zpvar = np.zeros((var.shape[1],var.shape[2])) #nlev, nlat
    for t in selected:
        zpvar[:,:] = zpvar+zvar[t,:,:]

    if count>0:
        zpvar /= count
    return zpvar



def alt_KM(press,scale_height_KM=8.,reference_press=610.):
    """
    Gives the approximate altitude in km for a given pressure:
    z = -H * ln(press / reference_press)

    Args:
        press: the pressure in [Pa]
        scale_height_KM: a scale height in [km], (default is 8 km)
        reference_press: reference surface pressure in [Pa], (default is 610 Pa)
    Returns:
        z_KM: the equivalent altitude for that pressure level in [km]

    """
    pressure_ratio = press/reference_press
    return -scale_height_KM*np.log(pressure_ratio) # p to altitude in km

def press_pa(alt_KM,scale_height_KM=8.,reference_press=610.):
    """
    Gives the approximate pressure in Pa for a given altitude:
    p = reference_press * exp(-alt_KM / H)

    Args:
        alt_KM: the altitude in  [km]
        scale_height_KM: a scale height in [km], (default is 8 km)
        reference_press: reference surface pressure in [Pa], (default is 610 Pa)
    Returns:
         press_pa: the equivalent pressure at that altitude in [Pa]

    """
    exponent = -alt_KM/scale_height_KM
    return reference_press*np.exp(exponent) # altitude in km to p

def lon180_to_360(lon):
    """
    Transform a float or an array from the -180/+180 coordinate system to 0-360
    (360 is added to every negative value; other values are unchanged).

    Args:
        lon: a float, 1D or 2D array of longitudes in the -180/+180 coordinate system
    Returns:
        lon: a new NumPy array (0-d for a scalar input) with the equivalent
             longitudes in the 0-360 coordinate system; the input is not modified

    """
    lon = np.array(lon)
    # np.where handles scalars and arrays of any shape alike
    return np.where(lon < 0, lon + 360, lon)

def lon360_to_180(lon):
    """
    Transform a float or an array from the 0-360 coordinate system to -180/+180
    (360 is subtracted from every value above 180; 180 itself is unchanged).

    Args:
        lon: a float, 1D or 2D array of longitudes in the 0-360 coordinate system
    Returns:
        lon: a new NumPy array (0-d for a scalar input) with the equivalent
             longitudes in the -180/+180 coordinate system; the input is not modified

    """
    lon = np.array(lon)
    # np.where handles scalars and arrays of any shape alike
    return np.where(lon > 180, lon - 360, lon)


def second_hhmmss(seconds,lon_180=0.,show_mmss=True):
    """
    Given the time in seconds return the local time at a certain longitude.

    The elapsed time is split into whole hours, minutes and seconds; the
    longitude offset (1 hr per 15 degrees) is then added to the hours, which
    are wrapped into [0, 24).
    NOTE(review): a fractional-hour offset (longitude not a multiple of 15
    degrees) is truncated from the hours and not carried into the minutes --
    confirm this is intended.

    Args:
        seconds: a float, the time in seconds
        lon_180: a float, the longitude in a -/+180 coordinate
        show_mmss: returns min and second if true
    Returns:
        hours as an int, or the tuple of ints (hours, minutes, seconds) if
        show_mmss is true

    """
    hours, remainder = divmod(seconds, 60*60)
    minutes, secs = divmod(remainder, 60)
    # Add timezone offset (1hr/15 degree)
    hours = np.mod(hours+lon_180/15., 24)

    # The builtin int replaces the np.int alias removed in NumPy 1.24
    if show_mmss:
        return int(hours), int(minutes), int(secs)
    return int(hours)


def sol2LTST(time_sol,lon_180=0.,show_minute=False):
    """
    Given the time in days, return the local time at a certain longitude.

    The time is converted to seconds with 86400 seconds per day and handed to
    second_hhmmss().

    Args:
        time_sol: a float, the time, eg. sols 2350.24
        lon_180: a float, the longitude in a -/+180 coordinate
        show_minute: if true return (hours, minutes, seconds), otherwise
                     return whole hours only
    Returns:
        Whatever second_hhmmss() returns: the local hour, or the tuple
        (hours, minutes, seconds)

    """
    elapsed_seconds = time_sol*86400.
    return second_hhmmss(elapsed_seconds, lon_180, show_minute)

def space_time(lon,timex, varIN,kmx,tmx):
    """
    Obtain west and east propagating waves. This is a Python implementation of John Wilson's  space_time routine by Alex
    Args:
        lon:   longitude array in [degrees]   0->360
        timex: 1D time array in units of [day]. Expl 1.5 days sampled every hour is  [0/24,1/24, 2/24,.. 1,.. 1.5]
        varIN: input array for the Fourier analysis.
               First axis must be longitude and last axis must be time.  Expl: varIN[lon,time] varIN[lon,lat,time],varIN[lon,lev,lat,time]
        kmx: an integer for the number of longitudinal wavenumber to extract   (max allowable number of wavenumbers is nlon/2)
        tmx: an integer for the number of tidal harmonics to extract           (max allowable number of harmonics  is nsamples/2)

    Returns:
        ampe:   East propagating wave amplitude [same unit as varIN]
        ampw:   West propagating wave amplitude [same unit as varIN]
        phasee: East propagating phase [degree], in [0, 360)
        phasew: West propagating phase [degree], in [0, 360)

    *NOTE*  1. ampe,ampw,phasee,phasew have dimensions [kmx,tmx] or [kmx,tmx,lat] [kmx,tmx,lev,lat] etc...
            2. The x and y axis may be constructed as follow to display the eastern and western modes:

                klon=np.arange(0,kmx)  [wavenumber]  [cycle/sol]
                ktime=np.append(-np.arange(tmx,0,-1),np.arange(0,tmx))
                KTIME,KLON=np.meshgrid(ktime,klon)

                amplitude=np.concatenate((ampw[:,::-1], ampe), axis=1)
                phase=    np.concatenate((phasew[:,::-1], phasee), axis=1)

            3. The progress() helper of this module is called once per
               wavenumber to display a progress bar.
    """
    lon = np.asarray(lon, dtype=float)
    timex = np.asarray(timex, dtype=float)

    dims = np.shape(varIN)       # input variable dimensions

    lon_id = dims[0]             # lon
    time_id = dims[-1]           # time
    dim_sup_id = tuple(dims[1:-1])   # additional dimensions stacked in the middle
    # Total number of points in the middle dimensions (1 if varIN is 2D);
    # the builtin int replaces the np.int alias removed in NumPy 1.24
    jd = int(np.prod(dim_sup_id))

    varIN = np.reshape(varIN, (lon_id, jd, time_id))   # flatten the middle dimensions if any

    ampw, ampe, phasew, phasee = (np.zeros((kmx, tmx, jd)) for _ in range(4))

    # TODO: the zonal-mean (axisymmetric) tides zamp/zphas and the stationary
    # waves stamp/stphs of the original routine are not implemented.

    tpi = 2*np.pi
    argx = lon * 2*np.pi/360     # normalize longitude array to radians
    rnorm = 2./len(argx)

    arg = timex * 2*np.pi
    rnormt = 2./len(arg)

    for kk in range(kmx):
        progress(kk,kmx)
        cosx = np.cos(kk*argx)*rnorm
        sinx = np.sin(kk*argx)*rnorm

        # Inner product to calculate the Fourier coefficients of the cosine
        # and sine contributions of the spatial variation: (time, jd)
        acoef = np.dot(varIN.T, cosx)
        bcoef = np.dot(varIN.T, sinx)

        # Now get the cos/sine series expansions of the temporal
        # variations of the acoef and bcoef spatial terms.
        for nn in range(tmx):
            cosray = rnormt*np.cos(nn*arg)
            sinray = rnormt*np.sin(nn*arg)

            cosA = np.dot(acoef.T, cosray)
            sinA = np.dot(acoef.T, sinray)
            cosB = np.dot(bcoef.T, cosray)
            sinB = np.dot(bcoef.T, sinray)

            # Real and imaginary parts of the westward / eastward components
            wr = 0.5*( cosA - sinB)
            wi = 0.5*(-sinA - cosB)
            er = 0.5*( cosA + sinB)
            ei = 0.5*( sinA - cosB)

            ampw[kk, nn, :] = np.sqrt(wr**2 + wi**2).T
            ampe[kk, nn, :] = np.sqrt(er**2 + ei**2).T
            # Phases folded into [0, 2*pi) and converted to degrees
            phasew[kk, nn, :] = (np.mod(-np.arctan2(wi, wr) + tpi, tpi)*180/np.pi).T
            phasee[kk, nn, :] = (np.mod(-np.arctan2(ei, er) + tpi, tpi)*180/np.pi).T

    # Restore the middle dimensions
    out_shape = (kmx, tmx) + dim_sup_id
    ampw = np.reshape(ampw, out_shape)
    ampe = np.reshape(ampe, out_shape)
    phasew = np.reshape(phasew, out_shape)
    phasee = np.reshape(phasee, out_shape)

    return ampe,ampw,phasee,phasew


def give_permission(filename):
    '''
    NAS system only: grant read permission on a file to the group s0846
    through an access-control list (``setfacl -R -m g:s0846:r``).

    The ``setfacl`` utility is probed first; when it is missing or the probe
    fails, nothing is done, so the call is a silent no-op on other systems.

    Args:
        filename: path handed to ``setfacl`` (the ACL is applied recursively)
    Returns:
        None
    '''
    import subprocess

    try:
        # Probe for setfacl and discard whatever it prints. DEVNULL avoids
        # leaving two file handles on os.devnull open after the call.
        subprocess.check_call(['setfacl', '-v'],
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # Non-zero exit status or executable not found: not a NAS system
        return

    try:
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact; '--' stops a leading '-' in the file name
        # from being parsed as an option.
        subprocess.call(['setfacl', '-R', '-m', 'g:s0846:r', '--', filename])
    except OSError:
        pass


def progress(k,Nmax):
    """
    Write a one-line text progress bar to stdout, redrawn in place, to
    monitor heavy calculations.
    Args:
        k: current iteration of the outer loop
        Nmax: max iteration of the outer loop
    Returns:
        None. The line written looks like:
        Running... [#---------] 10.64 %
    """
    import sys
    from math import ceil # used to round the percentage up to the 2nd digit

    bar_width = 10 # number of characters between the brackets
    note = ""
    fraction = float(k)/Nmax

    # Make sure the completed fraction is a float, flag it otherwise
    if isinstance(fraction, int):
        fraction = float(fraction)
    if not isinstance(fraction, float):
        fraction, note = 0, "error: progress var must be float\r\n"
    # Clamp the fraction to the [0, 1] interval
    if fraction < 0:
        fraction, note = 0, "Halt...\r\n"
    if fraction >= 1:
        fraction, note = 1, "Done...\r\n"

    filled = int(round(bar_width*fraction))
    bar = "#"*filled + "-"*(bar_width-filled)
    percent = ceil(fraction*100*100)/100
    # The leading carriage return makes successive calls overwrite the line
    sys.stdout.write("\rRunning... [{0}] {1} {2}%".format(bar, percent, note))
    sys.stdout.flush()


def dvar_dh(arr, h):
    '''
    Differentiate an array A(dim1,dim2,dim3...) with respect to h.  h and dim1 must have the same length and be the first dimension.
    One-sided differences are used at both ends and centered differences
    (A[i+1]-A[i-1])/(h[i+1]-h[i-1]) everywhere else, so h may be unevenly spaced.

    Args:
        arr:   an array of dimension n, with at least 2 points along the first axis
        h:     the coordinate of the first axis, eg Z, P, lat, lon (array or list)

    Returns:
        d_arr: the array differentiated with respect to h, same shape as arr.
               Floating-point input keeps its precision; any other input
               (e.g. integers) gives a float64 result so that the derivative
               is not truncated.

    *Example*
     #Compute dT/dz where T[time,lev,lat,lon] is the temperature and Zkm are the level heights in Km:
     #First we transpose t so the vertical dimension comes first as T[LEV,time,lat,lon] and then we transpose back to get dTdz[time,LEV,lat,lon].
     dTdz=dvar_dh(t.transpose([1,0,2,3]),Zkm).transpose([1,0,2,3])

    '''
    arr = np.asanyarray(arr)
    h = np.asanyarray(h)  # allows plain lists, which do not support h[2:]-h[:-2]

    # Storing the result in a copy of an integer array would silently
    # truncate every derivative, so only inexact dtypes are kept as they are.
    out_dtype = arr.dtype if np.issubdtype(arr.dtype, np.inexact) else np.float64
    d_arr = np.empty(arr.shape, dtype=out_dtype)

    # Shape (n-2, 1, ..., 1): lets the interior spacing broadcast against arr
    interior_dh = np.reshape(h[2:]-h[:-2], (-1,)+(1,)*(arr.ndim-1))

    d_arr[0,...] = (arr[1,...]-arr[0,...])/(h[1]-h[0])        # forward difference
    d_arr[-1,...] = (arr[-1,...]-arr[-2,...])/(h[-1]-h[-2])   # backward difference
    d_arr[1:-1,...] = (arr[2:,...]-arr[:-2,...])/interior_dh  # centered difference
    return d_arr

#=========================================================================
#=============Wrapper for creation of netcdf files========================
#=========================================================================

class Ncdf(object):
    '''
    Alex K.
    NetCdf wrapper for quick archiving of data into netcdf format

    USAGE:

    from netcdf_wrapper import Ncdf

    Fgeo= 0.03 #W/m2, a constant
    TG=np.ones((24,8)) #ground temperature

    #---create file---
    filename="/lou/s2n/mkahre/MCMC/analysis/working/myfile.nc"
    description="results from new simulation, Alex 01-01-19"
    Log=Ncdf(filename,description)

    #---Save the constant to the file---
    Log.add_constant('Fgeo',Fgeo,"geothermal flux","W/m2")

    #---Save the TG array to the file---
    Log.add_dimension('Nx',8)
    Log.add_dimension('time',24)

    Log.log_variable('TG',TG,('time','Nx'),'soil temperature','K')

    Log.close()


    '''
    def __init__(self,filename=None,description_txt="",action='w'):
        '''
        Open the netCDF file.
        Args:
            filename: full name of the file (ending with ".nc"), or a directory
                      in which a time-stamped "run_DD-MM-YYYY_h-m-s.nc" file is
                      placed. When omitted, the time-stamped file goes in the
                      current working directory.
            description_txt: text stored in the global "description" attribute
                      (write mode only)
            action:   'w' to create the file, 'a' to append to an existing one
        Raises:
            ValueError: if action is neither 'w' nor 'a'
        '''
        if action not in ('w','a'):
            # Without this check the object would be left without a file handle
            raise ValueError("action must be 'w' or 'a', got %r"%(action,))

        import datetime
        now = datetime.datetime.now()
        stamped_name='run_%02d-%02d-%04d_%i-%i-%i.nc'%(now.day,now.month,now.year,now.hour,now.minute,now.second)
        if filename:
            if filename[-3:]!=".nc":
                #assume that only path is provided so make a name for the file
                filename=filename+'/'+stamped_name
        else:   #create a default file name  if path and filename are not provided
            import os #use a default path if not provided
            filename=os.getcwd()+'/'+stamped_name
        self.filename=filename

        from netCDF4 import Dataset
        self.f_Ncdf = Dataset(filename, action, format='NETCDF4')
        #create dictionaries to hold dimensions and variables
        self.dim_dict=dict()
        self.var_dict=dict()
        if action=='w':
            self.f_Ncdf.description = description_txt
            print((filename+ " was created"))
        else:
            #append to file: start from what the file already contains so
            #that existing dimensions/variables are reused, not re-created
            self.dim_dict.update(self.f_Ncdf.dimensions)
            self.var_dict.update(self.f_Ncdf.variables)
            print((filename+ " was opened for appending"))

    def close(self):
        '''Close the netCDF file.'''
        self.f_Ncdf.close()
        print((self.filename+" was closed"))

    def add_dimension(self,dimension_name,length):
        '''Create a dimension of the given length and keep track of it.'''
        self.dim_dict[dimension_name]= self.f_Ncdf.createDimension(dimension_name,length)

    def print_dimension(self):
        '''Print the (name, dimension) pairs known to the wrapper.'''
        print((list(self.dim_dict.items())))
    def print_variable(self):
        '''Print the names of the variables known to the wrapper.'''
        print((list(self.var_dict.keys())))

    def add_constant(self,variable_name,value,longname_txt="",unit_txt=""):
        '''
        Store a scalar as a variable on the length-1 'constant' dimension.
        The value is also appended to the long name, e.g. "geothermal flux (0.03)".
        '''
        #exact name test: a dimension merely containing 'constant' in its
        #name must not prevent the creation of the 'constant' dimension
        if 'constant' not in self.dim_dict:
            self.add_dimension('constant',1)
        longname_txt =longname_txt+' (%g)'%(value)   #add the value to the longname
        #one-element tuple: ('constant') without the comma is a plain string
        self.def_variable(variable_name,('constant',),longname_txt,unit_txt)
        self.var_dict[variable_name][:]=value

    def def_variable(self,variable_name,dim_array,longname_txt="",unit_txt=""):
        '''
        Define a single-precision ('f4') variable on the dimensions listed in
        dim_array and attach its units, long_name and dim_name attributes.
        '''
        self.var_dict[variable_name]= self.f_Ncdf.createVariable(variable_name,'f4',dim_array)
        self.var_dict[variable_name].units=unit_txt
        self.var_dict[variable_name].long_name=longname_txt
        self.var_dict[variable_name].dim_name=dim_array

    def log_variable(self,variable_name,DATAin,dim_array,longname_txt="",unit_txt=""):
        '''
        Write DATAin to a variable, defining the variable first if needed.
        The attributes are refreshed on every call.
        '''
        #exact name test: a substring test would treat 'T' as already
        #defined once 'TG' exists and then fail on the lookup below
        if variable_name not in self.var_dict:
            self.def_variable(variable_name,dim_array,longname_txt,unit_txt)
        self.var_dict[variable_name].long_name=longname_txt
        self.var_dict[variable_name].dim_name=dim_array
        self.var_dict[variable_name].units=unit_txt
        self.var_dict[variable_name][:]=DATAin

#=========================================================================
#=======================vertical grid utilities===========================
#=========================================================================
def gauss_profile(x, alpha,x0=0.):
    """
    Evaluate a Gaussian line shape centered on x0 at x; handy to build a bell-shaped mountain.
    The value drops to half of its maximum at a distance alpha from x0.
    """
    ln2 = np.log(2)
    peak = np.sqrt(ln2 / np.pi) / alpha   # value reached at x = x0
    offset = (x-x0) / alpha               # distance to the center in units of alpha
    return peak * np.exp(-offset**2 * ln2)

def compute_uneven_sigma(num_levels, N_scale_heights, surf_res, exponent, zero_top ):
    """
    Construct an initial array of sigma based on the number of levels, an exponent
    For k=0..num_levels-1, with zeta = 1 - k/num_levels:
        z    = surf_res*zeta + (1 - surf_res)*zeta**exponent
        b[k] = exp(-z*N_scale_heights)
    and the last value is set to 1.
    Args:
        num_levels: the number of levels (converted to an integer)
        N_scale_heights: the number of scale heights to the top of the model (e.g scale_heights =12.5 ~102km assuming 8km scale height)
        surf_res: the resolution at the surface
        exponent: an exponent to increase the thickness of the levels
        zero_top: if True, force the top pressure boundary (in N=0) to 0 Pa
    Returns:
        b: an array of num_levels+1 sigma values, increasing to b[-1]=1

    """
    # The same integer count is used for the allocation and for the levels
    n_lev = int(num_levels)
    b = np.zeros(n_lev+1)
    # float() replaces the np.float alias, which recent NumPy no longer provides
    zeta = 1. - np.arange(n_lev)/float(n_lev) #zeta decreases with k
    z = surf_res*zeta + (1.0 - surf_res)*(zeta**exponent)
    b[:n_lev] = np.exp(-z*N_scale_heights)
    b[-1] = 1.0
    if zero_top:
        b[0] = 0.0
    return b


def transition( pfull, p_sigma=0.1, p_press=0.05):
    """
    Compute the sigma-to-pressure blending factor used to construct the ak and bk
    Args:
        pfull: the pressures in Pa (1D)
        p_sigma: the pressure level where the vertical grid starts transitioning from sigma to pressure
        p_press: the pressure level above which the vertical grid is pure (constant) pressure
    Returns:
        t: the transition factor, same shape and type as pfull:
           1 for pure sigma, 0 for pure pressure and 0<t<1 for the transition

    NOTE:
    In the FV code full pressure are computed from:
                       del(phalf)
         pfull = -----------------------------
                 log(phalf(k+1/2)/phalf(k-1/2))
    """
    def _factor(p):
        # pure pressure is tested first, as it takes precedence
        if p <= p_press:
            return 0.0
        if p >= p_sigma:
            return 1.0
        # sin**2 ramp between the two pressure bounds
        return (np.sin(0.5*np.pi*(p - p_press)/(p_sigma - p_press)))**2

    t = np.zeros_like(pfull)
    for k, p in enumerate(pfull):
        t[k] = _factor(p)
    return t


def swinbank(plev, psfc, ptrans=1.):
    """
    Build the ak and bk hybrid coefficients, with a pressure-to-sigma transition based on Swinbank
    Args:
        plev: the pressure levels in Pa (array, last element = bottom level)
        psfc: the surface pressure in Pa
        ptrans: the transition pressure in Pa
    Returns:
         aknew, bknew,ks: the coefficients for the new layers and the index
         of the first level with a non-zero bk
    """
    n_lev = len(plev)
    bottom = n_lev - 1

    # Level closest to the requested transition pressure
    ktrans = np.argmin(np.abs(plev - ptrans))

    # Pressures normalized by the surface pressure
    pnorm = psfc
    eta = plev / pnorm

    ep = eta[ktrans+1]   # first level below the transition level
    es = eta[-1]         # bottom level
    spread = (es-ep)**2

    #   Swinbank's alpha, beta, and gamma
    alpha = (ep**2 - 2.*ep*es) / spread
    beta  =        2.*ep*es**2 / spread
    gamma =        -(ep*es)**2 / spread

    #   Every level starts as a pure pressure level (ak = p, bk = 0)
    aknew = eta * pnorm
    bknew = np.zeros(n_lev)

    #   Hybrid pressure-sigma levels, from below the transition to just above the bottom
    hybrid = slice(ktrans+1, bottom)
    aknew[hybrid] = alpha*eta[hybrid] + beta + gamma/eta[hybrid]
    aknew[hybrid] = aknew[hybrid] * pnorm
    bknew[hybrid] = (plev[hybrid] - aknew[hybrid])/psfc

    #   The bottom level is pure sigma
    aknew[-1] = 0.0
    bknew[-1] = 1.0

    #   Transition level ks: first level where bk is not zero.
    #   ks is the one that would be used in fortran indexing in fv_eta.f90
    ks = next(k for k, bk in enumerate(bknew) if bk != 0.)
    return  aknew, bknew,ks

def printvar(infile):
    '''Print the names of all the variables held by an open netCDF file object'''
    names = list(infile.variables.keys())
    # header, list of names, then two empty lines
    print("Variables:", names, "\n", sep="\n")

def getind(myloc,field):
    '''
    Find the index of the element of a 1D field (lat, lon, time or pfull)
    closest to each requested location.
    Returns a single index when one location is given, an array of indices otherwise.
    '''
    values = field[:]
    # distance of every grid point to each requested location
    gaps = (abs(values - target) for target in np.atleast_1d(myloc))
    # first grid point reaching the smallest distance
    nearest = np.atleast_1d([np.where(gap == min(gap))[0][0] for gap in gaps])
    if len(nearest) != 1:
        return np.atleast_1d(nearest)
    return nearest[0]


# getinds=np.vectorize(getind, excluded="field")

def getvar(nc1,var,times=None,longitudes=None,latitudes=None,altitudes=None,t_mean=False,l_mean=False):
    """Get variables from netcdf file and slice it up!

    Selections are given in coordinate values and converted to the nearest
    grid indices with getind(). For each dimension:
      * one value        -> that single index is taken, the dimension is dropped
      * two values       -> inclusive range of indices between the two values
      * three or more    -> the individual nearest indices
    A selection that cannot be resolved for this variable is skipped.

    Args:
        nc1 : netcdf identifier
        var (str): variable name
        times (list, optional): times to select. Defaults to None.
        longitudes (list, optional): longitudes to select. Defaults to None.
        latitudes (list, optional): latitudes to select. Defaults to None.
        altitudes (list, optional): altitudes to select. Defaults to None.
        t_mean (bool, optional): Apply mean on time dimension. Defaults to False.
        l_mean (bool, optional): Apply mean on longitude dimension. Defaults to False.

    Returns:
        array: variable data array
    """

    # Alternative names tried when the requested one is not in the file
    aliases = {"latitude": "lat", "longitude": "lon", "Time": "time_counter",
               "aire": "area", "phisinit": "phisfi"}
    if var in aliases and var not in nc1.variables:
        var = aliases[var]

    var_data = nc1.variables[var]
    dims = var_data.dimensions
    myvar = var_data[:]

    slicer = [slice(None)] * len(dims)
    removed = set()  # positions in dims that are no longer an axis of myvar

    # iterate over dimensions to pick indices to select
    selections = ((times,TIME_DIMS), (longitudes,LON_DIMS),
                  (latitudes,LAT_DIMS), (altitudes,ALT_DIMS))
    for values, dim in selections:
        # "is None" rather than truthiness: a selection of 0 is valid and
        # NumPy arrays have no unambiguous truth value
        if values is None or np.size(values) == 0:
            continue
        try:
            dim_idx = find_dim_index(dims, dim)
            tmp_var = find_coord_var(nc1, dim)
            file_values = nc1.variables[tmp_var][:]
            i = getind(values, file_values)
            print(dim, i)
            # getind gives a bare scalar for a single value: normalize it
            idx = np.atleast_1d(i)
            if idx.size == 1: # specific index
                selection = int(idx[0])
            elif idx.size == 2: # range of indices
                selection = slice(int(idx.min()), int(idx.max())+1)
            else: # list of indices
                # NOTE: index arrays given for several dimensions at once
                # are broadcast together by NumPy
                selection = idx
            slicer[dim_idx] = selection
        except Exception:
            # best effort, as before: presumably this variable does not
            # have the dimension, so the selection is ignored
            continue
        if idx.size == 1:
            removed.add(dim_idx)

    # slice the variable according to the indices
    myvar = myvar[tuple(slicer)]

    # longitudinal mean first, then temporal mean
    for wanted, family in ((l_mean,LON_DIMS), (t_mean,TIME_DIMS)):
        if not wanted:
            continue
        axis = find_dim_index(dims, family)
        if axis in removed:
            continue  # already reduced to a single index, nothing to average
        # shift the axis by the number of dimensions removed before it
        shift = sum(1 for gone in removed if gone < axis)
        myvar = np.mean(myvar, axis=axis-shift if shift else axis)
        removed.add(axis)

    print(np.shape(myvar))
    if var == "altitude" and getattr(var_data, "units", None) == "m":
        # not in place, so that integer data can be converted as well
        myvar = myvar/1000 # plot in km
    return myvar

def getarea(filename2,var,tint):
    '''
    Return the whole "aire" variable read from the ``nc2`` handle, after printing its shape.

    NOTE(review): filename2, var and tint are not used, and neither ``nc2``
    nor ``shape`` is defined in this function; they are looked up in the
    enclosing module (presumably provided by the star imports) -- to confirm.
    '''
    area = nc2.variables["aire"][:]
    print(shape(area))
    return area
