'''
Implementations of different line search methods
for use in the course TMA4180 - Optimisation 1.

Markus Grasmair
Trondheim, February 2023 -- March 2024
'''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import TMA4180_definitions
from matplotlib import animation
import mpl_toolkits.axes_grid1
import matplotlib.widgets
import matplotlib.patches
import scipy.linalg as la
import copy

############################################################
# definition of a Player class for interactive animations
# code by StackOverflow user ImportanceOfBeingErnest,
# taken from his answer to this question:
# https://stackoverflow.com/questions/44985966/managing-dynamic-plotting-in-matplotlib-animation-module/44989063#44989063
class Player(animation.FuncAnimation):
    '''
    FuncAnimation with a row of playback buttons
    (one step back, play backwards, stop, play forwards, one step forward).

    The current frame index is stored in self.i. It starts at 0 and is
    stepped by +1 or -1; playback stops once it reaches mini or maxi.
    The animation callback func is called with this index.
    The ``frames`` argument is accepted but not used: the frames are always
    produced by the generator self.play().
    '''
    def __init__(self, fig, func, frames=None, init_func=None, fargs=None,
                 save_count=None, mini=0, maxi=100, pos=(0.125, 0.92), **kwargs):
        '''
        fig        -- figure that is animated and receives the button bar
        func       -- callback func(i) that draws frame i
        frames     -- ignored (see the class docstring)
        init_func, fargs, save_count, **kwargs -- passed on to FuncAnimation
        mini, maxi -- bounds of the frame index
        pos        -- (left, bottom) of the button bar, passed to
                      fig.add_axes in setup()
        '''
        self.i = 0
        self.min=mini
        self.max=maxi
        self.runs = True
        self.forwards = True
        self.fig = fig
        self.func = func
        # the buttons are created before FuncAnimation is initialised
        self.setup(pos)
        animation.FuncAnimation.__init__(self,self.fig, self.func, frames=self.play(), 
                                         init_func=init_func, fargs=fargs,
                                         save_count=save_count, **kwargs )    

    def play(self):
        # Frame generator. While self.runs is True, move self.i by +1
        # (forwards) or -1 (backwards): the booleans are used as integers,
        # so forwards-(not forwards) is +1 or -1. Once self.i leaves the open
        # interval (min, max), playback is stopped and the boundary index is
        # yielded one last time; the loop then ends because stop() cleared
        # self.runs.
        while self.runs:
            self.i = self.i+self.forwards-(not self.forwards)
            if self.i > self.min and self.i < self.max:
                yield self.i
            else:
                self.stop()
                yield self.i

    def start(self):
        # resume playback by restarting the animation's event source
        self.runs=True
        self.event_source.start()

    def stop(self, event=None):
        # pause playback; ``event`` allows use as a button callback
        self.runs = False
        self.event_source.stop()

    def forward(self, event=None):
        # button callback: play forwards
        self.forwards = True
        self.start()
    def backward(self, event=None):
        # button callback: play backwards
        self.forwards = False
        self.start()
    def oneforward(self, event=None):
        # button callback: single step forwards
        self.forwards = True
        self.onestep()
    def onebackward(self, event=None):
        # button callback: single step backwards
        self.forwards = False
        self.onestep()

    def onestep(self):
        # Move a single frame in the current direction and redraw. Inside
        # (min, max) the index moves freely; at a boundary it may only move
        # back into the interval; otherwise it is left unchanged.
        if self.i > self.min and self.i < self.max:
            self.i = self.i+self.forwards-(not self.forwards)
        elif self.i == self.min and self.forwards:
            self.i+=1
        elif self.i == self.max and not self.forwards:
            self.i-=1
        self.func(self.i)
        self.fig.canvas.draw_idle()

    def setup(self, pos):
        # Build the button bar: one axes at ``pos`` plus four axes appended
        # to its right, each holding one Button wired to a callback above.
        playerax = self.fig.add_axes([pos[0],pos[1], 0.22, 0.04])
        divider = mpl_toolkits.axes_grid1.make_axes_locatable(playerax)
        bax = divider.append_axes("right", size="80%", pad=0.05)
        sax = divider.append_axes("right", size="80%", pad=0.05)
        fax = divider.append_axes("right", size="80%", pad=0.05)
        ofax = divider.append_axes("right", size="100%", pad=0.05)
        self.button_oneback = matplotlib.widgets.Button(playerax, label=u'$\u29CF$')
        self.button_back = matplotlib.widgets.Button(bax, label=u'$\u25C0$')
        self.button_stop = matplotlib.widgets.Button(sax, label=u'$\u25A0$')
        self.button_forward = matplotlib.widgets.Button(fax, label=u'$\u25B6$')
        self.button_oneforward = matplotlib.widgets.Button(ofax, label=u'$\u29D0$')
        self.button_oneback.on_clicked(self.onebackward)
        self.button_back.on_clicked(self.backward)
        self.button_stop.on_clicked(self.stop)
        self.button_forward.on_clicked(self.forward)
        self.button_oneforward.on_clicked(self.oneforward)

#################################################################################
# methods for creating plots

def CreateAnimation(f,points_x,points_y,speed = 500):
    '''
    Create an interactive animation of the iterates of an optimisation method.

    Parameters:
    f        -- object whose plotf() method draws the function
                (presumably into the current axes)
    points_x -- first coordinates of the iterates x_0, x_1, ...
    points_y -- second coordinates of the iterates
    speed    -- delay between frames in milliseconds

    Frame i shows the step from x_{i-1} to x_i in red and the path
    x_0, ..., x_{i-1} in black. Returns the Player animation object.
    '''
    animation_fig,ax = plt.subplots()
    ax.clear()
    current_line, = ax.plot([],[],lw=1,color='r')
    previous_line, = ax.plot([],[],lw=1,color='k')
    f.plotf()
    num_iterations = len(points_x)
    num_digits = len(str(num_iterations-1))

    # The Player adds its button axes to the figure after ``ax`` was created,
    # so all drawing below goes through ``ax`` explicitly instead of the
    # pyplot "current axes".
    def animation_init():
        ax.set_title('Iteration 0')
        current_line.set_data([],[])
        previous_line.set_data([],[])
        return(current_line,previous_line,)

    def animation_animate(i):
        # The Player yields i == num_iterations once it hits its upper bound;
        # clamp so that the title shows the last existing iterate.
        ax.set_title('Iteration {:-{count}}'.format(min(i,num_iterations-1),count=num_digits))
        if i <= 0:
            # No step taken yet: remove lines left over from later frames
            # (e.g. after stepping backwards).
            current_line.set_data([],[])
            previous_line.set_data([],[])
        else:
            current_line.set_data(points_x[i-1:i+1],points_y[i-1:i+1])
            previous_line.set_data(points_x[:i],points_y[:i])
        return(current_line,previous_line,)

    ani = Player(animation_fig, animation_animate,
                 mini = 0, maxi = num_iterations,
                 init_func = animation_init,
                 interval = speed,
    )
    return ani

def CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed = 500):
    '''
    Create an interactive animation of the iterates (left panel) next to a
    semilogarithmic plot of the errors (right panel).

    Parameters:
    f        -- object whose plotf() method draws the function
                (presumably into the current axes)
    points_x -- first coordinates of the iterates x_0, x_1, ...
    points_y -- second coordinates of the iterates
    errors   -- sequence of errors plotted on a logarithmic scale
    speed    -- delay between frames in milliseconds

    Returns the Player animation object.
    '''
    animation_fig,(ax1,ax2) = plt.subplots(1,2)
    # f.plotf() draws via pyplot, so make ax1 the current axes first
    plt.sca(ax1)
    ax1.clear()
    f.plotf()
    num_iterations = len(points_x)
    num_digits = len(str(num_iterations-1))
    current_line, = ax1.plot([],[],lw=1,color='r')
    previous_line, = ax1.plot([],[],lw=1,color='k')
    ax2.clear()
    ax2.semilogy(errors)

    # The Player adds its button axes after ax1/ax2 were created, so the
    # frame callbacks address ax1 directly instead of the current axes.
    def animation_init():
        ax1.set_title('Iteration 0')
        current_line.set_data([],[])
        previous_line.set_data([],[])
        return(current_line,previous_line,)

    def animation_animate(i):
        # the Player may yield i == num_iterations at its upper bound
        ax1.set_title('Iteration {:-{count}}'.format(min(i,num_iterations-1),count=num_digits))
        if i <= 0:
            # iteration 0: no step has been taken yet
            current_line.set_data([],[])
            previous_line.set_data([],[])
        else:
            current_line.set_data(points_x[i-1:i+1],points_y[i-1:i+1])
            previous_line.set_data(points_x[:i],points_y[:i])
        return(current_line,previous_line,)

    ani = Player(animation_fig, animation_animate,
                 mini = 0, maxi = num_iterations,
                 init_func = animation_init,
                 interval = speed,
    )
    return ani

def CreateConvergencePlot(errors):
    '''
    Plot the given errors against the iteration number with a logarithmic
    y-axis (via plt.semilogy) and return whatever plt.semilogy returns.
    '''
    return plt.semilogy(errors)

def CreateAnimation_QP(f,c,penalty_list,points_x,points_y,speed = 1000):
    '''
    Animate a quadratic penalty method: frame i shows the penalised function
    f + mu_i/2 * c**2 (image and contour lines) together with the iterates.

    Parameters:
    f            -- objective; must provide build_discretisation(), val(),
                    plot_scaling(), plotbounds, N_points and N_lines
    c            -- constraint; must provide val()
    penalty_list -- penalty parameters mu_0, mu_1, ...
    points_x     -- first coordinates of the iterates
    points_y     -- second coordinates of the iterates
    speed        -- delay between frames in milliseconds

    Returns the Player animation object.
    '''
    f.build_discretisation()
    x_lower = f.plotbounds[0,0]
    x_upper = f.plotbounds[0,1]
    y_lower = f.plotbounds[1,0]
    y_upper = f.plotbounds[1,1]
    N_points = f.N_points
    delta_x = (x_upper-x_lower)/N_points
    delta_y = (y_upper-y_lower)/N_points
    xx = np.arange(x_lower,x_upper,delta_x)
    yy = np.arange(y_lower,y_upper,delta_y)
    aspect_ratio = (x_upper-x_lower)/(y_upper-y_lower)
    X, Y = np.meshgrid(xx,yy)
    Z_f = f.val(np.array([X,Y]))
    Z_c = c.val(np.array([X,Y]))
    # initial picture: the objective without penalty term (already scaled)
    data = f.plot_scaling(Z_f + 0.*Z_c)
    num_iterations = len(points_x)
    num_penalties = len(penalty_list)

    def penalised(mu):
        # scaled values of f + mu/2 * c**2 on the plotting grid
        return f.plot_scaling(Z_f + 0.5*mu*Z_c**2)

    def remove_contour(contour):
        # Depending on the matplotlib version a ContourSet is removed either
        # as a whole or through its individual collections.
        try:
            contour.remove()
        except AttributeError:
            for tp in contour.collections:
                tp.remove()

    # create an animation of the algorithm
    animation_fig,ax = plt.subplots()
    ax.clear()
    f_image = ax.imshow(data, interpolation='bilinear',
                        origin='lower', cmap=cm.Blues,
                        extent=(x_lower,x_upper,y_lower,y_upper),
                        aspect=aspect_ratio,
                        vmin = np.min(data),
                        vmax = np.max(data)
    )
    f_contour = ax.contour(X,Y,data,f.N_lines)
    current_line, = ax.plot([],[],lw=1,color='r')
    previous_line, = ax.plot([],[],lw=1,color='k')

    # All drawing goes through ``ax``: the Player adds its button axes to
    # the figure afterwards, so pyplot's "current axes" is not ``ax``.
    def animation_init():
        # ``data`` is already scaled; do not apply plot_scaling again
        f_image.set_data(data)
        f_image.set_clim(vmin=np.min(data),vmax=np.max(data))
        current_line.set_data([],[])
        previous_line.set_data([],[])
        return(current_line,previous_line,f_image,f_contour,)

    def animation_animate(i):
        # f_contour is replaced below; without nonlocal it would be a local
        # variable and the removal of the old contour would fail.
        nonlocal f_contour
        # the Player may yield i == num_iterations at its upper bound
        mu = penalty_list[min(i,num_penalties-1)]
        frame_data = penalised(mu)
        ax.set_title('Penalty mu = {}'.format(mu))
        f_image.set_data(frame_data)
        f_image.set_clim(vmin=np.min(frame_data),vmax=np.max(frame_data))
        remove_contour(f_contour)
        f_contour = ax.contour(X,Y,frame_data,f.N_lines)
        if i <= 0:
            # no step has been taken yet
            current_line.set_data([],[])
            previous_line.set_data([],[])
        else:
            current_line.set_data(points_x[i-1:i+1],points_y[i-1:i+1])
            previous_line.set_data(points_x[:i],points_y[:i])
        return(current_line,previous_line,f_image,f_contour,)

    ani = Player(animation_fig, animation_animate,
                 mini = 0, maxi = num_iterations,
                 init_func = animation_init,
                 interval = speed,
    )
    return ani


#################################################################################
# step length selection


def Backtracking(f,x,p,initial_step_length,initial_value,descent,
                 c1 = 1e-3,rho = 0.1,max_steps=50):
    '''
    Implementation of backtracking Armijo line search.

    Starting from alpha = initial_step_length, alpha is multiplied by rho
    until the Armijo condition
        f(x + alpha*p) <= initial_value + c1*alpha*descent
    holds or max_steps reductions have been performed.

    Parameters:
    f                   -- object providing val(x)
    x                   -- current iterate
    p                   -- search direction
    initial_step_length -- first step length to try
    initial_value       -- f(x)
    descent             -- directional derivative <grad f(x), p>
    c1                  -- Armijo constant
    rho                 -- reduction factor for alpha
    max_steps           -- maximal number of reductions

    Returns (next_x, next_value, tot_func_evals). If no acceptable step is
    found within max_steps reductions, the last trial point is returned.
    '''
    alpha = initial_step_length
    n_steps = 0
    next_x = x+alpha*p
    next_value = f.val(next_x)

    # Written as "not (<=)" so that a NaN function value (e.g. a trial point
    # outside the domain of f) counts as a failure and alpha is reduced.
    while (not (next_value <= initial_value+c1*alpha*descent)) and (n_steps < max_steps):
        alpha *= rho
        next_x = x+alpha*p
        next_value = f.val(next_x)
        n_steps += 1

    tot_func_evals = n_steps+1
    return next_x,next_value,tot_func_evals

