import numpy as np
import matplotlib.pyplot as plt
# Turn on matplotlib's interactive mode so figures are drawn without blocking
# (presumably so the repeated plt.plot calls made while integrating show up
# incrementally).
plt.ion()
# Clear the current figure so each run of the script starts from an empty plot.
plt.clf()

def simpson(f, a, b, m=10):
    """
    Approximate the integral of f over [a, b] by the composite Simpson's rule.

    Input:
       f:    integrand. It is evaluated both on scalars (the end nodes) and on
             NumPy array slices of the nodes, so it must be vectorized.
       a, b: integration interval
       m:    number of Simpson panels (double intervals). The rule uses
             n = 2*m equal subintervals of width h = (b-a)/(2*m).
    Output: The approximation to the integral
    Raises:
       ValueError: if m < 1 (no panels, so h would be undefined).
    """
    if m < 1:
        raise ValueError(f"m must be a positive integer, got {m}")
    n = 2*m
    x_nodes = np.linspace(a, b, n+1)       # nodes x_0, ..., x_n, equally spaced from a to b
    h = (b-a)/n                            # stepsize
    S1 = f(x_nodes[0]) + f(x_nodes[n])     # S1 = f(x_0) + f(x_n)
    S2 = np.sum(f(x_nodes[1:n:2]))         # S2 = f(x_1) + f(x_3) + ... + f(x_{n-1})   (odd nodes)
    S3 = np.sum(f(x_nodes[2:n-1:2]))       # S3 = f(x_2) + f(x_4) + ... + f(x_{n-2})   (interior even nodes; empty when m == 1)
    S = h*(S1 + 4*S2 + 2*S3)/3
    return S
# end of simpson

def simpson_basic(f, a, b):
    """
    Simpson's method with error estimate on a single interval [a, b].

    Computes S_1(a,b), Simpson's rule on the whole interval, and S_2(a,b),
    Simpson's rule applied separately to the two halves [a,c] and [c,b]
    where c is the midpoint. Their difference gives an estimate of the
    error of S_2.

    Input:
       f:    integrand (evaluated once at each of the five points a, d, c, e, b)
       a, b: integration interval
    Output:
       S_2(a,b) and the error estimate (S_2 - S_1)/15, which approximates
       I - S_2 for a smooth integrand.
    """
    # The nodes: c is the midpoint of [a, b]; d and e are the midpoints of the halves.
    c = 0.5*(a+b)
    d = 0.5*(a+c)
    e = 0.5*(c+b)

    # Evaluate f once per node and reuse the values in both rules
    # (the naive form calls f eight times for these five points).
    fa, fb, fc = f(a), f(b), f(c)
    fd, fe = f(d), f(e)

    # Calculate S1=S_1(a,b), S2=S_2(a,b)
    H = b-a
    S1 = H*(fa + 4*fc + fb)/6                       # step H/2
    S2 = 0.5*H*(fa + 4*fd + 2*fc + 4*fe + fb)/6     # step H/4

    # The error of S_1 is about 16 times that of S_2, so
    # I - S_2 ~ (S_2 - S_1)/15.
    error_estimate = (S2-S1)/15
    return S2, error_estimate
# end of simpson_basic

def simpson_adaptive(f, a, b, tol=1.e-6, level=0, maks_level=15, illustrate=True):
    """
    Simpson's adaptive method.

    On [a, b], compute S_2(a,b) and its error estimate with simpson_basic.
    If the estimate is below tol, the corrected value S_2 + error_estimate
    is accepted. Otherwise [a, b] is split at its midpoint and each half is
    integrated recursively with tolerance tol/2. The tolerances of the
    accepted pieces therefore add up to at most the original tol.

    Input:
       f:    integrand (must accept NumPy arrays when illustrate is True,
             since it is evaluated on a grid for plotting)
       a, b: integration interval
       tol:  tolerance
       level:      current recursion depth; 0 for the top-level call.
       maks_level: maximum recursion depth. If a subinterval at this depth
                   still misses the tolerance, a warning is printed and the
                   uncorrected S_2 for that subinterval is used.
       illustrate: if True (default), print a table row and plot the
                   integrand and subinterval endpoints for every subinterval
                   visited.
    Output:
       The approximation to the integral
    """
    Q, error_estimate = simpson_basic(f, a, b)    # The quadrature and the error estimate

    if illustrate:
        # -------------------------------------------------
        # Write the output, and plot the nodes.
        # This part is only for illustration.
        if level == 0:
            print(' l   a           b         feil_est   tol')
            print('==============================================')
        print('{:2d}   {:.6f}   {:.6f}   {:.2e}   {:.2e}'.format(
                level, a, b, abs(error_estimate), tol))

        x = np.linspace(a, b, 101)
        plt.plot(x, f(x), [a, b], [f(a), f(b)], '.r')
        plt.title('The integrand and the subintervals')
        # -------------------------------------------------

    # Check acceptance first, so an interval that already meets the
    # tolerance at the deepest level is accepted without a spurious warning.
    if abs(error_estimate) < tol:
        return Q + error_estimate

    if level >= maks_level:
        print('Warning: Maximum number of levels used.')
        return Q

    # Divide the interval in two, and apply the algorithm to each half.
    # maks_level and illustrate are forwarded so the caller's settings
    # apply at every depth.
    c = 0.5*(b+a)
    sub_options = dict(tol=0.5*tol, level=level+1,
                       maks_level=maks_level, illustrate=illustrate)
    result_left = simpson_adaptive(f, a, c, **sub_options)
    result_right = simpson_adaptive(f, c, b, **sub_options)
    return result_right + result_left
