## Author: AS
def errormess(text,printvar=None):
    """Print an error message (and optionally a value), then stop the program.

    text     : message printed first.
    printvar : optional extra object, printed after the message when it is
               not None.
    Never returns: the interpreter is stopped by raising SystemExit.
    """
    ## sys.exit is used instead of the bare exit() helper: exit() is only
    ## injected by the site module and is missing with "python -S" or in
    ## embedded interpreters.
    import sys
    print(text)
    if printvar is not None: print(printvar)
    sys.exit()

## Author: AS
def adjust_length (tab, zelen):
    """Return a sequence of settings holding exactly zelen values.

    tab   : None, or a sequence of values.
    zelen : expected number of values.

    - tab is None (or empty while values are expected): an array of zelen
      missing-value markers (-999999) is returned.
    - len(tab) differs from zelen: a warning is printed and the first value
      of tab is repeated zelen times.
    - otherwise tab itself is returned unchanged.
    """
    from numpy import ones
    missing = -999999
    ## an empty tab has no first value to replicate: treat it like None
    ## instead of failing with an IndexError on tab[0]
    if tab is None or (len(tab) == 0 and zelen != 0):
        return ones(zelen) * missing
    if zelen != len(tab):
        print("not enough or too much values... setting same values all variables")
        return ones(zelen) * tab[0]
    return tab

## Author: AS
def getname(var=False,winds=False,anomaly=False):
    """Build the base name associated with a variable and/or a wind plot.

    var     : variable name (string), or a false value when none is plotted.
    winds   : true when winds are plotted.
    anomaly : true to prefix the name with 'd'.

    Returns var + '_UV', var, or 'UV' (optionally prefixed with 'd').
    When neither var nor winds is set, errormess() is called and the
    program stops.
    """
    if var and winds:     basename = var + '_UV'
    elif var:             basename = var
    elif winds:           basename = 'UV'
    ## no dataset is available in this function, so only the message is
    ## passed on (the previous printvar=nc.variables referenced an undefined
    ## name and raised NameError instead of printing the message)
    else:                 errormess("please set at least winds or var")
    if anomaly:           basename = 'd' + basename
    return basename

## Author: AS
def localtime(utc,lon):
    """Return the local time for a UTC hour and a longitude.

    The longitude is converted to hours (15 degrees per hour) and added to
    utc; the sum is truncated to one decimal and wrapped into [0, 24).
    """
    shifted = utc + lon / 15.
    ## keep a single decimal (int() truncates toward zero)
    truncated = int (shifted * 10) / 10.
    return truncated % 24

## Author: AS
def whatkindfile (nc):
    """Guess the kind of file from the variables it contains.

    nc : object exposing a `variables` container supporting `in`.

    Returns 'gcm', 'mesoapi', 'meso' or 'geo'. The first marker variable
    found (in the order listed below) decides; with no marker at all the
    file is reported as 'gcm'.
    """
    ## (marker variable, kind of file), tested in this order
    markers = ( ('controle', 'gcm'),
                ('phisinit', 'gcm'),
                ('vert',     'mesoapi'),
                ('U',        'meso'),
                ('HGT_M',    'geo') )
    for marker, typefile in markers:
        if marker in nc.variables:
            return typefile
    ## no marker found: for lslin-ed files from gcm
    return 'gcm'

## Author: AS
def getfield (nc,var):
    """Load the whole variable `var` of dataset `nc` in memory.

    nc  : object exposing `variables[var]`; the variable must provide a
          `dimensions` attribute and support slicing.
    var : name of the variable.

    Returns the sliced data. When the data contains NaN and is not already a
    masked array, it is returned as a masked array with invalid values
    masked and NaN as fill value (a warning is printed).
    """
    ## this allows to get much faster (than simply referring to nc.variables[var])
    import numpy as np
    variable = nc.variables[var]
    dimension = len(variable.dimensions)
    ## slice every axis explicitly, whatever the number of dimensions
    ## (previously only 2, 3 and 4 dimensions were handled and anything else
    ## failed on an unbound local)
    if dimension > 0:   field = variable[(slice(None),) * dimension]
    else:               field = variable[...]
    # if there are NaNs in the ncdf, they should be loaded as a masked array which will be
    # recasted as a regular array later in reducefield
    already_masked = isinstance(field, np.ma.MaskedArray)
    if not already_masked and np.isnan(np.sum(field)):
        print("Warning: netcdf as nan values but is not loaded as a Masked Array.")
        print("recasting array type")
        out = np.ma.masked_invalid(field)
        out.set_fill_value(np.nan)
    else:
        out = field
    return out

