import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import solve, norm    # Solve linear systems and compute norms

# Turn on matplotlib's interactive mode at import time (module-level side
# effect). None of the example plotting functions below call plt.show(), so
# they presumably rely on this to display their figures.
plt.ion()

def bisection(f, a, b, tol=1.e-6, max_iter = 100):
    ''' Solve the scalar equation f(x)=0 by bisection.
        The result of each iteration is printed.
    Input:
        f:        The function.
        a, b:     The interval [a, b]; f(a) and f(b) must have opposite signs.
        tol:      Tolerance. The iteration stops when the half-width of the
                  current bracket (the error bound for the midpoint) is below
                  tol, or when |f(c)| < 1e-14.
        max_iter: Maximum number of iterations (must be at least 1).
    Output:
        The last midpoint c and the index k of the last iteration
        (k counts from 0, so k+1 midpoints were evaluated).
    Raises:
        AssertionError: if f(a)*f(b) >= 0.
        ValueError:     if max_iter < 1 (no midpoint would ever be computed).
    '''
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1')

    fa = f(a)
    fb = f(b)

    # Explicit raise rather than `assert`, so the check survives `python -O`;
    # the exception type and message are unchanged for existing callers.
    if not fa*fb < 0:
        raise AssertionError('Error: f(a)*f(b)>0, there may be no root in the interval.')

    for k in range(max_iter):
        c = 0.5*(a+b)                 # The midpoint
        fc = f(c)                     # One evaluation of f per iteration
        print(f"k ={k:3d}, a = {a:.6f}, b = {b:.6f}, c = {c:.6f}, f(c) = {fc:10.3e}")
        if abs(fc) < 1.e-14 or (b-a) < 2*tol:     # The zero is found!
            break
        elif fa*fc < 0:
            b = c                     # There is a root in [a, c]
        else:
            a, fa = c, fc             # There is a root in [c, b]; keep fa == f(a)
    return c, k


def fixedpoint(g, x0, tol=1.e-8, max_iter=30):
    ''' Solve x=g(x) by fixed point iterations.
        The output of each iteration is printed.
    Input:
        g:        The function g(x)
        x0:       Initial value
        tol:      The tolerance; the iteration stops when |x_{k+1} - x_k| < tol
        max_iter: Maximum number of iterations
    Output:
        The last iterate and the number of iterations performed
        (equal to max_iter if the tolerance was never met; 0 if max_iter <= 0,
        in which case x0 is returned unchanged).
    '''
    x = x0
    print(f"k ={0:3d}, x = {x:14.10f}")
    nit = 0                              # Defined even if the loop never runs
    for nit in range(1, max_iter + 1):
        x_old = x                        # Store old values for error estimation
        x = g(x)                         # The iteration
        err = abs(x-x_old)               # Error estimate
        print(f"k ={nit:3d}, x = {x:14.10f}")
        if err < tol:                    # The solution is accepted
            break
    return x, nit


def newton(f, df, x0, tol=1.e-8, max_iter=30):
    ''' Solve f(x)=0 by Newtons method.
        The output of each iteration is printed.
    Input:
        f, df:    The function f and its derivative f'.
        x0:       Initial value
        tol:      The tolerance; x is accepted when |f(x)| < tol
        max_iter: Maximum number of Newton steps
    Output:
        The last iterate and the number of Newton steps performed.
        Returns 0 steps if |f(x0)| < tol. A count of max_iter means the
        step budget was used up; the last iterate's residual is then not
        checked against tol.
    '''
    x = x0
    fx = f(x)                       # Evaluate f once per iterate and reuse it
    print(f"k ={0:3d}, x = {x:18.15f}, f(x) = {fx:10.3e}")
    for k in range(max_iter):
        if abs(fx) < tol:           # Accept the solution after k steps
            return x, k
        x = x - fx/df(x)            # Newton-iteration
        fx = f(x)
        print(f"k ={k+1:3d}, x = {x:18.15f}, f(x) = {fx:10.3e}")
    return x, max_iter


# Module-level side effect: changes NumPy's global print options for the whole
# process. newton_system below prints its array iterates through f-strings
# (`{x}`), which honour these options, so the iterates are shown with up to
# 15 digits of precision.
np.set_printoptions(precision=15)          # Output with high accuracy

def newton_system(f, jac, x0, tol = 1.e-10, max_iter=20):
    ''' Solve the nonlinear system f(x)=0 by Newtons method.
        The iterates are printed.
    Input:
        f:        The vector valued function; f(x) returns an array
        jac:      The Jacobian of f; jac(x) returns a square matrix
        x0:       Initial value (array)
        tol:      The tolerance; x is accepted when max(|f_i(x)|) < tol
        max_iter: Maximum number of Newton steps
    Output:
        The last iterate and the number of Newton steps performed.
        Returns 0 steps if x0 already satisfies the tolerance. A count of
        max_iter means the step budget was used up without the residual of
        an iterate passing the test (the final iterate is not re-checked).
    '''
    x = x0
    print(f"k = {0:3d}, x = {x}")
    for k in range(max_iter):
        fx = f(x)
        if norm(fx, np.inf) < tol:          # The solution is accepted after k steps.
            return x, k
        delta = solve(jac(x), -fx)          # Solve J(x) delta = -f(x)
        x = x + delta
        print(f"k = {k+1:3d}, x = {x}")
    # Budget exhausted: report max_iter so callers can detect non-convergence.
    return x, max_iter

