import numpy as np
from numpy.linalg import solve, norm 
import matplotlib.pyplot as plt
from scipy import interpolate
plt.ion()   # Turn on matplotlib interactive mode for the examples below
plt.clf()   # Clear the current figure before any example draws into it

def cardinal(xdata, x):
    """
    Evaluate the Lagrange cardinal functions l_i for the nodes in xdata.

    In: xdata, sequence/array with the (distinct) nodes x_i.
        x, array, list or scalar of values in which the cardinal functions
           are evaluated.
    Return: l, a list with one array per node; l[i] holds
            l_i(x) = prod_{j != i} (x - x_j)/(x_i - x_j) and has the shape
            of x (0-d for a scalar x).
    """
    x = np.asarray(x)                # accept scalars and lists, not only arrays
    n = len(xdata)                   # Number of nodes (= number of cardinal functions)
    l = []
    for i in range(n):               # Loop over the cardinal functions
        li = np.ones(np.shape(x))    # shape of x instead of len(x): works for scalars
        for j in range(n):           # Loop to make the product for l_i
            if i != j:               # value comparison; `is not` tests object identity
                li = li*(x-xdata[j])/(xdata[i]-xdata[j])
        l.append(li)                 # Append the array to the list
    return l

def lagrange(ydata, l):
    """
    Combine data values and cardinal functions into the Lagrange polynomial.

    In: ydata, the y-values y_i of the interpolation points.
        l, list of the cardinal functions l_i evaluated in x, as returned
           by cardinal(xdata, x).
    Return: p(x) = sum_i y_i * l_i(x), accumulated term by term starting
            from 0 (so an empty ydata gives the number 0).
    """
    total = 0
    for idx, weight in enumerate(ydata):
        term = weight * l[idx]       # y_i * l_i(x)
        total = total + term
    return total
# end of lagrange

def divdiff(xdata,ydata):
    """
    Build the table of divided differences for the data (xdata, ydata).

    Column k holds the k-th order differences, F[i, k] = f[x_i, ..., x_{i+k}];
    entries with i + k >= n are left as zero. Only the first row F[0, :] is
    needed for the Newton polynomial, but the whole table is returned for
    inspection.
    """
    n = len(xdata)
    F = np.zeros((n, n))
    F[:, 0] = ydata                           # order 0: the function values
    for order in range(1, n):
        for row in range(n - order):
            numerator = F[row+1, order-1] - F[row, order-1]
            F[row, order] = numerator/(xdata[row+order] - xdata[row])
    return F

def newton_interpolation(F, xdata, x):
    """
    Evaluate the Newton form of the interpolation polynomial in x.

    In: F, square table of divided differences as returned by divdiff;
           only its first row F[0, :] (the Newton coefficients) is used.
        xdata, the interpolation nodes x_0, ..., x_n.
        x, array, list or scalar of evaluation points.
    Return: array with the shape of x holding
            p(x) = F[0,0] + F[0,1](x-x_0) + ... + F[0,n](x-x_0)...(x-x_{n-1}).
    """
    x = np.asarray(x)                           # accept lists and scalars, not only arrays
    n = np.shape(F)[0]                          # number of nodes
    xpoly = np.ones(np.shape(x))                # (x-x[0])(x-x[1])...
    newton_poly = F[0,0]*np.ones(np.shape(x))   # The Newton polynomial
    for j in range(n-1):
        xpoly = xpoly*(x-xdata[j])
        newton_poly = newton_poly + F[0,j+1]*xpoly
    return newton_poly
# end of newton_interpolation


def equidistributed_bound(n, M, a, b):
    """
    Return the error bound for interpolation in equidistributed nodes.

    In: n, degree of the interpolation polynomial; the node spacing is
           h = (b-a)/n, i.e. n+1 equidistributed nodes (n must be >= 1).
        M, presumably max |f^(n+1)(x)| over [a, b] -- confirm with callers.
        a, b, the interval endpoints.
    Return: h^(n+1)/(4(n+1)) * M.
    """
    h = (b-a)/n                          # spacing between neighbouring nodes
    return 0.25*h**(n+1)/(n+1)*M