## Author: AS + TN + AC
def reducefield (input,d4=None,d3=None,d2=None,d1=None,yint=False,alt=None,anomaly=False):
    """Reduce a 2D, 3D or 4D field by selecting and averaging indices along its axes.

    input      : array-like with 2, 3 or 4 dimensions (other ranks are
                 returned untouched).
    d4,d3,d2,d1: index (or indices) to keep along each axis, None to leave
                 the axis alone. For a 4D input they apply to axes 0,1,2,3;
                 for a 3D input d4,d2,d1 apply to axes 0,1,2; for a 2D input
                 d2,d1 apply to axes 0,1. Anything that is not a numpy
                 ndarray is wrapped in a one-element list.
    yint, alt  : forwarded to reduce_zaxis (as yint and vert) whenever the
                 d3 axis of a 4D input is reduced.
    anomaly    : when true, some of the 4D branches replace the field by
                 100 * (field / smooth(field, csmooth) - 1) before the d2/d1
                 averaging.

    Returns (output, error). error is True when one of the bound checks
    fails; output is then the input itself. The input and output
    dimensions/shapes are printed.
    """
    ### we do it the reverse way to be compliant with netcdf "t z y x" or "t y x" or "y x"
    ### it would be actually better to name d4 d3 d2 d1 as t z y x
    ### ... note, anomaly is only computed over d1 and d2 for the moment
    import numpy as np
    from mymath import max,mean
    csmooth = 12 ## a fair amount of grid points
    dimension = np.array(input).ndim
    shape = np.array(input).shape
    #print 'd1,d2,d3,d4: ',d1,d2,d3,d4
    print 'IN  REDUCEFIELD dim,shape: ',dimension,shape
    if anomaly: print 'ANOMALY ANOMALY'
    output = input
    error = False
    #### this is needed to cope the case where d4,d3,d2,d1 are single integers and not arrays
    if d4 is not None and not isinstance(d4, np.ndarray): d4=[d4]
    if d3 is not None and not isinstance(d3, np.ndarray): d3=[d3]
    if d2 is not None and not isinstance(d2, np.ndarray): d2=[d2]
    if d1 is not None and not isinstance(d1, np.ndarray): d1=[d1]
    ### now the main part
    if dimension == 2:
        ## NOTE(review): unlike the 3D/4D branches, d2/d1 (lists, arrays or None
        ## at this point) are compared to an int directly instead of through
        ## max() -- confirm this is intended
        if   d2 >= shape[0]: error = True
        elif d1 >= shape[1]: error = True
        elif d1 is not None and d2 is not None:  output = mean(input[d2,:],axis=0); output = mean(output[d1],axis=0)
        elif d1 is not None:         output = mean(input[:,d1],axis=1)
        elif d2 is not None:         output = mean(input[d2,:],axis=0)
    elif dimension == 3:
        ## NOTE(review): max() comes from mymath and is called before the None
        ## checks below; presumably it tolerates None -- confirm
        if   max(d4) >= shape[0]: error = True
        elif max(d2) >= shape[1]: error = True
        elif max(d1) >= shape[2]: error = True
        elif d4 is not None and d2 is not None and d1 is not None:  
            output = mean(input[d4,:,:],axis=0); output = mean(output[d2,:],axis=0); output = mean(output[d1],axis=0)
        elif d4 is not None and d2 is not None:    output = mean(input[d4,:,:],axis=0); output=mean(output[d2,:],axis=0)
        elif d4 is not None and d1 is not None:    output = mean(input[d4,:,:],axis=0); output=mean(output[:,d1],axis=1)
        elif d2 is not None and d1 is not None:    output = mean(input[:,d2,:],axis=1); output=mean(output[:,d1],axis=1)
        elif d1 is not None:                       output = mean(input[:,:,d1],axis=2)
        elif d2 is not None:                       output = mean(input[:,d2,:],axis=1)
        elif d4 is not None:                       output = mean(input[d4,:,:],axis=0)
    elif dimension == 4:
        if   max(d4) >= shape[0]: error = True
        elif max(d3) >= shape[1]: error = True
        elif max(d2) >= shape[2]: error = True
        elif max(d1) >= shape[3]: error = True
        ## the branches below go from the most to the least constrained
        ## combination of d4,d3,d2,d1; the first matching one is applied
        elif d4 is not None and d3 is not None and d2 is not None and d1 is not None:
             output = mean(input[d4,:,:,:],axis=0)
             output = reduce_zaxis(output[d3,:,:],ax=0,yint=yint,vert=alt,indice=d3)
             if anomaly: output = 100. * ((output / smooth(output,csmooth)) - 1.) 
             output = mean(output[d2,:],axis=0)
             output = mean(output[d1],axis=0)
        elif d4 is not None and d3 is not None and d2 is not None: 
             output = mean(input[d4,:,:,:],axis=0)
             output = reduce_zaxis(output[d3,:,:],ax=0,yint=yint,vert=alt,indice=d3)
             if anomaly: output = 100. * ((output / smooth(output,csmooth)) - 1.) 
             output = mean(output[d2,:],axis=0)
        elif d4 is not None and d3 is not None and d1 is not None: 
             output = mean(input[d4,:,:,:],axis=0)
             output = reduce_zaxis(output[d3,:,:],ax=0,yint=yint,vert=alt,indice=d3)
             if anomaly: output = 100. * ((output / smooth(output,csmooth)) - 1.) 
             output = mean(output[:,d1],axis=1)
        elif d4 is not None and d2 is not None and d1 is not None: 
             output = mean(input[d4,:,:,:],axis=0)
             if anomaly:
                 for k in range(output.shape[0]):  output[k,:,:] = 100. * ((output[k,:,:] / smooth(output[k,:,:],csmooth)) - 1.)
             output = mean(output[:,d2,:],axis=1)
             output = mean(output[:,d1],axis=1)
             #noperturb = smooth1d(output,window_len=7)
             #lenlen = len(output) ; output = output[1:lenlen-7] ; yeye = noperturb[4:lenlen-4]
             #plot(output) ; plot(yeye) ; show() ; plot(output-yeye) ; show()
        elif d3 is not None and d2 is not None and d1 is not None: 
             output = reduce_zaxis(input[:,d3,:,:],ax=1,yint=yint,vert=alt,indice=d3) 
             if anomaly:
                 for k in range(output.shape[0]):  output[k,:,:] = 100. * ((output[k,:,:] / smooth(output[k,:,:],csmooth)) - 1.)
             output = mean(output[:,d2,:],axis=1)
             output = mean(output[:,d1],axis=1) 
        elif d4 is not None and d3 is not None:  
             output = mean(input[d4,:,:,:],axis=0) 
             output = reduce_zaxis(output[d3,:,:],ax=0,yint=yint,vert=alt,indice=d3)
             if anomaly: output = 100. * ((output / smooth(output,csmooth)) - 1.) 
        elif d4 is not None and d2 is not None:  
             output = mean(input[d4,:,:,:],axis=0) 
             output = mean(output[:,d2,:],axis=1)
        elif d4 is not None and d1 is not None:  
             output = mean(input[d4,:,:,:],axis=0) 
             output = mean(output[:,:,d1],axis=2)
        elif d3 is not None and d2 is not None:  
             output = reduce_zaxis(input[:,d3,:,:],ax=1,yint=yint,vert=alt,indice=d3)
             output = mean(output[:,d2,:],axis=1)
        elif d3 is not None and d1 is not None:  
             output = reduce_zaxis(input[:,d3,:,:],ax=1,yint=yint,vert=alt,indice=d3)
             ## NOTE(review): axis=0 here, whereas the d4+d1 branch above uses
             ## axis=2 after the same [:,:,d1] indexing -- confirm
             output = mean(output[:,:,d1],axis=0)
        elif d2 is not None and d1 is not None:  
             output = mean(input[:,:,d2,:],axis=2)
             output = mean(output[:,:,d1],axis=2)
        elif d1 is not None:        output = mean(input[:,:,:,d1],axis=3)
        elif d2 is not None:        output = mean(input[:,:,d2,:],axis=2)
        ## NOTE(review): ax=0 here, whereas every other branch slicing
        ## input[:,d3,:,:] passes ax=1 -- confirm
        elif d3 is not None:        output = reduce_zaxis(input[:,d3,:,:],ax=0,yint=yint,vert=alt,indice=d3)
        elif d4 is not None:        output = mean(input[d4,:,:,:],axis=0)
    dimension = np.array(output).ndim
    shape = np.array(output).shape
    print 'OUT REDUCEFIELD dim,shape: ',dimension,shape
    return output, error

## Author: AC + AS
def reduce_zaxis (input,ax=None,yint=False,vert=None,indice=None):
    """Reduce `input` along axis `ax`, by integration or by averaging.

    input  : array (possibly a numpy masked array).
    ax     : axis to reduce.
    yint   : when true (and vert and indice are given) the field is
             integrated with the trapezoidal rule against vert[indice];
             otherwise it is averaged with mymath.mean.
    vert   : coordinate array along the reduced axis (indexed by indice).
    indice : indices of the selected levels.

    Masked input is filled with NaN before integration; note that this sets
    the fill value of the array passed in.
    """
    ## numpy was used below without being imported in this function
    import numpy as np
    if yint and vert is not None and indice is not None:
        from scipy import integrate
        ## recent SciPy releases only provide trapezoid, older ones only trapz
        trapz = getattr(integrate, 'trapezoid', None) or integrate.trapz
        if isinstance(input, np.ma.MaskedArray):
            input.set_fill_value(np.nan)
            output = trapz(input.filled(),x=vert[indice],axis=ax)
        else:
            output = trapz(input,x=vert[indice],axis=ax)
    else:
        ## imported only where it is needed
        from mymath import mean
        output = mean(input,axis=ax)
    return output

## Author: AS + TN
def definesubplot ( numplot, fig ):
    """Choose the subplot grid for numplot plots and tune figure spacing and font size.

    numplot : number of plots.
    fig     : object providing subplots_adjust(wspace=..., hspace=...).

    Returns (subv, subh), the number of rows and columns. For numplot <= 1
    the pair (99999, 99999) is returned and the figure is left alone. The
    global matplotlib font size is first reset to 12 and then scaled
    according to the grid. More than 16 plots prints a message and exits.
    """
    from matplotlib.pyplot import rcParams
    rcParams['font.size'] = 12. ## default (important for multiple calls)
    if numplot <= 0 or numplot == 1:
        return 99999,99999
    ## one entry per layout, tested in this order:
    ## (exact match?, limit, rows, columns, spacing, font factor numerator, denominator)
    layouts = (
        (True,   2, 1, 2, {'wspace': 0.35},               3., 4.),
        (True,   3, 2, 2, {'wspace': 0.5},                1., 2.),
        (True,   4, 2, 2, {'wspace': 0.3, 'hspace': 0.3}, 2., 3.),
        (False,  6, 2, 3, {'wspace': 0.4, 'hspace': 0.0}, 1., 2.),
        (False,  8, 2, 4, {'wspace': 0.3, 'hspace': 0.3}, 1., 2.),
        (False,  9, 3, 3, {'wspace': 0.3, 'hspace': 0.3}, 1., 2.),
        (False, 12, 3, 4, {'wspace': 0.1, 'hspace': 0.1}, 1., 2.),
        (False, 16, 4, 4, {'wspace': 0.3, 'hspace': 0.3}, 1., 2.),
    )
    for exact, limit, subv, subh, spacing, num, den in layouts:
        if exact:   matched = (numplot == limit)
        else:       matched = (numplot <= limit)
        if matched:
            fig.subplots_adjust(**spacing)
            rcParams['font.size'] = int( rcParams['font.size'] * num / den )
            return subv,subh
    print("number of plot supported: 1 to 16")
    exit()

