# Python script to vertically interpolate a 3D field from a netCDF file
# L. Fita, LMD. Jussieu, July 2014
#
## export PATH=/u/lflmd/bin/gcc_Python-2.7.5/bin:${PATH}
## g.e. # vertical_interpolation.py -f /home/lluis/WRF/util/p_interp/wrfout_d01_1995-01-15_00:00:00 -o WRFp -i 100000.,92500.,85000.,70000.,60000.,50000.,40000.,30000.,20000.,10000. -k 'lin' -v WRFt,WRFght -d T:Time,Z:bottom_top,Y:south_north,X:west_east -D T:Times,Z:ZNU,Y:XLAT,X:XLONG
## g.e. # vertical_interpolation.py -f /media/data1/etudes/WRF_LMDZ/WaquaL/WRF_LMDZ/NPv31/wrfout/wrfout_d01_1980-03-01_00:00:00 -o WRFp -i 100000.,97500.,95000.,92500.,90000.,85000.,80000.,75000.,70000.,65000.,60000.,55000.,50000.,45000.,40000.,35000.,30000.,25000.,20000.,15000.,10000. -k 'lin' -v WRFt,U,V,WRFrh,WRFght -d T:Time,Z:bottom_top,Y:south_north,X:west_east -D T:Times,Z:ZNU,Y:XLAT,X:XLONG

import numpy as np
from netCDF4 import Dataset as NetCDFFile
import os
import re
import nc_var_tools as ncvar
from optparse import OptionParser

def lonlat_creation(dimpn,nco,nwnc):
    """ Function to create longitude/latitude variables according to a dictionary
      dimpn: dictionary of varibale names for each dimension: 'T', 'Z', 'Y', 'x' 
      nco: original netCDF object
      nwnc: new netCDF object
    """
    fname = 'lonlat_creation' 

    print 'Lluis:',dimpn['X']

    if dimpn.has_key('X') and dimpn.has_key('Y'):
        lonobj = nco.variables[dimpn['X']]
        if len(lonobj.shape) == 3:
            newvar = nwnc.createVariable('lon', 'f8', ('y', 'x'))
            neattr = ncvar.basicvardef(newvar, 'lon', 'longitude', 'degrees East')
            newvar[:] = lonobj[0,:,:]
        
            newvar = nwnc.createVariable('lat', 'f8', ('y', 'x'))
            neattr = ncvar.basicvardef(newvar, 'lat', 'latitude', 'degrees Nord')
            newvar[:] = nco.variables[dimvns[1]][0,:,:]
        elif len(lonobj.shape) == 2:
            newvar = nwnc.createVariable('lon', 'f8', ('y', 'x'))
            neattr = ncvar.basicvardef(newvar, 'lon', 'longitude', 'degrees East')
            newvar[:] = lonobj[:]
        
            newvar = nwnc.createVariable('lat', 'f8', ('y', 'x'))
            neattr = ncvar.basicvardef(newvar, 'lat', 'latitude', 'degrees Nord')
            newvar[:] = nco.variables[dimvns[1]][:]
        else:
            print errormsg
            print '  ' + main + ': shape of longitude: ',lonobj.shape,' not ready !!'
            quit(-1)
    
    elif dimpn.has_key('X') and not dimpn.has_key('Y'):
        print warnmsg
        print '  ' + main + ": variable pressure with 'X', but not 'Y'!!"
        lonobj = nco.variables[dimpn['X']]
        if len(lonobj.shape) == 3:
            newvar = nwnc.createVariable('lon', 'f8', ('x'))
            neattr = ncvar.basicvardef(newvar, 'lon', 'longitude', 'degrees East')
            newvar[:] = lonobj[0,0,:]
        
        elif len(lonobj.shape) == 2:
            newvar = nwnc.createVariable('lon', 'f8', ('x'))
            neattr = ncvar.basicvardef(newvar, 'lon', 'longitude', 'degrees East')
            newvar[:] = lonobj[0,:]    
    
        elif len(lonobj.shape) == 1:
            newvar = nwnc.createVariable('lon', 'f8', ('x'))
            neattr = ncvar.basicvardef(newvar, 'lon', 'longitude', 'degrees East')
            newvar[:] = lonobj[:]
        else:
            print errormsg
            print '  ' + main + ': shape of longitude: ',lonobj.shape,' not ready !!'
            quit(-1)
    
    elif not dimpn.has_key('X') and dimpn.has_key('Y'):
        print warnmsg
        print '  ' + main + ": variable pressure with 'Y', but not 'X'!!"
        latobj = nco.variables[dimpn['Y']]
        if len(latobj.shape) == 3:
            newvar = nwnc.createVariable('lat', 'f8', ('y'))
            neattr = ncvar.basicvardef(newvar, 'lat', 'latitude', 'degrees North')
            newvar[:] = latobj[0,0,:]
        
        elif len(latobj.shape) == 2:
            newvar = nwnc.createVariable('lat', 'f8', ('x'))
            neattr = ncvar.basicvardef(newvar, 'lat', 'latitude', 'degrees North')
            newvar[:] = latobj[0,:]    
    
        elif len(latobj.shape) == 1:
            newvar = nwnc.createVariable('lat', 'f8', ('x'))
            neattr = ncvar.basicvardef(newvar, 'lat', 'latitude', 'degrees North')
            newvar[:] = latobj[:]
        else:
            print errormsg
            print '  ' + main + ': shape of latitude: ',latobj.shape,' not ready !!'
            quit(-1)
    else:
        print warnmsg
        print '  ' + main + ": variable pressure without 'X' and 'Y'!!"

    return

