"""This module includes functions and classes representing 3D clouds and cloud
microphysics.  It provides everything needed to generate single scattering
data files (via the arts_scat module) and particle number density fields,
which together enable the representation of 3D cloud fields in ARTS
simulations.

Example of use: a 3D ice and liquid cloud field.

#load iwc and lwc fields from TRMM data
iwc_field=arts_types.GriddedField3().load('../016/iwc_field.xml')
lwc_field=arts_types.GriddedField3().load('../016/lwc_field.xml')

#load a previously prepared temperature field
t_field=arts_types.GriddedField3().load('t_field.xml')

#these need to be padded to account for a nasty inefficiency in ARTS
iwc_field.pad()
lwc_field.pad()
t_field.pad()

#Define hydrometeors
ice_column=clouds.Crystal(NP=-2,aspect_ratio=0.5,ptype=30,npoints=10)
water_droplet=clouds.Droplet(c1=6,c2=1,rc=20)

#Create cloud field
a_cloud=clouds.Cloud(t_field=t_field,iwc_field=iwc_field,lwc_field=lwc_field)
#add hydrometeors
a_cloud.addHydrometeor(ice_column,habit_fraction=1.0)
a_cloud.addHydrometeor(water_droplet,habit_fraction=1.0)

#generate (or find existing) single scattering data
a_cloud.scat_file_gen(f_grid=[200e9,201e9],num_proc=2)

#generate pnd fields
a_cloud.pnd_field_gen('pnd_field.xml')

#save cloud object for later
quickpickle(a_cloud,'Cloud3D.pickle')


"""


from general import dict_combine_with_default,DATA_PATH
from arts_types import *
from Numeric import *
import artsXML
import MLab
import arts_math
import arts_scat
import Laguerre


#SIZE DISTRIBUTIONS

def mh97(IWC,r,TK):
    """MH97: McFarquhar and Heymsfield (1997) ice particle size distribution.

    n,integrated_n=mh97(IWC,r,TK)

    IWC: ice water content in gm^-3
    r:   array of radii in microns
    TK:  temperature in Kelvin
    n:   number density in l^-1mm^-1 (or m^-3micron(radius)^-1)
    integrated_n:  integrated number density in each size bin in m^-3

    The distribution is the sum of a gamma component scaled by
    IWC_lt=min(IWC,0.252*IWC**0.837) and a log-normal component scaled by
    the remainder IWC_gt=IWC-IWC_lt.  The log-normal part is absent when
    IWC_gt is zero, which happens for IWC below roughly 2e-4 gm^-3.
    Size-bin boundaries (in diameter) lie midway between neighbouring
    diameters.  The first bin starts at D=1 micron and the last extends half
    a spacing past the largest diameter.  integrated_n is rescaled so that
    the ice mass summed over the bins equals IWC.  A non-positive IWC gives
    two zero arrays.
    """
    if IWC<=0:
        # Clear sky, or a non-physical negative IWC (e.g. from interpolation).
        # The fractional power and logarithms below are undefined there.
        return zeros(len(r),Float),zeros(len(r),Float)

    D=r*2
    T=TK-273.15
    nD=len(D)
    #Calculate bin boundaries (diameters in microns); one more than bins
    Di=zeros(nD+1,Float)
    Di[0]=1
    Di[1:nD]=0.5*(D[:nD-1]+D[1:nD])
    Di[nD]=1.5*D[nD-1]-0.5*D[nD-2]

    IWC_0=1.0 #gm^-3
    IWC_lt=min(IWC,0.252*(IWC/IWC_0)**0.837)
    IWC_gt=IWC-IWC_lt
    alpha=-4.99e-3-0.049*log10(IWC_lt/IWC_0)#micron^-1
    rho_ice=0.917#gcm^-3
    D_0=1.0#micron

    #Some abbreviations
    a=IWC_lt*alpha**5/4/pi/rho_ice*1e-6#micron^-5
    b=alpha#micron^-1

    #Gamma distribution component: density at D, antiderivative at Di
    gamma_n=a*D*exp(-b*D)#micron^-4
    gamma_int=a*exp(-b*Di)*(-1/b**2-Di/b)#micron^-3

    #Log normal component: density at D, antiderivative at Di
    if IWC_gt>0:
        sigma=0.47+2.1e-3*T+(0.018-2.1e-4*T)*log10(IWC_gt/IWC_0)
        mu=5.2+0.0013*T+(0.026-1.2e-3*T)*log10(IWC_gt/IWC_0)
        c=6*IWC_gt/(sqrt(2*pi**3)*rho_ice*D_0**3*sigma*\
                    exp(3*mu+4.5*sigma**2))*1e-6
        g=sigma
        f=mu
        lognorm_n=c/D*exp(-0.5*((log(D/D_0)-f)/g)**2)
        lognorm_int=c*g*sqrt(pi/2)*arts_math.erf((log(Di/D_0)-f)/sqrt(2)/g)
    else:
        lognorm_n=zeros(nD,Float)
        # One entry per bin boundary, exactly like gamma_int, so that the
        # boundary differences below line up.
        lognorm_int=zeros(nD+1,Float)

    #Combine: bin i integrates from Di[i] to Di[i+1]
    n=gamma_n+lognorm_n
    integrated_n=gamma_int[1:]+lognorm_int[1:]-\
                  gamma_int[:nD]-lognorm_int[:nD]
    n=n*2e18   # convert to l^-1mm^-1 (m^-3micron-1 in radius)
    integrated_n=integrated_n*1e18 # convert to m^-3

    IWC_check=sum(integrated_n*4/3*pi*(D*1e-6/2)**3*rho_ice*1e6)

    #This seems to be what JPL do to conserve IWC.  The alternative would be to
    #take bin boundaries and choose the appropriate D that conserves IWC.
    #Skip the rescale when the binned mass is zero to avoid dividing by zero.
    if IWC_check>0:
        integrated_n=integrated_n*IWC/IWC_check
    return n,integrated_n

