# Python script to transform from LMDZ output to regular output
# L. Fita, LMD. Jussieu, September 2013
#   In the coupling of WRF and LMDZ, matrices of size (1,dimx*dimy,dimz) are passed and therefore written out with this shape.
#     This script transforms the LMDZ output to a regular (dimx,dimy,dimz) output
#
## export PATH=/u/lflmd/bin/gcc_Python-2.7.5/bin:${PATH}
## e.g.: # WRFLMDZ_regularout.py -d 31,31 -f /d4/lflmd/etudes/WRF_LMDZ/test_phylmd/run/histmth.nc -o histins_reg.nc
import numpy as np
from netCDF4 import Dataset as NetCDFFile
import os
import re
import nc_var_tools as ncvar
from optparse import OptionParser

main='WRFLMDZout_regularout.py'
errormsg='ERROR -- error -- ERROR -- error'
warnmsg='WARNING -- warning -- WARNING -- warning'

####### ###### ##### #### #### ## #

parser = OptionParser()
parser.add_option("-d", "--dimensions", dest="dims",
                  help="dimx,dimy", metavar="VALUES")
parser.add_option("-f", "--LMDZ_file", dest="lfile",
                  help="LMDZ file to use", metavar="FILE")
parser.add_option("-o", "--output_file", dest="ofile",
                  help="output file name to use", metavar="FILE")

(opts, args) = parser.parse_args()

#######    #######
## MAIN
    #######

ofile=opts.ofile

dimx=int(opts.dims.split(',')[0])
dimy=int(opts.dims.split(',')[1])

if not os.path.isfile(opts.lfile):
    print errormsg
    print '  ' + main + ' LMDZ file: "' + opts.lfile + '" does not exist !!'
    print errormsg
    quit(-1)

objlfile = NetCDFFile(opts.lfile, 'r')

# Checking dimensions. Remeber lat=(dimx*dimy)
##
varobj=objlfile.variables['lat']
varinf = ncvar.variable_inf(varobj)
if dimx*dimy != varinf.dims[0]:
    print errormsg
    print '  ' + main + ': given dimensions', dimx,',', dimy,                        \
      'does not coincide with lat size: ',varinf.dims [0],'!!!'
    quit(-1)

# Checking dimensions. time
##
varobj=objlfile.variables['time_counter']
varinf = ncvar.variable_inf(varobj)
if varinf.dims[0] == 0:
    print errormsg
    print '  ' + main + ': variable "time_counter" does not have values!!!!'
    quit(-1)

objofile = NetCDFFile(opts.ofile, 'w')

lfilevars = objlfile.variables
for varn in lfilevars:
    print '  Transforming "' + varn + '"...'
    varobj=objlfile.variables[varn]

    varinf = ncvar.variable_inf(varobj)

    vardims = varinf.dims
    vardimns = varinf.dimns
    vartype = varinf.dtype
    varattr = varinf.attributes

# Checking dimensions
##
    newdimns = []
    for vdim in vardimns:
        objofiledimns = objofile.dimensions.keys()
        if vdim == 'time_counter':
            newvdim = 'time'
        elif vdim == 'tbnds':
            newvdim = 'bnds'
        else:
            newvdim = vdim

        if not ncvar.searchInlist(objofiledimns, newvdim):
            objvdim = objlfile.dimensions[vdim]
            if vdim == 'lon':
                dimsize = dimx
            elif vdim == 'lat':
                dimsize = dimy
            else:
                if objvdim.isunlimited():
                    dimsize=None
                else:
                    dimsize = len(objvdim)

            print '      Adding dimension "' + newvdim + '" size:', dimsize
            dim = objofile.createDimension(newvdim, dimsize)
        newdimns.append(newvdim)

    newvardimns=tuple(newdimns)