# end of equidistributed_bound

def omega(xdata, x):
    """
    Evaluate omega(x) = (x-x_0)(x-x_1)...(x-x_n) for the nodes in xdata.

    In: xdata, the nodes x_0, ..., x_n.
        x, array, list or scalar of evaluation points.
    Return: array with the shape of x (all ones if xdata is empty).
    """
    x = np.asarray(x)                        # accept lists and scalars, not only arrays
    omega_value = np.ones(np.shape(x))
    for xj in xdata:
        omega_value = omega_value*(x-xj)     # multiply in one factor (x - x_j)
    return omega_value
# end of omega

def plot_omega(n=8, a=-1, b=1):
    """
    Plot omega(x) for n+1 equidistributed nodes on [a, b] and print max|omega|.

    In: n, polynomial degree; omega is built from n+1 nodes (default 8).
        a, b, the interval (default [-1, 1]).
    """
    x = np.linspace(a, b, 501)
    xdata = np.linspace(a, b, n+1)      # n+1 interpolation points, as in example5
    omega_x = omega(xdata, x)           # evaluate once; reused for plot and maximum
    plt.plot(x, omega_x)
    plt.grid(True)
    plt.xlabel('x')
    plt.ylabel('omega(x)')
    print("n = {:2d}, max|omega(x)| = {:.2e}".format(n, np.max(np.abs(omega_x))))
# end of plot_omega
 
def chebyshev_nodes(a, b, n):
    """
    Return the n Chebyshev nodes mapped to the interval [a, b].

    On [-1, 1] the nodes are cos((2i+1)pi/(2n)), i = 0, ..., n-1; they are
    mapped affinely onto [a, b] (in decreasing order when b > a).
    """
    i = np.arange(n)                          # i = [0, 1, ..., n-1]
    x = np.cos((2*i+1)*np.pi/(2*n))           # nodes over the interval [-1,1]
    return 0.5*(b-a)*x+0.5*(b+a)              # nodes over the interval [a,b]
# end of chebyshev_nodes

def example1():
    """
    Example 1: interpolate cos(pi*x/2) at x = 0, 2/3, 1.

    Plots the function, the quadratic interpolation polynomial
    p_2(x) = (-3x^2 - x + 4)/4 and the interpolation data.
    """
    xdata = [0,2/3., 1]                    # Interpolation data
    ydata = [1, 1/2., 0]
    x = np.linspace(0,1,100)               # Gridpoints for plotting
    p2 = (-3*x**2-x+4)/4                   # Interpolation polynomial
    f = np.cos(np.pi*x/2)                  # Original function
    plt.plot(x, f, 'c', x, p2, 'm', xdata, ydata, "ok")
    # Raw string for the LaTeX label: backslash-c / backslash-p are invalid
    # escape sequences in a normal string literal (SyntaxWarning on 3.12+).
    plt.legend([r'$\cos(\pi x/2)$', '$p_2(x)$', 'Interpolation data'])
# end of example 1

def example2():
    """
    Example 2: interpolation polynomial through five points on [0, 7].

    Fits the polynomial with numpy, prints it and plots it together with the
    interpolation points.
    """
    # Interpolation data and the interval they span
    points_x = [0, 1, 3, 4, 7]
    points_y = [3, 8, 6, 0, -4]
    left, right = 0, 7
    grid = np.linspace(left, right, 101)

    # Degree is one less than the number of points, so the fit interpolates
    degree = len(points_x) - 1
    poly = np.polynomial.Polynomial.fit(points_x, points_y, degree)
    print(poly.convert())              # convert() undoes the domain scaling used by fit

    plt.plot(grid, poly(grid))
    plt.plot(points_x, points_y, 'o')
    plt.title('The interpolation polynomial p(x)')
    plt.xlabel('x')
# end of example 2