####### ###### ##### #### ### ## #

main='vertical_interpolation.py'
errormsg='ERROR -- error -- ERROR -- error'
warnmsg='WARNING -- warning -- WARNING -- warning'

parser = OptionParser()
parser.add_option("-o", "--original_ZValues", dest="ozvals",
                  help="variable name with the original z-values ('WRFp', WRF derived pressure values)", metavar="VALUE")
parser.add_option("-i", "--interpolate_ZValues", dest="izvals",
                  help="comma-separated list of values to interpolate ([zi1],[zi2],[...[ziN]])", metavar="VALUES")
parser.add_option("-f", "--file", dest="file",
                  help="netCDF file to use", metavar="FILE")
parser.add_option("-k", "--interpolation_kind", dest="kfig",
                  help="kind of vertical inerpolation ('lin', linear)", metavar="LABEL")
parser.add_option("-v", "--variables", dest="vars",
                  help="comma-separated list of name of the variables to inerpolate ('all', all variables)",
   metavar="LABEL")
parser.add_option("-d", "--dimensions", dest="dims",
                  help="comma-separated list of dimensions name (where applicable, T:[dimt],Z:[dimz],Y:[dimy],X:[dimx])",
   metavar="LABEL")
parser.add_option("-D", "--dimension_vnames", dest="dimvns",
                  help="comma-separated list of variables with the dimensions values (T:[dimt],Y:[dimy],X:[dimx])", metavar="LABEL")

(opts, args) = parser.parse_args()

# Every option but '-k' is used below (the interpolation is always linear):
#   stop with the usage message when one of them is missing, instead of
#   failing later on a None value
for optdest, optflag in [('ozvals', '-o'), ('izvals', '-i'), ('file', '-f'),    \
  ('vars', '-v'), ('dims', '-d'), ('dimvns', '-D')]:
    if getattr(opts, optdest) is None:
        parser.error("option '" + optflag + "' is required")

####### ###### ##### #### ### ## #

# Variables wich do not exist in the file, but they will be computed
NOfileVars = ['WRFght', 'WRFrh', 'WRFt']

#######    #######
## MAIN
    #######

ofile = 'vertical_interpolation_' + opts.ozvals +'.nc'

if not os.path.isfile(opts.file):
    print errormsg
    print '  ' + main + ' file: "' + opts.file + '" does not exist !!'
    quit(-1)

ncobj = NetCDFFile(opts.file, 'r')

if opts.ozvals == 'WRFp':
    ncdims = ncobj.variables['P'].dimensions
else:
    ncdims = ncobj.dimensions


dimns = opts.dims.split(',')

# Variables with values of the dimensions
dimvns = []
d4dimvns = {}

