# Python script to compute diagnostics
# L. Fita, LMD. CNR, UPMC-Jussieu, Paris, France
# File diagnostics.inf provides the combination of variables to get the desired diagnostic
#
## e.g. # diagnostics.py -d 'Time@time,bottom_top@ZNU,south_north@XLAT,west_east@XLONG' -v 'clt|CLDFRA,cllmh|CLDFRA@WRFp,RAINTOT|RAINC@RAINNC@XTIME' -f WRF_LMDZ/NPv31/wrfout_d01_1980-03-01_00:00:00
## e.g. # diagnostics.py -f /home/lluis/PY/diagnostics.inf -d variable_combo -v WRFprc

from optparse import OptionParser
import numpy as np
from netCDF4 import Dataset as NetCDFFile
import os
import re
import nc_var_tools as ncvar
import datetime as dt

# Name of this script, used as prefix in the printed messages
main = 'diagnostics.py'
# Banners printed before the error and warning messages
errormsg = 'ERROR -- error -- ERROR -- error'
warnmsg = 'WARNING -- warning -- WARNING -- warning'

# General information
##
def reduce_spaces(string):
    """ Function to split a line of text into its words
      [string]= line of text; newline characters are dropped and only the
        blank ' ' is used as separator (tabs are kept inside the words)
      Returns the list of non-empty words
    >>> reduce_spaces('  clt   CLDFRA ')
    ['clt', 'CLDFRA']
    """
    words = string.replace('\n', '').split(' ')

    return [word for word in words if len(word) > 0]

def variable_combo(varn,combofile):
    """ Function to provide variables combination from a given variable name
      varn= name of the variable ('h' prints this help and stops)
      combofile= ASCII file with the combination of variables
        [varn] [combo]
          [combo]: '@' separated list of variables to use to generate [varn]
            [WRFdt] to get WRF time-step (from general attributes)
        blank lines and lines with fewer than two words are skipped
      Returns the [combo] of the first line whose first word is [varn], or
        'ERROR' (after printing an error message) if [varn] is not found
    >>> variable_combo('WRFprls','/home/lluis/PY/diagnostics.inf')
    deaccum@RAINNC@XTIME@prnc
    """
    fname = 'variable_combo'

    if varn == 'h':
        print(fname + '_____________________________________________________________')
        print(variable_combo.__doc__)
        quit()

    if not os.path.isfile(combofile):
        print(errormsg)
        print('  ' + fname + ": file with combinations '" + combofile +              \
          "' does not exist!!")
        quit(-1)

    combo = 'ERROR'
    found = False
    # 'with' closes the file whatever happens while reading it
    with open(combofile, 'r') as objf:
        for line in objf:
            # rstrip also removes the '\r' of files with DOS line endings
            linevals = reduce_spaces(line.rstrip('\r\n'))
            # Blank (or single-word) lines cannot hold a [varn] [combo] pair
            if len(linevals) < 2:
                continue
            if linevals[0] == varn:
                combo = linevals[1]
                found = True
                break

    if not found:
        print(errormsg)
        print('  ' + fname + ": variable '" + varn + "' not found in '" +            \
          combofile + "' !!")
        combo = 'ERROR'

    return combo

# Mathematical operators
##
def compute_deaccum(varv, dimns, dimvns):
    """ Function to compute the deaccumulation of a variable
    compute_deaccum(varv, dimns, dimvns)
      [varv]= values to deaccum (assuming [t,]); array or netCDF variable (it is
        read only once)
      [dimns]= list of the name of the dimensions of the [varv]
      [dimvns]= list of the name of the variables with the values of the 
        dimensions of [varv]
      Returns [deac, deacdims, deacvdims]: deac[t] = varv[t] - varv[t-1] along the
        first axis with deac[0] = 0, and copies of [dimns] and [dimvns]
      e.g. values [0., 1., 3., 6.] give [0., 1., 2., 3.]
    """
    fname = 'compute_deaccum'

    deacdims = dimns[:]
    deacvdims = dimvns[:]

    # A single read: [:] loads a netCDF variable, and is just a view of an array
    values = varv[:]
    nt = values.shape[0]

    # Index of the previous step for every step; step 0 points to itself so its
    #   deaccumulation is 0
    previous = np.arange(nt)
    previous[1:] = np.arange(nt-1)

    deac = values - values[previous]

    return deac, deacdims, deacvdims

def derivate_centered(var,dim,dimv):
    """ Function to compute the centered derivate of a given field
      centered derivate(n) = (var(n+1) - var(n-1))/(dimv(n+1) - dimv(n-1))
        i.e. (var(n+1) - var(n-1))/(2*dn) for a constant spacing dn
      The first and the last points along [dim] are left as 0.
    [var]= variable (array or netCDF variable)
    [dim]= which dimension to compute the derivate
    [dimv]= dimension values: a scalar (constant spacing dn) or an array (can be
      of different dimension of [var]); with fewer dimensions they have to match
      the last ones of [var] ([..., N, M] vs [N, M]) and are repeated along the
      leading ones
    >>> derivate_centered(np.arange(16).reshape(4,4)*1.,1,1.)
    [[ 0.  1.  1.  0.]
     [ 0.  1.  1.  0.]
     [ 0.  1.  1.  0.]
     [ 0.  1.  1.  0.]]
    """

    fname = 'derivate_centered'

    vark = var.dtype
    ndims = len(var.shape)

    if dim > ndims - 1:
        print(errormsg)
        print('  ' + fname + ': dimension ' + str(dim) + '  too big for given ' +    \
          'variable of shape: ' + str(var.shape) + ' !!!')
        quit(-1)

    # Values are read only once (a netCDF variable is loaded here)
    vals = var[:]

    isarray = hasattr(dimv, "__len__")
    if isarray:
        dimarr = np.asarray(dimv)
        if len(dimarr.shape) > ndims:
            print(errormsg)
            print('  ' + fname + ': dimension difference between variable ' +       \
              str(var.shape) + ' and variable with dimension values ' +            \
              str(dimarr.shape) + '  not ready !!!')
            quit(-1)
        elif len(dimarr.shape) == ndims:
            dimvals = dimarr
        else:
            # Broadcasting repeats the dimension values along the leading
            #   dimensions of [var] (cast to the dtype of [var])
            dimvals = np.empty(var.shape, dtype=vark)
            dimvals[...] = dimarr

    # Slices picking the previous (n-1), next (n+1) and central (n) points along
    #   [dim]; every other dimension is taken whole
    nd = var.shape[dim]
    slicebef = [slice(None)]*ndims
    sliceaft = [slice(None)]*ndims
    sliceder = [slice(None)]*ndims
    slicebef[dim] = slice(0, nd-2)
    sliceaft[dim] = slice(2, nd)
    sliceder[dim] = slice(1, nd-1)
    slicebef = tuple(slicebef)
    sliceaft = tuple(sliceaft)
    sliceder = tuple(sliceder)

    derivate = np.zeros(var.shape, dtype=vark)
    dvar = vals[sliceaft] - vals[slicebef]

    if isarray:
        derivate[sliceder] = dvar/(dimvals[sliceaft] - dimvals[slicebef])
    else:
        derivate[sliceder] = dvar/(2.*dimv)

    return derivate