## Author: AS
def getstralt(nc,nvert):
    """Return the string describing vertical level nvert.

    nc    : dataset; its kind is given by whatkindfile(nc).
    nvert : index of the vertical level.

    'meso' files give "_lvl<nvert>". 'mesoapi' files read the level value
    from nc.variables['vert'] and format it according to the vertical
    dimension present in nc.dimensions. Any other kind gives "".
    """
    typefile = whatkindfile(nc)
    ## strings are compared with ==; `is` tests object identity and only
    ## worked through interpreter string interning
    if typefile == 'meso':
        stralt = "_lvl" + str(nvert)
    elif typefile == 'mesoapi':
        zelevel = int(nc.variables['vert'][nvert])
        ## values of 10000 and above are written in thousands with a "km" suffix
        if abs(zelevel) < 10000.:   strheight=str(zelevel)+"m"
        else:                       strheight=str(int(zelevel/1000.))+"km"
        if 'altitude'       in nc.dimensions:   stralt = "_"+strheight+"-AMR"
        elif 'altitude_abg' in nc.dimensions:   stralt = "_"+strheight+"-ALS"
        elif 'bottom_top'   in nc.dimensions:   stralt = "_"+strheight
        elif 'pressure'     in nc.dimensions:   stralt = "_"+str(zelevel)+"Pa"
        else:                                   stralt = ""
    else:
        stralt = ""
    return stralt

## Author: AS
def getlschar ( namefile ):
    """Read timing information from the file next to namefile.

    Trailing characters belonging to "_z", then "_zabg", then "_p" are
    stripped from namefile before the file is opened.

    Returns (lschar, zehour, zehourin):
      lschar   : "_Ls<value>" built with sol2ls, or "".
      zehour   : hour read from the first 'Times' entry.
      zehourin : absolute difference between the first two entries.
    When 'Times' is missing or has fewer than 2 entries, or when the file
    holds a 'vert' variable, ("", 0, 1) is returned.
    """
    from netCDF4 import Dataset
    from timestuff import sol2ls
    from numpy import array
    #### we assume that wrfout is next to wrfout_z and wrfout_zabg
    for trailing in ("_z", "_zabg", "_p"):
        namefile = namefile.rstrip(trailing)
    nc  = Dataset(namefile)
    first = None
    if 'Times' in nc.variables:
        first = nc.variables['Times'][0]
        if array(nc.variables['Times']).shape[0] < 2: first = None
    if first is None or 'vert' in nc.variables:
        return "", 0, 1

    def hours(stamp):
        ## characters 11-12 and 14-15 hold the two numbers of the time of day
        return int(stamp[11]+stamp[12]) + int(stamp[14]+stamp[15])/37.

    #### strangely enough this does not work for api or ncrcat results!
    start = getattr(nc, 'START_DATE')
    elapsed = int(first[8]+first[9]) - int(start[8]+start[9])
    if elapsed < 0:
        lschar = ""  ## might have crossed a month... fix soon
    else:
        lschar = "_Ls"+str( int( 10. * sol2ls ( getattr( nc, 'JULDAY' ) + elapsed ) ) / 10. )
    ###
    second = nc.variables['Times'][1]
    zehour = hours(first)
    zehourin = abs ( hours(second) - zehour )
    return lschar, zehour, zehourin

## Author: AS
def getprefix (nc):
    """Return the file-name prefix 'LMD_MMM_d<GRID_ID>_<DX/1000>km_'.

    nc : object with GRID_ID and DX attributes; DX is divided by 1000 and
         truncated to an integer.
    """
    grid = str(getattr(nc,'GRID_ID'))
    resolution = str(int(getattr(nc,'DX')/1000.))
    return ''.join(['LMD_MMM_', 'd', grid, '_', resolution, 'km_'])

## Author: AS
def getproj (nc):
    """Return the name of the map projection to use for dataset nc.

    'mesoapi', 'meso' and 'geo' files: decided from the MAP_PROJ and CEN_LAT
    attributes (2: "npstere" when CEN_LAT > 10 else "spstere"; 1: "lcc";
    anything else: "merc"). 'gcm' files: "cyl". Other kinds: "ortho".
    """
    typefile = whatkindfile(nc)
    if typefile in ['mesoapi','meso','geo']:
        ### (should CEN_LON be passed to the projection?)
        map_proj = getattr(nc, 'MAP_PROJ')
        cen_lat  = getattr(nc, 'CEN_LAT')
        if map_proj == 2:
            ## stereographic polar domain: pick the pole from the central latitude
            if cen_lat > 10.:   return "npstere"
            return "spstere"
        if map_proj == 1:
            ## lambert projection domain
            return "lcc"
        ## mercator projection (map_proj == 3) and any other value
        return "merc"
    if typefile in ['gcm']:
        ## problem with the others (drawing behind the sphere?)
        return "cyl"
    return "ortho"

## Author: AS
def ptitle (name):
    """Set `name` as the title of the current matplotlib axes and print it."""
    import matplotlib.pyplot as plt
    plt.title(name)
    print(name)

## Author: AS
def polarinterv (lon2d,lat2d):
    """Return [wlon, wlat] bounds for a polar domain.

    lon2d, lat2d : 2D coordinate arrays.

    wlon is the [min, max] of all longitudes; wlat is the [min, max] of the
    latitudes along the middle row of lat2d.
    """
    import numpy as np
    wlon = [np.min(lon2d),np.max(lon2d)]
    ## floor division: the result is used as an index and must be an integer
    ## whatever the division semantics of the interpreter
    ind = np.array(lat2d).shape[0] // 2  ## to get a good boundlat and to get the pole
    wlat = [np.min(lat2d[ind,:]),np.max(lat2d[ind,:])]
    return [wlon,wlat]

## Author: AS
def simplinterv (lon2d,lat2d):
    """Return [[min lon, max lon], [min lat, max lat]] over the whole arrays."""
    import numpy as np
    wlon = [np.min(lon2d),np.max(lon2d)]
    wlat = [np.min(lat2d),np.max(lat2d)]
    return [wlon,wlat]

## Author: AS
def wrfinterv (lon2d,lat2d):
    """Return [wlon, wlat] bounds taken from two opposite corners of the grid.

    lon2d, lat2d : 2D coordinate arrays of identical shape.

    The corners are the first element [0,0] and the last element of the
    arrays. Each pair is sorted in increasing order. When the mean of the
    two corner latitudes is beyond 60 degrees (in absolute value), the upper
    longitude bound is widened by 10% of the mean absolute corner longitude.
    """
    nx = len(lon2d[0,:])-1   ## last index along the second axis
    ny = len(lon2d[:,0])-1   ## last index along the first axis
    lon1 = lon2d[0,0]
    ## the last corner is [ny,nx]: the first index runs over the first axis
    ## (indexing [nx,ny] was only correct for square grids)
    lon2 = lon2d[ny,nx]
    lat1 = lat2d[0,0]
    lat2 = lat2d[ny,nx]
    if abs(0.5*(lat1+lat2)) > 60.:  wider = 0.5 * (abs(lon1)+abs(lon2)) * 0.1
    else:                           wider = 0.
    if lon1 < lon2:  wlon = [lon1, lon2 + wider]
    else:            wlon = [lon2, lon1 + wider]
    if lat1 < lat2:  wlat = [lat1, lat2]
    else:            wlat = [lat2, lat1]
    return [wlon,wlat]