for i in range(4):
    dimvid = opts.dimvns.split(',')[i].split(':')[0]
    dimvv = opts.dimvns.split(',')[i].split(':')[1]

    d4dimvns[dimvid] = dimvv
    if dimvid != 'Z': dimvns.append(dimvv)

dimpresv = {}
idim = 0
dimpresn = {}
dimpresnv = {}
d4dims = {}
for dim in dimns:
    dimid = dim.split(':')[0]
    dimn = dim.split(':')[1]
    if not ncvar.searchInlist(ncdims, dimn):
        print warnmsg
        print '  ' + main + ' in file: "' + opts.file + '" dimension "' + dimn +     \
          '" does not exist !!'
#        quit(-1)
    else:
        dimpresv[dimid] = len(ncobj.dimensions[dimn])
        dimpresn[dimid] = dimn
        dimpresnv[dimn] = len(ncobj.dimensions[dimn])

    d4dims[dimid] = dimn
    idim = idim + 1

print 'Generic 4d dimensions:',d4dims

dimvl = []
for dimid in ['T', 'Z', 'Y', 'X']:
    if dimpresv.has_key(dimid): 
        dimvl.append(dimpresv[dimid])

dimv = np.array(dimvl)
print 'dimensions pressure variable: ',dimpresnv

# Interpolation values
intervals = opts.izvals.split(',')
intvals = ncvar.list_toVec(intervals, 'npfloat')
Nintvals = len(intvals)

# Obtaining original vertical coordinate
if opts.ozvals == 'WRFp':
    print '    ' + main + ': Retrieving original pressure values from WRF ' +       \
      'files as P + PB'
    zorigvals = np.zeros((dimv), dtype = np.float)
    zorigvals = ncobj.variables['P'][:] + ncobj.variables['PB'][:]
    zorigname = 'pressure'
    zorigsn = 'pressure'
    zorigln = 'pressure of the air'
    zorigu = 'hPa'
else:
    zorigvals = ncobj.variables[opts.ozvals][:]
    zorigname = opts.ozvals
    varattrs = ncobj.variables[opts.ozvals].ncattrs()
    if ncvar.searchInlist(varattrs,'standard_name'):
        zorigsn = ncobj.variables[var].getncattr('standard_name')
    else:
        zorigsn = var

    if ncvar.searchInlist(varattrs,'long_name'):
        zorigln = ncobj.variables[var].getncattr('long_name')
    else:
        zorigln = ncobj.variables[var].getncattr('description')

    zorigu = ncobj.variables[var].getncattr('units')

# Which is the sense of the original vertical coordinate?
Ndimszorig = len(zorigvals.shape)

if Ndimszorig == 4:
    Lzorig = zorigvals.shape[1]
    inczorig = zorigvals[0,Lzorig-1,0,0] - zorigvals[0,0,0,0]
elif Ndimszorig == 3:
    Lzorig = zorigvals.shape[0]
    inczorig = zorigvals[Lzorig-1,0,0] - zorigvals[0,0,0]
elif Ndimszorig == 2:
# Assuming Time,z
    Lzorig = zorigvals.shape[1]
    inczorig = zorigvals[0,Lzorig-1] - zorigvals[0,0]
else:
    print errormsg
    print '  ' + main + ': vertical coordinate shape:',zorigvals.shape,'not ready!!!'
    quit(-1)

print '  ' + main + ': vertical coordinate to interpolate:',dimpresn['Z']
# Which are the desired variables
if opts.vars == 'all':
    varns0 = ncobj.variables
    varns = []
# only getting that variables which contain dimns[1]
    for varn in varns0:
        vobj = ncobj.variables[varn]
        if ncvar.searchInlist(vobj.dimensions, dimpresn['Z']):
            varns.append(varn)
    print '    ' + main + ': Interpolation of all variables !!!'
    print '      ', varns
else:
    varns = opts.vars.split(',')

# Creation of output file
##
newnc = NetCDFFile(ofile, 'w')