def rotational_z(Vx,Vy,pos):
    """ z-component of the rotational of a horizontal vectorial field
      rot_z(Vx,Vy) = d(Vy)/dx - d(Vx)/dy
    Both derivatives come from derivate_centered: x is taken as the last
      dimension of the fields and y as the one before it
    [Vx]= variable component x
    [Vy]= variable component y
    [pos]= position of the grid points, used as dimension values of both
      derivatives (a scalar means a constant grid spacing)
    """

    fname =  'rotational_z'

    xdim = len(Vx.shape) - 1
    ydim = xdim - 1

    dVy_dx = derivate_centered(Vy, xdim, pos)
    dVx_dy = derivate_centered(Vx, ydim, pos)

    return dVy_dx - dVx_dy

# Diagnostics
##

def var_clt(cfra):
    """ Function to compute the total cloud fraction following 'newmicro.F90' from 
      LMDZ using 1D vertical column values
      [cfra]= cloud fraction values of the column (assuming [z])
      Returns the total cloud fraction of the column (0. for an empty column)
    """
    ZEPSEC=1.0E-12

    fname = 'var_clt'

    zclear = 1.
    zcloud = 0.

    for iz in range(cfra.shape[0]):
        # Clear fraction updated with the cloud of this level and the previous one
        zclear = zclear*(1.-np.max([cfra[iz],zcloud]))/(1.-np.min([zcloud,1.-ZEPSEC]))
        zcloud = cfra[iz]

    # Computed once after the loop, so it is also defined for an empty column
    clt = 1. - zclear

    return clt

def compute_clt(cldfra, dimns, dimvns):
    """ Function to compute the total cloud fraction following 'newmicro.F90' from 
      LMDZ
    compute_clt(cldfra, dimns, dimvns)
      [cldfra]= cloud fraction values (assuming [[t],z,y,x]); array or netCDF
        variable
      [dimns]= list of the name of the dimensions of [cldfra]
      [dimvns]= list of the name of the variables with the values of the 
        dimensions of [cldfra]
      Returns [clt, cltdims, cltvdims]: total cloud fraction ([[t],y,x]) and
        copies of [dimns], [dimvns] without the vertical dimension
    """
    fname = 'compute_clt'

    cltdims = dimns[:]
    cltvdims = dimvns[:]

    if len(cldfra.shape) == 4:
        nt, nz, ny, nx = cldfra.shape
        cltdims.pop(1)
        cltvdims.pop(1)

        clt = np.zeros((nt, ny, nx), dtype=float)

        for it in range(nt):
            # One read per time-step instead of one per column
            cfra_t = cldfra[it]
            for ix in range(nx):
                for iy in range(ny):
                    ncvar.percendone(it*nx*ny + ix*ny + iy, nx*ny*nt, 5, 'diagnosted')
                    clt[it,iy,ix] = var_clt(cfra_t[:,iy,ix])

    else:
        nz, ny, nx = cldfra.shape
        cltdims.pop(0)
        cltvdims.pop(0)

        clt = np.zeros((ny, nx), dtype=float)

        cfvals = cldfra[:]
        for ix in range(nx):
            for iy in range(ny):
                # Without a time axis the total number of columns is nx*ny
                ncvar.percendone(ix*ny + iy, nx*ny, 5, 'diagnosted')
                clt[iy,ix] = var_clt(cfvals[:,iy,ix])

    return clt, cltdims, cltvdims

def var_cllmh(cfra, p):
    """ Function to compute cllmh on a 1D column: low/medium/high cloud fraction
      following newmicro.F90 from LMDZ
      [cfra]= cloud fraction of the column
      [p]= pressure of the column [Pa]
      Layers: high p < 440 hPa; medium 440 hPa <= p < 680 hPa; low p >= 680 hPa
        (levels with a NaN pressure belong to none of them)
      Returns an array of 3 values: [low, medium, high] cloud fraction
    """

    fname = 'var_cllmh'

    ZEPSEC =1.0E-12
    prmhc = 440.*100.
    prmlc = 680.*100.

    # Clear fraction of each layer (0: low, 1: medium, 2: high)
    cllmh = np.ones((3), dtype=float)
    # Cloud fraction of the previous level found within each layer
    zcloud = [0., 0., 0.]

    for iz in range(cfra.shape[0]):
        if p[iz] < prmhc:
            ilay = 2
        elif p[iz] < prmlc:
            ilay = 1
        elif p[iz] >= prmlc:
            ilay = 0
        else:
            # NaN pressure: every comparison is False
            continue

        cllmh[ilay] = cllmh[ilay]*(1.-np.max([cfra[iz], zcloud[ilay]]))/(1.-        \
          np.min([zcloud[ilay],1.-ZEPSEC]))
        zcloud[ilay] = cfra[iz]

    cllmh = 1.- cllmh

    return cllmh