## Author: AS
def makeplotres (filename,res=None,pad_inches_value=0.25,folder='',disp=True,ext='png',erase=False):
    """Save the current matplotlib figure and optionally display/archive it.

    filename         : base name of the output file (no extension).
    res              : resolution in dpi; when given it is cast to int and
                       appended to the name as "_<res>".
    pad_inches_value : padding passed to savefig.
    folder           : when not empty, prepended to the name with a '/'.
    disp             : call display() on the saved file.
    ext              : file extension; 'eps', 'ps' and 'svg' files are put
                       in a .tar.gz archive and the original is removed.
    erase            : move the file to "to_be_erased".
    All external actions go through shell commands built from the name.
    """
    import  matplotlib.pyplot as plt
    from os import system
    if res is None:
        suffix = ""
    else:
        res = int(res)
        suffix = "_"+str(res)
    name = filename+suffix+"."+ext
    if folder != '':
        name = folder+'/'+name
    plt.savefig(name,dpi=res,bbox_inches='tight',pad_inches=pad_inches_value)
    if disp:
        display(name)
    if ext in ['eps','ps','svg']:
        system("tar czvf "+name+".tar.gz "+name+" ; rm -f "+name)
    if erase:
        system("mv "+name+" to_be_erased")
    return

## Author: AS
def dumpbdy (field,n,stag=None):
    """Return field without its n outermost points on each side.

    field : 2D array.
    n     : number of points dropped at the start of each axis; the end of
            each slice is (size - 1 - n).
    stag  : 'U' shortens the second axis by one more point, 'V' the first
            axis.
    """
    last_col = len(field[0,:])-1
    last_row = len(field[:,0])-1
    if stag == 'U': last_col = last_col-1
    if stag == 'V': last_row = last_row-1
    rows = slice(n, last_row-n)
    cols = slice(n, last_col-n)
    return field[rows,cols]

## Author: AS
def getcoorddef ( nc ):
    """Return (lon2d, lat2d) for the predefined kinds of files.

    The kind of file comes from whatkindfile(nc):
      'gcm'            : 1D "latitude"/"longitude" variables, meshed in 2D.
      'geo'            : 'XLONG_M'/'XLAT_M' variables.
      'mesoapi','meso' : default getcoord2d variables, with 6 points
                         removed at the boundaries by dumpbdy.
    """
    import numpy as np
    ## getcoord2d for predefined types
    kind = whatkindfile(nc)
    if kind in ['gcm']:
        lon2d,lat2d = getcoord2d(nc,nlat="latitude",nlon="longitude",is1d=True)
    elif kind in ['geo']:
        lon2d,lat2d = getcoord2d(nc,nlat='XLAT_M',nlon='XLONG_M')
    elif kind in ['mesoapi','meso']:
        lon2d,lat2d = getcoord2d(nc)
        lon2d = dumpbdy(lon2d,6)
        lat2d = dumpbdy(lat2d,6)
    return lon2d,lat2d

## Author: AS
def getcoord2d (nc,nlat='XLAT',nlon='XLONG',is1d=False):
    """Return the 2D longitude and latitude arrays of dataset nc.

    nlat, nlon : names of the latitude and longitude variables.
    is1d       : when true the variables are 1D and are combined with
                 numpy.meshgrid; otherwise the first index of the leading
                 axis of the 3D variables is taken.

    Returns (lon2d, lat2d).
    """
    import numpy as np
    if not is1d:
        lat2d = nc.variables[nlat][0,:,:]
        lon2d = nc.variables[nlon][0,:,:]
        return lon2d,lat2d
    lat1d = nc.variables[nlat][:]
    lon1d = nc.variables[nlon][:]
    lon2d,lat2d = np.meshgrid(lon1d,lat1d)
    return lon2d,lat2d

## FROM COOKBOOK http://www.scipy.org/Cookbook/SignalSmooth
def smooth1d(x,window_len=11,window='hanning'):
    """Smooth a 1D signal by convolution with a normalised window.

    The signal is extended at both ends with reflected copies of itself
    (window_len-1 points each) so that transient parts are minimised at the
    beginning and end of the output, then convolved with the window scaled
    to unit sum.

    input:
        x: the input signal (1D array-like)
        window_len: the dimension of the smoothing window; should be an odd integer
        window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
            flat window will produce a moving average smoothing.
    output:
        the smoothed signal, of length len(x) + window_len - 1
        (x itself, as an array, when window_len < 3)
    raises:
        ValueError if x is not 1D, is shorter than window_len, or if window
        is not one of the supported names.
    example:
        t = numpy.linspace(-2,2,41)
        x = numpy.sin(t) + numpy.random.randn(len(t))*0.1
        y = smooth1d(x)
    see also:
    numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve
    scipy.signal.lfilter
    TODO: the window parameter could be the window itself if an array instead of a string
    """
    import numpy
    x = numpy.array(x)
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")
    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")
    if window_len<3:
        return x
    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError("Window is on of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")
    ## reflected copies of the signal on both sides
    s=numpy.r_[x[window_len-1:0:-1],x,x[-1:-window_len:-1]]
    if window == 'flat': #moving average
        w=numpy.ones(window_len,'d')
    else:
        ## window was validated above: plain attribute lookup, no eval needed
        w=getattr(numpy, window)(window_len)
    y=numpy.convolve(w/w.sum(),s,mode='valid')
    return y

## Author: AS
def smooth (field, coeff):
    """Return field blurred with blur_image when coeff > 1, else field itself.

    coeff is cast to int before being handed to blur_image.
    """
    ## actually blur_image could work with different coeff on x and y
    if not coeff > 1:
        return field
    return blur_image(field,int(coeff))

## FROM COOKBOOK http://www.scipy.org/Cookbook/SignalSmooth
def gauss_kern(size, sizey=None):
    """Return a normalized 2D gauss kernel array for convolutions.

    size  : half-width of the kernel along the first axis (cast to int).
    sizey : half-width along the second axis; defaults to size when not
            given (or given as a false value).

    The kernel has shape (2*size+1, 2*sizey+1) and sums to 1.
    """
    ## indentation normalised to spaces: the former mix of tabs and spaces is
    ## rejected by Python 3 (TabError)
    import numpy as np
    size = int(size)
    if not sizey:
        sizey = size
    else:
        sizey = int(sizey)
    x, y = np.mgrid[-size:size+1, -sizey:sizey+1]
    g = np.exp(-(x**2/float(size)+y**2/float(sizey)))
    return g / g.sum()

## FROM COOKBOOK http://www.scipy.org/Cookbook/SignalSmooth
def blur_image(im, n, ny=None) :
    """Blur the image by convolving it with a gaussian kernel of typical size n.

    im : array to blur.
    n  : kernel size handed to gauss_kern.
    ny : optional different size in the y direction.

    Returns the convolved array (scipy.signal.convolve with mode='same').
    """
    ## indentation normalised to spaces: the former mix of tabs and spaces is
    ## rejected by Python 3 (TabError)
    from scipy.signal import convolve
    g = gauss_kern(n, sizey=ny)
    improc = convolve(im, g, mode='same')
    return improc

