# Reconstructing ORCHIDEE's forcing files as matrices for a given area
# L. Fita, CIMA. December 2017
#
import numpy as np
from netCDF4 import Dataset as NetCDFFile
import os
import re
import numpy.ma as ma
# Importing generic tools file 'generic_tools.py'
import generic_tools as gen
import nc_var_tools as ncvar
import subprocess as sub
import module_ForSci as Sci
from optparse import OptionParser

# Command-line interface. Every option takes a single value (metavar 'VALUE') and
#   is described as: (short flag, long flag, attribute in 'opts', help message)
optdefs = [
    ("-F", "--filename", "fn", "name of files (overwrites -f option)"),
    ("-f", "--fileHEader", "fh", "header of files"),
    ("-i", "--indices", "indn", "name of the variable indices"),
    ("-L", "--latitude", "latn", "name of the variable latitiude"),
    ("-l", "--longitude", "lonn", "name of the variable longitiude"),
    ("-t", "--TransposeVariables", "tvars",
      "whether variables should be trasposed"),
    ("-v", "--Variables", "varns", "',' separated list of variables"),
    ("-y", "--year", "year", "year to process"),
]

parser = OptionParser()
for shortflag, longflag, destname, helpmsg in optdefs:
    parser.add_option(shortflag, longflag, dest=destname, help=helpmsg,
      metavar="VALUE")

opts, args = parser.parse_args()

####### ###### ##### #### ### ## #

# Options the script can not run without. Stopping here with a usage message
#   avoids an obscure 'TypeError'/'AttributeError' later on, when the missing
#   value ('None') would be concatenated to a string or split
mandatory = [('-i', opts.indn), ('-L', opts.latn), ('-l', opts.lonn),
  ('-v', opts.varns), ('-y', opts.year)]
for optflag, optval in mandatory:
    if optval is None:
        parser.error("option '" + optflag + "' is required")

# Input file: an explicit name (-F) takes precedence over [header][year].nc
if opts.fn is not None:
    filen = opts.fn
elif opts.fh is not None:
    filen = opts.fh + opts.year + '.nc'
else:
    parser.error("provide either a file name '-F' or a header of files '-f'")

# Variable which provides the indices of a 1D vector from the dimy, dimx space
indvar = opts.indn

# 2D longitude, latitude matrices
lonvar = opts.lonn
latvar = opts.latn

# Range to retrieve. With Xmin = 'all' the four limits are replaced further down
#   by the extremes of the longitudes and latitudes of the file. Give a number to
#   Xmin (e.g. -90.25) to retrieve only the box defined by the four values
Xmin = 'all'
Xmax = -33.25
Ymin = -67.25
Ymax = 15.25

# Variables to reconstruct ('all' for every variable of the file)
variable = opts.varns

# Resolution
resX = 0.5
resY = 0.5

# Difference from matrix to localized point passed as 'maxdiff' to the
#   reconstruction routine
maxdiff = 0.05

# Projection of the matrix
matProj = 'latlon'

#######    #######
## MAIN
    #######
main = 'ORforcing_reconstruct.py'
fname = 'ORforcing_reconstruct.py'

onc = NetCDFFile(filen, 'r')

availProj = ['latlon']

ofilen = 'reconstruct_matrix_' + opts.year + '.nc'

ncvars = onc.variables.keys()

varstocheck = [indvar, lonvar, latvar]
for vn in varstocheck:
    if not gen.searchInlist(ncvars, vn):
        print gen.errormsg
        print '  ' + main + ": file '" + filen + "' does not have variable '" + vn + \
          "' !!"
        print '    available ones:', ncvars
        quit(-1)

oind = onc.variables[indvar]
olon = onc.variables[lonvar]
olat = onc.variables[latvar]

indv = oind[:]
lonv = olon[:]
latv = olat[:]