def Goldstein(f,x,p,initial_step_length,initial_value,descent,
              c1 = 1e-3,c2 = 0.9,
              max_extrapolation_iterations = 50,
              max_interpolation_iterations = 20,
              rho = 2.0
):
    '''
    Implementation of a bisection based bracketing method
    for the Goldstein conditions

        initial_value + c2*alpha*descent <= f(x + alpha*p)
                                         <= initial_value + c1*alpha*descent

    (with c1 < c2 and descent = <grad f(x), p> < 0).

    First alpha is multiplied by rho while it is too small (lower condition
    violated); then the bracket [alphaL, alphaR] is bisected until both
    conditions hold.

    Returns (next_x, next_value, tot_func_evals, alpha). If no acceptable
    step length is found, a message is printed and the last trial point is
    returned.
    '''

    # initialise the step length and the lower bound
    alpha = initial_step_length
    alphaL = 0.0
    # compute the proposed update and the function value
    next_x = x+alpha*p
    next_value = f.val(next_x)
    tot_func_evals = 1

    # LowerBound is False only if alpha is definitely too small. Written as
    # "not (<)" so that a NaN value counts as "not too small" (and then
    # fails Armijo, i.e. is treated as too large).
    LowerBound = not (next_value < initial_value + c2*alpha*descent)
    itnr = 0
    while (itnr < max_extrapolation_iterations and not(LowerBound)):
        alphaL = alpha
        alpha = rho*alpha
        next_x = x+alpha*p
        next_value = f.val(next_x)
        tot_func_evals += 1
        LowerBound = not (next_value < initial_value + c2*alpha*descent)
        itnr += 1

    Armijo = (next_value <= initial_value+c1*alpha*descent)
    # Bisection needs an upper bound; if alpha is still too small after the
    # extrapolation budget is used up, there is none.
    if not LowerBound:
        max_interpolation_iterations = 0
    # now use bisection until we have found an acceptable step length
    itnr = 0
    while (itnr < max_interpolation_iterations and not(LowerBound and Armijo)):
        itnr += 1
        # On entry LowerBound holds, so the first pass always takes the
        # "too large" branch and defines alphaR.
        if not(Armijo):
            # alpha is too large: new upper bound
            alphaR = alpha
        else:
            # Armijo holds but the lower condition fails: alpha too small
            alphaL = alpha
        alpha = 0.5*(alphaL+alphaR)
        next_x = x+alpha*p
        next_value = f.val(next_x)
        tot_func_evals += 1
        # both conditions are evaluated at the new trial step length
        Armijo = (next_value <= initial_value+c1*alpha*descent)
        LowerBound = not (next_value < initial_value + c2*alpha*descent)

    if not (LowerBound and Armijo):
        print("Line search did not yield an acceptable step length!")

    return next_x,next_value,tot_func_evals,alpha
            



def StrongWolfe(fun,x,p,
                initial_value,
                initial_descent,
                initial_step_length = 1.0,
                c1 = 1e-3,
                c2 = 0.9,
                max_extrapolation_iterations = 100,
                max_interpolation_iterations = 50,
                rho = 2.0):
    '''
    Implementation of a bisection based bracketing method
    for the strong Wolfe conditions

        f(x + alpha*p) <= initial_value + c1*alpha*initial_descent   (Armijo)
        |<grad f(x + alpha*p), p>| <= c2*|initial_descent|          (curvature)

    where initial_descent = <grad f(x), p> < 0.

    Returns (next_x, next_value, next_grad, tot_func_evals, alpha). The value
    and gradient at the new point are returned so the caller does not have
    to recompute them. tot_func_evals counts calls of fun.val (each is paired
    with one call of fun.grad). If the iteration budgets are exhausted, the
    last trial point is returned without warning.
    '''

    def evaluate(step):
        # Evaluate f and its gradient at x + step*p and test the Armijo
        # condition and both halves of the curvature condition.
        trial_x = x+step*p
        trial_value = fun.val(trial_x)
        trial_grad = fun.grad(trial_x)
        slope = np.inner(p,trial_grad)
        armijo = (trial_value <= initial_value+c1*step*initial_descent)
        curvature_low = (slope >= c2*initial_descent)
        curvature_high = (slope <= -c2*initial_descent)
        return trial_x,trial_value,trial_grad,armijo,curvature_low,curvature_high

    # Work with a plain float: an in-place "alphaR *= rho" on an array-valued
    # step length would modify the caller's object and alias alphaL/alphaR.
    alphaL = 0.0
    alphaR = float(initial_step_length)
    next_x,next_value,next_grad,Armijo,curvatureLow,curvatureHigh = evaluate(alphaR)
    tot_func_evals = 1

    # Increase alphaR while it is definitely too small: Armijo holds but
    # curvatureLow fails (then curvatureHigh holds automatically).
    itnr = 0
    while (itnr < max_extrapolation_iterations and Armijo and not curvatureLow):
        itnr += 1
        # the old alphaR becomes the new lower bound
        alphaL = alphaR
        alphaR = rho*alphaR
        next_x,next_value,next_grad,Armijo,curvatureLow,curvatureHigh = evaluate(alphaR)
        tot_func_evals += 1

    if (Armijo and curvatureLow and curvatureHigh):
        return next_x,next_value,next_grad,tot_func_evals,alphaR
    if (Armijo and not curvatureLow):
        # Extrapolation budget used up while alphaR is still too small: no
        # upper bound is known, so bisection could not make any progress.
        return next_x,next_value,next_grad,tot_func_evals,alphaR

    # Now alphaL is too small and alphaR too large: bisect the bracket.
    alpha = alphaR
    itnr = 0
    while (itnr < max_interpolation_iterations and not (Armijo and curvatureLow and curvatureHigh)):
        itnr += 1
        if (Armijo and not curvatureLow):
            # alpha was too small: new lower bound
            alphaL = alpha
        else:
            # alpha was too large: new upper bound
            alphaR = alpha
        alpha = (alphaL+alphaR)/2
        next_x,next_value,next_grad,Armijo,curvatureLow,curvatureHigh = evaluate(alpha)
        tot_func_evals += 1

    return next_x,next_value,next_grad,tot_func_evals,alpha


#################################################################################
# implementation of line search methods





def GradDescLinear(f,x_init,
                   max_steps = 50,
                   animationspeed = 500,
                   gradient_stop = 0.0,
):
    '''
    Gradient descent with the step length
        alpha = <g, g> / <g, H g>,   g = -grad f(x),  H = hess f(x),
    which is the exact minimiser along the search line for a quadratic
    function with positive definite Hessian.

    Parameters:
    f              -- objective; must provide grad(x), hess(x) and plotf()
    x_init         -- starting point (at least two components; the first
                      two are animated)
    max_steps      -- maximal number of iterations
    animationspeed -- delay between animation frames in milliseconds
    gradient_stop  -- stop once |grad f(x)| <= gradient_stop. The default
                      0.0 stops only at an exact stationary point, where the
                      step length formula would evaluate 0/0.

    Returns (x, ani): the final iterate and the animation of the iterates.
    '''
    x = x_init
    points_x = [x_init[0]]
    points_y = [x_init[1]]

    n_step = 0
    while n_step < max_steps:
        search_direction = - f.grad(x)
        if np.linalg.norm(search_direction) <= gradient_stop:
            break
        n_step += 1
        # NOTE(review): if <g, H g> <= 0 (Hessian not positive definite
        # along g) this step length is not a minimiser along the line.
        alpha = np.inner(search_direction,search_direction)/np.inner(search_direction,f.hess(x)@search_direction)
        x = x + alpha*search_direction
        points_x.append(x[0])
        points_y.append(x[1])

    ani = CreateAnimation(f,points_x,points_y,speed = animationspeed)
    return x,ani