def example4():
    """
    Example 4: Lagrange interpolation through three points.

    Evaluates the cardinal functions on a grid over [0, 3], combines them
    with the data values and plots the polynomial with the data points.
    """
    nodes = [0, 1, 3]
    values = [3, 8, 6]
    grid = np.linspace(0, 3, 101)                     # evaluation points
    polynomial = lagrange(values, cardinal(nodes, grid))
    plt.plot(grid, polynomial)
    plt.plot(nodes, values, 'o')
    plt.title('The interpolation polynomial p(x)')
    plt.xlabel('x')
# end of example 4

def example5():
    """
    Example 5: interpolate f(x) = sin(x) on [0, 2*pi] in n+1 equidistributed
    nodes, plot f, p and the error f - p, and print the maximum error.
    """
    # Define the function
    def f(x):
        return np.sin(x)

    # Set the interval
    a, b = 0, 2*np.pi                  # The interpolation interval
    x = np.linspace(a, b, 101)         # The 'x-axis'

    # Set the interpolation points
    n = 8                              # Degree; n+1 interpolation points
    xdata = np.linspace(a, b, n+1)     # Equidistributed nodes (can be changed)
    ydata = f(xdata)

    # Compute the interpolation polynomial using built in numpy functions
    p = np.polynomial.Polynomial.fit(xdata,ydata,n)
    # p is a Polynomial object: evaluate it on the grid before subtracting
    # (p - array would treat the array as polynomial coefficients).
    error = f(x) - p(x)

    # Plot f(x) and p(x) and the interpolation points
    plt.subplot(2,1,1)
    plt.plot(x, f(x), x, p(x), xdata, ydata, 'o')
    plt.legend(['f(x)','p(x)'])
    plt.grid(True)

    # Plot the interpolation error
    plt.subplot(2,1,2)
    plt.plot(x, error)
    plt.xlabel('x')
    plt.ylabel('Error: f(x)-p(x)')
    plt.grid(True)
    print("Max error is {:.2e}".format(np.max(np.abs(error))))
# end of example 5

def example_divided_differences():
    """
    Example: divided differences and the Newton interpolation formula.

    Prints the divided-difference table for three points and plots the
    Newton polynomial built from it on [0, 1].
    """
    nodes = [0, 2/3, 1]
    values = [1, 1/2, 0]
    table = divdiff(nodes, values)
    print('The table of divided differences:\n', table)

    grid = np.linspace(0, 1, 101)                     # evaluation points
    plt.plot(grid, newton_interpolation(table, nodes, grid))
    plt.plot(nodes, values, 'o')
    plt.title('The interpolation polynomial p(x)')
    plt.grid(True)
    plt.xlabel('x')
# end of example_divided_differences


def example7():
    """
    Example 7: piecewise linear (linear spline) interpolation of five points
    on [0, 7], plotted together with the data.
    """
    nodes = [0, 1, 3, 4, 7]
    values = [3, 8, 6, -1, 2]
    grid = np.linspace(0, 7, 101)
    plt.plot(grid, np.interp(grid, nodes, values))    # linear spline via numpy
    plt.plot(nodes, values, 'o')
    plt.title('Interpolation with a linear spline')
    plt.xlabel('x')
# end of example 7


def example8():
    """
    Example 8: cubic spline interpolation with three boundary conditions.

    Builds not-a-knot, clamped and natural cubic splines through five points
    with scipy's interpolate.CubicSpline and plots them with the data.
    """
    nodes = [0, 1, 3, 4, 7]
    values = [3, 8, 6, -1, 2]
    grid = np.linspace(0, 7, 101)

    # (boundary condition, line colour, legend label)
    variants = [('not-a-knot', 'r', 'Not a knot'),
                ('clamped', 'b', 'Clamped'),
                ('natural', 'k', 'Natural')]
    # Construct every spline first, then plot them in the same order
    splines = [interpolate.CubicSpline(nodes, values, bc_type=bc)
               for bc, _, _ in variants]
    for spline, (_, colour, label) in zip(splines, variants):
        plt.plot(grid, spline(grid), colour, label=label)

    plt.plot(nodes, values, 'o')
    plt.title('Interpolation with a cubic spline')
    plt.legend()
    plt.xlabel('x')
# end of example 8
