import time
from Numeric import *
from FFT import *

def PtSample(n, step=1, type='FFT'):
    """
    Return a Float array of equally spaced sample points starting at 0.

    n    -- number of grid intervals.
    step -- spacing between points; a zero step yields an array of zeros
            of the same length (used by the solvers to allocate storage).
    type -- 'FFT' gives n points (0, step, ..., (n-1)*step);
            'FD' gives n+1 points (0, step, ..., n*step).

    Raises ValueError if type is neither 'FFT' nor 'FD'.
    """
    if type == 'FFT':
        npts = n
    elif type == 'FD':
        npts = n + 1
    else:
        # A string is not a valid exception object; raise a real one.
        raise ValueError('"Type" not valid!')
    if step:
        # Build the points from an integer count and scale them.  Calling
        # arange(0, npts*step, step) lets floating-point rounding of
        # npts*step/step add a spurious extra point (e.g. n=3, step=0.1).
        return arange(0, npts, 1, Float) * step
    return zeros(npts, Float)

def PtIndex(n1, n2, step=1, type='FFT'):
    """
    Return the grid indices from n1 towards n2 with stride step.

    type -- 'FFT': n2 is excluded, i.e. range(n1, n2, step);
            'FD' : n2 is included, i.e. range(n1, n2+1, step).

    Raises ValueError if type is neither 'FFT' nor 'FD'.
    """
    if type == 'FFT':
        stop = n2
    elif type == 'FD':
        stop = n2 + 1
    else:
        # A string is not a valid exception object; raise a real one.
        raise ValueError('"Type" not valid!')
    return range(n1, stop, step)

def FourierCoeff(n, d):
    """
    Return the n angular wavenumbers in FFT storage order.

    n -- number of samples.
    d -- sample spacing; the fundamental wavenumber is 2*pi/(d*n).

    Entry 0 is zero, entries with 2*i < n hold the positive wavenumbers
    i*2*pi/(d*n) and the remaining entries hold the negative ones
    -(n-i)*2*pi/(d*n).  For even n the entry n/2 is stored as negative.
    When n is less than 4 the array is returned with all entries zero.
    """
    wk = zeros(n, Float)
    pdn = pi*2./(d*n)
    # Explicit integer division: the guard must not depend on the
    # interpreter's '/' semantics.
    if n // 2 >= 2:
        for i in range(1, n):
            # Splitting on 2*i < n keeps the spectrum symmetric for odd n
            # too; splitting on i < n//2 would store index (n-1)//2 as the
            # negative wavenumber -(n+1)//2 instead of +(n-1)//2.
            if 2*i < n:
                wk[i] = pdn*i
            else:
                wk[i] = -pdn*(n-i)
    return wk

def FourierDeriv(func, order=1, coeff=None):
    """
    Differentiate sampled data 'order' times via the FFT.

    func  -- sequence of samples.
    order -- number of times the derivative is applied; with order=0 the
             input is returned unchanged.
    coeff -- wavenumbers in FFT storage order (see FourierCoeff).  When
             omitted they are computed by FourierCoeff for a sample
             spacing of 2*pi/len(func).

    Each pass transforms the data, multiplies by 1j*coeff and transforms
    back, so the result of any pass is a complex array; callers take
    .real where a real result is wanted.
    """
    # 'is None' rather than '== None': comparing an array with == is an
    # element-wise operation, not a test for a missing argument.
    if coeff is None:
        coeff = FourierCoeff(len(func), 2*pi/len(func))
    for _ in range(order):
        spectrum = fft(func)
        func = inverse_fft(1j*coeff*spectrum)
    return func

# ==============================================================================