# Coordinates as 1D vectors and names of the x/y dimensions of the input file
if len(olon.dimensions) != 2:
    # Coordinates which are not 2D are taken as they are, with generic names for
    #   the x/y dimensions
    Xdimn, Ydimn = 'x', 'y'
    veclon, veclat = lonv.copy(), latv.copy()
else:
    # 2D coordinates with dimensions ordered as (y, x): flattening them
    dimy, dimx = lonv.shape
    Ydimn, Xdimn = olon.dimensions
    veclon = lonv.reshape(dimx*dimy)
    veclat = latv.reshape(dimx*dimy)

# Name of the first dimension of the variable with the indices
vecdimn = oind.dimensions[0]

if not gen.searchInlist(availProj, matProj):
    print errormsg
    print '  ' + fname + ": projection '" + matProj + "' not available !!"
    print '    available ones:', availProj
    quit(-1)

ncdims = onc.dimensions

if variable == 'all':
    varns = ncvars
else:
    varns = variable.split(',')
    for vn in varns:
        if not gen.searchInlist(ncvars, vn):
            print gen.errormsg
            print '  ' + fname + ": file '" + filen + "' does not have " +           \
              " variable '" + vn + "' !!"
            print '  available ones:', ncvars
            quit(-1)

# Units of the coordinates as given in the input file, '-' when they do not have
#   the attribute 'units'
xunits = '-'
if gen.searchInlist(olon.ncattrs(), 'units'):
    xunits = olon.getncattr('units')

yunits = '-'
if gen.searchInlist(olat.ncattrs(), 'units'):
    yunits = olat.getncattr('units')

if type(Xmin) == type('2') and Xmin == 'all':
    Xmin = np.min(lonv)
    Xmax = np.max(lonv)
    Ymin = np.min(latv)
    Ymax = np.max(latv)

# Matrix values
if matProj == 'latlon':
    dimx = int((Xmax - Xmin+resX)/resX)
    dimy = int((Ymax - Ymin+resY)/resY)

print 'Xmin:', Xmin, 'Xmax:', Xmax, 'Ymin:', Ymin, 'Ymax:', Ymax, 'maxdiff:', maxdiff

matindt, matXt, matYt, matdifft = Sci.module_scientific.reconstruct_matrix(          \
  vectorxpos=veclon, vectorypos=veclat, dvec=veclon.shape[0], xmin=Xmin, xmax=Xmax,  \
  ymin=Ymin,ymax=Ymax, dmatx=dimx, dmaty=dimy, matproj=matProj, maxdiff=maxdiff)

matind = matindt.transpose()
Nfound = np.sum(matind != -1)
Nstations = veclon.shape[0]
print '  Nfound:', Nfound, ' number of stations:', Nstations

if Nfound*1. / Nstations < 0.8:
    print gen.errormsg
    print '  '+main + ': only ', '{:.2f}'.format(Nfound*100./Nstations),\
      '% of points ' + 'have been found !!'
    print '    this is not enough. Something must went wrong!'
    print '    Longitudes Latitudes _______'
    for i in range(veclon.shape[0]):
        print '    ', veclon[i], veclat[i]
    dx2 = dimx/2
    dy2 = dimy/2
    print '    dx2, dy2 -/+ 5 fraction of destiny longitudes _______'
    for j in range(-5,5):
        print matXt[dx2-5:dx2+5,dy2+j]
    print '    dx2, dy2 -/+ 5 fraction of destiny latitudes _______'
    for j in range(-5,5):
        print matYt[dx2-5:dx2+5,dy2+j]
    print '    dx2, dy2 -/+ 5 fraction of indices equivalency _______'
    for j in range(-5,5):
        print matindt[dx2-5:dx2+5,dy2+j]
    print '     min distance lon(dy2,dx2)=', matXt[dx2,dy2], ':',                    \
      np.min(veclon - matXt[dx2,dy2])
    print '     min distance lat(dy2,dx2)=', matYt[dx2,dy2], ':',                    \
      np.min(veclat - matYt[dx2,dy2])
    print '     longitude borders:', Xmin, Xmax, 'dX:', resX
    print '     latitude borders:', Ymin, Ymax, 'dY:', resY
    #quit(-1)