## Author: AS
def getwinddef (nc):
    """Return the wind variable names for the predefined kinds of files.

    The kind of file comes from whatkindfile(nc).

    Returns (uchar, vchar, metwind):
      uchar, vchar : names of the two wind components ('not found' for
                     kinds without predefined names).
      metwind      : False for 'meso' files (winds relative to the grid),
                     True otherwise (zonal/meridional winds).
    """
    ## getwinds for predefined types
    typefile = whatkindfile(nc)
    ### strings are matched by value (the former `is` comparisons tested
    ### object identity and only worked through string interning)
    names = { 'mesoapi': ['Um','Vm'],
              'gcm':     ['u','v'],
              'meso':    ['U','V'] }
    [uchar,vchar] = names.get(typefile, ['not found','not found'])
    ###
    if typefile in ['meso']:     metwind = False ## geometrical (wrt grid)
    else:                        metwind = True  ## meteorological (zon/mer)
    if metwind is False:         print("Not using meteorological winds. You trust numerical grid as being (x,y)")
    ###
    return uchar,vchar,metwind

## Author: AS
def vectorfield (u, v, x, y, stride=3, scale=15., factor=250., color='black', csmooth=1, key=True):
    """Draw a vector field with matplotlib quiver, with an optional reference key.

    u, v    : 2D vector components, passed through smooth(., csmooth) first.
    x, y    : 2D coordinates of the vectors.
    stride  : only one point out of `stride` is drawn along each axis.
    scale   : sets the reference vector shown in the key.
    factor  : sets all lengths (including the reference); INCREASE it to
              shorten the vectors.
    color   : arrow color; the key is drawn in black when the arrows are
              white or yellow.
    key     : draw the reference key, placed slightly below/left of the
              minimum x and y.
    """
    import  matplotlib.pyplot               as plt
    import  numpy                           as np
    keyx = np.min(x) - np.std(x) / 10.
    keyy = np.min(y) - np.std(y) / 10.
    u = smooth(u,csmooth)
    v = smooth(v,csmooth)
    widthvec = 0.003 #0.005 #0.003
    sub = (slice(None,None,stride), slice(None,None,stride))
    arrows = plt.quiver( x[sub], y[sub], u[sub], v[sub],
                         angles='xy',color=color,pivot='middle',
                         scale=factor,width=widthvec )
    if color in ['white','yellow']:     kcolor='black'
    else:                               kcolor=color
    if key:
        plt.quiverkey(arrows,keyx,keyy,scale,
                      str(int(scale)),coordinates='data',color=kcolor,labelpos='S',labelsep = 0.03)
    return

## Author: AS
def display (name):
    """Launch the external `display` program on file `name` in the background.

    The output of the command is discarded. Returns name.
    """
    import os
    command = "display "+name+" > /dev/null 2> /dev/null &"
    os.system(command)
    return name

## Author: AS
def findstep (wlon):
    """Return a grid-line step adapted to the interval wlon = [start, end].

    The target is a quarter of the interval width (truncated to an integer).
    Starting from 120, the step is halved down to 15, then lowered by 5 down
    to 5, then by 1 down to 1, stopping as soon as it no longer exceeds the
    target.
    """
    target = int((wlon[1]-wlon[0])/4.)  #3
    step = 120.
    while step > target and step > 15. :
        step = step / 2.
    ## (enter the stage when step <= gate, lowest value of the stage, decrement)
    for gate, lowest, decrement in ((15., 5., 5.), (5., 1., 1.)):
        if step <= gate:
            while step > target and step > lowest :
                step = step - decrement
    if step <= 1.:
        step = 1.
    return step

## Author: AS
def define_proj (char,wlon,wlat,back=None,blat=None):
    """Create the Basemap object for projection `char` and draw its graticule.

    char : projection name; handled values are "cyl", "moll", "ortho",
           "lcc", "npstere", "spstere", "nplaea", "laea", "nsper", "merc".
    wlon : [min, max] longitude window.
    wlat : [min, max] latitude window.
    back : optional background name; when set, the image given by
           marsmap(back) is warped on the map.
    blat : bounding latitude for the polar stereographic projections; when
           None it is derived from wlat. When given, it is also used as the
           central latitude of the "ortho" projection.

    Returns the Basemap instance, with meridians and parallels drawn.
    """
    from    mpl_toolkits.basemap            import Basemap
    import  numpy                           as np
    import  matplotlib                      as mpl
    from mymath import max
    meanlon = 0.5*(wlon[0]+wlon[1])
    meanlat = 0.5*(wlat[0]+wlat[1])
    if blat is None:
        ortholat=meanlat
        ## NOTE(review): blat stays None if none of the four tests below
        ## matches (e.g. wlat[1] < 0 < wlat[0]) -- confirm wlat is always sorted
        if   wlat[0] >= 80.:   blat =  40. 
        elif wlat[1] <= -80.:  blat = -40.
        elif wlat[1] >= 0.:    blat = wlat[0] 
        elif wlat[0] <= 0.:    blat = wlat[1]
    else:  ortholat=blat
    #print "blat ", blat
    h = 50.  ## in km
    ## presumably the radius of Mars in metres -- TODO confirm
    radius = 3397200.
    ## NOTE(review): there is no fallback branch, so an unsupported `char`
    ## leaves m undefined and fails further down
    if   char == "cyl":     m = Basemap(rsphere=radius,projection='cyl',\
                              llcrnrlat=wlat[0],urcrnrlat=wlat[1],llcrnrlon=wlon[0],urcrnrlon=wlon[1])
    elif char == "moll":    m = Basemap(rsphere=radius,projection='moll',lon_0=meanlon)
    elif char == "ortho":   m = Basemap(rsphere=radius,projection='ortho',lon_0=meanlon,lat_0=ortholat)
    elif char == "lcc":     m = Basemap(rsphere=radius,projection='lcc',lat_1=meanlat,lat_0=meanlat,lon_0=meanlon,\
                              llcrnrlat=wlat[0],urcrnrlat=wlat[1],llcrnrlon=wlon[0],urcrnrlon=wlon[1])
    elif char == "npstere": m = Basemap(rsphere=radius,projection='npstere', boundinglat=blat, lon_0=0.)
    elif char == "spstere": m = Basemap(rsphere=radius,projection='spstere', boundinglat=blat, lon_0=180.)
    elif char == "nplaea":  m = Basemap(rsphere=radius,projection='nplaea', boundinglat=wlat[0], lon_0=meanlon)
    elif char == "laea":    m = Basemap(rsphere=radius,projection='laea',lon_0=meanlon,lat_0=meanlat,lat_ts=meanlat,\
                              llcrnrlat=wlat[0],urcrnrlat=wlat[1],llcrnrlon=wlon[0],urcrnrlon=wlon[1])
    elif char == "nsper":   m = Basemap(rsphere=radius,projection='nsper',lon_0=meanlon,lat_0=meanlat,satellite_height=h*1000.)
    elif char == "merc":    m = Basemap(rsphere=radius,projection='merc',lat_ts=0.,\
                              llcrnrlat=wlat[0],urcrnrlat=wlat[1],llcrnrlon=wlon[0],urcrnrlon=wlon[1])
    ## graticule labels use 3/4 of the current font size
    fontsizemer = int(mpl.rcParams['font.size']*3./4.)
    ## the step adapts to the longitude window for these projections, 10 otherwise
    if char in ["cyl","lcc","merc","nsper","laea"]:   step = findstep(wlon)
    else:                                             step = 10.
    ## meridians are drawn twice as far apart as parallels
    steplon = step*2.
    #if back in ["geolocal"]:                          
    #    step = np.min([5.,step])
    #    steplon = step
    m.drawmeridians(np.r_[-180.:180.:steplon], labels=[0,0,0,1], color='grey', fontsize=fontsizemer)
    m.drawparallels(np.r_[-90.:90.:step], labels=[1,0,0,0], color='grey', fontsize=fontsizemer)
    if back: m.warpimage(marsmap(back),scale=0.75)
            #if not back:
            #    if not var:                                        back = "mola"    ## if no var:         draw mola
            #    elif typefile in ['mesoapi','meso','geo'] \
            #       and proj not in ['merc','lcc','nsper','laea']:  back = "molabw"  ## if var but meso:   draw molabw
            #    else:                                              pass             ## else:              draw None
    return m