# Checking fill value
## 
    if ncvar.searchInlist(varattr, '_FillValue'):
        varfil = varobj._FillValue
    else:
        varfil = False

    if varn == 'time_counter':
        varn = 'time'
    elif varn == 'time_counter_bnds':
        varn='time_bnds'

    varvalues = varobj[:]

    varshape = varvalues.shape
    Ndims = len(varshape)

    regvardims = list(vardims)
    for idim in range(varinf.Ndims):
        if vardimns[idim] == 'lon':
            regvardims[idim] = dimx
        elif vardimns[idim] == 'lat':
            regvardims[idim] = dimy
        else:
            regvardims[idim] = vardims[idim]

    if varinf.Ndims != 0:
        print '      Adding variable: "' + varn + '" shape: ', regvardims[0:varinf.Ndims]
        newvar = objofile.createVariable(varn, vartype, newvardimns, fill_value=varfil)

        if varinf.Ndims == 1:
            varvals = np.zeros(regvardims[0], dtype=vartype) 
            if varn == 'lon':
                dlon=360./(dimx*1.)
                print 'Computing the lon values centered at 180.!'
                varvals=dlon/2. + dlon*np.array(range(dimx))
                print varvals
            elif varn == 'lat':
                varvals = varvalues[dimx*np.array(range(dimy))]
            elif varn == 'presnivs':
                print varvalues
                dimz=regvardims[0]
                varvals = varvalues
#                varvals = dimz*1.-np.array(range(dimz))*1.
            else:
                varvals = varvalues
        elif varinf.Ndims == 2:
            varvals = np.zeros((regvardims[0], regvardims[1]), dtype=vartype)
            varvals = varvalues.reshape(regvardims[0], regvardims[1])
        elif varinf.Ndims == 3:
            varvals = np.zeros((regvardims[0], regvardims[1], regvardims[2]),        \
              dtype=vartype)
            varvals = varvalues.reshape(regvardims[0], regvardims[1], regvardims[2])
        elif varinf.Ndims == 4:
            varvals = np.zeros((regvardims[0], regvardims[1], regvardims[2],         \
              regvardims[3]), dtype=vartype) 
            varvals = varvalues.reshape(regvardims[0], regvardims[1], regvardims[2], \
              regvardims[3])
        elif varinf.Ndims == 5:
            varvals = np.zeros((regvardims[0], regvardims[1], regvardims[2],         \
              regvardims[3], regvardims[4]), dtype=vartype)
            varvals = varvalues.reshape(regvardims[0], regvardims[1], regvardims[2], \
              regvardims[3], regvardims[4])
        elif varinf.Ndims == 6:
            varvals = np.zeros((regvardims[0], regvardims[1], regvardims[2],         \
              regvardims[3], regvardims[4], regvardims[5]), dtype=vartype)
            varvals = varvalues.reshape(regvardims[0], regvardims[1], regvardims[2], \
              regvardims[3], regvardims[4], regvardims[5])
        else:
            print errormsg
            print '  variable size ',varinf.Ndims,' is not ready!!!!'

    for attr in varattr:
        newvarattrs = newvar.ncattrs()
        attrv = varobj.getncattr(attr)
        if not ncvar.searchInlist(newvarattrs, attr):
            if attr == 'coordinates':
                newattrv = attrv.replace('time_counter','time')
                attrv = newattrv.replace('time_counter_bnds','time_bnds')
                newattrv = attrv.replace('tbnds','bnds')
                attrv = newattrv

            newvar.setncattr(attr, attrv)

    newvar[:] = varvals
    objofile.sync()

objofile.sync()
lfilegattrs = objlfile.ncattrs()
Nattrs = len(lfilegattrs)
print '   Adding ', Nattrs,' global atributes'

for attr in lfilegattrs:
    attrv = objlfile.getncattr(attr)
    atvar = ncvar.set_attribute(objofile, attr, attrv)

objlfile.close()
objofile.sync()
objofile.close()

print 'File "' + opts.ofile + '" succesffully created!'