# Creation of dimensions
if dimpresn.has_key('X'): newdim = newnc.createDimension('x', dimpresv['X'])
if dimpresn.has_key('Y'): newdim = newnc.createDimension('y', dimpresv['Y'])
if dimpresn.has_key('Z'): newdim = newnc.createDimension('z', Nintvals)
if dimpresn.has_key('T'):
    if not newnc.dimensions.has_key('time'):
        timeobj = ncobj.variables[d4dimvns['T']]
        # The time dimension/variable is needed whatever the kind of original
        #   vertical coordinate
        newdim = newnc.createDimension('time', None)
        newvar = newnc.createVariable('time', 'f8', ('time',))

        if d4dimvns['T'] == 'Times':
            # WRF date strings converted to hours since 1949-12-01 00:00:00
            timeu = 'hours since 1949-12-01 00:00:00'
            timevals = np.zeros((dimv[0]), dtype=np.float64)
            for it in range(dimv[0]):
                gdate = ncvar.datetimeStr_conversion(timeobj[it,:],'WRFdatetime',\
                  'matYmdHMS')
                timevals[it] = ncvar.realdatetime1_CFcompilant(gdate,            \
                  '19491201000000', 'hours')
        else:
            timevals = timeobj[:]
            timeu = timeobj.getncattr('units')

        newattr = ncvar.basicvardef(newvar, 'time', 'time', timeu)
        newvar[:] = timevals

# Variables with the horizontal coordinates
ofunc = lonlat_creation(d4dimvns,ncobj,newnc)

# Values of the new vertical coordinate
newvar = newnc.createVariable(zorigname, 'f8', ('z',))
neattr = ncvar.basicvardef(newvar, zorigsn, zorigln, zorigu)
newvar[:] = intvals

# Looping along the variables
for var in varns:
    print 'variable:',var

    if not ncvar.searchInlist(ncobj.variables,var) and not                           \
      ncvar.searchInlist(NOfileVars,var):
        print errormsg
        print '  ' + main + ' in file: "' + opts.file + '" variable "' + var +       \
          '" does not exist !!'
        quit(-1)

    if ncvar.searchInlist(NOfileVars,var):
        if var == 'WRFght':
            print '    ' + main + ': computing geopotential height from WRF as ' +   \
              ' PH + PHB ...' 
            varvals = ncobj.variables['PH'][:] + ncobj.variables['PHB'][:]
            varsn = 'zg'
            varln = 'geopotential height'
            varu = 'gpm'
            varinterp = np.zeros((dimv[0], Nintvals, dimv[2], dimv[3]), dtype=np.float)
            newvar = newnc.createVariable(var, 'f4', ('time','z','y','x'))
        elif var == 'WRFrh':
            print '    ' + main + ': computing relative humidity from WRF as ' +     \
              " Tetens' equation (T,P) ..." 
            p0=100000.
            p=ncobj.variables['P'][:] + ncobj.variables['PB'][:]

            tk = (ncobj.variables['T'][:] + 300.)*(p/p0)**(2./7.)
            qv = ncobj.variables['QVAPOR'][:]

            data1 = 10.*0.6112*np.exp(17.67*(tk-273.16)/(tk-29.65))
            data2 = 0.622*data1/(0.01*p-(1.-0.622)*data1)
            varvals = qv/data2

            varsn = 'rh'
            varln = 'relative humidity of the air'
            varu = '%'
            varinterp = np.zeros((dimv[0], Nintvals, dimv[2], dimv[3]), dtype=np.float)
            newvar = newnc.createVariable(var, 'f4', ('time','z','y','x'))
        elif var == 'WRFt':
            print '    ' + main + ': computing temperature from WRF as ' +           \
              ' inv_potT(T + 300) ...' 
            p0=100000.
            p=ncobj.variables['P'][:] + ncobj.variables['PB'][:]

            varvals = (ncobj.variables['T'][:] + 300.)*(p/p0)**(2./7.)
            varsn = 'ta'
            varln = 'temperature of the air'
            varu = 'K'
            varinterp = np.zeros((dimv[0], Nintvals, dimv[2], dimv[3]), dtype=np.float)
            newvar = newnc.createVariable(var, 'f4', ('time','z','y','x'))
        else:
            print errormsg
            print '  ' + fname + ': variable "' + var + '" not ready !!!!'
            quit(-1)
    else:
        varobj = ncobj.variables[var]
        vardims = varobj.dimensions
        varattrs = varobj.ncattrs()
        if ncvar.searchInlist(varattrs,'standard_name'):
            varsn = ncobj.variables[var].getncattr('standard_name')
        else:
            varsn = var

        if ncvar.searchInlist(varattrs,'long_name'):
            varln = ncobj.variables[var].getncattr('long_name')
        else:
            varln = ncobj.variables[var].getncattr('description')

        varu = ncobj.variables[var].getncattr('units')