def compute_cllmh(cldfra, pres, dimns, dimvns):
    """ Function to compute cllmh: low/medium/high cloud fraction following newmicro.F90 from LMDZ
    compute_cllmh(cldfra, pres, dimns, dimvns)
      [cldfra]= cloud fraction values (assuming [[t],z,y,x]); array or netCDF
        variable
      [pres] = pressure field (same shape as [cldfra])
      [dimns]= list of the name of the dimensions of [cldfra]
      [dimvns]= list of the name of the variables with the values of the 
        dimensions of [cldfra]
      Returns [cllmh, cllmhdims, cllmhvdims]: cllmh has a leading axis of size 3
        (0: low, 1: medium, 2: high, as returned by var_cllmh) followed by
        [[t],y,x]; the dimension lists come without the vertical dimension
    """
    fname = 'compute_cllmh'

    cllmhdims = dimns[:]
    cllmhvdims = dimvns[:]

    if len(cldfra.shape) == 4:
        nt, nz, ny, nx = cldfra.shape
        cllmhdims.pop(1)
        cllmhvdims.pop(1)

        cllmh = np.ones((3, nt, ny, nx), dtype=float)

        for it in range(nt):
            # One read per time-step instead of one per column
            cfra_t = cldfra[it]
            pres_t = pres[it]
            for ix in range(nx):
                for iy in range(ny):
                    ncvar.percendone(it*nx*ny + ix*ny + iy, nx*ny*nt, 5, 'diagnosted')
                    cllmh[:,it,iy,ix] = var_cllmh(cfra_t[:,iy,ix], pres_t[:,iy,ix])

    else:
        nz, ny, nx = cldfra.shape
        cllmhdims.pop(0)
        cllmhvdims.pop(0)

        cllmh = np.ones((3, ny, nx), dtype=float)

        cfvals = cldfra[:]
        presvals = pres[:]
        for ix in range(nx):
            for iy in range(ny):
                ncvar.percendone(ix*ny + iy, nx*ny, 5, 'diagnosted')
                cllmh[:,iy,ix] = var_cllmh(cfvals[:,iy,ix], presvals[:,iy,ix])

    return cllmh, cllmhdims, cllmhvdims

def var_virtualTemp (temp,rmix):
    """ Virtual temperature [K] from temperature and water vapour mixing ratio
      Tv = temp*(0.622 + rmix)/(0.622*(1 + rmix))
      temp: temperature [K]
      rmix: mixing ratio [kgkg-1]
    Works element-wise on arrays as well as on scalars
    """

    fname = 'var_virtualTemp'

    epsilon = 0.622
    numerator = temp*(epsilon + rmix)
    denominator = epsilon*(1. + rmix)

    return numerator/denominator


def var_mslp(pres, psfc, ter, tk, qv):
    """ Function to compute mslp on a 1D column
      [pres]= pressure of the column [Pa] (bottom-to-top or top-to-bottom)
      [psfc]= surface pressure [Pa]
      [ter]= terrain height [m]
      [tk]= temperature of the column [K]
      [qv]= water vapour mixing ratio of the column [kgkg-1]
    The column has to cross the 400 hPa reference level (the program stops
      otherwise). Below ground, a standard lapse rate (0.0065 K/m) temperature
      profile passes through the temperature of the model level closest to a
      "target" pressure of 700 hPa everywhere (similar to the Benjamin and
      Miller (1990) method; p_interp.F90 uses psfc - 150 hPa instead). That
      temperature is brought to [psfc], converted to virtual temperature Tv,
      and used to reduce [psfc] to sea level:
        mslp = psfc*((Tv + 0.0065*ter)/Tv)**(9.81/(287.04*0.0065))
    """

    fname = 'var_mslp'

    expon=287.04*.0065/9.81
    pref = 40000.

    dz=len(pres) 

    # Sanity check: find where about 400 hPa is located
    kref = -1
    if pres[0] - pres[dz-1] < 0.:
        # Pressure increasing with the index (top of the column first)
        for iz in range(1,dz):
            if pres[iz-1] < pref and pres[iz] >= pref:
                kref = iz
                break
    else:
        # Pressure decreasing with the index (bottom of the column first)
        for iz in range(dz-1):
            if pres[iz] >= pref and pres[iz+1] < pref:
                kref = iz
                break

    if kref == -1:
        print(errormsg)
        print('  ' + fname + ': no reference pressure: ' + str(pref) + ' found!!')
        print('    values: ' + str(pres[:]))
        quit(-1)

    # Model level closest to the "target" pressure (independent of the ordering)
    ptarget = 70000.
    kupper = int(np.argmin(np.abs(np.asarray(pres, dtype=float) - ptarget)))

    tbotextrap = tk[kupper]*(psfc/ptarget)**expon
    tvbotextrap = var_virtualTemp(tbotextrap, qv[kupper])
    mslp = psfc*( (tvbotextrap+0.0065*ter)/tvbotextrap)**(1./expon)

    return mslp