def mh97_int(IWC,r,TK):
    """Return only the per-bin integrated number density from mh97.

    This is the second element of mh97's (n, integrated_n) result.  The
    spectral density n is discarded."""
    return mh97(IWC,r,TK)[1]

def nioku(LWC,r,rc,c1,c2):
    """Evaluate the nioku modified gamma size distribution for water droplets.

    Follows the form given in the MLS cloudy sky ATBD:
    n(r) = A*r**c1*exp(-B*r**c2).  B=c1/(c2*rc**c2) puts the peak of the
    distribution at r=rc.  A is chosen so that the water content implied by
    n integrates to LWC.  The units are l-1mm-1 (m-3micron-1)."""
    slope=float(c1)/(c2*rc**c2)
    exponent=(c1+4.0)/c2
    # arts_math.gosper(x-1) stands in for Gamma(x) in the normalisation
    norm=3*LWC*c2*slope**exponent*1e12/(4*pi*arts_math.gosper(exponent-1))
    return norm*r**c1*exp(-slope*r**c2)

def nioku_int(LWC,r,rc,c1,c2):
    """Gives the integrated droplet number density for each size bin, with
    size bin radii (microns) defined by the input vector r.

    Bin i runs between the midpoints separating r[i] from its neighbours.
    The first bin starts at 0.01 micron, and the last extends half a spacing
    beyond r[-1].  Each bin is integrated with arts_math.qromb.  The whole
    vector is then rescaled so that the water content summed over the bins
    equals LWC, with each bin represented by its central radius r[i].
    A non-positive LWC gives zeros.

    Raises ValueError if r has fewer than two radii, since the bin edges
    cannot be defined."""
    if len(r)<2:
        raise ValueError('nioku_int needs at least two radii to define size bins')
    intn=zeros(r.shape,Float)
    if LWC<=0:
        # No liquid: every bin is empty.  Returning here also avoids the 0/0
        # that the LWC rescaling below would otherwise hit.
        return intn

    #define function to be integrated
    def func(r1):
        return nioku(LWC,r1,rc,c1,c2)

    #bin edges: midpoints between neighbouring radii plus the two outer limits
    edges=[0.01]+[(r[i]+r[i+1])/2 for i in range(len(r)-1)]+\
          [r[-1]+(r[-1]-r[-2])/2]
    #do the integrations, summing estimated LWC along the way
    lwc_sum=0
    for i in range(len(r)):
        intn[i]=arts_math.qromb(func,edges[i],edges[i+1],
                                EPS=0.05,JMAX=10,K=3)[0]
        lwc_sum+=intn[i]*4*pi/3*1e-12*r[i]**3
    #rescale to conserve LWC
    intn*=LWC/lwc_sum
    return intn