# Getting variable values:
        if len(varobj.shape) == 4:
            varinterp = np.zeros((dimv[0], Nintvals, dimv[2], dimv[3]), dtype=np.float)
            newvar = newnc.createVariable(var, 'f4', ('time','z','y','x'))
            varvals = varobj[:]
        elif len(varobj.shape) <= 3 and len(varobj.shape) >= 1:
            varpdimvs = []
            varpdimns = []
            varpslice = []
            for vdim in vardims:
                vardim = len(ncobj.dimensions[vdim])
                if ncvar.searchInlist(ncdims,vdim):
                    if vdim == dimpresn['Z']:
                        varpdimvs.append(Nintvals)
                        varpdimns.append('z')
                        varpslice.append(slice(0,vardim))
                    else:
                        varpdimvs.append(vardim)
                        varpslice.append(slice(0,vardim))
                        if dimpresn.has_key('X') and vdim == dimpresn['X']:
                            varpdimns.append('x')
                        elif dimpresn.has_key('Y') and vdim == dimpresn['Y']:
                            varpdimns.append('y')
                        elif dimpresn.has_key('T') and vdim == dimpresn['T']:
                            varpdimns.append('time')
                        else:
                            print errormsg
                            print '  ' + main + ": dimension variable '" + vdim +        \
                              "' is in pressure but it is not found?"
                            print '    pressure dimensions:', ncdims
                            quit(-1)
                else:
# Dimension of the variable is not in the pressure variable
                    varpdimvs.append(vardim)
                    varpdimns.append(vdim)
                    varpslice.append(slice(0,vardim))
                    if not newnc.dimensions.has_key(vdim) and not                        \
                      ncvar.searchInlist(dimpresn,vdim):
                        print '  ' + main + ": dimension '" + vdim + "' not in the " +   \
                          'pressure variable adding it!'
                        if ncobj.dimensions[vdim].isunlimited():
                            newnc.createDimension(vdim, None)
                        else:
                            newnc.createDimension(vdim, vardim)
    
            varinterp = np.zeros((varpdimvs), dtype=np.float)
            newvar = newnc.createVariable(var, 'f4', tuple(varpdimns))
    
            varvals = varobj[tuple(varpslice)]
        else:
            print errormsg
            print '  ' + main + ': variable shape "', varobj.shape, '" not ready !!!!'
            quit(-1)

#    print 'variable:',var,'shape:',varvals.shape,'len:',len(varvals.shape)    
    if len(varvals.shape) == 3 and len(zorigvals.shape) == 3:
        if (varvals.shape[1] + varvals.shape[2]) != (dimv[1] + dimv[2]):
            print warnmsg
            print '  ' + main + ': variable=', varvals.shape[1:3],                   \
              'with different shape!',dimv[1],dimv[2]
            if (varvals.shape[1] + varvals.shape[2]) - (dimv[1] + dimv[2]) == 1:
                print '    Assuming staggered variable'
                varvals0 = np.zeros((Nintvals, dimv[1], dimv[2]), dtype=np.float)

                if (varvals.shape[1] > dimv[1]):
                    varvals0 = (varvals[0:dimv[1],:] + varvals[1:dimv[1]+1,:])/2.
                else:
                    varvals0 = (varvals[:,0:dimv[2]] + varvals[:,1:dimv[2]+1])/2.
            varinterp = ncvar.interpolate_3D(zorigvals, varvals0, intvals, 'lin')
        else:
            varinterp = ncvar.interpolate_3D(zorigvals, varvals, intvals, 'lin')

    elif len(varvals.shape) == 4 and len(zorigvals.shape) == 4:
        if (varvals.shape[2] + varvals.shape[3]) != (dimv[2] + dimv[3]):
            print warnmsg
            print '  ' + main + ': variable=', varvals.shape[2:4],                   \
              'with different shape!',dimv[2],dimv[3]
            if (varvals.shape[2] + varvals.shape[3]) - (dimv[2] + dimv[3]) == 1:
                print '    Assuming staggered variable'
                varvals0 = np.zeros((dimv[0],dimv[1],dimv[2],dimv[3]),               \
                  dtype=np.float)
                if (varvals.shape[2] > dimv[2]):
                    varvals0=(varvals[:,:,0:dimv[2],:]+varvals[:,:,1:dimv[2]+1,:])/2.
                else:
                    varvals0=(varvals[:,:,:,0:dimv[3]]+varvals[:,:,:,1:dimv[3]+1])/2.

            for it in range(dimv[0]):
                ncvar.percendone(it, dimv[0], 5, 'interpolated')

                varinterp[it,:,:,:] = ncvar.interpolate_3D(zorigvals[it,:,:,:],      \
                  varvals0[it,:,:,:], intvals, 'lin') 
        else:
            for it in range(dimv[0]):
                ncvar.percendone(it, dimv[0], 5, 'interpolated')

                varinterp[it,:,:,:] = ncvar.interpolate_3D(zorigvals[it,:,:,:],      \
                  varvals[it,:,:,:], intvals, 'lin')