# Fortran like, First 1: shifting the found indices to Python's 0-based ones and
#   keeping -1 for the grid points without any point of the vector
matind = np.where(matind != -1, matind - 1, matind)

# Creation of file
onewnc = NetCDFFile(ofilen, 'w')

# Dimensions
newdim = onewnc.createDimension('x', dimx)
newdim = onewnc.createDimension('y', dimy)

# Variable-dimension. Matrices from the routine are transposed to be (y, x)
newvar = onewnc.createVariable('lon', 'f8', ('y', 'x'))
newvar[:] = matXt.transpose()
ncvar.basicvardef(newvar, 'lon', 'Longitude', 'degrees_east')

newvar = onewnc.createVariable('lat', 'f8', ('y', 'x'))
newvar[:] = matYt.transpose()
ncvar.basicvardef(newvar, 'lat', 'Latitude', 'degrees_north')
onewnc.sync()

# Variable indices
newvar = onewnc.createVariable('vec1D_matind', 'i', ('y', 'x'), fill_value=-1)
newvar[:] = matind
ncvar.basicvardef(newvar, 'vec1D_matind', 'matrix with the equivalencies from 1D ' + \
  'vector indices', '-')
ncvar.set_attribute(newvar, 'coordinates', 'lon lat')

# Variable differences. They are distances in degrees (the tolerance 'maxdiff' is
#   a fraction of a degree), thus stored as reals: an integer variable would
#   truncate them
# NOTE(review): assumes the routine returns real values in 'matdifft' -- confirm
newvar = onewnc.createVariable('vec1D_matdiff', 'f8', ('y', 'x'), fill_value=-1)
newvar[:] = matdifft.transpose()
ncvar.basicvardef(newvar,'vec1D_matdiff', 'matrix differences respect 1D ', 'degrees')
ncvar.set_attribute(newvar, 'coordinates', 'lon lat')

# Looking for equivalencies in the 1D vector: every grid point with a point of the
#   vector (index different than -1) gets the value that the variable with the
#   indices has at that position. Done at once with NumPy indexing instead of
#   visiting the dimy*dimx grid points one by one
matlonlat = matind.copy()
found = matind != -1
matlonlat[found] = indv[matind[found]]

newvar = onewnc.createVariable('lonlat_matind', 'i', ('y', 'x'), fill_value=-1)
newvar[:] = matlonlat
ncvar.basicvardef(newvar, 'lonlat_matind', 'matrix with the equivalencies from ' +   \
  '2D lon, lat matrices', '-')
ncvar.set_attribute(newvar, 'coordinates', 'lon lat')
onewnc.sync()