###################Cloud structures################################

class Cloud:
    """A high level class for the generation of ARTS cloud field data.
    A Cloud object is initialised with up to three arts_types.GriddedField3
    objects representing 3D temperature, ice water content, and liquid
    water content fields.  The temperature field is compulsory but either
    IWC or LWC may be omitted.
    Droplet or Crystal objects can then be added to the Cloud Object using
    the addHydrometeor method.  The user is encouraged to create their own
    Hydrometeor classes (all that is required is that they have scat_calc
    and pnd_calc methods with the same input/output arguments).
    The scat_file_gen and pnd_field_gen methods create the single
    scattering data files and particle number density files required to
    represent the cloud field in ARTS simulations.

    """
    def __init__(self,t_field,iwc_field=None,lwc_field=None):
        """All three inputs are arts_types.GriddedField3 objects.  Either
        iwc_field or lwc_field may be omitted if you have a single phase
        cloud field.  The cloud's p_grid, lat_grid and lon_grid are copied
        from iwc_field if given, otherwise from lwc_field.

        Raises ValueError if neither iwc_field nor lwc_field is supplied."""
        self.iwc_field=iwc_field
        self.lwc_field=lwc_field
        self.t_field=t_field
        if iwc_field is not None:
            grid_source=iwc_field
        elif lwc_field is not None:
            grid_source=lwc_field
        else:
            raise ValueError('Must supply either IWC or LWC field (or both)')
        self.p_grid=grid_source['p_grid']
        self.lat_grid=grid_source['lat_grid']
        self.lon_grid=grid_source['lon_grid']
        self.hydrometeors=[]
        self.habit_fractions=[]
        self.scat_files=[]
        self.pnd_fields=[]

    def addHydrometeor(self,hydrometeor,habit_fraction=1.0):
        """Adds a hydrometeor (e.g. a Droplet or Crystal object) to the cloud
        object.  The habit_fraction argument allows the implementation of
        multi-habit ice clouds. The habit_fractions for all of the added
        Crystal objects should add up to 1.0. Otherwise the specified iwc_field
        will not be reproduced.  Returns self so calls can be chained."""
        self.hydrometeors.append(hydrometeor)
        self.habit_fractions.append(habit_fraction)
        return self

    def scat_file_gen(self,f_grid,za_grid=arange(0,181,10),
                      aa_grid=arange(0,181,10),num_proc=1):
        """Calculates all of the single scattering data files required to
        represent the cloud field in an ARTS simulation.  The file names are
        stored in the scat_files data member.  f_grid, za_grid, and aa_grid
        are numpy arrays determining the corresponding data in the
        arts_scat.SingleScatteringData objects.  The optional argument
        num_proc determines the number of processes used to complete the
        task.  scat_files is rebuilt on every call.  Returns self."""
        # Start afresh so that repeated calls do not accumulate duplicate
        # entries and scat_files stays aligned with the pnd fields.
        self.scat_files=[]
        for hydrometeor in self.hydrometeors:
            self.scat_files.extend(hydrometeor.scat_calc(f_grid, za_grid,
                                                         aa_grid, num_proc))
        return self

    def pnd_field_gen(self,filename=''):
        """Calculates the pnd data required to represent the cloud field in
        an ARTS simulation.  This produces an arts_types.ArrayOfGriddedField3
        object, which has the same number of elements as the scat_files data
        member.  This is stored in the pnd_data member and output to *filename*
        in ARTS XML format.  pnd_fields is rebuilt on every call.
        Returns self."""
        self.pnd_fields=[]
        for hydrometeor,habit_fraction in zip(self.hydrometeors,
                                              self.habit_fractions):
            pnd_fields=hydrometeor.pnd_calc(self.lwc_field,self.iwc_field,
                                            self.t_field)
            for pnd_field in pnd_fields:
                # Scale into a new field instead of multiplying in place.
                # Droplet.pnd_calc returns this cloud's own lwc_field, and an
                # in-place product would overwrite it, compounding on each call.
                self.pnd_fields.append(
                    GriddedField3({'p_grid':pnd_field['p_grid'],
                                   'lat_grid':pnd_field['lat_grid'],
                                   'lon_grid':pnd_field['lon_grid'],
                                   'data':pnd_field['data']*habit_fraction}))
        self.pnd_data=ArrayOfGriddedField3(self.pnd_fields)
        self.pnd_data.save(filename)
        self.pnd_file=filename
        return self

