import numpy as np
from numpy.linalg import norm, solve
import matplotlib.pyplot as plt
# Draw grid lines on all axes by default.
plt.rcParams['axes.grid'] = True
# Turn on matplotlib's interactive mode so figures are displayed/updated
# as they are drawn, without a blocking plt.show() call.
plt.ion()

# Start from two empty figures. The example functions below switch to
# figure 2 (plt.figure(2)) for their error plots; figure 1 is selected last
# so it is the current figure that receives their solution plots.
plt.figure(2)
plt.clf()
plt.figure(1)
plt.clf()


def euler(f, t0, y0, tend, N=100):
    '''
    Euler's method for solving y'=f(t,y), y(t0)=y0.

    Parameters
    ----------
    f : callable
        Right hand side, called as f(t, y) where y is a 1d numpy array of
        length m. Must return something broadcastable to shape (m,).
    t0, tend : float
        Start and end of the integration interval.
    y0 : scalar or array_like
        Initial value. A scalar (or 0d array) gives a scalar problem; a
        list, tuple or 1d array gives a system of len(y0) equations.
        Complex initial values are supported.
    N : int, optional
        Number of equally sized steps (default 100).

    Returns
    -------
    tsol : ndarray, shape (N+1,)
        Time grid t0, t0+h, ..., tend (the last point is exactly tend).
    ysol : ndarray, shape (N+1,) if m == 1, otherwise (N+1, m)
        Numerical solution; ysol[n] approximates y(tsol[n]).
    '''
    h = (tend-t0)/N         # Stepsize

    # Treat scalars, lists/tuples and arrays uniformly as 1d vectors.
    # (Only real ndarrays were handled as systems before; lists crashed.)
    y0 = np.atleast_1d(np.asarray(y0))
    m = len(y0)

    # Make arrays for storing the solution. Keep complex values complex;
    # integer/bool initial values are promoted to float.
    ysol = np.zeros((N+1, m), dtype=np.result_type(y0.dtype, float))
    ysol[0] = y0
    # Build the grid directly instead of accumulating t+h, which drifts by
    # rounding so that the final time would not be exactly tend.
    tsol = np.linspace(t0, tend, N+1)

    # Main loop: one step of Euler's method per iteration
    for n in range(N):
        y = ysol[n]
        ysol[n+1] = y + h*f(tsol[n], y)

    # In case of a scalar problem, convert the solution into a 1d-array
    if m == 1:
        ysol = ysol[:, 0]

    return tsol, ysol
# end of euler

def heun(f, t0, y0, tend, N=100):
    '''
    Heun's method for solving y'=f(t,y), y(t0)=y0.

    Parameters
    ----------
    f : callable
        Right hand side, called as f(t, y) where y is a 1d numpy array of
        length m. Must return something broadcastable to shape (m,).
    t0, tend : float
        Start and end of the integration interval.
    y0 : scalar or array_like
        Initial value. A scalar (or 0d array) gives a scalar problem; a
        list, tuple or 1d array gives a system of len(y0) equations.
        Complex initial values are supported.
    N : int, optional
        Number of equally sized steps (default 100). Each step evaluates
        f twice.

    Returns
    -------
    tsol : ndarray, shape (N+1,)
        Time grid t0, t0+h, ..., tend (the last point is exactly tend).
    ysol : ndarray, shape (N+1,) if m == 1, otherwise (N+1, m)
        Numerical solution; ysol[n] approximates y(tsol[n]).
    '''
    h = (tend-t0)/N         # Stepsize

    # Treat scalars, lists/tuples and arrays uniformly as 1d vectors.
    # (Only real ndarrays were handled as systems before; lists crashed.)
    y0 = np.atleast_1d(np.asarray(y0))
    m = len(y0)

    # Make arrays for storing the solution. Keep complex values complex;
    # integer/bool initial values are promoted to float.
    ysol = np.zeros((N+1, m), dtype=np.result_type(y0.dtype, float))
    ysol[0] = y0
    # Build the grid directly instead of accumulating t+h, which drifts by
    # rounding so that the final time would not be exactly tend.
    tsol = np.linspace(t0, tend, N+1)

    # Main loop
    for n in range(N):
        y = ysol[n]
        # One step with Heun's method: an Euler predictor to the next grid
        # point, then the average of the slopes at both ends of the step.
        k1 = f(tsol[n], y)
        k2 = f(tsol[n+1], y+h*k1)
        ysol[n+1] = y + 0.5*h*(k1+k2)

    # In case of a scalar problem, convert the solution into a 1d-array
    if m == 1:
        ysol = ysol[:, 0]

    return tsol, ysol
# end of heun