def compute_mslp(pressure, psurface, terrain, temperature, qvapor, dimns, dimvns):
    """ Function to compute mslp: mean sea level pressure following p_interp.F90 from WRF
    compute_mslp(pressure, psurface, terrain, temperature, qvapor, dimns, dimvns)
      [pressure]= pressure field [Pa] (assuming [[t],z,y,x])
      [psurface]= surface pressure field [Pa] (assuming [[t],y,x])
      [terrain]= topography [m] ([y,x], or [t,y,x] of which only the first
        time-step is used)
      [temperature]= temperature [K]
      [qvapor]= water vapour mixing ratio [kgkg-1]
      [dimns]= list of the name of the dimensions of [pressure]
      [dimvns]= list of the name of the variables with the values of the 
        dimensions of [pressure]
      Returns [mslpv, mslpdims, mslpvdims]: mslp ([[t],y,x]) and the dimension
        lists without the vertical one. Where terrain <= 0 m the surface pressure
        is returned unchanged
    """

    fname = 'compute_mslp'

    mslpdims = list(dimns[:])
    mslpvdims = list(dimvns[:])

    # Terrain... to 2D !
    if len(terrain.shape) == 3:
        terval = terrain[0,:,:]
    else:
        terval = terrain

    if len(pressure.shape) == 4:
        mslpdims.pop(1)
        mslpvdims.pop(1)

        nt, nz, ny, nx = pressure.shape
        mslpv = np.zeros((nt, ny, nx), dtype=float)

        for ix in range(nx):
            for iy in range(ny):
                if terval[iy,ix] > 0.:
                    for it in range(nt):
                        mslpv[it,iy,ix] = var_mslp(pressure[it,:,iy,ix],             \
                          psurface[it,iy,ix], terval[iy,ix], temperature[it,:,iy,ix],\
                          qvapor[it,:,iy,ix])

                        ncvar.percendone(it*nx*ny + ix*ny + iy, nx*ny*nt, 5, 'diagnosted')
                else:
                    mslpv[:,iy,ix] = psurface[:,iy,ix]

    else:
        mslpdims.pop(0)
        mslpvdims.pop(0)

        nz, ny, nx = pressure.shape
        mslpv = np.zeros((ny, nx), dtype=float)

        for ix in range(nx):
            for iy in range(ny):
                ncvar.percendone(ix*ny + iy, nx*ny, 5, 'diagnosted')
                if terval[iy,ix] > 0.:
                    mslpv[iy,ix] = var_mslp(pressure[:,iy,ix], psurface[iy,ix],      \
                      terval[iy,ix], temperature[:,iy,ix], qvapor[:,iy,ix])
                else:
                    mslpv[iy,ix] = psurface[iy,ix]

    return mslpv, mslpdims, mslpvdims

def compute_prw(dens, q, dimns, dimvns):
    """ Function to compute water vapour path (prw) as the vertical sum of dens*q
      [dens] = weight of each layer (assuming [t],z,y,x), e.g. the WRFdens field
        of this script ((MU + MUB)*|DNW|/g); NOTE(review): units to be confirmed
      [q] = mixing ratio in [kgkg-1] (assuming [t],z,y,x)
      [dimns]= list (or tuple) of the name of the dimensions of [q]
      [dimvns]= list (or tuple) of the name of the variables with the values of
        the dimensions of [q]
      Returns [prw, prwdims, prwvdims]: prw ([[t],y,x]) and the dimension lists
        without the vertical one
    """
    fname = 'compute_prw'

    prwdims = list(dimns)
    prwvdims = list(dimvns)

    # Vertical axis: 1 with a time axis ([t,z,y,x]), 0 otherwise ([z,y,x])
    if len(q.shape) == 4:
        zaxis = 1
    else:
        zaxis = 0

    prwdims.pop(zaxis)
    prwvdims.pop(zaxis)

    prw = np.sum(dens*q, axis=zaxis)

    return prw, prwdims, prwvdims

def compute_rh(p, t, q, dimns, dimvns):
    """ Function to compute relative humidity following 'Tetens' equation (T,P) ...'
      rh = q/qs, with
        es [hPa] = 6.112*exp(17.67*(t-273.16)/(t-29.65))
        qs = 0.622*es/(0.01*p - (1-0.622)*es)
      [p] = pressure field in [Pa] (0.01*p converts it to hPa)
      [t]= temperature (assuming [[t],z,y,x] in [K])
      [q] = mixing ratio in [kgkg-1]
      [dimns]= list of the name of the dimensions of [t]
      [dimvns]= list of the name of the variables with the values of the 
        dimensions of [t]
      Returns [rh, rhdims, rhvdims]: relative humidity as a fraction (1 =
        saturation) and copies of [dimns], [dimvns]
    """
    fname = 'compute_rh'

    rhdims = dimns[:]
    rhvdims = dimvns[:]

    # Saturation vapour pressure [hPa]
    es = 10.*0.6112*np.exp(17.67*(t-273.16)/(t-29.65))
    # Saturation mixing ratio [kgkg-1]
    qs = 0.622*es/(0.01*p-(1.-0.622)*es)

    rh = q/qs

    return rh, rhdims, rhvdims

def turbulence_var(varv, dimvn, dimn):
    """ Function to compute the Taylor's decomposition turbulence term of a
      variable along its time axis
      x*=<x^2>_t-(<X>_t)^2
    turbulence_var(varv, dimvn, dimn)
      varv= values of the variable
      dimvn= names of the dimensions of the variable
      dimn= names of the dimensions (as a dictionary with 'X', 'Y', 'Z', 'T');
        only dimn['T'] is used, to locate the time axis within [dimvn]
    >>> turbulence_var(np.arange((27)).reshape(3,3,3),['time','y','x'],{'T':'time', 'Y':'y', 'X':'x'})
    [[ 54.  54.  54.]
     [ 54.  54.  54.]
     [ 54.  54.  54.]]
    """
    fname = 'turbulence_varv'

    taxis = dimvn.index(dimn['T'])

    squares = varv*varv
    tmean = np.mean(varv, axis=taxis)
    squares_tmean = np.mean(squares, axis=taxis)

    return squares_tmean - (tmean*tmean)

def compute_turbulence(v, dimns, dimvns):
    """ Function to compute the turbulence term of the Taylor's decomposition ...'
      x*=<x^2>_t-(<X>_t)^2, the time average being taken along the first axis
      [v]= variable (assuming [[t],z,y,x])
      [dimns]= list of the name of the dimensions of [v]
      [dimvns]= list of the name of the variables with the values of the 
        dimensions of [v]
      Returns [turb, turbdims, turbvdims]: the turbulence term and copies of
        [dimns], [dimvns] without their first (time) entry
    """
    fname = 'compute_turbulence'

    # The time axis (first one) is averaged out
    turbdims = dimns[:]
    turbdims.pop(0)
    turbvdims = dimvns[:]
    turbvdims.pop(0)

    tmean = np.mean(v, axis=0)
    squares_tmean = np.mean(v*v, axis=0)
    turb = squares_tmean - (tmean*tmean)

    return turb, turbdims, turbvdims

