# Python program to modify any surface (sfc) variable of a netCDF file, imposing a prescribed profile on it
# L. Fita Borrell, LMD-Jussieu, IPSL, UMPC, CNRS, Paris, France
# July 2013

from optparse import OptionParser
import numpy as np
from netCDF4 import Dataset as NetCDFFile
import os

## e.g. # python sfcVAR_global_modification.py -f met_em.d01_1979-01-01_00:00:00 -k control -v SST -L XLONG_M -l XLAT_M -s 273.15:27

# Name of this script, used as a prefix in diagnostic messages
fname = 'sfcVAR_global_modification.py'

# Banner printed before and after every error report
errormsg = 'ERROR -- error -- ERROR -- error'

# Celsius -> Kelvin offset
kelvin = 273.15

# Latitude limits of the profiles, in radians: pi/3 = 60 deg, pi/36 = 5 deg
pi3, pi36 = np.pi / 3., np.pi / 36.

def HS_kind(Ndims,dimx,dimy,dimt,lonvals,latvals,zero,amp,dl):
    """Return a "Held & Suarez (1994)"-like field: zero + amp - dl*sin(lat)**2.

    Ndims selects how the (radian) latitudes are laid out:
      1: latvals is 1-D (dimy,)            -> result shape (dimy, dimx)
      2: latvals is 2-D (dimy, dimx)       -> result shape (dimy, dimx)
      3: latvals has a leading axis of which only index 0 is read
                                           -> result shape (dimt, dimy, dimx),
         the same horizontal field repeated along the leading axis.
    lonvals is not used by this profile. The profile is stored as float32 and
    then offset by 'zero'.

    Raises ValueError for any other Ndims.
    """
    print('  Imposing "Held & Suarez, 94, BAMS" characteristics...')
    if Ndims == 1:
        # one latitude per row, replicated along the longitude axis
        lat = np.asarray(latvals)[:dimy, np.newaxis]
        shape = (dimy, dimx)
    elif Ndims == 2:
        lat = np.asarray(latvals)[:dimy, :dimx]
        shape = (dimy, dimx)
    elif Ndims == 3:
        # only the first entry of the leading axis of the coordinates is used
        lat = np.asarray(latvals)[0, :dimy, :dimx]
        shape = (dimt, dimy, dimx)
    else:
        raise ValueError('HS_kind: Ndims=%s not supported (1, 2 or 3)' % Ndims)

    varvals = np.zeros(shape, dtype=np.float32)
    # broadcasting fills every longitude (Ndims=1) / every leading index (Ndims=3)
    varvals[...] = amp - dl*np.sin(lat)**2.

    return varvals + zero


def control_kind(Ndims,dimx,dimy,dimt,lonvals,latvals,zero,amp):
    """Return the "control" profile (Neale & Hoskins, 2001).

    value = zero + amp*(1 - sin(3*lat/2)**2) for |lat| < pi/3 (60 deg),
            zero                              elsewhere
    Latitudes are in radians; lonvals is not used by this profile.

    Ndims / shapes:
      1: latvals (dimy,)                  -> result (dimy, dimx)
      2: latvals (dimy, dimx)             -> result (dimy, dimx)
      3: only index 0 of the leading axis of latvals is read
                                          -> result (dimt, dimy, dimx)

    Raises ValueError for any other Ndims.
    """
    print('  Imposing "control" characteristics...')
    if Ndims == 1:
        lat = np.asarray(latvals)[:dimy, np.newaxis]
        shape = (dimy, dimx)
    elif Ndims == 2:
        lat = np.asarray(latvals)[:dimy, :dimx]
        shape = (dimy, dimx)
    elif Ndims == 3:
        lat = np.asarray(latvals)[0, :dimy, :dimx]
        shape = (dimt, dimy, dimx)
    else:
        raise ValueError('control_kind: Ndims=%s not supported (1, 2 or 3)' % Ndims)

    varvals = np.zeros(shape, dtype=np.float32)
    varvals[...] = np.where(np.abs(lat) < pi3,
                            amp*(1.-np.sin(3.*lat/2.)**2.), 0.)

    return varvals + zero

