from __future__ import division
import numpy as np
from netCDF4 import Dataset
import sys 
import os
import math
from scipy.interpolate import interp1d
import netCDF4 as cdf


# Target pressure levels for the output file. They are compared directly with
# the input file's 'level' coordinate, so they must use the same units as that
# coordinate (presumably hPa/mb — multiply by 100 for Pa; confirm per input).
# Written from the top of the atmosphere downward, then flipped so the output
# runs from the highest pressure (1000) to the lowest (1).
_LEVELS_TOP_DOWN = (
    1., 2., 3., 5., 7., 10., 20., 30., 50., 70.,
    100., 125., 150., 175., 200., 225., 250., 300., 350., 400.,
    450., 500., 550., 600., 650., 700., 750., 775., 800., 825.,
    850., 875., 900., 925., 950., 975., 1000.,
)
levels = list(reversed(_LEVELS_TOP_DOWN))
print(levels)

# Open the source file and load everything the interpolation needs.
# The temperature field is read exactly once (it is a full 4-D read, so a
# second read would double both I/O time and peak memory).
dw1 = Dataset('2001.nc','r')
dtfull = dw1.variables['t'][:,:,:,:]
dpfulloned = dw1.variables['level'][:]
time_counterfull = dw1.variables['time'][:]
lat = dw1.variables['lat'][:]
lon = dw1.variables['lon'][:]

# Shape of the 4-D temperature array; dim2/dim3 are the two horizontal axes
# (ordered like the 'lat'/'lon' output dimensions below).
dim0, dim1, dim2, dim3 = dtfull.shape

# Pressure of every input grid point, shape (level, dim2, dim3): the 1-D level
# coordinate broadcast over the horizontal grid. Values stay in the input
# file's level units (no Pa conversion is applied).
dpfull = np.zeros((len(dpfulloned), dim2, dim3))
dpfull[:] = np.asarray(dpfulloned, dtype=float)[:, np.newaxis, np.newaxis]
print(dpfull)

# Pre-allocated output for interpolated temperature: (time, target level, dim2, dim3).
out3 = np.zeros((dim0, len(levels), dim2, dim3))

def vertical_int2(levels, pfull, dataset):
    """Linearly interpolate a 3-D field from its native pressures onto ``levels``.

    Parameters
    ----------
    levels : sequence of float
        Target pressures, in the same units as ``pfull``.
    pfull : array_like
        Pressure of each input level. Either 1-D with shape ``(nlev,)`` (one
        profile shared by every column) or 3-D with the same shape as
        ``dataset``, ``(nlev, ni, nj)``.
    dataset : array_like, shape (nlev, ni, nj)
        Field to interpolate; axis 0 is the vertical axis.

    Returns
    -------
    numpy.ndarray, shape (len(levels), ni, nj), float64
        Interpolated field. Targets outside the range of a column's pressures
        are linearly extrapolated.
    """
    nl, ni, nj = len(levels), dataset.shape[1], dataset.shape[2]
    out = np.zeros((nl, ni, nj))
    if ni == 0 or nj == 0:
        return out

    p = np.asarray(pfull)

    # When all columns share one pressure profile, a single interpolator along
    # axis 0 gives the same linear interpolation as one interpolator per
    # column, without building ni*nj interp1d objects.
    if p.ndim == 1:
        profile = p
    elif np.all(p == p[:, :1, :1]):
        profile = p[:, 0, 0]
    else:
        profile = None

    if profile is not None:
        f = interp1d(profile, dataset, axis=0, kind='linear', fill_value="extrapolate")
        out[:, :, :] = f(levels)
        return out

    # General case: each column has its own pressure profile.
    for i in range(ni):
        for j in range(nj):
            f = interp1d(p[:, i, j], dataset[:, i, j], kind='linear', fill_value="extrapolate")
            out[:, i, j] = f(levels)
    return out

pfull = dpfull


def _interp_each_time(source):
    """Vertically interpolate every time step of a 4-D field onto ``levels``.

    ``source`` is anything indexable as ``source[step, :, :, :]``: an in-memory
    array or a netCDF4 variable. Passing the variable itself reads one time
    step at a time instead of loading the whole 4-D field first.
    Returns a float64 array of shape (dim0, len(levels), dim2, dim3).
    """
    result = np.zeros((dim0, len(levels), dim2, dim3))
    for step in range(dim0):
        print("time step=", step)
        result[step, :, :, :] = vertical_int2(levels, pfull, source[step, :, :, :])
    return result


# Temperature is already in memory; release it once interpolated.
out3 = _interp_each_time(dtfull)
del dtfull
out4 = _interp_each_time(dw1.variables['u'])
out5 = _interp_each_time(dw1.variables['v'])


print("-----------------------------writing NetCDF--------------------------------")
print(out3)

# Descriptive attributes carried over from the input file so the output stays
# self-describing (the time axis in particular is meaningless without units).
# Packing/fill attributes are deliberately not copied: the data written here
# is already unpacked float.
_COPIED_ATTRS = ("units", "long_name", "standard_name", "calendar", "axis")


def _copy_attrs(src, dst, names=_COPIED_ATTRS):
    """Copy each attribute in ``names`` from netCDF variable ``src`` to ``dst`` if present."""
    present = src.ncattrs()
    for name in names:
        if name in present:
            dst.setncattr(name, src.getncattr(name))


try:
    f = cdf.Dataset('ERA4.nc', 'w', format='NETCDF4')
except Exception:
    # Report and stop: carrying on would only fail later with a NameError on ``f``.
    print("Error occurred while opening new netCDF file, Error: ", sys.exc_info()[0])
    raise

# Target levels are written unchanged, i.e. in the input 'level' units.
levels2 = levels
with f:
    f.createDimension('lon', len(lon))
    f.createDimension('lat', len(lat))
    f.createDimension('lev', len(levels2))
    f.createDimension('time_counter', len(time_counterfull))
    vlon = f.createVariable('lon', 'f4', 'lon')
    vlat = f.createVariable('lat', 'f4', 'lat')
    vlev = f.createVariable('lev', 'f4', 'lev')
    vtime = f.createVariable('time_counter', 'f8', 'time_counter')
    vt = f.createVariable('t', 'f4', ('time_counter', 'lev', 'lat', 'lon'))
    vu = f.createVariable('u', 'f4', ('time_counter', 'lev', 'lat', 'lon'))
    vv = f.createVariable('v', 'f4', ('time_counter', 'lev', 'lat', 'lon'))

    for src_name, dst_var in (('lon', vlon), ('lat', vlat), ('level', vlev),
                              ('time', vtime), ('t', vt), ('u', vu), ('v', vv)):
        _copy_attrs(dw1.variables[src_name], dst_var)

    vlon[:] = lon
    vlat[:] = lat
    vlev[:] = levels2
    vtime[:] = time_counterfull
    vt[:] = out3
    vu[:] = out4
    vv[:] = out5
print("-----------------------done--------------------")