def timeunits_seconds(dtu):
    """ Function to transform a time units to seconds
    timeunits_seconds(dtu)
      [dtu]= time units value to transform in seconds (e.g. the first word of a
        'units' attribute like 'minutes since 1980-03-01 00:00:00')
      Accepted values: 'years' (365 days), 'weeks', 'days', 'hours', 'minutes',
        'seconds', 'miliseconds'/'milliseconds', their singular forms and the
        abbreviations 'd', 'h', 'hr', 'min', 's', 'sec' and 'ms'
      Returns the length of one [dtu] in seconds; stops the program for any
        other value
    >>> timeunits_seconds('hours')
    3600.0
    """
    fname='timeunits_seconds'

    day = 24.*3600.
    unitsecs = {
      'years': 365.*day, 'year': 365.*day,
      'weeks': 7.*day, 'week': 7.*day,
      'days': day, 'day': day, 'd': day,
      'hours': 3600., 'hour': 3600., 'hr': 3600., 'h': 3600.,
      'minutes': 60., 'minute': 60., 'min': 60.,
      'seconds': 1., 'second': 1., 'sec': 1., 's': 1.,
      'miliseconds': 1./1000., 'milliseconds': 1./1000.,
      'millisecond': 1./1000., 'ms': 1./1000.}

    if dtu not in unitsecs:
        print(errormsg)
        print('  ' + fname  + ": time units '" + dtu + "' not ready !!")
        quit(-1)

    return unitsecs[dtu]

####### ###### ##### #### ### ## #
# Extra help of -d: how to query the combination file instead of computing
comboinf="\nIF -d 'variable_combo', provides information of the combination to obtain -v [varn] with the ASCII file with the combinations as -f [combofile]"

parser = OptionParser()
# -f: input netCDF file (or the combination file with -d 'variable_combo')
parser.add_option("-f", "--netCDF_file", dest="ncfile", metavar="FILE",
  help="file to use")
# -d: ',' list of [dimension name]@[variable with its values]
parser.add_option("-d", "--dimensions", dest="dimns", metavar="LABELS",
  help="[dimxn]@[dxvn],[dimyn]@[dxvn],[...,[dimtn]@[dxvn]], ',' list with the couples [dimDn]@[dDvn], [dimDn], name of the dimension D and name of the variable [dDvn] with the values of the dimension" + comboinf)
# -v: ',' list of [diagnostic]|[dependency variables]
parser.add_option("-v", "--variables", dest="varns", metavar="VALUES",
  help=" [varn1]|[var11]@[...[varN1]],[...,[varnM]|[var1M]@[...[varLM]]] ',' list of variables to compute [varnK] and its necessary ones [var1K]...[varPK]")

opts, args = parser.parse_args()

#######    #######
## MAIN
    #######
# Diagnostics known by the script
availdiags = ['ACRAINTOT', 'clt', 'cllmh', 'deaccum', 'LMDZrh', 'mslp', 'RAINTOT',   \
  'rvors', 'turbulence', 'WRFrvors']

# Variables not to check: dependency names which do not need to exist in the
#   input file (they are computed by this script or are keywords like 'deaccum')
NONcheckingvars = ['cllmh', 'deaccum', 'WRFbils', 'WRFdens', 'WRFgeop', 'WRFp',      \
  'WRFpos', 'WRFprc', 'WRFprls', 'WRFrh', 'LMDZrh', 'LMDZrhs', 'WRFrhs', 'WRFrvors', \
  'WRFt', 'WRFtime']

# Output file (created, or overwritten, below)
ofile = 'diagnostics.nc'

# The three options are needed in both running modes (-f is the ASCII
#   combination file when -d is 'variable_combo')
for optval, optn in [(opts.ncfile, '-f'), (opts.dimns, '-d'), (opts.varns, '-v')]:
    if optval is None:
        print(errormsg)
        print('  ' + main + ": option '" + optn + "' is required !!")
        quit(-1)

dimns = opts.dimns
varns = opts.varns

# Special method. knowing variable combination: prints the combination that
#   provides -v and stops
##
if opts.dimns == 'variable_combo':
    print(warnmsg)
    print('  ' + main + ': knowing variable combination !!!')
    combination = variable_combo(opts.varns,opts.ncfile)
    print('     COMBO: ' + combination)
    quit(-1)

if not os.path.isfile(opts.ncfile):
    print(errormsg)
    print('   ' + main + ": file '" + opts.ncfile + "' does not exist !!")
    quit(-1)

# dimensions: ',' list of [dimension name]@[variable with its values]; parsed
#   before opening any file, so a malformed -d does not leave a half-created output
dimvalues = dimns.split(',')
dnames = []
dvnames = []

for dimval in dimvalues:
    if dimval.find('@') == -1:
        print(errormsg)
        print('  ' + main + ": dimension '" + dimval + "' is not of the form " +  \
          "'[dimn]@[dimvn]' !!")
        quit(-1)
    dnames.append(dimval.split('@')[0])
    dvnames.append(dimval.split('@')[1])

ncobj = NetCDFFile(opts.ncfile, 'r')

# File creation
newnc = NetCDFFile(ofile,'w')

# diagnostics to compute: ',' list of [diagnostic]|[dep1]@[dep2]@...
diags = varns.split(',')
Ndiags = len(diags)

for diagv in diags:
    if diagv.find('|') == -1:
        print(errormsg)
        print('  ' + main + ": diagnostic '" + diagv + "' is not of the form " +  \
          "'[diagn]|[dependencies]' !!")
        quit(-1)

# Looking for specific variables that might be used in more than one diagnostic:
#   they are computed only once, before going for the diagnostics
WRFp_compute = False
WRFt_compute = False
WRFrh_compute = False
WRFght_compute = False
WRFdens_compute = False
WRFpos_compute = False