def FourierSolver1dHomo(I, f, c, U_0, U_L, L, n, dt, tstop, user_action=None):
    """
    Time-step u_tt = c**2*u_xx + f on [0, L) with constant c, evaluating
    u_xx with FourierDeriv (FFT based) and using a three-level scheme
    in time.

    I           -- initial condition, called once as I(x) with the grid array.
    f           -- source term, called as f(x, t) with the grid array.
    c           -- wave speed (scalar).
    U_0         -- U_0(t) is written into grid point 0 after every step.
    U_L         -- not used by this solver.
    L, n        -- domain length and number of grid points (dx = L/n).
    dt          -- time step; if dt <= 0, 2*dx/(c*pi) is used.
    tstop       -- stepping continues while t <= tstop.
    user_action -- optional callback user_action(u, x, t), called for the
                   initial state and after every time step.

    Returns the elapsed time measured with time.clock().
    """
    t0 = time.clock()  # measure the CPU time

    dx = L/float(n)
    x = PtSample(n, step=dx, type='FFT')  # grid points in x dir
    if dt <= 0:
        dt = 2*dx/(float(c)*pi)  # max time step?
    C2 = (c*dt)**2  # help variable in the scheme
    dt2 = dt*dt
    k = FourierCoeff(n, dx)

    # No zero-filled work arrays are allocated up front: u, um and up are
    # all bound to freshly computed arrays before they are first read.
    t = 0.0
    u = I(x)
    # special formula for the first step (level preceding the initial one)
    um = u + 0.5*C2*FourierDeriv(u, coeff=k, order=2).real + dt2*f(x, t)

    if user_action is not None:
        user_action(u, x, t)

    while t <= tstop:
        t_old = t
        t += dt
        up = - um + 2*u + C2*FourierDeriv(u, coeff=k, order=2).real + dt2*f(x, t_old)

        # insert boundary conditions:
        up[0] = U_0(t)

        if user_action is not None:
            user_action(up, x, t)

        # shift time levels; up is rebuilt from scratch in the next pass
        um, u = u, up

    t1 = time.clock()
    return t1-t0

# ==============================================================================

def FourierSolver1dNotHomo(I, f, c, U_0, U_L, L, n, dt, tstop, user_action=None):
    """
    Time-step u_tt = (c(x)**2*u_x)_x + f on [0, L) with a space-dependent
    wave speed, evaluating both x-derivatives with FourierDeriv (FFT
    based) and using a three-level scheme in time.

    I           -- initial condition, called once as I(x) with the grid array.
    f           -- source term, called as f(x, t) with the grid array.
    c           -- wave speed, called once as c(x) with the grid array.
    U_0         -- U_0(t) is written into grid point 0 after every step.
    U_L         -- not used by this solver.
    L, n        -- domain length and number of grid points (dx = L/n).
    dt          -- time step; if dt <= 0, 2*dx/(max(c(x))*pi) is used.
    tstop       -- stepping continues while t <= tstop.
    user_action -- optional callback user_action(u, x, t), called for the
                   initial state and after every time step.

    Returns the elapsed time measured with time.clock().
    """
    t0 = time.clock()  # measure the CPU time

    dx = L/float(n)
    x = PtSample(n, step=dx, type='FFT')  # grid points in x dir
    cc = c(x)
    cc2 = cc**2
    if dt <= 0:
        dt = 2.*dx/(float(max(cc))*pi)  # max time step?
    dt2 = dt*dt
    k = FourierCoeff(n, dx)

    def spatial_term(v):
        # d/dx( c**2 * dv/dx ), both derivatives taken spectrally
        return FourierDeriv(cc2*FourierDeriv(v, coeff=k).real, coeff=k).real

    # No zero-filled work arrays are allocated up front: u, um and up are
    # all bound to freshly computed arrays before they are first read.
    t = 0.0
    u = I(x)
    # special formula for the first step (level preceding the initial one)
    um = u + 0.5*spatial_term(u)*dt2 + dt2*f(x, t)

    if user_action is not None:
        user_action(u, x, t)

    while t <= tstop:
        t_old = t
        t += dt
        up = - um + 2*u + spatial_term(u)*dt2 + dt2*f(x, t_old)

        # insert boundary conditions:
        up[0] = U_0(t)

        if user_action is not None:
            user_action(up, x, t)

        # shift time levels; up is rebuilt from scratch in the next pass
        um, u = u, up

    t1 = time.clock()
    return t1-t0

# ==============================================================================