def boxcloud(ztopkm,zbottomkm,lat1,lat2,lon1,lon2,cb_size,zfile,tfile,IWC):
    """Return a box shaped Cloud object with uniform ice water content.

    ztopkm, zbottomkm:  cloud top and bottom altitudes in km
    lat1, lat2:         latitude limits of the box
    lon1, lon2:         longitude limits of the box
    cb_size:            dict with integer entries 'np', 'nlat' and 'nlon',
                        the number of grid points inside the box along
                        pressure, latitude and longitude
    zfile, tfile:       files loadable as GriddedField3 holding altitude and
                        temperature on a pressure grid.  These are assumed to
                        be single profiles once squeezed (TODO confirm).
                        Altitudes are presumably in metres, since the km
                        limits are converted to metres before interpolation.
    IWC:                ice water content inside the box

    One extra zero-IWC grid point is added on each side of the box in every
    dimension: p1+1 / p2-1 in pressure, and 0.1 outside in lat/lon.  Both
    fields are padded with GriddedField3.pad().  The box limits are recorded
    in the cloud's cloudbox dict.
    """
    zfield=GriddedField3().load(zfile)
    tfield=GriddedField3().load(tfile)
    ztopm=ztopkm*1e3
    zbottomm=zbottomkm*1e3

    #pressures at cloud bottom/top, interpolating log pressure in altitude
    p1=exp(arts_math.interp(MLab.squeeze(zfield['data']),log(zfield['p_grid']),zbottomm))
    p2=exp(arts_math.interp(MLab.squeeze(zfield['data']),log(zfield['p_grid']),ztopm))

    #create zero valued gridpoints just outside the cloudbox
    p_grid=zeros(cb_size['np']+2,Float)
    lat_grid=zeros(cb_size['nlat']+2,Float)
    lon_grid=zeros(cb_size['nlon']+2,Float)
    data=zeros([cb_size['np']+2,cb_size['nlat']+2,cb_size['nlon']+2],Float)

    p_grid[0]=p1+1;p_grid[-1]=p2-1
    p_grid[1:-1]=arts_math.nlogspace(p1,p2,cb_size['np'])

    lat_grid[0]=lat1-0.1;lat_grid[-1]=lat2+0.1
    lat_grid[1:-1]=arts_math.nlinspace(lat1,lat2,cb_size['nlat'])

    lon_grid[0]=lon1-0.1;lon_grid[-1]=lon2+0.1
    lon_grid[1:-1]=arts_math.nlinspace(lon1,lon2,cb_size['nlon'])

    data[1:-1,1:-1,1:-1]=IWC*ones([cb_size['np'],cb_size['nlat'],cb_size['nlon']],Float)

    iwcfield=GriddedField3({'p_grid':p_grid,
                            'lat_grid':lat_grid,
                            'lon_grid':lon_grid,
                            'data':data})

    #interpolate the temperature profile (in log pressure) onto the new grid
    old_t_grid=MLab.squeeze(tfield['data'])
    new_t_grid=arts_math.interp(log(tfield['p_grid']),old_t_grid,log(p_grid))
    tdata=zeros([cb_size['np']+2,cb_size['nlat']+2,cb_size['nlon']+2],Float)
    #copy the same profile into every (lat, lon) column; the inner loop must
    #run over lon_grid so that boxes with nlat != nlon are filled correctly
    for i in range(len(lat_grid)):
        for j in range(len(lon_grid)):
            tdata[:,i,j]=new_t_grid

    tfield=GriddedField3({'p_grid':p_grid,
                          'lat_grid':lat_grid,
                          'lon_grid':lon_grid,
                          'data':tdata})

    iwcfield.pad()
    tfield.pad()
    the_cloud=Cloud(iwc_field=iwcfield,t_field=tfield)
    the_cloud.cloudbox={'p1':p1,'p2':p2,'lat1':lat1,'lat2':lat2,'lon1':lon1,
                        'lon2':lon2}
    return the_cloud