for idiag in range(Ndiags):
    if diags[idiag].split('|')[1].find('@') == -1:
        depvars = diags[idiag].split('|')[1]
        if depvars == 'WRFp': WRFp_compute = True
        if depvars == 'WRFt': WRFt_compute = True
        if depvars == 'WRFrh': WRFrh_compute = True
        if depvars == 'WRFght': WRFght_compute = True
        if depvars == 'WRFdens': WRFdens_compute = True
        if depvars == 'WRFpos': WRFpos_compute = True

    else:
        depvars = diags[idiag].split('|')[1].split('@')
        if ncvar.searchInlist(depvars, 'WRFp'): WRFp_compute = True
        if ncvar.searchInlist(depvars, 'WRFt'): WRFt_compute = True
        if ncvar.searchInlist(depvars, 'WRFrh'): WRFrh_compute = True
        if ncvar.searchInlist(depvars, 'WRFght'): WRFght_compute = True
        if ncvar.searchInlist(depvars, 'WRFdens'): WRFdens_compute = True
        if ncvar.searchInlist(depvars, 'WRFpos'): WRFpos_compute = True

if WRFp_compute:
    print('  ' + main + ': Retrieving pressure value from WRF as P + PB')
    dimv = ncobj.variables['P'].shape
    WRFp = ncobj.variables['P'][:] + ncobj.variables['PB'][:]

if WRFght_compute:
    print('    ' + main + ': computing geopotential height from WRF as PH + PHB ...')
    WRFght = ncobj.variables['PH'][:] + ncobj.variables['PHB'][:]

if WRFrh_compute:
    print('    ' + main + ": computing relative humidity from WRF as 'Tetens'" +     \
      ' equation (T,P) ...')
    p0=100000.
    p=ncobj.variables['P'][:] + ncobj.variables['PB'][:]
# Temperature from the potential temperature (T + 300) and the Exner factor
    tk = (ncobj.variables['T'][:] + 300.)*(p/p0)**(2./7.)
    qv = ncobj.variables['QVAPOR'][:]

# Saturation vapour pressure [hPa] and saturation mixing ratio (0.01*p: to hPa)
    data1 = 10.*0.6112*np.exp(17.67*(tk-273.16)/(tk-29.65))
    data2 = 0.622*data1/(0.01*p-(1.-0.622)*data1)

# Relative humidity as a fraction (1 = saturation)
    WRFrh = qv/data2

if WRFt_compute:
    print('    ' + main + ': computing temperature from WRF as inv_potT(T + 300) ...')
    p0=100000.
    p=ncobj.variables['P'][:] + ncobj.variables['PB'][:]

    WRFt = (ncobj.variables['T'][:] + 300.)*(p/p0)**(2./7.)

if WRFdens_compute:
    print('    ' + main + ': computing air density from WRF as ((MU + MUB) * ' +     \
      'DNW)/g ...')
    grav = 9.81

    mu = (ncobj.variables['MU'][:] + ncobj.variables['MUB'][:])
    dnw = ncobj.variables['DNW'][:]

# One value per [t,z,y,x]: (MU + MUB)*|DNW|/g, with DNW given per [t,z]
    WRFdens = np.zeros((mu.shape[0], dnw.shape[1], mu.shape[1], mu.shape[2]),        \
      dtype=float)
    levval = np.zeros((mu.shape[1], mu.shape[2]), dtype=float)

    for it in range(mu.shape[0]):
        for iz in range(dnw.shape[1]):
            levval.fill(np.abs(dnw[it,iz]))
            WRFdens[it,iz,:,:] = mu[it,:,:]*levval/grav
if WRFpos_compute:
# WRF positions from the lowest-leftest corner of the matrix: distances are
#   accumulated point by point as DX/MAPFAC_M along x and DY/MAPFAC_M along y
#   (first column); the same field is repeated for every time-step
    print('    ' + main + ': computing position from MAPFAC_M as accumulated ' +    \
      'DX/MAPFAC_M and DY/MAPFAC_M ...')

    mapfac = ncobj.variables['MAPFAC_M'][:]

    distx = float(ncobj.getncattr('DX'))
    disty = float(ncobj.getncattr('DY'))

    print('distx: ' + str(distx) + ' disty: ' + str(disty))

    dx = mapfac.shape[2]
    dy = mapfac.shape[1]
# Number of time-steps; not named 'dt' so the 'datetime as dt' module alias stays
#   usable by the diagnostics below (e.g. RAINTOT with WRFtime)
    ntpos = mapfac.shape[0]

    WRFpos = np.zeros((ntpos, dy, dx), dtype=float)

# First row accumulated along x, as every other row
    for i in range(1,dx):
        WRFpos[0,0,i] = WRFpos[0,0,i-1] + distx/mapfac[0,0,i]
    for j in range(1,dy):
# First column accumulated along y, then the row accumulated along x
        WRFpos[0,j,0] = WRFpos[0,j-1,0] + disty/mapfac[0,j,0]
        for i in range(1,dx):
            WRFpos[0,j,i] = WRFpos[0,j,i-1] + distx/mapfac[0,j,i]

    WRFpos[1:,:,:] = WRFpos[0,:,:]

### ## #
# Going for the diagnostics
### ## #
print('  ' + main + ' ...')

for idiag in range(Ndiags):
    print '    diagnostic:',diags[idiag]
    diag = diags[idiag].split('|')[0]
    depvars = diags[idiag].split('|')[1].split('@')
    if diags[idiag].split('|')[1].find('@') != -1:
        depvars = diags[idiag].split('|')[1].split('@')
        if depvars[0] == 'deaccum': diag='deaccum'
        for depv in depvars:
            if not ncobj.variables.has_key(depv) and not                             \
              ncvar.searchInlist(NONcheckingvars, depv) and depvars[0] != 'deaccum':
                print errormsg
                print '  ' + main + ": file '" + opts.ncfile +                       \
                  "' does not have variable '" + depv + "' !!"
                quit(-1)
    else:
        depvars = diags[idiag].split('|')[1]
        if not ncobj.variables.has_key(depvars) and not                              \
          ncvar.searchInlist(NONcheckingvars, depvars) and depvars[0] != 'deaccum':
            print errormsg
            print '  ' + main + ": file '" + opts.ncfile +                           \
              "' does not have variable '" + depvars + "' !!"
            quit(-1)

    print "\n    Computing '" + diag + "' from: ", depvars, '...'

