# FROM: https://pypi.org/project/pupygrib/
import numpy as np
import pupygrib
import os
from netCDF4 import Dataset as NetCDFFile
import generic_tools as gen
import nc_var_tools as ncvar

# Header line printed before every fatal error message of this script
errormsg = 'ERROR -- error -- ERROR -- error'

# GRIB file read by the main section, and the netCDF file it writes
ifile, ofile = 'ERAI_pl200503_130.grib', 'ERAI_pl_130.nc'

def grib128_ncname(varid):
    """ Function to provide CF name from a ECMWF GRIB table 128 variable ID
      varid: GRIB parameter number (e.g. 130), looked up as a zero-padded
        3-digit ID in the file 'transform_128ECMWF.html' next to this script
    Returns a 6-element list:
      [short name, long name, units (inline HTML tags removed),
       cfvalues[0], cfvalues[1], cfvalues[4] with '|' replaced by ' ']
      where cfvalues = gen.variables_values(short name)
    Prints an error and quits with -1 if the HTML file does not exist or has no
      row for varid.
    NOTE(review): the caller uses items [3], [4], [5] as netCDF variable name,
      standard name and long name -- confirm against gen.variables_values.
    """
    fname = 'grib128_ncname'

    folder = os.path.dirname(os.path.realpath(__file__))
    infile = os.path.join(folder, 'transform_128ECMWF.html')

    if not os.path.isfile(infile):
        print(errormsg)
        print('  ' + fname + ": File '" + infile + "' does not exist !!")
        quit(-1)

    # A variable row starts with its zero-padded 3-digit ID cell
    idS = str(varid).zfill(3)
    rowstart = '<td align="center">' + idS + '</td>'

    found = None
    with open(infile, 'r') as of:
        for line in of:
            # Length guard: a usable row also carries name/description/units cells
            if len(line) > 28 and line.startswith(rowstart):
                found = line
                break

    if found is None:
        # Previously this fell through to an UnboundLocalError on 'varn'
        print(errormsg)
        print('  ' + fname + ": variable ID '" + idS + "' not found in '" +   \
          infile + "' !!")
        quit(-1)

    linevals = found.replace('\n', '').split('</td>')
    varn = linevals[1].replace('<td align="center">', '')
    Lvarn = linevals[2].replace('<td align="left">', '')
    units = linevals[3].replace('<td align="left">', '')
    # Remove inline HTML formatting tags from the units string
    for tag in ('sup', 'sub', 'nobr'):
        units = units.replace('<' + tag + '>', '').replace('</' + tag + '>', '')

    cfvalues = gen.variables_values(varn)

    return [varn, Lvarn, units, cfvalues[0], cfvalues[1],                    \
      cfvalues[4].replace('|', ' ')]

timeu = 'minutes since 1949-12-01 00:00:00'

#######    #######
## MAIN
#######    #######

# Output file; it is created when the first GRIB message is read
onewnc = None
# Index along each growing dimension of every value already written:
#   timeS string -> 'time' index, pressure [Pa] -> 'press' index,
#   level -> 'height' index
timeidx = {}
pressidx = {}
heightidx = {}
# grib128_ncname() results per GRIB parameter ID, so the HTML table is parsed
#   once per variable instead of once per message
varparams_cache = {}