##############Microphysics########################

class Droplet:
    """Produces scattering data and pnd fields for liquid water clouds.
    The size distribution is the modified gamma distribution of Nioku, as used
    in the EOSMLS cloudy-sky forward model.  A Droplet object is initialised
    with the distribution parameters c1, c2, and rc. Scattering properties are
    integrated over the size distribution using an *npoints* Laguerre Gauss
    quadrature, to give a single arts_scat.SingleScatteringData object
    intended to be scaled by LWC.  The pnd field is therefore the lwc_field
    itself.  The methods scat_calc and pnd_calc are called by the parent
    Cloud object"""
    def __init__(self,c1,c2,rc, npoints=3, T_grid=[260,280,300,320]):
        """The parameters c1,c2,rc are for the nioku distribution (rc is the
        radius at which the distribution peaks).  npoints selects the
        Laguerre Gauss quadrature order, and T_grid is passed to arts_scat
        as the temperature grid of the scattering data."""
        self.c1=c1
        self.c2=c2
        self.rc=rc
        self.T_grid=T_grid

        #First get Laguerre Gauss weights and abscissa
        x=array(arts_math.LagGaussData[npoints]['a'])
        w=array(arts_math.LagGaussData[npoints]['w'])
        #turn the abscissa into radii via x=B*r**c2 (the exponent in nioku)
        B=float(c1)/(c2*rc**c2)
        # 1.0/c2 rather than 1/c2: with an integer c2 such as 2, Python 2
        # integer division gives an exponent of 0 and every radius becomes 1.
        self.r=(x/B)**(1.0/c2)
        # A is the nioku normalisation constant for unit LWC
        A=3*c2*B**((c1+4.0)/c2)*1e12/(4*pi*arts_math.gosper((c1+4.0)/c2-1))#gosper is not very good
        # NOTE(review): pnd_vec carries no r**c1 or dr/dx factor.  That is
        # only right if the LagGaussData weights already absorb them, so
        # confirm against arts_math and arts_scat.combine.
        self.pnd_vec=A*w

    def scat_calc(self,f_grid, za_grid=arange(0,181,10),
                  aa_grid=arange(0,181,10),num_proc=1):
        """produces a single scattering data object that can be simply
        scaled by LWC.  The xx_grid variables have the same meaning as in
        arts_scat.SingleScatteringData objects. num_proc determines the number
        of processes used to complete the task. Returns a list (length=1) of
        scattering data files"""
        #file name encodes the distribution parameters, f_grid and T_grid limits
        self.scat_file=DATA_PATH+'/scat/Droplet'+str(self.c1)+'_'+str(self.c2)+'_'+str(self.rc)+\
                        'f'+str(f_grid[0])+'-'+str(f_grid[-1])+'T'+\
                        str(self.T_grid[0])+'-'+str(self.T_grid[-1])+'.xml'
        #generate single scattering data files (one per quadrature radius)
        scat_params={'f_grid':[f_grid],'T_grid':[self.T_grid],'NP':[-1],
                     'aspect_ratio':[1.000001],'ptype':[20],
                     'phase':'liquid','equiv_radius':self.r,'za_grid':za_grid,
                     'aa_grid':aa_grid}
        scat_files=arts_scat.batch_generate(scat_params,num_proc)
        #Combine these, weighted by pnd_vec, into a single scattering data object
        scat_data_list=[]
        for fname in scat_files:
            scat_data_list.append(arts_scat.SingleScatteringData().load(fname))
        scat_data=arts_scat.combine(scat_data_list,self.pnd_vec)
        scat_data.file_gen(self.scat_file)
        return [self.scat_file]

    def pnd_calc(self,LWC_field,IWC_field,T_field):
        """Returns the pnd field associated with the scattering data
        calculated by the scat_calc method above.  This is a list
        (length=1) holding the LWC_field object itself (not a copy).
        IWC_field and T_field are unused."""
        return [LWC_field]