# acraintot: accumulated total precipitation from WRF RAINC, RAINNC
    if diag == 'ACRAINTOT':
            
        var0 = ncobj.variables[depvars[0]]
        var1 = ncobj.variables[depvars[1]]
        diagout = var0[:] + var1[:]

        dnamesvar = var0.dimensions
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        ncvar.insert_variable(ncobj, 'acpr', diagout, dnamesvar, dvnamesvar, newnc)

# cllmh with cldfra, pres
    elif diag == 'cllmh':
            
        var0 = ncobj.variables[depvars[0]]
        if depvars[1] == 'WRFp':
            var1 = WRFp
        else:
            var01 = ncobj.variables[depvars[1]]
            if len(size(var1.shape)) < len(size(var0.shape)):
                var1 = np.brodcast_arrays(var01,var0)[0]
            else:
                var1 = var01

        diagout, diagoutd, diagoutvd = compute_cllmh(var0,var1,dnames,dvnames)
        ncvar.insert_variable(ncobj, 'cll', diagout[0,:], diagoutd, diagoutvd, newnc)
        ncvar.insert_variable(ncobj, 'clm', diagout[1,:], diagoutd, diagoutvd, newnc)
        ncvar.insert_variable(ncobj, 'clh', diagout[2,:], diagoutd, diagoutvd, newnc)

# clt with cldfra
    elif diag == 'clt':
            
        var0 = ncobj.variables[depvars]
        diagout, diagoutd, diagoutvd = compute_clt(var0,dnames,dvnames)
        ncvar.insert_variable(ncobj, 'clt', diagout, diagoutd, diagoutvd, newnc)

# deaccum: deacumulation of any variable as (Variable, time [as [tunits] 
#   from/since ....], newvarname)
    elif diag == 'deaccum':

        var0 = ncobj.variables[depvars[1]]
        var1 = ncobj.variables[depvars[2]]

        dnamesvar = var0.dimensions
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        diagout, diagoutd, diagoutvd = compute_deaccum(var0,dnamesvar,dvnamesvar)

# Transforming to a flux
        if depvars[2] == 'XTIME':
            dtimeunits = var1.getncattr('description')
            tunits = dtimeunits.split(' ')[0]
        else:
            dtimeunits = var1.getncattr('units')
            tunits = dtimeunits.split(' ')[0]

        dtime = (var1[1] - var1[0])*timeunits_seconds(tunits)
        ncvar.insert_variable(ncobj, depvars[3], diagout/dtime, diagoutd, diagoutvd, newnc)

# LMDZrh (pres, t, r)
    elif diag == 'LMDZrh':
            
        var0 = ncobj.variables[depvars[0]][:]
        var1 = ncobj.variables[depvars[1]][:]
        var2 = ncobj.variables[depvars[2]][:]

        diagout, diagoutd, diagoutvd = compute_rh(var0,var1,var2,dnames,dvnames)
        ncvar.insert_variable(ncobj, 'hus', diagout, diagoutd, diagoutvd, newnc)

# LMDZrhs (psol, t2m, q2m)
    elif diag == 'LMDZrhs':
            
        var0 = ncobj.variables[depvars[0]][:]
        var1 = ncobj.variables[depvars[1]][:]
        var2 = ncobj.variables[depvars[2]][:]

        dnamesvar = ncobj.variables[depvars[0]].dimensions
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        diagout, diagoutd, diagoutvd = compute_rh(var0,var1,var2,dnamesvar,dvnamesvar)

        ncvar.insert_variable(ncobj, 'huss', diagout, diagoutd, diagoutvd, newnc)

# mslp: mean sea level pressure (pres, psfc, terrain, temp, qv)
    elif diag == 'mslp' or diag == 'WRFmslp':
            
        var1 = ncobj.variables[depvars[1]][:]
        var2 = ncobj.variables[depvars[2]][:]
        var4 = ncobj.variables[depvars[4]][:]

        if diag == 'WRFmslp':
            var0 = WRFp
            var3 = WRFt
            dnamesvar = ncobj.variables['P'].dimensions
        else:
            var0 = ncobj.variables[depvars[0]][:]
            var3 = ncobj.variables[depvars[3]][:]
            dnamesvar = ncobj.variables[depvars[0]].dimensions

        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        diagout, diagoutd, diagoutvd = compute_mslp(var0, var1, var2, var3, var4,    \
          dnamesvar, dvnamesvar)

        ncvar.insert_variable(ncobj, 'psl', diagout, diagoutd, diagoutvd, newnc)

# raintot: instantaneous total precipitation from WRF as (RAINC + RAINC) / dTime
    elif diag == 'RAINTOT':

        var0 = ncobj.variables[depvars[0]]
        var1 = ncobj.variables[depvars[1]]
        if depvars[2] != 'WRFtime':
            var2 = ncobj.variables[depvars[2]]

        var = var0[:] + var1[:]

        dnamesvar = var0.dimensions
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        diagout, diagoutd, diagoutvd = compute_deaccum(var,dnamesvar,dvnamesvar)