def control5N_kind(Ndims,dimx,dimy,dimt,lonvals,latvals,zero,amp):
    """Return the "control5N" profile (Neale & Hoskins, 2001): maximum at 5N.

    value = zero + amp*(1 - sin(90/55*(lat - 5deg))**2)  for  5N <= lat < 60N
            zero + amp*(1 - sin(90/65*(lat - 5deg))**2)  for 60S <  lat <  5N
            zero                                          elsewhere
    Latitudes are in radians; lonvals is not used by this profile.

    The 5N boundary is inclusive on the northern side for every Ndims (the
    original 1-D branch used a strict '<' and set points exactly at 5N to 0).

    Ndims / shapes:
      1: latvals (dimy,)                  -> result (dimy, dimx)
      2: latvals (dimy, dimx)             -> result (dimy, dimx)
      3: only index 0 of the leading axis of latvals is read
                                          -> result (dimt, dimy, dimx)

    Raises ValueError for any other Ndims.
    """
    print('  Imposing "control5N" characteristics...')
    if Ndims == 1:
        lat = np.asarray(latvals)[:dimy, np.newaxis]
        shape = (dimy, dimx)
    elif Ndims == 2:
        lat = np.asarray(latvals)[:dimy, :dimx]
        shape = (dimy, dimx)
    elif Ndims == 3:
        lat = np.asarray(latvals)[0, :dimy, :dimx]
        shape = (dimt, dimy, dimx)
    else:
        raise ValueError('control5N_kind: Ndims=%s not supported (1, 2 or 3)' % Ndims)

    # the two bands are disjoint: north is lat >= pi36, south is lat < pi36
    north = (pi36 <= lat) & (lat < pi3)
    south = (-pi3 < lat) & (lat < pi36)
    northvals = amp*(1.-np.sin(90./55.*(lat-pi36))**2.)
    southvals = amp*(1.-np.sin(90./65.*(lat-pi36))**2.)

    varvals = np.zeros(shape, dtype=np.float32)
    varvals[...] = np.where(north, northvals, np.where(south, southvals, 0.))

    return varvals + zero

def flat_kind(Ndims,dimx,dimy,dimt,lonvals,latvals,zero,amp):
    """Return the "flat" profile (Neale & Hoskins, 2001).

    value = zero + amp*(1 - sin(3*lat/2)**4) for |lat| < pi/3 (60 deg),
            zero                              elsewhere
    Latitudes are in radians; lonvals is not used by this profile.

    Ndims / shapes:
      1: latvals (dimy,)                  -> result (dimy, dimx)
      2: latvals (dimy, dimx)             -> result (dimy, dimx)
      3: only index 0 of the leading axis of latvals is read
                                          -> result (dimt, dimy, dimx)

    Raises ValueError for any other Ndims.
    """
    print('  Imposing "flat" characteristics...')
    if Ndims == 1:
        lat = np.asarray(latvals)[:dimy, np.newaxis]
        shape = (dimy, dimx)
    elif Ndims == 2:
        lat = np.asarray(latvals)[:dimy, :dimx]
        shape = (dimy, dimx)
    elif Ndims == 3:
        lat = np.asarray(latvals)[0, :dimy, :dimx]
        shape = (dimt, dimy, dimx)
    else:
        raise ValueError('flat_kind: Ndims=%s not supported (1, 2 or 3)' % Ndims)

    varvals = np.zeros(shape, dtype=np.float32)
    varvals[...] = np.where(np.abs(lat) < pi3,
                            amp*(1.-np.sin(3.*lat/2.)**4.), 0.)

    return varvals + zero