def ode_example1():
    '''
    Numerical example 1: solve y' = -2*t*y, y(0) = 1 on [0, 1] with Euler's
    method and compare with the exact solution y(t) = exp(-t**2).

    Plots the numerical and exact solutions in the current figure, the
    pointwise error in figure 2, and prints the maximum error.
    '''
    # Define the problem to be solved
    def f(t,y):
        # The right hand side of y' = f(t,y)
        return -2*t*y

    def y_exact(t):
        # Exact solution of the problem above
        return np.exp(-t**2)

    # Set the initial value
    y0 = 1
    t0, tend = 0, 1

    # Number of steps
    N = 10

    # Solve the equation numerically
    tsol, ysol = euler(f, t0, y0, tend, N=N)

    # Plot the numerical solution together with the exact one, sampled on
    # the same interval [t0, tend] as the numerical solution
    texact = np.linspace(t0, tend, 1001)
    plt.plot(tsol, ysol, 'o', label='Euler')
    plt.plot(texact, y_exact(texact), '--', label='Exact')
    plt.legend()
    plt.xlabel('t');
    plt.figure(2) # only for the python file

    # Plot the error
    error = np.abs(ysol-y_exact(tsol))
    plt.semilogy(tsol, error, 'o--')
    plt.title('Error')
    plt.xlabel('t')
    print(f'Max error is {np.max(error):.2e}')
# end of ode_example1


def ode_example2():
    '''
    Numerical example 2: the Lotka-Volterra system

        y1' = alpha*y1 - beta*y1*y2
        y2' = delta*y1*y2 - gamma*y2

    solved with Euler's method on [0, 20] from y(0) = (2, 0.5); both
    components are plotted in the current figure.
    '''
    # Model parameters, shared with the right hand side below
    alpha, beta, delta, gamma = 2, 1, 0.5, 1

    def lotka_volterra(t, y):
        # The right hand side of y' = f(t,y). The system is autonomous, so
        # t is not used. y[0] grows on its own (prey), y[1] decays on its
        # own (predator).
        prey, predator = y[0], y[1]
        return np.array([alpha*prey - beta*prey*predator,
                         delta*prey*predator - gamma*predator])

    # Initial value, interval and number of steps
    y0 = np.array([2.0, 0.5])
    t0, tend = 0, 20
    N = 1000

    # Solve the system numerically
    tsol, ysol = euler(lotka_volterra, t0, y0, tend, N=N)

    # Plot both components of the numerical solution
    plt.plot(tsol, ysol)
    plt.legend(['y1','y2'])
    plt.xlabel('t');
# end of ode_example2

def ode_example3():
    '''
    Numerical example 3: compare Euler's and Heun's methods on
    y' = -2*t*y, y(0) = 1, t in [0, 1], whose exact solution is exp(-t**2).

    Plots both numerical solutions with the exact one in the current
    figure and their pointwise errors in figure 2, prints the maximum
    errors, and then prints a table of the error at t = tend while the
    number of steps is doubled ten times.
    '''
    # Define the problem to be solved
    def f(t,y):
        # The right hand side of y' = f(t,y)
        return -2*t*y

    def y_exact(t):
        # Exact solution of the problem above
        return np.exp(-t**2)

    # Set the initial value
    y0 = 1
    t0, tend = 0, 1

    # Number of steps
    N = 10

    # Solve the equation numerically with Euler's method and with Heun's
    # method. Heun evaluates f twice per step and Euler once, so Heun gets
    # N//2 steps and both methods use the same number of f-evaluations.
    t_euler, y_euler = euler(f, t0, y0, tend, N=N)
    t_heun, y_heun = heun(f, t0, y0, tend, N=N//2)

    # Plot the numerical solutions together with the exact one
    plt.plot(t_euler, y_euler, 'o', label='Euler')
    plt.plot(t_heun, y_heun, 'd', label='Heun')
    texact = np.linspace(t0, tend, 1001)
    plt.plot(texact, y_exact(texact), '--', label='Exact')
    plt.legend()
    plt.xlabel('t');
    plt.figure(2) # only for the python file ex3

    # Plot the errors for both methods
    error_euler = np.abs(y_euler-y_exact(t_euler))
    error_heun = np.abs(y_heun-y_exact(t_heun))
    plt.semilogy(t_euler, error_euler, 'o--', label='Euler')
    plt.semilogy(t_heun, error_heun, 'd--', label='Heun')
    plt.title('Error')
    plt.legend()   # the two error curves are labelled; show which is which
    plt.xlabel('t')
    print(f'Max error for Euler is {np.max(error_euler):.2e} and for Heun {np.max(error_heun):.2e}')
    # end of error plot ex3
    print('\n\n')

    # Print the error at t = tend as a function of the step size.
    # The h column is Euler's step (tend-t0)/N; Heun runs with N//2 steps,
    # i.e. step 2h, so both use the same number of f-evaluations.
    y_end = y_exact(tend)
    N = 10
    print('Error in Euler and Heun\n')
    print('h           Euler       Heun')
    for _ in range(10):
        _, y_euler = euler(f, t0, y0, tend, N=N)
        _, y_heun = heun(f, t0, y0, tend, N=N//2)
        error_euler = np.abs(y_end-y_euler[-1])
        error_heun = np.abs(y_end-y_heun[-1])
        print(f'{(tend-t0)/N:.3e}   {error_euler:.3e}   {error_heun:.3e}')
        N *= 2

# end of ode_example3