def example1():
    """Example 1: plot f(x) = x^3 + x^2 - 3x - 3 on [-2, 2] together with the x-axis."""
    def f(x):
        return x**3 + x**2 - 3*x - 3

    xs = np.linspace(-2, 2, 101)     # 101 evaluation points on [-2, 2]
    plt.plot(xs, f(xs))
    plt.plot(xs, 0*xs, 'r')          # the x-axis, drawn as a red line
    plt.xlabel('x')
    plt.grid(True)
    plt.ylabel('f(x)')


def example2():
    """Example 2: locate the root of x^3 + x^2 - 3x - 3 in [1.5, 2] by bisection."""
    def f(x):
        return x**3 + x**2 - 3*x - 3

    root, nit = bisection(f, 1.5, 2, tol=1.e-6)

    # f(x) = (x + 1)(x^2 - 3), so the exact root in [1.5, 2] is sqrt(3).
    exact = np.sqrt(3)
    print(f"\n\nResult:\nx = {root:.10f} \nnumber of iterations = {nit:2d} \nerror = {abs(root-exact):.2e}")


def example3():
    # Example 3: plot g(x) = (x^3 + x^2 - 3)/3 together with the line y = x;
    # their intersections are the fixed points of g.
    x = np.linspace(-2, 2, 101)
    plt.plot(x, (x**3+x**2-3)/3, 'b', x, x, '--r')
    # NOTE(review): the x-limit 3 extends past the sampled range [-2, 2];
    # confirm this is intended.
    plt.axis([-2, 3, -2, 2])
    plt.xlabel('x')
    plt.grid(True)          # a bool, as in the other examples (was the string 'True')
    plt.legend(['g(x)', 'x'])


def example3_iter():
    """Solve x = g(x) with g(x) = (x^3 + x^2 - 3)/3 by fixed point iteration from x0 = 1.5."""
    def g(x):
        return (x**3 + x**2 - 3)/3

    root, nit = fixedpoint(g, 1.5, tol=1.e-6, max_iter=30)
    print(f"\n\nResult:\nx = {root:.10f} \nnumber of iterations = {nit:2d}")


def example3_revisited():
    """Illustrate the assumptions of the fixed point theorem for
    g(x) = (x^3 + x^2 - 3)/3 near its fixed point r = -1.

    Left panel: g(x) and the diagonal y = x on [a, b].
    Right panel: g'(x) together with the bounds y = -1 and y = +1.
    """
    def g(x):
        return (x**3 + x**2 - 3)/3

    def dg(x):                                 # g'(x)
        return (3*x**2 + 2*x)/3

    a, b = -1.3, -0.7                          # interval around r = -1
    n_points = 101
    x = np.linspace(a, b, n_points)
    ones = np.ones(n_points)                   # constant line at height 1 for the g' bounds

    plt.rcParams['figure.figsize'] = 10, 5     # wide figure for two panels

    plt.subplot(1, 2, 1)
    plt.plot(x, g(x), x, x, 'r--')
    plt.xlabel('x')
    plt.axis([a, b, a, b])
    plt.grid(True)
    plt.legend(['g(x)', 'x'])

    plt.subplot(1, 2, 2)
    plt.plot(x, dg(x), x, ones, 'r--', x, -ones, 'r--')
    plt.xlabel('x')
    plt.grid(True)
    plt.title('dg/dx')


def example4():
    """Example 4: solve x^3 + x^2 - 3x - 3 = 0 by Newton's method starting at x0 = 1."""
    def f(x):
        return x**3 + x**2 - 3*x - 3

    def df(x):
        return 3*x**2 + 2*x - 3     # f'(x)

    root, nit = newton(f, df, 1, tol=1.e-14, max_iter=30)
    print(f"\n\nResult:\nx = {root:.10f} \nnumber of iterations = {nit:2d}")


def example5_graphs():
    """Example 5: plot the two curves of the nonlinear system
    x1^3 - x2 + 1/4 = 0 and x1^2 + x2^2 - 1 = 0; their intersections
    are the solutions of the system."""
    x1 = np.linspace(-1.0, 1.0, 101)
    plt.plot(x1, x1**3 + 0.25)                  # first equation solved for x2
    theta = np.linspace(0, 2*np.pi, 101)
    plt.plot(np.cos(theta), np.sin(theta))      # second equation: the unit circle
    plt.axis('equal')
    plt.xlabel('x1')
    plt.ylabel('x2')
    plt.grid(True)
    plt.legend(['x1^3-x2+1/4=0', 'x1^2+x2^2-1=0'])

def example6():
    # Example 6: solve the nonlinear system
    #     x1^3 - x2 + 1/4 = 0
    #     x1^2 + x2^2 - 1 = 0
    # by Newton's method from the starting value (1, 1).

    # The vector valued function. Notice the indexing: x[0] is x1, x[1] is x2.
    def f(x):
        y = np.array([x[0]**3-x[1]   +0.25,
                      x[0]**2+x[1]**2-1    ])
        return y

    # The Jacobian, J[i, j] = d f_i / d x_j
    def jac(x):
        J = np.array([[3*x[0]**2, -1     ],
                      [2*x[0],     2*x[1]]])
        return J

    x0 = np.array([1.0, 1.0])          # Starting values
    tol = 1.e-12
    max_iter = 20
    x, _ = newton_system(f, jac, x0, tol=tol, max_iter=max_iter)  # Apply Newton's method

    fx = f(x)
    print(f"\nTest: f(x)={fx}")
    # Judge convergence by the same residual test newton_system uses to accept
    # a solution, instead of comparing the returned iteration count with
    # max_iter: the count alone does not reliably signal failure.
    if norm(fx, np.inf) >= tol:
        print('Warning: Convergence har not been achieved')