# Getting variables
for vn in varns:
    if not onewnc.variables.has_key(vn):
        ovar = onc.variables[vn]
        if gen.Str_Bool(opts.tvars):
            indn0 = ovar.dimensions
            indn = list(indn0)[::-1]
        else:
            indn = ovar.dimensions
        vardims = []
        shapevar = []
        for dn in indn:
            if not gen.searchInlist(onewnc.dimensions, dn) and dn != Xdimn and   \
              dn != Ydimn:
                if onc.dimensions[dn].isunlimited():
                    newdim = onewnc.createDimension(dn, None)
                else:
                    newdim = onewnc.createDimension(dn, len(onc.dimensions[dn]))
                
            if dn == vecdimn: 
                vardims.append('y')
                vardims.append('x')
                shapevar.append(dimy)
                shapevar.append(dimx)
            else: 
                vardims.append(dn)
                shapevar.append(len(onc.dimensions[dn]))

        if ovar.dtype == type(int(2)):
            newvar= onewnc.createVariable(vn,ncvar.nctype(ovar.dtype),tuple(vardims),\
              fill_value=gen.fillValueI)
            varvals = np.ones(tuple(shapevar), dtype=ovar.dtype)*gen.fillValueI
        elif ovar.dtype == type(np.int32(2)):
            newvar= onewnc.createVariable(vn,ncvar.nctype(ovar.dtype),tuple(vardims),\
              fill_value=gen.fillValueI)
            varvals = np.ones(tuple(shapevar), dtype=ovar.dtype)*gen.fillValueI
        elif ovar.dtype == type(np.int64(2)):
            newvar= onewnc.createVariable(vn,ncvar.nctype(ovar.dtype),tuple(vardims),\
              fill_value=gen.fillValueI)
            varvals = np.ones(tuple(shapevar), dtype=ovar.dtype)*gen.fillValueI
        elif ovar.dtype == type(np.float(2.)):
            newvar= onewnc.createVariable(vn,ncvar.nctype(ovar.dtype),tuple(vardims),\
              fill_value=gen.fillValueF)
            varvals = np.ones(tuple(shapevar), dtype=ovar.dtype)*gen.fillValueF
        elif ovar.dtype == type(np.float32(2.)):
            newvar= onewnc.createVariable(vn,ncvar.nctype(ovar.dtype),tuple(vardims),\
              fill_value=gen.fillValueF)
            varvals = np.ones(tuple(shapevar), dtype=ovar.dtype)*gen.fillValueF
        else:
            print gen.errormsg
            print '  ' + fname + ': variable type:', ovar.dtype, ' not ready !!'
            quit(-1)

        print '  reconstructing:', vn, ' shape:', newvar.shape, '...'
        # Filling variable. It would be faster if we can avoid this loop... I'm feeling lazy!
        if not gen.searchInlist(vardims,'x') and not gen.searchInlist(vardims,'y'):
            if gen.Str_Bool(opts.tvars):
                newvar[:] = ovar[:].transpose() 
            else:
                newvar[:] = ovar[:] 
        else:
            if gen.Str_Bool(opts.tvars):
                ovart = ovar[:]
            else:
                ovart = ovar[:].transpose()
            print '  Lluis shapes ovart:', ovart.shape, 'newvar:', newvar.shape
            if newvar.dtype == type(float(2.)) or newvar.dtype == type(np.float(2.)) \
              or newvar.dtype == type(np.float32(2)) or                              \
              newvar.dtype == type(np.float64(2)):
                newvals = Sci.module_scientific.fill3dr_2dvec(matind=matindt,        \
                  inmat=ovart, id1=ovart.shape[0], id2=ovart.shape[1],               \
                  od1=newvar.shape[2], od2=newvar.shape[1], od3=newvar.shape[0])
            else:
                newvals = Sci.module_scientific.fill3di_2dvec(matind=matindt,        \
                  inmat=ovart, id1=ovart.shape[0], id2=ovart.shape[1],               \
                  od1=newvar.shape[2], od2=newvar.shape[1], od3=newvar.shape[0])
            newvar[:] = newvals.transpose()

        # Attributes
        for atn in ovar.ncattrs():
            if atn != '_FillValue' and atn != 'units':
                atv = ovar.getncattr(atn)
                ncvar.set_attribute(newvar, atn, atv)
        ncvar.set_attribute(newvar, 'coordinates', 'lon lat')
        onewnc.sync()
   
# Global attributes
for atn in onc.ncattrs():
    atv = onc.getncattr(atn)
    ncvar.set_attribute(onewnc, atn, atv)
onewnc.sync()
ncvar.add_global_PyNCplot(onewnc, main, fname, '0.1')
onc.close()

# Reconstructing times
#otime = onewnc.variables['time']
#ncvar.set_attribute(otime, 'units', 'seconds since ' + opts.year + '-01-01 00:00:00')

onewnc.sync()
onewnc.close()

print fname + ": Successful writing of file '" + ofilen + ".nc' !!"