def FDSolver1dHomo(I, f, c, U_0, U_L, L, n, dt, tstop, user_action=None, version='vectorized'):
    """
    Time-step u_tt = c**2*u_xx + f on [0, L] with constant c using
    centered finite differences in space and a three-level scheme in time.

    I           -- initial condition, called point-wise as I(x[i]).
    f           -- source term f(x, t); called point-wise in the 'scalar'
                   version and with the slice x[1:n] in the 'vectorized' one.
    c           -- wave speed (scalar).
    U_0, U_L    -- boundary values U_0(t), U_L(t) written into points 0 and n.
    L, n        -- domain length and number of intervals (n+1 grid points).
    dt          -- time step; if dt <= 0, dx/c is used.
    tstop       -- stepping continues while t <= tstop.
    user_action -- optional callback user_action(u, x, t), called for the
                   initial state and after every time step.
    version     -- 'scalar' (explicit loop) or 'vectorized' (array slices).

    Returns the elapsed time measured with time.clock().
    Raises ValueError for an unknown version (when the first step is taken).
    """
    t0 = time.clock()  # measure the CPU time

    dx = L/float(n)
    x = PtSample(n, step=dx, type='FD')  # grid points in x dir
    if dt <= 0:
        dt = dx/float(c)  # max time step?
    C2 = (c*dt/dx)**2  # help variable in the scheme
    dt2 = dt*dt

    up = PtSample(n, step=0, type='FD')  # new time level
    u = up.copy()                        # current time level
    um = up.copy()                       # previous time level

    t = 0.0
    for i in PtIndex(0, n, type='FD'):
        u[i] = I(x[i])
    # special formula for the first step, interior points 1..n-1
    for i in PtIndex(1, n-1, type='FD'):
        um[i] = u[i] + 0.5*C2*(u[i-1] - 2*u[i] + u[i+1]) + dt2*f(x[i], t)
    um[0] = U_0(t+dt)
    um[n] = U_L(t+dt)

    if user_action is not None:
        user_action(u, x, t)

    while t <= tstop:
        t_old = t
        t += dt
        # update all inner points:
        if version == 'scalar':
            # type='FD' makes the index range 1..n-1 inclusive, matching the
            # vectorized slice 1:n; the default type would stop at n-2 and
            # leave point n-1 holding a stale value.
            for i in PtIndex(1, n-1, type='FD'):
                up[i] = - um[i] + 2*u[i] + C2*(u[i-1] - 2*u[i] + u[i+1]) + dt2*f(x[i], t_old)
        elif version == 'vectorized':
            up[1:n] = - um[1:n] + 2*u[1:n] + C2*(u[0:n-1] - 2*u[1:n] + u[2:n+1]) + dt2*f(x[1:n], t_old)
        else:
            raise ValueError('version=%s' % version)

        # insert boundary conditions:
        up[0] = U_0(t)
        up[n] = U_L(t)

        if user_action is not None:
            user_action(up, x, t)

        # update data structures for next step (the old um array is
        # recycled as storage for the next up):
        um, u, up = u, up, um

    t1 = time.clock()
    return t1-t0

# ==============================================================================