class Crystal:
    """Produces scattering data and pnd fields for ice clouds.
    The size distribution is the McFarquhar-Heymsfield 1997 distribution, as
    used in the EOSMLS cloudy-sky forward model.  A Crystal object is
    initialised with the particle parameters ptype, aspect_ratio and NP,
    which have the same meaning as in arts_scat.SingleScatteringData
    objects.  Equivalent particle radii and pnd values come from the
    abscissas and weights of an *npoints* Laguerre Gauss quadrature.  This
    gives *npoints* arts_scat.SingleScatteringData objects and one pnd field
    per particle.  The methods scat_calc and pnd_calc are called by the
    parent Cloud object"""

    def __init__(self,ptype,aspect_ratio,NP, npoints=10, T_grid=[215,272.15]):
        """Store the particle parameters ptype, aspect_ratio and NP (as in
        arts_scat.SingleScatteringData).  Derive the equivalent radii from an
        *npoints* Laguerre Gauss quadrature, with abscissa x mapped to
        radius r=x/(2*a)."""
        self.ptype,self.aspect_ratio,self.NP=ptype,aspect_ratio,NP
        self.a=0.02 # quadrature scale factor; this can be tuned
        abscissa,weights=Laguerre.laggausdata(npoints,0)
        self.w=weights
        self.r=abscissa/2/self.a
        self.T_grid=T_grid

    def scat_calc(self,f_grid,  za_grid=arange(0,181,10),
                  aa_grid=arange(0,181,10), num_proc=1):
        """Generate one single scattering data file per quadrature radius.
        Returns the list (length=npoints) of file names."""
        params=dict(f_grid=[f_grid],T_grid=[self.T_grid],NP=[self.NP],
                    aspect_ratio=[self.aspect_ratio],ptype=[self.ptype],
                    phase='ice',equiv_radius=self.r,za_grid=za_grid,
                    aa_grid=aa_grid)
        return arts_scat.batch_generate(params,num_proc)

    def pnd_calc(self,LWC_field,IWC_field,T_field):
        """Returns a list (length=npoints) of pnd fields on IWC_field's grids.
        Each is the mh97 number density at one quadrature radius, multiplied
        by that radius's quadrature weight.  LWC_field is unused."""
        p_grid=IWC_field['p_grid']
        lat_grid=IWC_field['lat_grid']
        lon_grid=IWC_field['lon_grid']
        iwc=IWC_field['data']
        temperature=T_field['data']
        np_,nlat,nlon=len(p_grid),len(lat_grid),len(lon_grid)
        a=self.a
        fields=[]
        for ir,radius in enumerate(self.r):
            data=zeros([np_,nlat,nlon],Float)
            for ip in range(np_):
                for ilat in range(nlat):
                    for ilon in range(nlon):
                        dens=mh97(iwc[ip,ilat,ilon],array([radius]),
                                  temperature[ip,ilat,ilon])[0]
                        # undo the exp(-x) Laguerre weight and dx=2a*dr Jacobian
                        data[ip,ilat,ilon]=dens[0]*exp(2*a*radius)/2/a*self.w[ir]
            fields.append(GriddedField3({'p_grid':p_grid,
                                         'lat_grid':lat_grid,
                                         'lon_grid':lon_grid,
                                         'data':data}))
        return fields

##################################################
#Some tests

import unittest

class NiokuTest(unittest.TestCase):
    """Checks nioku's normalisation.  Integrating 4/3*pi*r**3*n(r) from 0.01
    to 1000 micron must recover the input water content to within 1%."""
    def runTest(self):
        target=0.01
        rc,c1,c2=10.0,6.0,1.0
        def integrand(radius):
            return nioku(target,radius,rc,c1,c2)*radius**3
        recovered=4*pi/3*1e-12*arts_math.qromb(integrand,0.01,1000.0)[0]
        print('IWC = '+str(recovered))
        assert (abs(recovered-target)/target < 0.01),"normalisation problem in nioku"

def test_suite():
    """Build and return a unittest.TestSuite holding this module's tests
    (a single NiokuTest)."""
    return unittest.TestSuite([NiokuTest()])

def run_tests():
    """Run test_suite() through a unittest text runner (results go to stderr)."""
    unittest.TextTestRunner().run(test_suite())