## Author: AS
#### temporary test
def putpoints (map,plot):
    """Mark a hard-coded list of named locations on a map.

    map  : callable converting (lons, lats) to map coordinates, also
           providing plot(x, y, style).
    plot : object providing text(x, y, label).

    Each location is drawn as a blue filled circle and labelled with its
    name. Returns None.
    """
    #### from http://www.scipy.org/Cookbook/Matplotlib/Maps
    # names and lat/lon coordinates of the locations
    points = ['Olympus Mons']
    lats = [18.4]
    lons = [-134.0]
    # compute the native map projection coordinates
    x,y = map(lons,lats)
    # plot filled circles at the locations
    map.plot(x,y,'bo')
    # write the names, shifted by `offset` in both directions
    offset = 0 #1000 #50000
    for label,xpt,ypt in zip(points,x,y):
        plot.text(xpt+offset,ypt+offset,label)
    ## the name does not show up...
    return

## Author: AS
def calculate_bounds(field,vmin=None,vmax=None):
    """Return the (min, max) bounds to adopt for plotting field.

    Only the values of field below 9e+35 are considered. A bound that is
    not given is set to mean -/+ 3 standard deviations of those values; both
    bounds are computed that way when vmin equals vmax. A negative lower
    bound is raised to 0 when all the considered values are positive. The
    field extrema and the adopted bounds are printed.
    """
    import numpy as np
    from mymath import max,min,mean
    keep = np.where(field < 9e+35)
    valid = field[ keep ] # the compact syntax does not work if field is a tuple
    ###
    dev = np.std(valid)*3.0
    ### equal bounds are replaced by computed ones (for continuity)
    same = (vmin == vmax)
    if vmin is None or same:   zevmin = mean(valid) - dev
    else:                      zevmin = vmin
    if vmax is None or same:   zevmax = mean(valid) + dev
    else:                      zevmax = vmax
    ###
    if zevmin < 0. and min(valid) > 0.: zevmin = 0.
    print("%s %s %s" % ("BOUNDS field ", min(valid), max(valid)))
    print("%s %s %s" % ("BOUNDS adopted ", zevmin, zevmax))
    return zevmin, zevmax

## Author: AS
def bounds(what_I_plot,zevmin,zevmax):
    """Clamp what_I_plot in place between zevmin and zevmax and return it.

    Values below zevmin (moved by a relative 1e-7 toward larger values) are
    set to that limit. Values above 9e+35 are then set to -9e+35, and the
    remaining values above zevmax are set to zevmax. The new extrema are
    printed.
    """
    from mymath import max,min,mean
    ### might be convenient to add the missing value in arguments
    if zevmin < 0: floor = zevmin*(1. - 1.e-7)
    else:          floor = zevmin*(1. + 1.e-7)
    what_I_plot[ what_I_plot < floor ] = floor
    print("%s %s" % ("NEW MIN ", min(what_I_plot)))
    ## done after the lower clamp, so these values end up below zevmin
    what_I_plot[ what_I_plot > 9e+35  ] = -9e+35
    what_I_plot[ what_I_plot > zevmax ] = zevmax
    print("%s %s" % ("NEW MAX ", max(what_I_plot)))
    return what_I_plot

## Author: AS
def nolow(what_I_plot):
    """Replace in place the small values of what_I_plot by 1.e40 and return it.

    The threshold is 15% of the mean of |max| and |min| of the field; values
    whose absolute value is below it are replaced. The threshold is printed.
    """
    from mymath import max,min
    amplitude = abs(max(what_I_plot))+abs(min(what_I_plot))
    lim = 0.15*0.5*amplitude
    print("%s %s" % ("NO PLOT BELOW VALUE ", lim))
    too_low = abs(what_I_plot) < lim
    what_I_plot [ too_low ] = 1.e40
    return what_I_plot


## Author : AC
def hole_bounds(what_I_plot,zevmin,zevmax):
    """Set to NaN, in place, every value of what_I_plot outside [zevmin, zevmax].

    what_I_plot    : numpy float array (any number of dimensions).
    zevmin, zevmax : lower and upper bounds; values equal to a bound are kept.

    Returns what_I_plot (the same object, modified).
    """
    import numpy as np
    ## a boolean mask replaces the former element-by-element Python loop
    ## (which was limited to 2D arrays); np.nan is used because the np.NaN
    ## alias no longer exists in NumPy 2
    outside = (what_I_plot < zevmin) | (what_I_plot > zevmax)
    what_I_plot[outside] = np.nan
    return what_I_plot

## Author: AS
def zoomset (wlon,wlat,zoom):
    """Shrink the [wlon, wlat] window by zoom percent of its half-width on each side.

    wlon, wlat : [start, end] pairs.
    zoom       : percentage of the half-width removed at both ends.

    Prints and returns the new (wlon, wlat).
    """
    half_lon = abs(wlon[1]-wlon[0])/2.
    half_lat = abs(wlat[1]-wlat[0])/2.
    shift_lon = zoom*half_lon/100.
    shift_lat = zoom*half_lat/100.
    wlon = [wlon[0]+shift_lon, wlon[1]-shift_lon]
    wlat = [wlat[0]+shift_lat, wlat[1]-shift_lat]
    print("%s %s %s %s" % ("ZOOM %", zoom, wlon, wlat))
    return wlon,wlat

## Author: AS
def fmtvar (whichvar="def"):
    """Return the printf-style number format associated with variable whichvar.

    Unknown names get the format of the "def" entry.
    """
    formats = {
        "TK":           "%.0f",
        # Variables from TES ncdf format
        "T_NADIR_DAY":  "%.0f",
        "T_NADIR_NIT":  "%.0f",
        # Variables from tes.py ncdf format
        "TEMP_DAY":     "%.0f",
        "TEMP_NIGHT":   "%.0f",
        # Variables from MCS and mcs.py ncdf format
        "DTEMP":        "%.0f",
        "NTEMP":        "%.0f",
        "TPOT":         "%.0f",
        "TSURF":        "%.0f",
        "def":          "%.1e",
        "PTOT":         "%.0f",
        "HGT":          "%.1e",
        "USTM":         "%.2f",
        "HFX":          "%.0f",
        "ICETOT":       "%.1e",
        "TAU_ICE":      "%.2f",
        "VMR_ICE":      "%.1e",
        "MTOT":         "%.1f",
        "ANOMALY":      "%.1f",
        "W":            "%.1f",
        "WMAX_TH":      "%.1f",
        "QSURFICE":     "%.0f",
        "UM":           "%.0f",
        "ALBBARE":      "%.2f",
        "TAU":          "%.1f",
        "CO2":          "%.2f",
        #### T.N.
        "TEMP":         "%.0f",
        "VMR_H2OICE":   "%.0f",
        "VMR_H2OVAP":   "%.0f",
        "TAUTES":       "%.2f",
        "TAUTESAP":     "%.2f",
    }
    return formats.get(whichvar, formats["def"])