def FDSolver2dHomo(I,f,c,bc,Lx,Ly,nx,ny,dt,tstop,user_action=None):
    """
    Time-step u_tt = c**2*(u_xx + u_yy) + f on [0, Lx] x [0, Ly] with
    constant c using centered finite differences in space and a
    three-level scheme in time (scalar loops over the grid).

    I           -- initial condition, called point-wise as I(x, y).
    f           -- source term, called point-wise as f(x, y, t).
    c           -- wave speed (scalar).
    bc          -- boundary value, called point-wise as bc(x, y, t) on all
                   four edges of the grid.
    Lx, Ly      -- domain lengths; nx, ny -- number of intervals per direction.
    dt          -- time step; if dt <= 0, (1/c)/sqrt(1/dx**2 + 1/dy**2) is used.
    tstop       -- stepping continues while t <= tstop.
    user_action -- optional callback user_action(u, x, y, t), called for
                   the initial state and after every time step.

    Returns the elapsed time measured with time.clock().
    """
    t0 = time.clock()  # measure the CPU time

    dx = Lx/float(nx)
    dy = Ly/float(ny)
    x = arange(0, Lx+dx, dx)
    y = arange(0, Ly+dy, dy)
    if dt <= 0:
        dt = (1/float(c))*(1/sqrt(1/dx**2+1/dy**2))
    Cx2 = (c*dt/dx)**2
    Cy2 = (c*dt/dy)**2
    dt2 = dt**2

    up = zeros((nx+1, ny+1), Float)  # new time level
    u = up.copy()                    # current time level
    um = up.copy()                   # previous time level

    def set_boundary(a, tb):
        # Write bc(x, y, tb) into every point of the four edges of a,
        # corners included.  One helper serves both the first step and the
        # time loop, so the two cannot cover different index ranges.
        for j in range(0, ny+1):
            a[0, j] = bc(x[0], y[j], tb)
        for i in range(0, nx+1):
            a[i, 0] = bc(x[i], y[0], tb)
        for j in range(0, ny+1):
            a[nx, j] = bc(x[nx], y[j], tb)
        for i in range(0, nx+1):
            a[i, ny] = bc(x[i], y[ny], tb)

    t = 0.0
    for i in range(0, nx+1):
        for j in range(0, ny+1):
            u[i, j] = I(x[i], y[j])
    # special formula for the first step, interior points only
    for i in range(1, nx):
        for j in range(1, ny):
            um[i, j] = u[i, j] + \
                0.5*Cx2*(u[i-1, j]-2*u[i, j]+u[i+1, j]) + \
                0.5*Cy2*(u[i, j-1]-2*u[i, j]+u[i, j+1]) + \
                dt2*f(x[i], y[j], t)
    set_boundary(um, t+dt)

    if user_action is not None:
        user_action(u, x, y, t)

    while t <= tstop:
        t_old = t
        t += dt
        for i in range(1, nx):
            for j in range(1, ny):
                up[i, j] = -um[i, j]+2*u[i, j] + \
                    Cx2*(u[i-1, j]-2*u[i, j]+u[i+1, j]) + \
                    Cy2*(u[i, j-1]-2*u[i, j]+u[i, j+1]) + \
                    dt2*f(x[i], y[j], t_old)

        # Boundary values on the complete edges.  Stopping the loops at
        # ny-2 / nx-2 would leave the last two points of each edge with
        # stale data from the recycled array.
        set_boundary(up, t)

        if user_action is not None:
            user_action(up, x, y, t)

        # rotate time levels (the old um array is recycled as the next up)
        um, u, up = u, up, um

    t1 = time.clock()

    return t1-t0

# ==============================================================================

def polyVal(x, fjat):
    """
    Evaluate at x the polynomial interpolating a table of points,
    using Neville's algorithm.

    fjat is a sequence of n (xi, yi) pairs; it is transposed so that row 0
    holds the abscissae xi and row 1 the ordinates yi.  The interpolant
    P has order n-1 and satisfies yi=P(xi).

    Returns y=P(x).  The last correction dy added to y is computed along
    the way but is NOT returned.

    If two abscissae produce a zero difference, a message is printed and
    the computation continues into the division by that difference.
    """
    xy=transpose(fjat)
    # c and d are the two running correction arrays; both start as the yi
    c=xy[1].astype('Float');   d=c.copy()
    # signed distances xi - x
    jts=xy[0].astype('Float')-x
    n=size(c); 
    jD=zeros(n,Float);  
    # index of the tabulated abscissa closest to x; its ordinate is the
    # starting approximation
    js=argmin(fabs(jts))
    y=c[js]
    js=js-1
    for m in range(n-1):
        # number of entries that are updated in this pass
        nn=n-m-1
        jD[0:nn]=jts[0:nn]-jts[1+m:n]
        if (sometrue(jD[0:nn] == 0.0)):
            print'polyVal: calculation failure'
        jD[0:nn]=(c[1:n-m]-d[0:nn])/jD[0:nn]
        d[0:nn]=jts[1+m:n]*jD[0:nn]
        c[0:nn]=jts[0:nn]*jD[0:nn]
        # pick the correction from c or from d depending on where js sits
        # relative to the middle of the remaining entries; taking it from
        # d also moves js down by one
        if (2*js < nn-1):
            dy=c[js+1]
        else:
            dy=d[js]
            js=js-1
        y=y+dy
        # dy is the correction applied in this pass (the docstring of the
        # original called the final one an error estimate)
    return y

def testpolyVal():
    t0=time.clock()
    x=arange(0.,pi,pi/100.)
    y=sin(x)
    fjat=transpose([x,y])
    #print fjat
    for xp in arange(0.,pi,pi/1001.): 
        pv=polyVal(xp,fjat)
        #print xp, sin(xp)- pv
    print 'time ',time.clock()-t0