def KEQ_kind(Ndims,dimx,dimy,dimt,lonvals,latvals,zero,amp,L0,dL,dl):
    """Return the "KEQ" perturbation: a cos**2 x cos**2 bump centred on (L0, 0).

    Inside the box L0-dL < lon < L0+dL and |lat| < dl (all in radians):
        value = amp * cos(pi/2*(lon-L0)/dL)**2 * cos(pi/2*lat/dl)**2
    and 0 elsewhere. Unlike most other *_kind functions, 'zero' is NOT added
    here: the caller adds this perturbation on top of a control_kind field.

    Ndims / shapes:
      1: lonvals (dimx,), latvals (dimy,) -> result (dimy, dimx)
      2: lonvals, latvals (dimy, dimx)    -> result (dimy, dimx)
      3: only index 0 of the leading axis of lonvals/latvals is read
                                          -> result (dimt, dimy, dimx)

    Exits the program (quit(-1)) when no horizontal grid point lies inside
    the box. Raises ValueError for any other Ndims.
    """
    print('  Imposing "KEQ" characteristics...')
    if Ndims == 1:
        # (1, dimx) and (dimy, 1) broadcast to the (dimy, dimx) grid
        lon = np.asarray(lonvals)[np.newaxis, :dimx]
        lat = np.asarray(latvals)[:dimy, np.newaxis]
        shape = (dimy, dimx)
    elif Ndims == 2:
        lon = np.asarray(lonvals)[:dimy, :dimx]
        lat = np.asarray(latvals)[:dimy, :dimx]
        shape = (dimy, dimx)
    elif Ndims == 3:
        lon = np.asarray(lonvals)[0, :dimy, :dimx]
        lat = np.asarray(latvals)[0, :dimy, :dimx]
        shape = (dimt, dimy, dimx)
    else:
        raise ValueError('KEQ_kind: Ndims=%s not supported (1, 2 or 3)' % Ndims)

    inside = (L0-dL < lon) & (lon < L0+dL) & (np.abs(lat) < dl)
    # number of horizontal grid points that receive the perturbation
    npoints = np.count_nonzero(inside)

    # The bump is also evaluated outside the box (and may be undefined there,
    # e.g. dL = 0); those values are discarded by np.where, so silence warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        bump = amp*np.cos(np.pi/2.*(lon-L0)/dL)**2.*np.cos(np.pi/2.*lat/dl)**2.
    varvals = np.zeros(shape, dtype=np.float32)
    varvals[...] = np.where(inside, bump, 0.)

    if npoints < 1:
        print(errormsg)
        print('    KEQ: not modified points generated!!!!')
        print('       L0-dL: ' + str(L0-dL))
        print('       L0+dL: ' + str(L0+dL))
        print('       dl: ' + str(dl))
        print(errormsg)
        quit(-1)

    return varvals

def KW1_kind(Ndims,dimx,dimy,dimt,lonvals,latvals,zero,amp,L0,dL,dl):
    """Return the "KW1" perturbation: a wavenumber-1 pattern in an equatorial band.

    For |lat| < dl (radians):
        value = amp * cos(lon - L0) * cos(pi/2*lat/dl)**2
    and 0 elsewhere. dL is accepted for signature parity with KEQ_kind but is
    not used. As in KEQ_kind, 'zero' is NOT added: the caller adds this
    perturbation on top of a control_kind field.

    Ndims / shapes:
      1: lonvals (dimx,), latvals (dimy,) -> result (dimy, dimx)
      2: lonvals, latvals (dimy, dimx)    -> result (dimy, dimx)
      3: only index 0 of the leading axis of lonvals/latvals is read
                                          -> result (dimt, dimy, dimx)

    Exits the program (quit(-1)) when no horizontal grid point lies inside
    the band. Raises ValueError for any other Ndims.
    """
    print('  Imposing "KW1" characteristics...')
    if Ndims == 1:
        # (1, dimx) and (dimy, 1) broadcast to the (dimy, dimx) grid
        lon = np.asarray(lonvals)[np.newaxis, :dimx]
        lat = np.asarray(latvals)[:dimy, np.newaxis]
        shape = (dimy, dimx)
    elif Ndims == 2:
        lon = np.asarray(lonvals)[:dimy, :dimx]
        lat = np.asarray(latvals)[:dimy, :dimx]
        shape = (dimy, dimx)
    elif Ndims == 3:
        lon = np.asarray(lonvals)[0, :dimy, :dimx]
        lat = np.asarray(latvals)[0, :dimy, :dimx]
        shape = (dimt, dimy, dimx)
    else:
        raise ValueError('KW1_kind: Ndims=%s not supported (1, 2 or 3)' % Ndims)

    # broadcast the latitude test onto the full horizontal grid before counting
    inside = (np.abs(lat) < dl) & np.ones(np.broadcast(lon, lat).shape, dtype=bool)
    npoints = np.count_nonzero(inside)

    # values outside the band are discarded; silence warnings (e.g. dl = 0)
    with np.errstate(divide='ignore', invalid='ignore'):
        wave = amp*np.cos(lon-L0)*np.cos(np.pi/2.*(lat)/dl)**2.
    varvals = np.zeros(shape, dtype=np.float32)
    varvals[...] = np.where(inside, wave, 0.)

    if npoints < 1:
        print(errormsg)
        print('    KW1: not modified points generated!!!!')
        print('       dl: ' + str(dl))
        print(errormsg)
        quit(-1)

    return varvals