with open(ifile, 'rb') as stream:
    for i, msg in enumerate(pupygrib.read(stream), 1):
        if i == 1:
            lons, lats = msg.get_coordinates()
            onewnc = NetCDFFile(ofile, 'w')

            # Create dimensions
            dimx = lons.shape[1]
            dimy = lons.shape[0]

            onewnc.createDimension('lon', dimx)
            onewnc.createDimension('lat', dimy)
            onewnc.createDimension('time', None)
            onewnc.createDimension('Lstring', 64)

            dlon = msg[2].longitudeOfLastGridPoint - msg[2].longitudeOfFirstGridPoint
            dlat = msg[2].latitudeOfLastGridPoint - msg[2].latitudeOfFirstGridPoint
            print(str(dlon) + ' ' + str(dlat))

            # Latitude decreases along the rows: flip them so it increases
            if dlat < 0.:
                lons = lons[::-1,:]
                lats = lats[::-1,:]

            # Create dimension-variables
            newvar = onewnc.createVariable('lon', 'f8', ('lat', 'lon'))
            newvar[:] = lons[:]
            ncvar.basicvardef(newvar, 'lon', 'Longitude', 'degrees_east')
            newvar.setncattr('axis', 'X')
            newvar.setncattr('_CoordinateAxisType', 'Lon')

            newvar = onewnc.createVariable('lat', 'f8', ('lat', 'lon'))
            newvar[:] = lats[:]
            ncvar.basicvardef(newvar, 'lat', 'Latitude', 'degrees_north')
            newvar.setncattr('axis', 'Y')
            newvar.setncattr('_CoordinateAxisType', 'Lat')

            # Axis attributes go on the time variable itself (not on 'lat')
            newvartime = onewnc.createVariable('time', 'f8', ('time',))
            ncvar.basicvardef(newvartime, 'time', 'time', timeu)
            newvartime.setncattr('axis', 'T')
            newvartime.setncattr('_CoordinateAxisType', 'Time')

            # Same times as 'YmdHMS' character strings; auxiliary, not an axis
            newvartimeS = onewnc.createVariable('timeS', 'c', ('time', 'Lstring'))
            ncvar.basicvardef(newvartimeS, 'time', 'time', 'YmdHMS')

        values = msg.get_values()
        century = msg[1].centuryOfReferenceTimeOfData
        year = msg[1].yearOfCentury
        month = msg[1].month
        day = msg[1].day
        hour = msg[1].hour
        minute = msg[1].minute
        levtype = msg[1].indicatorOfTypeOfLevel
        level = msg[1].level
        varid = msg[1].indicatorOfParameter

        if levtype == 100:
            # Level value is multiplied by 100 and stored with units 'Pa'
            plev = level*100.
            if 'press' not in onewnc.variables:
                print('  we got a variable with a pressure value !!')
                onewnc.createDimension('press', None)
                newvarpress = onewnc.createVariable('press', 'f8', ('press',))
                ncvar.basicvardef(newvarpress, 'air_pressure', 'air pressure', 'Pa')
                newvarpress.setncattr('axis', 'Z')
                newvarpress.setncattr('_CoordinateAxisType', 'Press')
            if plev not in pressidx:
                pressidx[plev] = len(pressidx)
                newvarpress[pressidx[plev]] = plev
            iz = pressidx[plev]
        elif levtype == 1:
            if 'height' not in onewnc.variables:
                print('  we got a variable with a height value !!')
                onewnc.createDimension('height', None)
                newvarheight = onewnc.createVariable('height', 'f8', ('height',))
                ncvar.basicvardef(newvarheight, 'height', 'height', 'm')
                newvarheight.setncattr('axis', 'Z')
                newvarheight.setncattr('_CoordinateAxisType', 'Height')
            if level not in heightidx:
                heightidx[level] = len(heightidx)
                newvarheight[heightidx[level]] = level
            iz = heightidx[level]

        if varid not in varparams_cache:
            varparams_cache[varid] = grib128_ncname(varid)
        varparams = varparams_cache[varid]
        if i == 1: print(varparams)

        # Full year from GRIB1 century + year-of-century (e.g. 21, 5 -> 2005)
        timeS = str((century-1)*100 + year) + str(month).zfill(2) +           \
          str(day).zfill(2) + str(hour).zfill(2) + str(minute).zfill(2) + '00'

        # Write each distinct time only once, at the next free 'time' index
        if timeS not in timeidx:
            timeidx[timeS] = len(timeidx)
            newvartimeS[timeidx[timeS],0:14] = timeS
            newvartime[timeidx[timeS]] = gen.datetimeStr_conversion(timeS,     \
              'YmdHMS', 'cfTime,' + timeu)
        it = timeidx[timeS]

        ncname = varparams[3]
        if ncname not in onewnc.variables:
            if levtype == 1:
                dimns = ('time', 'height', 'lat', 'lon')
            elif levtype == 100:
                dimns = ('time', 'press', 'lat', 'lon')
            else:
                dimns = ('time', 'lat', 'lon')

            newvar = onewnc.createVariable(ncname, 'f', dimns)
            ncvar.basicvardef(newvar, varparams[4], varparams[5], varparams[2])
        else:
            newvar = onewnc.variables[ncname]

        # Same row flip as applied to lon/lat above
        if dlat < 0.:
            values = values[::-1,:]

        if levtype in (1, 100):
            newvar[it,iz,:,:] = values[:]
        else:
            newvar[it,:,:] = values[:]

        onewnc.sync()

if onewnc is None:
    # No messages: nothing was created, so there is nothing to finish/close
    print(errormsg)
    print("  main: no GRIB messages found in '" + ifile + "' !!")
    quit(-1)

# Global values
ncvar.add_global_PyNCplot(onewnc, 'red_grib.py', ' ', '0.1')

onewnc.close()
print("Successfull writting of '" + ofile + "' !!")