#            print varinterp[it,:,:,:]
#            quit()
    elif len(varvals.shape) == 2 and len(zorigvals.shape) == 2:
        zdimid = varpdimns.index('z')
        varinterp = ncvar.interpolate_2D(zorigvals, varvals, zdimid, intvals, 'lin')

    elif len(varvals.shape) == 3 and len(zorigvals.shape) == 2:
        print 'Lluis here variable with an extra dimension in comparison to pressure'
        zdimid = varpdimns.index('z')
        presdn = []
        for idi in range(2):
            presdn.append(dimpresn.keys()[idi].lower())
        for dimid in range(3):
            if not ncvar.searchInlist(presdn,varpdimns): loopd = dimid
# Looping along the extra dimension
        if loopd == 0:
            for il in range(varvals.shape[loopd]):
                varinterp[il,:,:] = ncvar.interpolate_2D(zorigvals, varvals[il,:,:], \
                  zdimid, intvals, 'lin')
        elif loopd == 1:
            for il in range(varvals.shape[loopd]):
                varinterp[:,il,:] = ncvar.interpolate_2D(zorigvals, varvals[:,il,:], \
                  zdimid, intvals, 'lin')
        elif loopd == 2:
            for il in range(varvals.shape[loopd]):
                varinterp[:,:,il] = ncvar.interpolate_2D(zorigvals, varvals[:,:,il], \
                  zdimid, intvals, 'lin')
    else:
#        print errormsg
        print warnmsg
        print '  ' + main + ': dimension of values:',varvals.shape,                  \
          ' not ready to interpolate using:',zorigvals.shape,'!!!'
        print '    skipping variable'
#        quit(-1)

    print 'v_interp Lluis dims newvar:',newvar.dimensions
    print 'v_interp Lluis shapes: newvar',newvar.shape,'varinterp:',varinterp.shape
    newvar[:] = varinterp
    ncvar.basicvardef(newvar, varsn, varln, varu)

newnc.sync()

# Global attributes
##

atvar = ncvar.set_attribute(newnc, 'program', 'vertical_inerpolation.py')
atvar = ncvar.set_attribute(newnc, 'version', '1.0')
atvar = ncvar.set_attribute(newnc, 'author', 'Fita Borrell, Lluis')
atvar = ncvar.set_attribute(newnc, 'institution', 'Laboratoire Meteorologie ' +      \
  'Dynamique')
atvar = ncvar.set_attribute(newnc, 'university', 'Universite Pierre et Marie ' +     \
  'Curie -- Jussieu')
atvar = ncvar.set_attribute(newnc, 'centre', 'Centre national de la recherche ' +    \
  'scientifique')
atvar = ncvar.set_attribute(newnc, 'city', 'Paris')
atvar = ncvar.set_attribute(newnc, 'original_file', opts.file)

gorigattrs = ncobj.ncattrs()

for attr in gorigattrs:
    attrv = ncobj.getncattr(attr)
    atvar = ncvar.set_attribute(newnc, attr, attrv)

ncobj.close()
newnc.close()

print main + ': successfull writting of verticaly interpolated file "'+ofile+'" !!!'