def peaked_kind(Ndims,dimx,dimy,dimt,lonvals,latvals,zero,amp):
    """Return the "peaked" profile (Neale & Hoskins, 2001).

    value = zero + amp*(1 - 3*|lat|/pi) for |lat| < pi/3 (60 deg),
            zero                         elsewhere
    Latitudes are in radians; lonvals is not used by this profile.

    Ndims / shapes:
      1: latvals (dimy,)                  -> result (dimy, dimx)
      2: latvals (dimy, dimx)             -> result (dimy, dimx)
      3: only index 0 of the leading axis of latvals is read
                                          -> result (dimt, dimy, dimx)

    Raises ValueError for any other Ndims.
    """
    print('  Imposing "peaked" characteristics...')
    if Ndims == 1:
        lat = np.asarray(latvals)[:dimy, np.newaxis]
        shape = (dimy, dimx)
    elif Ndims == 2:
        lat = np.asarray(latvals)[:dimy, :dimx]
        shape = (dimy, dimx)
    elif Ndims == 3:
        lat = np.asarray(latvals)[0, :dimy, :dimx]
        shape = (dimt, dimy, dimx)
    else:
        raise ValueError('peaked_kind: Ndims=%s not supported (1, 2 or 3)' % Ndims)

    varvals = np.zeros(shape, dtype=np.float32)
    varvals[...] = np.where(np.abs(lat) < pi3,
                            amp*(1.-(3.*np.abs(lat)/np.pi)), 0.)

    return varvals + zero

####### ###### ##### #### ### ## #

parser = OptionParser()
parser.add_option("-f", "--netCDF_file", dest="ncfile",
                  help="file to use", metavar="FILE")
parser.add_option("-k", "--kindMOD", type='choice', dest="modkind",
                  choices=['control', 'control5N', 'flat', 'peaked', 'Qobs', 'KEQ', 'KW1', 'HS'],
                  help="kind of modification from Neale and Hoskins, Atmos. Sci. Letters, (2001), HS: Held & Suarez, 1994 conditions (uses [ref]:[amp]:[perAmp, as /\\Ty])", metavar="VALUE")
parser.add_option("-L", "--longitude", dest="lonname",
                  help="name of the longitude", metavar="VAR")
parser.add_option("-l", "--latitude", dest="latname",
                  help="name of the latitude", metavar="VAR")
parser.add_option("-v", "--variable", dest="varname",
                  help="variable to modify", metavar="VAR")
parser.add_option("-s", "--values", dest="values",
                  help="':'-separated values: [ref, in [units]]:[amp, with respect 'ref' in [units]] for every kind; 'HS' also needs :[perAmp, in [units]]; 'KEQ' and 'KW1' also need :[perAmp, in [units]]:[L0, in deg]:[dL, in deg]:[dl, in deg]", metavar="VAR")

(opts, args) = parser.parse_args()

#######    #######
## MAIN
#######

# Every option is mandatory: reject a missing one here with a usage message
# instead of failing later with a TypeError/AttributeError on None.
for optflag, optvalue in [('-f', opts.ncfile), ('-k', opts.modkind),
                          ('-v', opts.varname), ('-L', opts.lonname),
                          ('-l', opts.latname), ('-s', opts.values)]:
    if optvalue is None:
        parser.error('option ' + optflag + ' is required')

# Number of ':'-separated values required by each kind (see the '-s' help)
if opts.modkind == 'KEQ' or opts.modkind == 'KW1':
    Nrequired = 6
elif opts.modkind == 'HS':
    Nrequired = 3
else:
    Nrequired = 2

additionalvalues = opts.values.split(':')
Nvalues = len(additionalvalues)
# 'KEQ', 'KW1' and 'HS' need an exact count; the remaining kinds read the
# first two values and ignore any extra ones.
if Nvalues < Nrequired or (Nrequired > 2 and Nvalues != Nrequired):
    print(errormsg)
    print('  ' + fname + ': option "' + opts.modkind + '" requires ' + str(Nrequired) + ' additional arguments! ' + str(Nvalues) + ' provided!')
    print(errormsg)
    quit(-1)

try:
    numvalues = [np.float32(val) for val in additionalvalues[:Nrequired]]
except ValueError:
    print(errormsg)
    print('  ' + fname + ': additional values "' + opts.values + '" must be numbers!')
    print(errormsg)
    quit(-1)