## Author: AS
####################################################################################################################
### Colorbars http://www.scipy.org/Cookbook/Matplotlib/Show_colormaps?action=AttachFile&do=get&target=colormaps3.png
def defcolorb (whichone="def"):
    """Return the name of the predefined matplotlib colormap for variable whichone.

    Unknown names get the colormap of the "def" entry.
    """
    whichcolorb =    { \
             "def":          "spectral",\
             "HGT":          "spectral",\
             "TK":           "gist_heat",\
             "TSURF":        "RdBu_r",\
             "QH2O":         "PuBu",\
             "USTM":         "YlOrRd",\
             #"T_nadir_nit":  "RdBu_r",\
             #"T_nadir_day":  "RdBu_r",\
             "HFX":          "RdYlBu",\
             "ICETOT":       "YlGnBu_r",\
             #"MTOT":         "PuBu",\
             "CCNQ":         "YlOrBr",\
             "CCNN":         "YlOrBr",\
             ## matplotlib colormap names are case-sensitive: "jet", not "Jet"
             "TEMP":         "jet",\
             "TAU_ICE":      "Blues",\
             "VMR_ICE":      "Blues",\
             "W":            "jet",\
             "WMAX_TH":      "spectral",\
             "ANOMALY":      "RdBu_r",\
             "QSURFICE":     "hot_r",\
             "ALBBARE":      "spectral",\
             "TAU":          "YlOrBr_r",\
             "CO2":          "YlOrBr_r",\
             #### T.N.
             "MTOT":         "jet",\
             "H2O_ICE_S":    "RdBu",\
             "VMR_H2OICE":   "PuBu",\
             "VMR_H2OVAP":   "PuBu",\
                     }
#W --> spectral or jet
#spectral BrBG RdBu_r
    #print "predefined colorbars"
    if whichone not in whichcolorb:
        whichone = "def"
    return whichcolorb[whichone]

## Author: AS
def definecolorvec (whichone="def"):
    """Return the vector color to use over background or colormap whichone.

    Unknown names get the color of the "def" entry.
    """
    colors = {
        "def":          "black",
        "vis":          "yellow",
        "vishires":     "yellow",
        "molabw":       "yellow",
        "mola":         "black",
        "gist_heat":    "white",
        ## NOTE(review): "tk" does not look like a usual color name -- confirm
        "hot":          "tk",
        "gist_rainbow": "black",
        "spectral":     "black",
        "gray":         "red",
        "PuBu":         "black",
    }
    return colors.get(whichone, colors["def"])

## Author: AS
def marsmap (whichone="vishires"):
    """Return the location (local path or URL) of a planetary background map.

    whichone -- key of the requested map ("vis", "vishires", "geolocal",
                "mola", "molabw", "clouds", "jupiter", ...). An unknown key
                prints a notice and falls back to "vishires".

    The first five maps are taken from a local directory when the host name
    contains "lmd.jussieu.fr", and from the LMD web server otherwise; the
    remaining entries are always remote URLs.
    """
    # platform.node() is the portable equivalent of os.uname()[1]
    # (os.uname does not exist on every platform).
    import platform
    mymachine = platform.node()
    ### not sure about speed-up with this method... looks the same
    if "lmd.jussieu.fr" in mymachine:
        domain = "/u/aslmd/WWW/maps/"
    else:
        domain = "http://www.lmd.jussieu.fr/~aslmd/maps/"
    whichlink = {
        # former upstream locations of the maps now served from `domain`:
        #"vis":         "http://maps.jpl.nasa.gov/pix/mar0kuu2.jpg",
        #"vishires":    "http://www.lmd.jussieu.fr/~aslmd/maps/MarsMap_2500x1250.jpg",
        #"geolocal":    "http://dl.dropbox.com/u/11078310/geolocal.jpg",
        #"mola":        "http://www.lns.cornell.edu/~seb/celestia/mars-mola-2k.jpg",
        #"molabw":      "http://dl.dropbox.com/u/11078310/MarsElevation_2500x1250.jpg",
        "vis":         domain+"mar0kuu2.jpg",
        "vishires":    domain+"MarsMap_2500x1250.jpg",
        "geolocal":    domain+"geolocal.jpg",
        "mola":        domain+"mars-mola-2k.jpg",
        "molabw":      domain+"MarsElevation_2500x1250.jpg",
        "clouds":      "http://www.johnstonsarchive.net/spaceart/marswcloudmap.jpg",
        "jupiter":     "http://www.mmedia.is/~bjj/data/jupiter_css/jupiter_css.jpg",
        "jupiter_voy": "http://www.mmedia.is/~bjj/data/jupiter/jupiter_vgr2.jpg",
        "bw":          "http://users.info.unicaen.fr/~karczma/TEACH/InfoGeo/Images/Planets/EarthElevation_2500x1250.jpg",
        "contrast":    "http://users.info.unicaen.fr/~karczma/TEACH/InfoGeo/Images/Planets/EarthMapAtmos_2500x1250.jpg",
        "nice":        "http://users.info.unicaen.fr/~karczma/TEACH/InfoGeo/Images/Planets/earthmap1k.jpg",
        "blue":        "http://eoimages.gsfc.nasa.gov/ve/2430/land_ocean_ice_2048.jpg",
        "blueclouds":  "http://eoimages.gsfc.nasa.gov/ve/2431/land_ocean_ice_cloud_2048.jpg",
        "justclouds":  "http://eoimages.gsfc.nasa.gov/ve/2432/cloud_combined_2048.jpg",
    }
    ### see http://www.mmedia.is/~bjj/planetary_maps.html
    if whichone not in whichlink:
        # single-argument print(...) gives the same output on Python 2 and 3
        print("marsmap: choice not defined... you'll get the default one... ")
        whichone = "vishires"
    return whichlink[whichone]

#def earthmap (whichone):
#	if   whichone == "contrast":	whichlink="http://users.info.unicaen.fr/~karczma/TEACH/InfoGeo/Images/Planets/EarthMapAtmos_2500x1250.jpg"
#	elif whichone == "bw":		whichlink="http://users.info.unicaen.fr/~karczma/TEACH/InfoGeo/Images/Planets/EarthElevation_2500x1250.jpg"
#	elif whichone == "nice":	whichlink="http://users.info.unicaen.fr/~karczma/TEACH/InfoGeo/Images/Planets/earthmap1k.jpg"
#	return whichlink

## Author: AS
def latinterv (area="Whole"):
    """Return the bounding box of a named area.

    Each table entry is a pair [[lat_min, lat_max], [lon_min, lon_max]].
    An unknown `area` falls back to "Whole".
    The function returns the two pairs in swapped order: (lon pair, lat pair).
    """
    boxes = {
        "Europe":                [[ 20., 80.],[- 50.,  50.]],
        "Central_America":       [[-10., 40.],[ 230., 300.]],
        "Africa":                [[-20., 50.],[- 50.,  50.]],
        "Whole":                 [[-90., 90.],[-180., 180.]],
        "Southern_Hemisphere":   [[-90., 60.],[-180., 180.]],
        "Northern_Hemisphere":   [[-60., 90.],[-180., 180.]],
        "Tharsis":               [[-30., 60.],[-170.,- 10.]],
        "Whole_No_High":         [[-60., 60.],[-180., 180.]],
        "Chryse":                [[-60., 60.],[- 60.,  60.]],
        "North_Pole":            [[ 50., 90.],[-180., 180.]],
        "Close_North_Pole":      [[ 75., 90.],[-180., 180.]],
        "Far_South_Pole":        [[-90.,-40.],[-180., 180.]],
        "South_Pole":            [[-90.,-50.],[-180., 180.]],
        "Close_South_Pole":      [[-90.,-75.],[-180., 180.]],
    }
    # unknown areas map to the whole globe
    olat, olon = boxes.get(area, boxes["Whole"])
    return olon, olat

## Author: TN
def separatenames (name):
  """Split a comma-separated string into a 1-D numpy array of names.

  name -- string such as "file1.nc,file2.nc" or "U,V,T", or None.

  Returns None when `name` is None; otherwise a numpy string array with one
  element per comma-delimited field. Empty fields are kept, so "a,,b" gives
  three entries and "" gives a single empty entry.
  """
  from numpy import array
  if name is None:
     return None
  # str.split performs the whole scan in one pass; the result is wrapped in
  # an array so callers keep receiving a numpy object, as before.
  return array(name.split(','))