def BacktrackingGradDesc(f,x_init,
                         max_steps = 50,
                         alpha_0 = 1.0,
                         create_animation = False,
                         convergence_plot = False,
                         gradient_stop = 1e-8,
                         animationspeed = 500,
                         c1_Armijo = 1e-3,
):
    '''
    Gradient descent with backtracking (Armijo) line search.

    Parameters:
    f                -- objective; must provide val(x) and grad(x), plus
                        plotf() for animations and minimisers for
                        convergence plots
    x_init           -- starting point
    max_steps        -- maximal number of iterations
    alpha_0          -- initial step length tried in every line search
    create_animation -- if True, also return an animation of the iterates
    convergence_plot -- if True, record and plot |x_k - f.minimisers|
    gradient_stop    -- stop once |grad f(x)| <= gradient_stop
    animationspeed   -- delay between animation frames in milliseconds
    c1_Armijo        -- Armijo constant of the line search

    Returns x, or (x, ani) / (x, fig) when an animation and/or a
    convergence plot is requested.
    '''

    # Initialise the variable x
    x = x_init

    # If we show a convergence plot, store the errors in each step
    if convergence_plot:
        # minimisers may be missing, None, or contain None entries
        minimisers = getattr(f, 'minimisers', None)
        if minimisers is None or any(m is None for m in np.ravel(np.asarray(minimisers, dtype=object))):
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = f.val(x)
    search_direction = -f.grad(x)
    norm_grad = np.linalg.norm(search_direction)
    tot_func_evals = 1
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        descent = -np.inner(search_direction,search_direction)
        x,current_value,func_evals_backtracking = Backtracking(f,x,search_direction,alpha_0,current_value,descent,c1=c1_Armijo,)
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-f.minimisers))
        search_direction = -f.grad(x)
        norm_grad = np.linalg.norm(search_direction)
        tot_func_evals += func_evals_backtracking

    # decide by the stopping criterion itself, so that convergence in the
    # very last allowed step is reported correctly
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))

    print("A total of {} function evaluations were required.".format(tot_func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x

def GD_Goldstein(f,x_init,
                 max_steps = 50,
                 alpha_0 = 1.0,
                 create_animation = False,
                 convergence_plot = False,
                 gradient_stop = 1e-8,
                 animationspeed = 500,
                 c1_Armijo = 1e-3,
                 c2_Goldstein = 0.9,
):
    '''
    Gradient descent with a line search for the Goldstein conditions.
    The accepted step length of each iteration is used as the initial
    step length of the next line search.

    Parameters:
    f                -- objective; must provide val(x) and grad(x), plus
                        plotf() for animations and minimisers for
                        convergence plots
    x_init           -- starting point
    max_steps        -- maximal number of iterations
    alpha_0          -- initial step length of the first line search
    create_animation -- if True, also return an animation of the iterates
    convergence_plot -- if True, record and plot |x_k - f.minimisers|
    gradient_stop    -- stop once |grad f(x)| <= gradient_stop
    animationspeed   -- delay between animation frames in milliseconds
    c1_Armijo        -- constant of the upper (Armijo) Goldstein condition
    c2_Goldstein     -- constant of the lower Goldstein condition

    Returns x, or (x, ani) / (x, fig) when an animation and/or a
    convergence plot is requested.
    '''

    # Initialise the variable x
    x = x_init

    # If we show a convergence plot, store the errors in each step
    if convergence_plot:
        # minimisers may be missing, None, or contain None entries
        minimisers = getattr(f, 'minimisers', None)
        if minimisers is None or any(m is None for m in np.ravel(np.asarray(minimisers, dtype=object))):
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = f.val(x)
    search_direction = -f.grad(x)
    norm_grad = np.linalg.norm(search_direction)
    tot_func_evals = 1
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        descent = -np.inner(search_direction,search_direction)
        x,current_value,func_evals_LS,alpha_0 = Goldstein(f,x,search_direction,alpha_0,current_value,descent,c1=c1_Armijo,c2=c2_Goldstein)
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-f.minimisers))
        search_direction = -f.grad(x)
        norm_grad = np.linalg.norm(search_direction)
        tot_func_evals += func_evals_LS

    # decide by the stopping criterion itself, so that convergence in the
    # very last allowed step is reported correctly
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))

    print("A total of {} function evaluations were required.".format(tot_func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x

def BacktrackingNewton(f,x_init,
                       max_steps = 50,
                       alpha_0 = 1.0,
                       create_animation = False,
                       convergence_plot = False,
                       descent_eps = 1e-6,
                       gradient_stop = 1e-12,
                       animationspeed = 500,
                       c1_Armijo = 1e-3,
):
    '''
    Newton's method with backtracking (Armijo) line search.

    If the Newton system cannot be solved (singular Hessian) or the Newton
    direction p does not satisfy
        <p, grad f(x)> <= -descent_eps * |grad f(x)| * |p|,
    the negative gradient is used as search direction instead.

    Parameters:
    f                -- objective; must provide val(x), grad(x), hess(x),
                        plus plotf() for animations and minimisers for
                        convergence plots
    x_init           -- starting point
    max_steps        -- maximal number of iterations
    alpha_0          -- initial step length tried in every line search
    create_animation -- if True, also return an animation of the iterates
    convergence_plot -- if True, record and plot |x_k - f.minimisers|
    descent_eps      -- angle condition for accepting the Newton direction
    gradient_stop    -- stop once |grad f(x)| <= gradient_stop
    animationspeed   -- delay between animation frames in milliseconds
    c1_Armijo        -- Armijo constant of the line search

    Returns x, or (x, ani) / (x, fig) when an animation and/or a
    convergence plot is requested.
    '''

    # Initialise the variable x
    x = x_init

    # If we show a convergence plot, store the errors in each step
    if convergence_plot:
        # minimisers may be missing, None, or contain None entries
        minimisers = getattr(f, 'minimisers', None)
        if minimisers is None or any(m is None for m in np.ravel(np.asarray(minimisers, dtype=object))):
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = f.val(x)
    current_gradient = f.grad(x)
    norm_grad = np.linalg.norm(current_gradient)
    tot_func_evals = 1
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        hessian = f.hess(x)
        try:
            search_direction = -np.linalg.solve(hessian,current_gradient)
            newton_ok = True
        except np.linalg.LinAlgError:
            # singular Hessian: no Newton direction available
            newton_ok = False
        if newton_ok:
            descent = np.inner(search_direction,current_gradient)
            # written as "<=" so that a NaN descent also triggers the fallback
            newton_ok = (descent <= -descent_eps*norm_grad*np.linalg.norm(search_direction))
        if not newton_ok:
            print("Newton direction is no descent direction. Switching to gradient direction.")
            search_direction = -current_gradient
            descent = -norm_grad**2
        x,current_value,func_evals_backtracking = Backtracking(f,x,search_direction,alpha_0,current_value,descent,c1=c1_Armijo)
        current_gradient = f.grad(x)
        norm_grad = np.linalg.norm(current_gradient)
        tot_func_evals += func_evals_backtracking
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-f.minimisers))

    # decide by the stopping criterion itself, so that convergence in the
    # very last allowed step is reported correctly
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))

    print("A total of {} function evaluations were required.".format(tot_func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x

def CG_FR(f,x_init,
          max_steps = 50,
          alpha_0 = 1.0,
          create_animation = False,
          convergence_plot = False,
          gradient_stop = 1e-8,
          animationspeed = 500,
          c1 = 1e-3,
          c2 = 0.49,
          restart = 0,
):
    '''
    Nonlinear conjugate gradient method (Fletcher-Reeves) with a strong
    Wolfe line search. The accepted step length of each iteration is used as
    the initial step length of the next line search.

    Parameters:
    f                -- objective; must provide val(x) and grad(x), plus
                        plotf() for animations and minimisers for
                        convergence plots
    x_init           -- starting point
    max_steps        -- maximal number of iterations
    alpha_0          -- initial step length of the first line search
    create_animation -- if True, also return an animation of the iterates
    convergence_plot -- if True, record and plot |x_k - f.minimisers|
    gradient_stop    -- stop once |grad f(x)| <= gradient_stop
    animationspeed   -- delay between animation frames in milliseconds
    c1, c2           -- constants of the strong Wolfe conditions
    restart          -- if nonzero, reset beta to 0 (steepest descent)
                        every ``restart`` iterations

    Returns x, or (x, ani) / (x, fig) when an animation and/or a
    convergence plot is requested.
    '''

    # Initialise the variable x
    x = x_init

    # If we show a convergence plot, store the errors in each step
    if convergence_plot:
        # minimisers may be missing, None, or contain None entries
        minimisers = getattr(f, 'minimisers', None)
        if minimisers is None or any(m is None for m in np.ravel(np.asarray(minimisers, dtype=object))):
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = f.val(x)
    func_evals = 1
    current_gradient = f.grad(x)
    norm_grad = np.linalg.norm(current_gradient)
    old_gradient_norm_2 = norm_grad**2
    search_direction = -current_gradient
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        descent = np.inner(search_direction,current_gradient)
        x,current_value,current_gradient,func_evals_LS,alpha_0 = StrongWolfe(f,x,search_direction,
                                                                             current_value,descent,
                                                                             initial_step_length = alpha_0,
                                                                             c1 = c1,
                                                                             c2 = c2)
        func_evals += func_evals_LS
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-f.minimisers))
        norm_grad = np.linalg.norm(current_gradient)
        current_gradient_norm_2 = norm_grad**2
        betaFR = current_gradient_norm_2/old_gradient_norm_2
        old_gradient_norm_2 = current_gradient_norm_2
        # periodic restart with the steepest descent direction
        if restart and n_step % restart == 0:
            betaFR = 0
        search_direction = -current_gradient + betaFR*search_direction

    # decide by the stopping criterion itself, so that convergence in the
    # very last allowed step is reported correctly
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))

    print("A total of {} function and gradient evaluations were required.".format(func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x


def CG_PR(f,x_init,
          max_steps = 50,
          alpha_0 = 1.0,
          create_animation = False,
          convergence_plot = False,
          gradient_stop = 1e-12,
          animationspeed = 500,
          c1 = 1e-3,
          c2 = 0.49,
          restart = 0,
):
    '''
    Nonlinear conjugate gradient method with the Polak-Ribiere coefficient,
    clamped to [-beta_FR, beta_FR] (beta_FR the Fletcher-Reeves coefficient),
    and a strong Wolfe line search. The accepted step length of each
    iteration is used as the initial step length of the next line search.

    Parameters:
    f                -- objective; must provide val(x) and grad(x), plus
                        plotf() for animations and minimisers for
                        convergence plots
    x_init           -- starting point
    max_steps        -- maximal number of iterations
    alpha_0          -- initial step length of the first line search
    create_animation -- if True, also return an animation of the iterates
    convergence_plot -- if True, record and plot |x_k - f.minimisers|
    gradient_stop    -- stop once |grad f(x)| <= gradient_stop
    animationspeed   -- delay between animation frames in milliseconds
    c1, c2           -- constants of the strong Wolfe conditions
    restart          -- if nonzero, reset beta to 0 (steepest descent)
                        every ``restart`` iterations

    Returns x, or (x, ani) / (x, fig) when an animation and/or a
    convergence plot is requested.
    '''

    # Initialise the variable x
    x = x_init

    # If we show a convergence plot, store the errors in each step
    if convergence_plot:
        # minimisers may be missing, None, or contain None entries
        minimisers = getattr(f, 'minimisers', None)
        if minimisers is None or any(m is None for m in np.ravel(np.asarray(minimisers, dtype=object))):
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = f.val(x)
    func_evals = 1
    current_gradient = f.grad(x)
    norm_grad = np.linalg.norm(current_gradient)
    old_gradient_norm_2 = norm_grad**2
    search_direction = -current_gradient
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        descent = np.inner(search_direction,current_gradient)
        old_gradient = current_gradient
        x,current_value,current_gradient,func_evals_LS,alpha_0 = StrongWolfe(f,x,search_direction,
                                                                             current_value,descent,
                                                                             initial_step_length = alpha_0,
                                                                             c1 = c1, c2 = c2)
        func_evals += func_evals_LS
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-f.minimisers))
        norm_grad = np.linalg.norm(current_gradient)
        current_gradient_norm_2 = norm_grad**2
        betaFR = current_gradient_norm_2/old_gradient_norm_2
        betaPR = np.inner(current_gradient,current_gradient-old_gradient)/old_gradient_norm_2
        # clamp the Polak-Ribiere coefficient to [-betaFR, betaFR]
        if betaPR < -betaFR:
            beta = -betaFR
        elif betaPR > betaFR:
            beta = betaFR
        else:
            beta = betaPR
        old_gradient_norm_2 = current_gradient_norm_2
        # periodic restart with the steepest descent direction
        if restart and n_step % restart == 0:
            beta = 0
        search_direction = -current_gradient + beta*search_direction

    # decide by the stopping criterion itself, so that convergence in the
    # very last allowed step is reported correctly
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))

    print("A total of {} function and gradient evaluations were required.".format(func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x

def CG_HS(f,x_init,
          max_steps = 50,
          alpha_0 = 1.0,
          create_animation = False,
          convergence_plot = False,
          gradient_stop = 1e-12,
          animationspeed = 500,
          c1 = 1e-3,
          c2 = 0.49,
          restart = 0,
):
    '''
    Nonlinear conjugate gradient method with the Hestenes-Stiefel
    coefficient, clamped to [-beta_FR, beta_FR] (beta_FR the Fletcher-Reeves
    coefficient), and a strong Wolfe line search. The accepted step length of
    each iteration is used as the initial step length of the next line
    search.

    Parameters:
    f                -- objective; must provide val(x) and grad(x), plus
                        plotf() for animations and minimisers for
                        convergence plots
    x_init           -- starting point
    max_steps        -- maximal number of iterations
    alpha_0          -- initial step length of the first line search
    create_animation -- if True, also return an animation of the iterates
    convergence_plot -- if True, record and plot |x_k - f.minimisers|
    gradient_stop    -- stop once |grad f(x)| <= gradient_stop
    animationspeed   -- delay between animation frames in milliseconds
    c1, c2           -- constants of the strong Wolfe conditions
    restart          -- if nonzero, reset beta to 0 (steepest descent)
                        every ``restart`` iterations

    Returns x, or (x, ani) / (x, fig) when an animation and/or a
    convergence plot is requested.
    '''

    # Initialise the variable x
    x = x_init

    # If we show a convergence plot, store the errors in each step
    if convergence_plot:
        # minimisers may be missing, None, or contain None entries
        minimisers = getattr(f, 'minimisers', None)
        if minimisers is None or any(m is None for m in np.ravel(np.asarray(minimisers, dtype=object))):
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = f.val(x)
    func_evals = 1
    current_gradient = f.grad(x)
    norm_grad = np.linalg.norm(current_gradient)
    old_gradient_norm_2 = norm_grad**2
    search_direction = -current_gradient
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        descent = np.inner(search_direction,current_gradient)
        old_gradient = current_gradient
        x,current_value,current_gradient,func_evals_LS,alpha_0 = StrongWolfe(f,x,search_direction,
                                                                             current_value,descent,
                                                                             initial_step_length = alpha_0,
                                                                             c1 = c1, c2 = c2)
        func_evals += func_evals_LS
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-f.minimisers))
        norm_grad = np.linalg.norm(current_gradient)
        current_gradient_norm_2 = norm_grad**2
        betaFR = current_gradient_norm_2/old_gradient_norm_2
        grad_update = current_gradient-old_gradient
        # search_direction is still the previous direction p_k here
        denominator = np.inner(grad_update,search_direction)
        if denominator != 0:
            betaHS = np.inner(current_gradient,grad_update)/denominator
            # clamp the Hestenes-Stiefel coefficient to [-betaFR, betaFR]
            if betaHS < -betaFR:
                beta = -betaFR
            elif betaHS > betaFR:
                beta = betaFR
            else:
                beta = betaHS
        else:
            # <y_k, p_k> = 0 (excluded when the strong Wolfe conditions hold):
            # the Hestenes-Stiefel coefficient is undefined, so restart with
            # steepest descent instead of producing a NaN direction
            beta = 0
        old_gradient_norm_2 = current_gradient_norm_2
        # periodic restart with the steepest descent direction
        if restart and n_step % restart == 0:
            beta = 0
        search_direction = -current_gradient + beta*search_direction

    # decide by the stopping criterion itself, so that convergence in the
    # very last allowed step is reported correctly
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))

    print("A total of {} function and gradient evaluations were required.".format(func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x

def BFGS(fun,x_init,
         max_steps = 50,
         create_animation = False,
         convergence_plot = False,
         gradient_stop = 1e-12,
         animationspeed = 500,
):
    '''
    Minimise a function with the BFGS quasi-Newton method, using a strong Wolfe
    line search with initial step length 1 in every iteration.

    The matrix H approximates the inverse Hessian. It is initialised with the
    identity, rescaled by (y^T s)/(y^T y) before the first update, and then
    updated with the BFGS formula
        H <- (I - rho s y^T) H (I - rho y s^T) + rho s s^T,   rho = 1/(y^T s).

    Parameters:
        fun:              object providing val(x) and grad(x); for a convergence
                          plot it must also provide the attribute minimisers
        x_init:           initial point (numpy array)
        max_steps:        maximal number of iterations
        create_animation: if True, record the iterates x[0], x[1] for an animation
        convergence_plot: if True, record the distances ||x - fun.minimisers||
        gradient_stop:    stop once the norm of the gradient is at most this value
        animationspeed:   speed passed on to the animation routines

    Returns:
        x alone, or the tuple (x, animation) / (x, figure) if an animation
        and/or a convergence plot was requested.
    '''
    # Initialise the variable x
    x = x_init
    # Initialise the quasi-Newton matrix as the identity
    H = np.identity(np.size(x))

    # If we show a convergence plot, store the errors in each step.
    # The attribute itself is tested for None; a missing minimiser must
    # disable the plot instead of raising an error.
    if convergence_plot:
        if getattr(fun, 'minimisers', None) is None:
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = fun.val(x)
    current_gradient = fun.grad(x)
    func_evals = 1
    norm_grad = np.linalg.norm(current_gradient)
    # the initial rescaling of H is applied together with the first BFGS update
    initial_scaling_done = False
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        search_direction = -H@current_gradient
        descent = np.inner(search_direction,current_gradient)
        x_old = x
        gradient_old = np.copy(current_gradient)
        x,current_value,current_gradient,func_ls,_ = StrongWolfe(fun,x,search_direction,
                                                                 current_value,descent,
                                                                 initial_step_length = 1.0)
        func_evals += func_ls
        norm_grad = np.linalg.norm(current_gradient)
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-fun.minimisers))
        s = x - x_old
        y = current_gradient - gradient_old
        curvature = np.inner(y,s)
        # The BFGS update is only well defined (and keeps H positive definite)
        # if y^T s > 0. The strong Wolfe conditions guarantee this, but if the
        # line search fails (e.g. returns s = 0) the update is skipped instead of
        # dividing by zero and filling H with inf/nan.
        if not curvature > 0:
            continue
        rho = 1/curvature
        if not initial_scaling_done:
            # H_0 = (y^T s)/(y^T y) * I
            H = H*(curvature/np.inner(y,y))
            initial_scaling_done = True
        z = H@y
        H += -rho*(np.outer(s,z) + np.outer(z,s)) + rho*(rho*np.inner(y,z)+1)*np.outer(s,s)

    # Report convergence based on the actual stopping criterion, so that reaching
    # the tolerance in the very last allowed step is reported as converged.
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))
    print("A total of {} function and gradient evaluations were required.".format(func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(fun,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(fun,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x

def NewtonWolfe(f,x_init,
                max_steps = 50,
                create_animation = False,
                convergence_plot = False,
                descent_eps = 1e-6,
                gradient_stop = 1e-12,
                animationspeed = 500,
):
    '''
    Newton's method globalised with a strong Wolfe line search (initial step length 1).

    In each step the Newton system f.hess(x) p = -grad f(x) is solved. If this
    fails (LinAlgError), or if p does not satisfy the angle condition
        grad f(x)^T p <= -descent_eps * ||grad f(x)|| * ||p||,
    the negative gradient is used as search direction instead.

    Parameters:
        f:                object providing val(x), grad(x) and hess(x); for a
                          convergence plot also the attribute minimisers
        x_init:           initial point (numpy array)
        max_steps:        maximal number of iterations
        create_animation: if True, record the iterates x[0], x[1] for an animation
        convergence_plot: if True, record the distances ||x - f.minimisers||
        descent_eps:      constant in the angle condition above
        gradient_stop:    stop once the norm of the gradient is at most this value
        animationspeed:   speed passed on to the animation routines

    Returns:
        x alone, or the tuple (x, animation) / (x, figure) if an animation
        and/or a convergence plot was requested.
    '''
    # Initialise the variable x
    x = x_init

    # If we show a convergence plot, store the errors in each step.
    # The attribute itself is tested for None; a missing minimiser must
    # disable the plot instead of raising an error.
    if convergence_plot:
        if getattr(f, 'minimisers', None) is None:
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    current_value = f.val(x)
    current_gradient = f.grad(x)
    func_evals = 1
    norm_grad = np.linalg.norm(current_gradient)
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        n_step += 1
        try:
            search_direction = -np.linalg.solve(f.hess(x),current_gradient)
        except np.linalg.LinAlgError:
            print("Step {}: Problems when solving the Newton system; switching to gradient descent.".format(n_step))
            # the gradient at x is already known; no need to evaluate it again
            search_direction = -current_gradient
            descent = -norm_grad**2
        else:
            descent = np.inner(search_direction,current_gradient)
            # "not ... <=" also rejects a NaN direction (e.g. from a Hessian with inf entries)
            if not descent <= -descent_eps*norm_grad*np.linalg.norm(search_direction):
                print("Step {}: Newton direction is no descent direction. Switching to gradient direction.".format(n_step))
                search_direction = -current_gradient
                descent = -norm_grad**2
        x,current_value,current_gradient,func_ls,_ = StrongWolfe(f,x,search_direction,current_value,descent,initial_step_length=1.0)
        func_evals += func_ls
        norm_grad = np.linalg.norm(current_gradient)
        if create_animation:
            points_x.append(x[0])
            points_y.append(x[1])
        if convergence_plot:
            errors.append(np.linalg.norm(x-f.minimisers))

    # Report convergence based on the actual stopping criterion, so that reaching
    # the tolerance in the very last allowed step is reported as converged.
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))
    print("A total of {} function and gradient evaluations were required.".format(func_evals))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x



#################################################################################
# implementation of trust region methods

def SolveTrustRegion(g,B,Delta,
                     max_solvesteps = 100,
                     accuracy = 1e-6,
                     min_lambda_par = 1e-6):
    '''
    Approximately solve the trust region subproblem
        min_p  g^T p + 1/2 p^T B p   subject to  ||p|| <= Delta
    for a symmetric matrix B, following Nocedal & Wright, Algorithm 4.3.

    If B is positive definite and the Newton step p = -B^{-1} g satisfies
    ||p|| <= Delta, this step is returned directly. Otherwise a parameter
    lambda with B + lambda*I positive definite is sought such that
    p(lambda) = -(B + lambda*I)^{-1} g satisfies ||p(lambda)|| = Delta. This
    uses Newton's method on 1/Delta - 1/||p(lambda)|| = 0 with Cholesky
    factorisations. lambda is kept above min_lambda_par, shifted by -lambda_min
    if B is not positive definite.

    Hard case: if B is not positive definite and ||p(lambda)|| stays below
    Delta as lambda approaches its lower bound, p(lambda) alone does not solve
    the subproblem. The step is then extended to the boundary along an
    eigenvector of the smallest eigenvalue of B.

    Parameters:
        g:              gradient vector of length n
        B:              symmetric n x n model Hessian
        Delta:          trust region radius (> 0)
        max_solvesteps: maximal number of Newton iterations for lambda
        accuracy:       relative tolerance for | ||p|| - Delta | / Delta
        min_lambda_par: safeguard distance of lambda from its lower bound

    Returns:
        the (approximate) solution p of the subproblem.
    '''
    n = np.size(g)
    # eigen-decomposition of B: eigenvalues ascending, eigenvectors as columns
    eigenvalues, eigenvectors = np.linalg.eigh(B)
    lambda_min = eigenvalues[0]
    if lambda_min > 0:
        p = -np.linalg.solve(B,g)
        norm_p = np.linalg.norm(p)
        if norm_p <= Delta:
            return p
        else:
            lambda_par = 0
            Blambda = B
    else:
        # modify B such that it becomes positive definite
        min_lambda_par += -lambda_min
        lambda_par = -lambda_min+1.e1
        Blambda = lambda_par*np.eye(n) + B
        p = -np.linalg.solve(Blambda,g)
        norm_p = np.linalg.norm(p)
    # run Newton's method for lambda
    hard_case = False
    n_step = 0
    rel_error = (norm_p-Delta)/Delta
    while ((n_step < max_solvesteps) and (np.abs(rel_error) > accuracy)):
        # Blambda = L L^T;  p = -(L L^T)^{-1} g;  q = L^{-1} p
        L = np.linalg.cholesky(Blambda)
        ptemp = -la.solve_triangular(L,g,lower=True)
        p = la.solve_triangular(np.transpose(L),ptemp)
        q = la.solve_triangular(L,p,lower=True)
        norm_p = np.linalg.norm(p)
        rel_error = (norm_p-Delta)/Delta
        if ((rel_error < 0) and (lambda_par < 1.1*min_lambda_par)):
            # p is inside the trust region although lambda cannot be decreased
            # much further; for a non-positive-definite B this is the hard case
            hard_case = lambda_min <= 0
            break
        # Newton update: lambda += (||p||/||q||)^2 * (||p|| - Delta)/Delta
        lambda_par += (norm_p**2)*rel_error/(np.inner(q,q))
        if lambda_par < min_lambda_par:
            lambda_par = min_lambda_par
        Blambda = lambda_par*np.identity(n) + B
        n_step += 1
    if hard_case:
        # Move along a unit eigenvector z of lambda_min (non-positive curvature)
        # to the boundary: solve ||p + tau z|| = Delta, i.e.
        #   tau^2 + 2 (p^T z) tau + ||p||^2 - Delta^2 = 0.
        # Since ||p|| < Delta there is one root of each sign; keep the one with
        # the lower model value (never worse than p itself).
        z = eigenvectors[:,0]
        pz = np.inner(p,z)
        root = np.sqrt(pz**2 + Delta**2 - norm_p**2)
        candidates = [p + (-pz+root)*z, p + (-pz-root)*z]
        p = min(candidates, key=lambda v: np.inner(g,v) + 0.5*np.inner(v,B@v))
    return p


def TrustRegionNewton(f,x_init,
                      max_steps = 50,
                      create_animation = False,
                      convergence_plot = False,
                      Delta_0 = 1.0,
                      Delta_max = 1e2,
                      eta = 1e-3,
                      gradient_stop = 1e-12,
                      animationspeed = 500,
):
    '''
    Trust region method with the exact Hessian as model matrix
    (cf. Nocedal & Wright, Algorithm 4.1).

    In each step the subproblem is solved with SolveTrustRegion. The ratio
    rho = actual decrease / predicted decrease controls the radius:
      - rho < 1/4 (or rho not a number): Delta is reduced by a factor 4,
      - rho > 3/4 and ||p|| > 0.9*Delta: Delta is doubled (at most Delta_max).
    The step is accepted if rho > eta.

    Parameters:
        f:                object providing val(x), grad(x) and hess(x); for a
                          convergence plot also the attribute minimisers
        x_init:           initial point (numpy array)
        max_steps:        maximal number of iterations (accepted and rejected)
        create_animation: if True, record the accepted iterates for an animation
        convergence_plot: if True, record the distances ||x - f.minimisers||
        Delta_0:          initial trust region radius
        Delta_max:        maximal trust region radius
        eta:              acceptance threshold for rho
        gradient_stop:    stop once the norm of the gradient is at most this value
        animationspeed:   speed passed on to the animation routines

    Returns:
        x alone, or the tuple (x, animation) / (x, figure) if an animation
        and/or a convergence plot was requested.
    '''
    # Initialise the variable x and the trust region radius
    x = x_init
    Delta = Delta_0

    # If we show a convergence plot, store the errors in each step.
    # The attribute itself is tested for None; a missing minimiser must
    # disable the plot instead of raising an error.
    if convergence_plot:
        if getattr(f, 'minimisers', None) is None:
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    # count the number of times where the update was rejected
    n_rejected = 0
    current_value = f.val(x)
    current_gradient = f.grad(x)
    norm_grad = np.linalg.norm(current_gradient)
    current_hessian = f.hess(x)
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        p = SolveTrustRegion(current_gradient,current_hessian,Delta)
        next_x = x+p
        # compute the ratio between expected and actual update
        next_value = f.val(next_x)
        actual_update = next_value-current_value
        expected_update = np.inner(current_gradient,p) + (1/2)*np.inner(p,current_hessian@p)
        rho = actual_update/expected_update
        # Decide whether to change the trust region radius. The test is written
        # as "not rho >= 1/4" so that a NaN ratio (e.g. f undefined at next_x)
        # shrinks the region instead of retrying the same rejected step forever.
        if not rho >= 1/4:
            Delta *= 1/4
        elif ((rho > 3/4) and (np.linalg.norm(p) > 0.9*Delta)):
            Delta = min(2*Delta, Delta_max)
        # decide whether to make a step
        if rho > eta:
            x = next_x
            current_value = next_value
            current_gradient = f.grad(x)
            norm_grad = np.linalg.norm(current_gradient)
            current_hessian = f.hess(x)
            if create_animation:
                points_x.append(x[0])
                points_y.append(x[1])
            if convergence_plot:
                errors.append(np.linalg.norm(x-f.minimisers))
        else:
            n_rejected += 1
        n_step += 1

    # Report convergence based on the actual stopping criterion, so that reaching
    # the tolerance in the very last allowed step is reported as converged.
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))
    print("In {} of these steps, the update was rejected.".format(n_rejected))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x