# Transforming to a flux
        if depvars[2] != 'WRFtime':
            dtimeunits = var2.getncattr('units')
            tunits = dtimeunits.split(' ')[0]

            dtime = (var2[1] - var2[0])*timeunits_seconds(tunits)
        else:
            var2 = ncobj.variables['Times']
            time1 = var2[0,:]
            time2 = var2[1,:]
            tmf1 = ''
            tmf2 = ''
            for ic in range(len(time1)):
                tmf1 = tmf1 + time1[ic]
                tmf2 = tmf2 + time2[ic]
            dtdate1 = dt.datetime.strptime(tmf1,"%Y-%m-%d_%H:%M:%S")
            dtdate2 = dt.datetime.strptime(tmf2,"%Y-%m-%d_%H:%M:%S")
            diffdate12 = dtdate2 - dtdate1
            dtime = diffdate12.total_seconds()
            print 'dtime:',dtime

        ncvar.insert_variable(ncobj, 'pr', diagout/dtime, diagoutd, diagoutvd, newnc)

# turbulence (var)
    elif diag == 'turbulence':

        var0 = ncobj.variables[depvars][:]

        dnamesvar = list(ncobj.variables[depvars].dimensions)
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        diagout, diagoutd, diagoutvd = compute_turbulence(var0,dnamesvar,dvnamesvar)
        valsvar = ncvar.variables_values(depvars)

        ncvar.insert_variable(ncobj, valsvar[0] + 'turb', diagout, diagoutd, 
          diagoutvd, newnc)
        varobj = newnc.variables[valsvar[0] + 'turb']
        attrv = varobj.long_name
        attr = varobj.delncattr('long_name')
        newattr = ncvar.set_attribute(varobj, 'long_name', attrv +                   \
          " Taylor decomposition turbulence term")

# WRFbils from WRF as HFX + LH
    elif diag == 'WRFbils':
            
        var0 = ncobj.variables[depvars[0]][:]
        var1 = ncobj.variables[depvars[1]][:]

        diagout = var0 + var1

        ncvar.insert_variable(ncobj, 'bils', diagout, dnames, dvnames, newnc)

# WRFp pressure from WRF as P + PB
    elif diag == 'WRFp':
            
        diagout = WRFp

        ncvar.insert_variable(ncobj, 'pres', diagout, dnames, dvnames, newnc)

# WRFpos 
    elif diag == 'WRFpos':
            
        dnamesvar = ncobj.variables['MAPFAC_M'].dimensions
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        ncvar.insert_variable(ncobj, 'WRFpos', WRFpos, dnamesvar, dvnamesvar, newnc)

# WRFprw: WRF water vapour path (WRFdens, QVAPOR)
    elif diag == 'WRFprw':
            
        var0 = WRFdens
        var1 = ncobj.variables[depvars[1]]

        dnamesvar = list(var1.dimensions)
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        diagout, diagoutd, diagoutvd = compute_prw(var0, var1, dnamesvar,dvnamesvar)

        ncvar.insert_variable(ncobj, 'prw', diagout, diagoutd, diagoutvd, newnc)

# WRFrh (P, T, QVAPOR)
    elif diag == 'WRFrh':
            
        dnamesvar = list(ncobj.variables[depvars[2]].dimensions)
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        ncvar.insert_variable(ncobj, 'hus', WRFrh, dnames, dvnames, newnc)

# WRFrhs (PSFC, T2, Q2)
    elif diag == 'WRFrhs':
            
        var0 = ncobj.variables[depvars[0]][:]
        var1 = ncobj.variables[depvars[1]][:]
        var2 = ncobj.variables[depvars[2]][:]

        dnamesvar = list(ncobj.variables[depvars[2]].dimensions)
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        diagout, diagoutd, diagoutvd = compute_rh(var0,var1,var2,dnamesvar,dvnamesvar)
        ncvar.insert_variable(ncobj, 'huss', diagout, diagoutd, diagoutvd, newnc)

# rvors (u10, v10, WRFpos)
    elif diag == 'WRFrvors':
            
        var0 = ncobj.variables[depvars[0]]
        var1 = ncobj.variables[depvars[1]]

        diagout = rotational_z(var0, var1, distx)

        dnamesvar = ncobj.variables[depvars[0]].dimensions
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        ncvar.insert_variable(ncobj, 'rvors', diagout, dnamesvar, dvnamesvar, newnc)

# wss (u10, v10)
    elif diag == 'wss':
            
        var0 = ncobj.variables[depvars[0]][:]
        var1 = ncobj.variables[depvars[1]][:]

        diagout = np.sqrt(var0*var0 + var1*var1)

        dnamesvar = ncobj.variables[depvars[0]].dimensions
        dvnamesvar = ncvar.var_dim_dimv(dnamesvar,dnames,dvnames)

        print 'dnamesvar',dnamesvar
        print 'dnames',dnames
        print 'dvnames',dvnames
        print 'dvnamesvar',dvnamesvar

        ncvar.insert_variable(ncobj, 'wss', diagout, dnamesvar, dvnamesvar, newnc)

    else:
        print errormsg
        print '  ' + main + ": diagnostic '" + diag + "' not ready!!!"
        print '    available diagnostics: ', availdiags
        quit(-1)

    newnc.sync()

#   end of diagnostics

# Global attributes
##
# Provenance of the diagnostics file, written in this exact order before the
#   global attributes of the input file are replicated
provenance = [
    ('program', 'diagnostics.py'),
    ('version', '1.0'),
    ('author', 'Fita Borrell, Lluis'),
    ('institution', 'Laboratoire Meteorologie Dynamique'),
    ('university', 'Universite Pierre et Marie Curie -- Jussieu'),
    ('centre', 'Centre national de la recherche scientifique'),
    ('city', 'Paris'),
    ('original_file', opts.ncfile),
]
for gattrn, gattrv in provenance:
    ncvar.set_attribute(newnc, gattrn, gattrv)

# Replicate every global attribute of the input file into the output file
#   NOTE(review): this runs after the provenance attributes above, so an input
#   attribute with the same name presumably replaces them (depends on
#   ncvar.set_attribute) -- confirm whether that is intended
for gattrn in ncobj.ncattrs():
    ncvar.set_attribute(newnc, gattrn, ncobj.getncattr(gattrn))

ncobj.close()
newnc.close()

print('\n' + main + ': successfull writting of diagnostics file "' + ofile + '" !!!')