## Author: TN [old]
def getopmatrix (kind,n):
  """Build the n-by-n matrix describing operations between files.

  Per the original notes: 1 means difference (row minus column), 2 means
  add, and a non-zero diagonal entry means "just plot".

  kind -- 'basic', 'difference1' (differences with the 1st file) or
          'difference2' (same, and also show the 1st file); any other
          value behaves like 'basic'.
  n    -- number of files; n == 1 always yields the 1x1 identity.
  """
  import numpy as np
  if n == 1:
    return np.eye(1)
  if kind not in ('difference1', 'difference2'):
    # 'basic' and every unrecognised kind: plain identity
    return np.eye(n)
  # both difference modes flag the whole first row ...
  opm = np.zeros((n,n))
  opm[0,:] = 1
  # ... and only 'difference1' clears the diagonal element again
  if kind == 'difference1':
    opm[0,0] = 0
  return opm
 
## Author: TN [old] 
def checkcoherence (nfiles,nlat,nlon,ntime):
  """Abort through errormess when several files AND several latitude slices are requested.

  nlon and ntime are accepted but not examined. Returns 1 otherwise.
  """
  several_files = nfiles > 1
  if several_files and nlat > 1:
     errormess("what you asked is not possible !")
  return 1

## Author: TN
def readslices(saxis):
  """Convert slice specifications into an (n, 2) float array of bounds.

  saxis -- sequence of strings, each either "value" or "lower,upper"
           (split with separatenames), or None.

  Returns None when `saxis` is None. Otherwise row i holds the two bounds
  of saxis[i]; a single value is stored in both columns.
  """
  from numpy import empty
  # identity test: `== None` is elementwise (and ambiguous) on numpy arrays
  if saxis is None:
     return None
  zesaxis = empty((len(saxis),2))
  for i, spec in enumerate(saxis):
     bounds = separatenames(spec)
     lower = float(bounds[0])
     # a lone value describes a degenerate interval [value, value]
     upper = lower if len(bounds) == 1 else float(bounds[1])
     zesaxis[i,0] = lower
     zesaxis[i,1] = upper
  return zesaxis

## Author: AS
def bidimfind(lon2d,lat2d,vlon,vlat):
   """Locate the grid point of 2-D lon/lat arrays closest to (vlon, vlat).

   When vlat is None only the longitude distance is minimised; when vlon
   is None only the latitude distance is. The first minimum in flattened
   order is retained.

   Calls errormess if the retained point is more than 5 away from a
   requested coordinate. Returns (idx, idy): second-axis index first,
   first-axis index second.
   """
   import numpy as np
   if vlat is not None and vlon is not None:
       sqdist = (lon2d - vlon)**2 + (lat2d - vlat)**2
   elif vlat is None:
       sqdist = (lon2d - vlon)**2
   else:
       sqdist = (lat2d - vlat)**2
   flatpos = np.argmin(sqdist)
   idy,idx = np.unravel_index( flatpos, lon2d.shape )
   # sanity check: reject a match too far from what was requested
   checks = ( (vlon, lon2d, "longitude not found "),
              (vlat, lat2d, "latitude not found ") )
   for wanted, grid, message in checks:
       if wanted is None:
           continue
       if (np.abs(grid[idy,idx]-wanted)) > 5:
           errormess(message,printvar=grid)
   return idx,idy

## Author: TN
def getsindex(saxis,index,axis):
  """Return every index of `axis` covered by the requested slice.

  saxis -- (n, 2) array of slice bounds, or None
  index -- row of `saxis` to use
  axis  -- coordinate values; when 2-D, only its first column is used

  Returns None when `saxis` is None. Otherwise returns the contiguous
  integer indices running between the grid points nearest to the two
  bounds (inclusive, whatever the order of the bounds).
  """
  import numpy as np
  if saxis is None:
      return None
  if ( np.array(axis).ndim == 2):
      axis = axis[:,0]
  # nearest grid point to each of the two bounds
  aaa = int(np.argmin(abs(saxis[index,0] - axis)))
  bbb = int(np.argmin(abs(saxis[index,1] - axis)))
  imin,imax = sorted((aaa,bbb))
  zeindex = np.arange(imin, imax+1)
  # because -180 and 180 are the same point in longitude,
  # we get rid of one for averaging purposes.
  if axis[imin] == -180 and axis[imax] == 180:
     zeindex = zeindex[:-1]
     # single-argument print(...) gives the same output on Python 2 and 3
     print("INFO: whole longitude averaging asked, so last point is not taken into account.")
  return zeindex
     
## Author: TN
def define_axis(lon,lat,vert,time,indexlon,indexlat,indexvert,indextime,what_I_plot,dim0,vertmode):
   """Choose the x and y coordinate arrays matching the field to plot.

   A dimension is a candidate axis when its index argument is None (the
   vertical one only when, in addition, dim0 equals 4). Candidates are
   examined in the order time, lon, lat, vertical: the first one becomes
   x and the LAST remaining one becomes y.

   vertmode -- 0: the vertical axis is the level number, range(len(vert));
               anything else: `vert` itself is used.

   Returns (what_I_plot, x, y) as numpy arrays. An axis that could not be
   assigned is array(None) (0-d). For a 1-D field whose length differs
   from x, x is replaced by y. A 2-D field shaped (len(x), len(y)) with
   len(x) != len(y) has its two axes swapped.
   To be improved !!!...
   """
   from numpy import array,swapaxes
   what_I_plot = array(what_I_plot)
   shape = what_I_plot.shape
   # free axes, in priority order
   free_axes = []
   if indextime is None:
      print("AXIS is time")
      free_axes.append(time)
   if indexlon is None:
      print("AXIS is lon")
      free_axes.append(lon)
   if indexlat is None:
      print("AXIS is lat")
      free_axes.append(lat)
   # value comparison: `is 4` is an identity test and is False for
   # numpy integers (or Python 2 longs) equal to 4
   if indexvert is None and dim0 == 4:
      print("AXIS is vert")
      if vertmode == 0: # vertical axis is as is (GCM grid)
         free_axes.append(range(len(vert)))
      else: # vertical axis is in kms
         free_axes.append(vert)
   # x is the first free axis; y is the last of the others (later
   # candidates overwrite earlier ones, as in the historical behaviour)
   x = array(free_axes[0]) if free_axes else array(None)
   y = array(free_axes[-1]) if len(free_axes) > 1 else array(None)
   # %-formatting keeps the Python 2 output while being valid on Python 3
   print("CHECK: what_I_plot.shape %s" % (what_I_plot.shape,))
   print("CHECK: x.shape, y.shape %s %s" % (x.shape, y.shape))
   if len(shape) == 1:
      if shape[0] != len(x):
         print("WARNING HERE !!!")
         x = y
   elif len(shape) == 2:
      if shape[1] == len(y) and shape[0] == len(x) and shape[0] != shape[1]:
         what_I_plot = swapaxes(what_I_plot,0,1)
         print("INFO: swapaxes %s %s" % (what_I_plot.shape, shape))
   return what_I_plot,x,y

# Author: TN + AS
def determineplot(slon, slat, svert, stime):
    """Count the requested slices along each dimension.

    Each argument is a sized collection of slices or None; None counts
    as a single slice.

    Returns (nlon, nlat, nvert, ntime, mapmode, nslices), where nslices is
    the product of the four counts and mapmode is 1 only when neither
    slon nor slat is given.
    """
    counts = [1 if s is None else len(s) for s in (slon, slat, svert, stime)]
    nlon, nlat, nvert, ntime = counts

    nslices = 1
    for howmany in counts:
        nslices = nslices*howmany

    # in this case we plot a map, with the given projection
    mapmode = 1 if (slon is None and slat is None) else 0

    return nlon, nlat, nvert, ntime, mapmode, nslices