# end of simpson_adaptive

def test1():
    """
    Numerical experiment 1: a single Simpson panel on a cubic polynomial.

    Simpson's rule is exact for polynomials of degree at most 3, so the
    printed error should be at rounding level.
    """
    def cubic(t):
        return 4*t**3 + t**2 + 2*t - 1

    lower, upper = -1, 2           # integration interval
    exact = 18.0                   # exact integral of cubic over [lower, upper]
    approx = simpson(cubic, lower, upper, m=1)
    print('I = {:.8f},  S = {:.8f},  error = {:.3e}'.format(exact, approx, exact - approx))
# end of test1

def test2():
    """
    Numerical experiment 2: convergence of composite Simpson.

    Integrates cos(pi*x/2) over [0, 1] for m = 1, 2, 4, 8, 16 panels and
    prints each error, plus the ratio to the previous error from the second
    run onward.
    """
    def integrand(t):
        return np.cos(0.5*np.pi*t)

    lower, upper = 0, 1
    exact = 2/np.pi
    previous = None                          # error from the previous m (none yet)
    for panels in [1, 2, 4, 8, 16]:
        error = exact - simpson(integrand, lower, upper, m=panels)
        if previous is None:
            print('m = {:3d},  error = {:.3e}'.format(panels, error))
        else:
            print('m = {:3d},  error = {:.3e},  reduction factor = {:.3e}'.format(
                panels, error, error/previous))
        previous = error
# end of test2



def test3():
    """
    Test of simpson_basic on cos(x) over [0, 1].

    Prints S_2 next to the exact value sin(1), then compares the true error
    of S_2 with the error estimate returned by simpson_basic.
    """
    exact = np.sin(1)                                # exact integral, for comparison
    approx, estimate = simpson_basic(np.cos, 0, 1)   # Simpson over two halves, with error estimate

    print(f"Numerical solution = {approx:.8f}, exact solution = {exact:.8f}")
    print(f"Error in S2 = {exact-approx:.3e},  error estimate for S2 = {estimate:.3e}")
# end of test3


def test4():
    """
    Test of the adaptive Simpson's method on Runge's function 1/(1+(4x)^2)
    over [0, 8] with tolerance 1e-3.

    Prints the adaptive result next to the exact integral
    (arctan(4b) - arctan(4a))/4, then the tolerance and the achieved error.
    """
    def f(x):                                      # Integrand
        return 1/(1+(4*x)**2)
    a, b = 0, 8                                    # Integration interval
    I_exact = 0.25*(np.arctan(4*b)-np.arctan(4*a)) # Exact integral
    tol = 1.e-3                                    # Tolerance
    # Apply the algorithm
    result = simpson_adaptive(f, a, b, tol=tol)
    # Print the result and the exact solution with 8 decimals
    # (".8f" is precision; the former "8f" only set a field width of 8).
    print(f"\nNumerical solution = {result:.8f}, exact solution = {I_exact:.8f}")
    # Compare the measured error and the tolerance
    err = I_exact - result
    print(f"\nTolerance = {tol:.1e}, error = {abs(err):.3e}")
# end of test4

def runge_df4():
    """
    Plot the fourth derivative of Runge's function 1/(1+16x^2) on [0, 8].

    The derivative is evaluated from its closed form
    6144*(1280x^4 - 160x^2 + 1)/(1+16x^2)^5 on 1001 equally spaced points.
    """
    grid = np.linspace(0, 8, 1001)
    fourth_derivative = 6144*(1280*grid**4-160*grid**2+1)/((1+16*grid**2)**5)
    plt.plot(grid, fourth_derivative)
    plt.title('The 4th derivative of Runges function')
    plt.xlabel('x')
# end of runge_df4