def TrustRegionSR1(f,x_init,
                   max_steps = 50,
                   create_animation = False,
                   convergence_plot = False,
                   Delta_0 = 1.0,
                   Delta_max = 1e2,
                   eta = 1e-3,
                   gradient_stop = 1e-12,
                   eps_safety = 1e-3,
                   animationspeed = 500,
):
    '''
    Trust region method with a symmetric-rank-one (SR1) approximation B of the
    Hessian (cf. Nocedal & Wright, Algorithm 6.2), starting from B = I.

    The radius is updated as in TrustRegionNewton:
      - rho < 1/4 (or rho not a number): Delta is reduced by a factor 4,
      - rho > 3/4 and ||p|| > 0.9*Delta: Delta is doubled (at most Delta_max).
    The step is accepted if rho > eta. B is updated after every trial step,
    accepted or not, with
        B <- B + z z^T / (p^T z),   z = y - B p,   y = grad f(x+p) - grad f(x).
    The update is skipped if |p^T z| <= eps_safety * ||p|| * ||z||.

    Parameters:
        f:                object providing val(x) and grad(x); for a convergence
                          plot also the attribute minimisers
        x_init:           initial point (numpy array)
        max_steps:        maximal number of iterations (accepted and rejected)
        create_animation: if True, record the accepted iterates for an animation
        convergence_plot: if True, record the distances ||x - f.minimisers||
        Delta_0:          initial trust region radius
        Delta_max:        maximal trust region radius
        eta:              acceptance threshold for rho
        gradient_stop:    stop once the norm of the gradient is at most this value
        eps_safety:       safeguard constant for skipping the SR1 update
        animationspeed:   speed passed on to the animation routines

    Returns:
        x alone, or the tuple (x, animation) / (x, figure) if an animation
        and/or a convergence plot was requested.
    '''
    # Initialise the variable x and the trust region radius
    x = x_init
    Delta = Delta_0

    # If we show a convergence plot, store the errors in each step.
    # The attribute itself is tested for None; a missing minimiser must
    # disable the plot instead of raising an error.
    if convergence_plot:
        if getattr(f, 'minimisers', None) is None:
            print("A convergence plot requires the (analytic) minimisers of the function to be set.")
            convergence_plot = False
        else:
            errors = []

    # If we want to show the whole iteration, store all the points
    if create_animation:
        points_x = [x_init[0]]
        points_y = [x_init[1]]

    # Main loop
    n_step = 0
    # count the number of times where the update was rejected
    n_rejected = 0
    current_value = f.val(x)
    current_gradient = f.grad(x)
    norm_grad = np.linalg.norm(current_gradient)
    B = np.identity(np.size(x))
    while ((n_step < max_steps) and (norm_grad > gradient_stop)):
        p = SolveTrustRegion(current_gradient,B,Delta)
        next_x = x+p
        # compute the ratio between expected and actual update
        next_value = f.val(next_x)
        actual_update = next_value-current_value
        expected_update = np.inner(current_gradient,p) + (1/2)*np.inner(p,B@p)
        rho = actual_update/expected_update
        # Decide whether to change the trust region radius. The test is written
        # as "not rho >= 1/4" so that a NaN ratio (e.g. f undefined at next_x)
        # shrinks the region instead of retrying the same rejected step forever.
        if not rho >= 1/4:
            Delta *= 1/4
        elif ((rho > 3/4) and (np.linalg.norm(p) > 0.9*Delta)):
            Delta = min(2*Delta, Delta_max)
        # SR1 update of B (also after rejected steps)
        next_gradient = f.grad(next_x)
        y = next_gradient-current_gradient
        z = y-B@p
        pz = np.inner(p,z)
        # skip the update if the denominator is (relatively) too small
        if (np.abs(pz) > eps_safety*np.linalg.norm(p)*np.linalg.norm(z)):
            B += np.outer(z,z)/pz
        # decide whether to make a step
        if rho > eta:
            x = next_x
            current_value = next_value
            current_gradient = next_gradient
            norm_grad = np.linalg.norm(current_gradient)
            if create_animation:
                points_x.append(x[0])
                points_y.append(x[1])
            if convergence_plot:
                errors.append(np.linalg.norm(x-f.minimisers))
        else:
            n_rejected += 1
        n_step += 1

    # Report convergence based on the actual stopping criterion, so that reaching
    # the tolerance in the very last allowed step is reported as converged.
    if norm_grad <= gradient_stop:
        print("Converged after {} iterations.".format(n_step))
    else:
        print("Did not converge after {} iterations.".format(n_step))
    print("In {} of these steps, the update was rejected.".format(n_rejected))

    if create_animation and not(convergence_plot):
        ani = CreateAnimation(f,points_x,points_y,speed=animationspeed)
        return x,ani
    elif create_animation and convergence_plot:
        ani = CreateAnimation_ConvergencePlot(f,points_x,points_y,errors,speed=animationspeed)
        return x,ani
    elif (not(create_animation) and convergence_plot):
        fig = CreateConvergencePlot(errors)
        return x,fig
    else:
        return x