refval = numvalues[0]
amplitude = numvalues[1]
print('  ' + fname + ': additional values____')
print('    ref: ' + str(refval))
print('    amp: ' + str(amplitude))
if Nrequired >= 3:
    Peramplitude = numvalues[2]
    print('    perAmp: ' + str(Peramplitude))
if Nrequired == 6:
    # box/band geometry is given in degrees and used in radians
    Lon0 = numvalues[3]*np.pi/180.
    dLon = numvalues[4]*np.pi/180.
    dlat = numvalues[5]*np.pi/180.
    print('    L0: ' + str(Lon0))
    print('    dL: ' + str(dLon))
    print('    dl: ' + str(dlat))

if not os.path.isfile(opts.ncfile):
    print(errormsg)
    print('  File ' + opts.ncfile + ' does not exist !!')
    print(errormsg)
    quit(-1)

ncf = NetCDFFile(opts.ncfile,'a')

for vname in [opts.varname, opts.lonname, opts.latname]:
    if vname not in ncf.variables:
        print(errormsg)
        print('   ' + fname + ': File "' + opts.ncfile + '" does not have variable "' + vname + '" !!!!')
        print(errormsg)
        ncf.close()
        quit(-1)

objvarmod = ncf.variables[opts.varname]
objlon = ncf.variables[opts.lonname]
objlat = ncf.variables[opts.latname]

Nds = len(objvarmod.shape)
Ncoords = len(objlon.shape)
# Supported: a variable of rank 1-3 whose coordinates have no more dimensions
# than the variable itself. Checked before any shape indexing below.
if Nds < 1 or Nds > 3 or Ncoords < 1 or Ncoords > Nds:
    print(errormsg)
    print(fname + ': Ndims= ' + str(Nds) + ' (coordinates Ndims= ' + str(Ncoords) + ') not ready !!!')
    print(errormsg)
    ncf.close()
    quit(-1)

# The *_kind functions select their branch from the coordinates' layout:
#   1: lon(x), lat(y)   2: lon/lat(y, x)   3: lon/lat(t, y, x)
kindNds = Ncoords
dx = objlon.shape[-1]
if Ncoords == 1:
    dy = objlat.shape[0]
else:
    dy = objlat.shape[-2]
if Nds == 3:
    dt = objvarmod.shape[0]
else:
    dt = 0

print(fname + ': Ndims: ' + str(Nds) + ' dimensions of the matrices: ' + str(dt) + ', ' + str(dy) + ',' + str(dx))

# longitudes mapped to [0, 360) and both coordinates converted to radians
lonV=objlon[:]
lonV=np.where(lonV < 0., lonV + 360., lonV)
lonV=lonV*np.pi/180.
latV=objlat[:]*np.pi/180.

if opts.modkind == 'HS':
    varVals = HS_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude,Peramplitude)

elif opts.modkind == 'control':
    varVals = control_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)

elif opts.modkind == 'control5N':
    varVals = control5N_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)

elif opts.modkind == 'flat':
    varVals = flat_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)

elif opts.modkind == 'KEQ':
    # 'control' profile plus the localized KEQ perturbation
    varValsA = control_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)
    varValsB = KEQ_kind(kindNds,dx,dy,dt,lonV,latV,refval,Peramplitude,Lon0,dLon,dlat)
    varVals = varValsA + varValsB

elif opts.modkind == 'KW1':
    # 'control' profile plus the wavenumber-1 KW1 perturbation
    varValsA = control_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)
    varValsB = KW1_kind(kindNds,dx,dy,dt,lonV,latV,refval,Peramplitude,Lon0,dLon,dlat)
    varVals = varValsA + varValsB

elif opts.modkind == 'peaked':
    varVals = peaked_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)

elif opts.modkind == 'Qobs':
    # mean of the 'control' and 'flat' profiles
    varvals1 = control_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)
    varvals2 = flat_kind(kindNds,dx,dy,dt,lonV,latV,refval,amplitude)
    varVals = (varvals1 + varvals2)/2.

else:
    print(errormsg)
    print('  ' + fname + ': kind "' + opts.modkind + '" does not exist!!!')
    print(errormsg)
    ncf.close()
    quit(-1)

if Nds == 3 and kindNds < 3:
    # horizontal-only coordinates: repeat the (dy, dx) field at every time step
    varVals = np.tile(varVals, (dt, 1, 1))

objvarmod[:] = varVals

ncf.sync()
ncf.close()

print(fname + ': "' + opts.varname + '" of "' + opts.ncfile + '" has been modified with the "' + opts.modkind + '"')