#################################################################################
# constrained optimisation

def QuadPenalty(f,c,x_init,
                penaltyparam_init = 1.e-2,
                penalty_mult = 10.,
                tolerance = 1.e-5
):
    '''
    Quadratic penalty method for min f(x) subject to the constraint c(x) = 0.

    For an increasing sequence of penalty parameters mu, the penalised function
        Q(x; mu) = f(x) + mu/2 * c(x)^2
    is minimised with BFGS. Each BFGS run is warm-started at the previous
    solution and uses the gradient tolerance max(tolerance, min(0.1/mu, 1e-4)).
    After each subproblem mu is multiplied by penalty_mult. The method stops
    once |c(x)| < tolerance or mu has reached 1e8.

    The penalty terms c.val(x)**2 and c.val(x)*c.grad(x) treat c as a single
    constraint.

    Returns:
        x, a numpy array of the penalty parameters used, and a numpy array of
        the constraint violations |c(x)| after each subproblem.
    '''
    max_penalty = 1.e8
    mu = penaltyparam_init
    x = x_init
    # copy of f whose val/grad are replaced by those of the penalised function
    penalised = copy.deepcopy(f)

    penalties = []
    violations = []
    # outer iterates; only needed for a (currently disabled) animation
    points_x = [x_init[0]]
    points_y = [x_init[1]]

    finished = False
    while not finished:
        penalties.append(mu)
        inner_tolerance = max(tolerance, min((1.e-1)/mu, 1.e-4))
        penalised.val = lambda z: f.val(z) + 0.5*mu*(c.val(z)**2)
        penalised.grad = lambda z: f.grad(z) + mu*c.val(z)*c.grad(z)
        print("Minimising the quadratic penalty functional with a penalty parameter of {}".format(mu))
        x = BFGS(penalised, x,
                 max_steps=50,
                 create_animation=False,
                 convergence_plot=False,
                 gradient_stop=inner_tolerance,
                 animationspeed=500)
        points_x.append(x[0])
        points_y.append(x[1])
        constraint_value = c.val(x)
        objective_value = f.val(x)
        violations.append(np.abs(constraint_value))
        mu = mu*penalty_mult
        print("Solution: {}".format(x))
        print("Function value: {}; Constraint: {}".format(objective_value, constraint_value))
        print(" ")
        # stop if the penalty parameter is too large or the constraint is satisfied
        if mu >= max_penalty:
            finished = True
        if np.abs(constraint_value) < tolerance:
            finished = True

    return x, np.array(penalties), np.array(violations)
    
def AugmentedLagrangian(f,c,x_init,lambda_init,penaltyparam=10,tolerance= 1.e-6,
                        max_iter=100):
    '''
    Augmented Lagrangian method for min f(x) subject to c(x) = 0 with a fixed
    penalty parameter mu = penaltyparam.

    In each outer iteration the augmented Lagrangian
        L_A(x; lambda, mu) = f(x) - lambda*c(x) + mu/2 * c(x)^2
    is minimised with BFGS, warm-started at the previous x. The multiplier
    estimate is then updated by lambda <- lambda - mu*c(x). The iteration stops
    once |c(x)| < tolerance or after max_iter outer iterations.

    Parameters:
        f:           objective, providing val(x) and grad(x)
        c:           constraint, providing val(x) and grad(x)
        x_init:      initial point (numpy array)
        lambda_init: initial Lagrange multiplier estimate (not modified)
        penaltyparam: penalty parameter mu
        tolerance:   stopping tolerance for |c(x)|
        max_iter:    maximal number of outer iterations (default 100)

    Returns:
        x and the multiplier estimate after the final update.
    '''
    AugLag = copy.deepcopy(f)
    x = x_init
    Lag = lambda_init
    converged = False
    num_iter = 0

    while(not(converged)):
        num_iter += 1
        # the lambdas read the current value of Lag; they are redefined every iteration
        AugLag.val = lambda x: f.val(x) - Lag*c.val(x) + 0.5*penaltyparam*(c.val(x)**2)
        AugLag.grad = lambda x: f.grad(x) - Lag*c.grad(x) + penaltyparam*c.val(x)*c.grad(x)
        x = BFGS(AugLag,x,
                 max_steps = 50,
                 create_animation = False,
                 convergence_plot = False,
                 gradient_stop = 1.e-7,
                 animationspeed = 500)
        constraint_value = c.val(x)
        print("x = {}; lambda = {}".format(x,Lag))
        print("|c(x)| = {}".format(np.abs(constraint_value)))
        print("nabla_x L = {}".format(np.abs(f.grad(x) - Lag*c.grad(x))))
        print(" ")
        if((np.abs(constraint_value) < tolerance) or num_iter >= max_iter):
            converged = True
        # Rebind instead of using "+=": with an array-valued lambda_init an
        # in-place update would modify the caller's array (and fail for
        # integer arrays because of the float result).
        Lag = Lag - penaltyparam*constraint_value

    return x,Lag
    
    
                        
