'''
Here I collect definitions of functions used
in examples in the course TMA4180 - Optimisation 1.

Markus Grasmair
Trondheim, January 2023 -- February 2024
'''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm


class testfunction:
    '''
    Class for test functions to be called by the various algorithms
    implemented for the course TMA4180 - Optimisation 1.

    Attributes:
    - val ... function that returns the value of the testfunction
    - grad ... function that returns the gradient of the testfunction
               (default: a function accepting any arguments and returning None)
    - hess ... function that returns the Hessian of the testfunction
               (default: a function accepting any arguments and returning None)
    - plots ... Boolean variable; indicates whether the instance supports plotting
    - minimisers ... Position of the minimisers of f. Not yet used.
    If the instance supports plotting, we have the following additional attributes:
    - plotbounds ... bounds of the plot, [[x_lower, x_upper], [y_lower, y_upper]];
                     defaults to [[0,1],[0,1]]
    - N_points ... number of discretisation points per coordinate direction
    - N_lines ... number of contour lines drawn by plotf
    - plot_scaling ... rescaling of the function used for better contour plots
    - X ... discretisation of the first variable; only produced if needed
    - Y ... discretisation of the second variable; only produced if needed
    - Z ... function values at the discretisation; only produced if needed

    Methods:
    - __init__
      Initialisation
    - build_discretisation
      Builds a discretisation of the function over the area indicated by plotbounds.
    - plotf
      If plotting is supported, returns a 2d plot including contour lines of the testfunction.
    - set_plotbounds
      Replaces plotbounds and rebuilds the discretisation.
    - information
      Not yet implemented.
      Should at some point provide information about the testfunction.
    '''

    def __init__(self, val = None,
                 grad = lambda *args, **kwargs: None,
                 hess = lambda *args, **kwargs: None,
                 plots = False,
                 plotbounds = None,
                 N_points = 800,
                 N_lines = 20,
                 plot_scaling = lambda x: x,
                 minimisers = None,
    ):
        '''
        Initialisation of the instance.

        The default grad/hess accept any arguments, so calling f.grad(x) or
        f.hess(x) on a function without derivatives returns None instead of
        raising a TypeError.
        '''

        self.val = val
        self.grad = grad
        self.hess = hess
        self.plots = plots
        if self.plots:
            if plotbounds is None:
                # Fresh array per instance, so instances never share a default.
                plotbounds = np.array([[0,1],[0,1]])
            self.plotbounds = plotbounds
            self.N_points = N_points
            self.plot_scaling = plot_scaling
            self.N_lines = N_lines
        self.minimisers = minimisers


    def build_discretisation(self):
        '''
        Build the grid X, Y over plotbounds and evaluate Z = val([X, Y]).

        Each axis gets exactly N_points samples
        lower + k*(upper-lower)/N_points, k = 0, ..., N_points-1.
        Prints a message instead of raising if plotting is not possible.
        '''
        if not self.plots:
            print("This function does not support plotting.")
            return
        try:
            x_lower = self.plotbounds[0,0]
            x_upper = self.plotbounds[0,1]
            y_lower = self.plotbounds[1,0]
            y_upper = self.plotbounds[1,1]
            n_points = int(self.N_points)
            # linspace(endpoint=False) gives the same sample points as
            # arange(lower, upper, (upper-lower)/N), but always exactly N of
            # them; a float-step arange may return N+1 due to rounding.
            xx = np.linspace(x_lower, x_upper, n_points, endpoint=False)
            yy = np.linspace(y_lower, y_upper, n_points, endpoint=False)
            self.X, self.Y = np.meshgrid(xx,yy)
            # The function is evaluated on the stacked grids at once, which
            # requires a vectorised implementation of val.
            self.Z = self.val(np.array([self.X,self.Y]))
        except AttributeError:
            print("This function does not support plotting.")
        except IndexError:
            print("Please vectorise the function for plotting.")

    def plotf(self):
        '''
        Plot the testfunction as an image with contour lines on top.

        The discretisation is built first if it does not exist yet.
        Returns the tuple (f_image, f_contour) produced by plt.imshow and
        plt.contour, or None (after printing a message) if plotting is not
        possible.
        '''
        if not self.plots:
            print("This function does not support plotting.")
            return None
        if not (hasattr(self, 'X') and hasattr(self, 'Y')):
            # Build the grid lazily. If that fails, build_discretisation has
            # already printed the reason; stop instead of retrying endlessly.
            self.build_discretisation()
            if not (hasattr(self, 'X') and hasattr(self, 'Y')):
                return None
        try:
            bounds = self.plotbounds
            scaling = self.plot_scaling
            n_lines = self.N_lines
        except AttributeError:
            print("This function does not support plotting.")
            return None
        x_lower = bounds[0,0]
        x_upper = bounds[0,1]
        y_lower = bounds[1,0]
        y_upper = bounds[1,1]
        # Makes a possibly non-square plot region display as a square image.
        aspect_ratio = (x_upper-x_lower)/(y_upper-y_lower)
        # Re-evaluated so that the plot reflects the current self.val.
        self.Z = self.val(np.array([self.X,self.Y]))
        f_image = plt.imshow(self.Z, interpolation='bilinear',
                             origin='lower', cmap=cm.Blues,
                             extent=(x_lower,x_upper,y_lower,y_upper),
                             aspect=aspect_ratio)
        f_contour = plt.contour(self.X, self.Y, scaling(self.Z), n_lines)
        return f_image, f_contour

    def set_plotbounds(self,newbounds):
        '''Replace plotbounds and rebuild the discretisation.'''
        self.plotbounds = newbounds
        self.build_discretisation()
            
    def information(self):
        print("Feature not yet implemented")
        
class constraint:
    '''
    Class for constraints to be called by the various algorithms
    implemented for the course TMA4180 - Optimisation 1.

    Attributes:
    - val ... function that returns the value of the constraint
    - grad ... function that returns the gradient of the constraint
               (default: a function accepting any arguments and returning None)
    - hess ... function that returns the Hessian of the constraint
               (default: a function accepting any arguments and returning None)
    - equality ... Boolean variable; indicates whether this is an equality (True)
                   or inequality (False) constraint

    Methods:
    - __init__
      Initialisation
    '''

    def __init__(self, val = None,
                 grad = lambda *args, **kwargs: None,
                 hess = lambda *args, **kwargs: None,
                 equality = True,
    ):
        '''
        Initialisation of the instance.

        The default grad/hess accept the evaluation point (or anything else)
        and return None, so calling them never raises a TypeError.
        '''

        self.val = val
        self.grad = grad
        self.hess = hess
        self.equality = equality
    

### definition of Himmelblau's function
def HB_val(x):
    '''
    Himmelblau's function
        f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2.
    Only elementwise operations are used, so x may also be a stacked
    array [X, Y] of grids.
    '''
    first, second = x[0], x[1]
    term_a = first**2 + second - 11
    term_b = first + second**2 - 7
    return term_a**2 + term_b**2

def HB_grad(x):
    '''
    Gradient of Himmelblau's function, returned as np.array([df/dx, df/dy]).
    '''
    first, second = x[0], x[1]
    inner_a = first**2 + second - 11
    inner_b = first + second**2 - 7
    return np.array([4*first*inner_a + 2*inner_b,
                     2*inner_a + 4*second*inner_b])

def HB_hess(x):
    '''
    Hessian of Himmelblau's function
        f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2,
    with entries
        d2f/dx2  = 4(x^2 + y - 11) + 8x^2 + 2,
        d2f/dxdy = 4x + 4y,
        d2f/dy2  = 2 + 4(x + y^2 - 7) + 8y^2.
    '''
    xx = x[0]
    yy = x[1]
    h11 = 4*(xx**2 + yy - 11) + 8*xx**2 + 2
    h12 = 4*xx + 4*yy
    h21 = h12
    h22 = 2 + 4*(xx + yy**2 - 7) + 8*yy**2
    return np.array([[h11,h12],[h21,h22]])

HB = testfunction(
    val=HB_val,
    grad=HB_grad,
    hess=HB_hess,
    plots=True,
    plotbounds=np.array([[-5., 5.], [-5., 5.]]),
    N_points=800,
    N_lines=20,
    # Contour rescaling: fourth root of the function value shifted by 0.8.
    plot_scaling=lambda x: ((x>=0)*(np.abs(x)+0.8) + (x<0)*(0.8*np.exp(-np.abs(x))))**(0.25),
    minimisers=None,
)

# The helper functions are only needed through HB from here on.
del HB_val, HB_grad, HB_hess



### definition of the 2d Rosenbrock function
def RB_val(x):
    '''
    The 2d Rosenbrock function f(x, y) = (1 - x)^2 + 100 (y - x^2)^2.
    '''
    u, v = x[0], x[1]
    return (1 - u)**2 + 100*(v - u**2)**2

def RB_grad(x):
    '''
    Gradient of the 2d Rosenbrock function, np.array([df/dx, df/dy]).
    '''
    u, v = x[0], x[1]
    residual = v - u**2
    return np.array([-2*(1 - u) - 400*residual*u, 200*residual])

def RB_hess(x):
    '''
    Hessian of the 2d Rosenbrock function:
        [[1200 x^2 - 400 y + 2, -400 x],
         [-400 x,               200   ]].
    '''
    u, v = x[0], x[1]
    off_diag = -400*u
    return np.array([[1200*u**2 - 400*v + 2, off_diag],
                     [off_diag, 200]])

RB = testfunction(
    val=RB_val,
    grad=RB_grad,
    hess=RB_hess,
    plots=True,
    plotbounds=np.array([[-1, 3], [-1, 3]]),
    N_points=800,
    # The Rosenbrock function is a sum of squares, so x + 0.8 > 0 here.
    plot_scaling=lambda x: np.log(x+0.8),
    minimisers=np.array([1.0, 1.0]),
)

# The helper functions are only needed through RB from here on.
del RB_val, RB_grad, RB_hess



### Example problem from lecture notes

def Note1_Ex28_val(x):
    '''
    Example function f(x, y) = 3x^4 + 4x^3 + 12y^2 - 24xy.
    '''
    p, q = x[0], x[1]
    return 3*p**4 + 4*p**3 + 12*q**2 - 24*p*q

def Note1_Ex28_grad(x):
    '''
    Gradient of f(x, y) = 3x^4 + 4x^3 + 12y^2 - 24xy:
        (12x^3 + 12x^2 - 24y, 24y - 24x).
    '''
    p, q = x[0], x[1]
    dfdx = 12*p**3 + 12*p**2 - 24*q
    dfdy = 24*q - 24*p
    return np.array([dfdx, dfdy])

def Note1_Ex28_hess(x):
    '''
    Hessian of f(x, y) = 3x^4 + 4x^3 + 12y^2 - 24xy:
        [[36x^2 + 24x, -24],
         [-24,          24]].
    Only the (1,1) entry depends on the point, and only through x.
    '''
    xx = x[0]
    # d2f/dx2 is the x-derivative of df/dx = 12x^3 + 12x^2 - 24y.
    h11 = 36*xx**2 + 24*xx
    h12 = -24
    h21 = -24
    h22 = 24
    return(np.array([[h11,h12],[h21,h22]]))

Note1_Ex28 = testfunction(
    val=Note1_Ex28_val,
    grad=Note1_Ex28_grad,
    hess=Note1_Ex28_hess,
    plots=True,
    plotbounds=np.array([[-3, 3], [-3, 3]]),
    N_points=800,
    N_lines=27,
    # The minimum value of f is -32 (at x = y = -2), so x + 33.11 > 0.
    plot_scaling=lambda x: np.log(x+33.11),
)

# The helper functions are only needed through Note1_Ex28 from here on.
del Note1_Ex28_val, Note1_Ex28_grad, Note1_Ex28_hess



### Counterexample for Nelder-Mead algorithm
# See McKinnon, Convergence of the Nelder-Mead Simplex Method to a Nonstationary Point
# SIAM J. Optimization 9(1), pp. 148-158, 1998.

def NM_CounterEx_val(x):
    '''
    McKinnon's counterexample for Nelder-Mead:
        f(x, y) = 6x^2 + y + y^2    for x >= 0,
        f(x, y) = 360x^2 + y + y^2  for x < 0.
    The branch is selected with boolean masks, so grids are supported.
    '''
    u, v = x[0], x[1]
    right = (u >= 0)*(6*u**2 + v + v**2)
    left = (u < 0)*(360*u**2 + v + v**2)
    return right + left

def NM_CounterEx_grad(x):
    '''
    Gradient of McKinnon's counterexample function:
        (12x, 1 + 2y)   for x >= 0,
        (720x, 1 + 2y)  for x < 0.
    The branch is selected elementwise with np.where. This matches the
    vectorised value function, so x may also be a stacked array of points.
    '''
    xx = x[0]
    yy = x[1]
    grad1 = np.where(xx >= 0, 12*xx, 720*xx)
    grad2 = 1 + 2*yy
    return(np.array([grad1,grad2]))

NM_CounterEx = testfunction(
    val=NM_CounterEx_val,
    grad=NM_CounterEx_grad,
    plots=True,
    plotbounds=np.array([[-0.1, 1.1], [-1.1, 1.1]]),
    N_points=800,
    N_lines=20,
    # f >= y + y^2 >= -1/4, so x + 1 > 0.
    plot_scaling=lambda x: np.log(x+1),
)

del NM_CounterEx_val, NM_CounterEx_grad


### Quadratic function


def Quadratic_val(x):
    '''
    Ill-conditioned quadratic f(x, y) = x^2 + eps*y^2 with eps = 0.05.
    '''
    eps = 0.05
    return x[0]**2 + eps*x[1]**2

def Quadratic_grad(x):
    '''
    Gradient (2x, 2*eps*y) of the quadratic with eps = 0.05.
    '''
    eps = 0.05
    return np.array([2*x[0], 2*eps*x[1]])

def Quadratic_hess(x):
    '''
    Constant Hessian diag(2, 2*eps) of the quadratic with eps = 0.05;
    the argument x is ignored.
    '''
    eps = 0.05
    diagonal = [2, 2*eps]
    return np.diag(diagonal)

QuadraticEx = testfunction(
    val=Quadratic_val,
    grad=Quadratic_grad,
    hess=Quadratic_hess,
    plots=True,
    plotbounds=np.array([[-1.2, 1.2], [-1.2, 1.2]]),
    N_points=800,
    N_lines=20,
    # The quadratic is non-negative, so x + 0.05 > 0.
    plot_scaling=lambda x: np.log(x+0.05),
)

del Quadratic_val, Quadratic_grad, Quadratic_hess



### Extended Rosenbrock function

def ExtRosenbrock_value(x):
    '''
    Extended Rosenbrock function
        f(x) = sum_{i=0}^{n-2} alpha (x_{i+1} - x_i^2)^2 + (1 - x_i)^2
    with alpha = 50. The dimension n is the length of x; sequences are
    converted to arrays first.
    '''
    alpha = 50
    x = np.asarray(x)
    value = np.sum(alpha*(x[1:]-x[:-1]**2)**2 + (1-x[:-1])**2)
    return(value)

def ExtRosenbrock_grad(x):
    '''
    Gradient of the extended Rosenbrock function (alpha = 50).

    Component i collects
      2 alpha (x_i - x_{i-1}^2)                       from term i-1 (i > 0),
      -4 alpha (x_{i+1} - x_i^2) x_i - 2 (1 - x_i)    from term i   (i < n-1).
    The dimension n is the length of x.
    '''
    alpha = 50
    x = np.asarray(x)
    n = x.shape[0]
    residual = x[1:] - x[:-1]**2
    temp = np.zeros(n)
    temp[1:] += 2*alpha*residual
    temp[:-1] += -4*alpha*residual*x[:-1] - 2*(1-x[:-1])
    return(temp)

def ExtRosenbrock_hess(x):
    '''
    Hessian of the extended Rosenbrock function (alpha = 50).

    The Hessian is tridiagonal:
      off-diagonal  d2f/dx_i dx_{i+1} = -4 alpha x_i,
      diagonal      -4 alpha x_{i+1} + 12 alpha x_i^2 + 2   (i < n-1)
                    + 2 alpha                              (i > 0).
    Returns an (n, n) array where n is the length of x.
    '''
    alpha = 50
    x = np.asarray(x)
    n = x.shape[0]
    temp = np.zeros([n,n])
    temp += np.diagflat(-4*alpha*x[:-1],k=1)
    temp += np.diagflat(-4*alpha*x[:-1],k=-1)
    tempdiag = np.zeros(n)
    tempdiag[:-1] += -4*alpha*x[1:]+12*alpha*x[:-1]**2 + 2
    tempdiag[1:] += 2*alpha
    temp += np.diagflat(tempdiag,k=0)
    return(temp)

# Plotting is disabled: the function lives in 100 dimensions.
ExtRosenbrock = testfunction(
    val=ExtRosenbrock_value,
    grad=ExtRosenbrock_grad,
    hess=ExtRosenbrock_hess,
    plots=False,
    minimisers=np.ones(100),
)

del ExtRosenbrock_value, ExtRosenbrock_grad, ExtRosenbrock_hess


### Non-smooth test function
# The function is not differentiable, but a gradient is provided nevertheless.
# Note, though, that gradient based optimisation methods are expected to fail.

def NonSmooth_value(x):
    '''
    Non-smooth test function
        f(x) = x_0^4/4 + x_0^2/2 + |x_1| + x_1^2/4,
    which is not differentiable where x_1 = 0.
    '''
    a, b = x[0], x[1]
    return (1/4)*a**4 + (1/2)*a**2 + np.abs(b) + (1/4)*b**2

def NonSmooth_gradient(x):
    '''
    "Gradient" of f(x) = x_0^4/4 + x_0^2/2 + |x_1| + x_1^2/4.

    Returns np.array([x_0^3 + x_0, sign(x_1) + x_1/2]). At x_1 = 0, np.sign
    gives 0, which is an element of the subdifferential [-1, 1] of |x_1|.
    '''
    g0 = x[0]**3 + x[0]
    # The x_1/2 term belongs to the second component only.
    g1 = np.sign(x[1]) + (1/2)*x[1]
    return(np.array([g0, g1]))

NonSmoothTestFunction = testfunction(
    val=NonSmooth_value,
    grad=NonSmooth_gradient,
    plots=True,
    plotbounds=np.array([[-1.2, 1.2], [-1.2, 1.2]]),
    N_points=800,
    N_lines=20,
    plot_scaling=lambda x: x,  # contours of f itself, no rescaling
    minimisers=np.array([0.0, 0.0]),
)

del NonSmooth_value, NonSmooth_gradient


### Elastic net regression - primal function
# The function is not differentiable, but a gradient is provided nevertheless.
# Note, though, that gradient based optimisation methods are expected to fail.

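# Specifically, we consider the elastic net objective
#   f(x) = 1/2 (x_0 + 2 x_1 - 2)^2 + 3/2 (|x_0| + |x_1|) + 1/2 (x_0^2 + x_1^2),
# i.e. a least-squares term plus an l1 penalty and a squared l2 penalty.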

def ElasticNet_value(x):
    '''
    Elastic net objective: a least-squares fit term plus weighted
    l1 and squared l2 penalties on x = (x_0, x_1).
    '''
    a, b = x[0], x[1]
    fit = (1/2)*(a+2*b-2)**2
    l1_penalty = (3/2)*(np.abs(a)+np.abs(b))
    l2_penalty = (1/2)*(a**2+b**2)
    return fit + l1_penalty + l2_penalty

def ElasticNet_gradient(x):
    '''
    "Gradient" of the elastic net objective; |t| is differentiated as
    np.sign(t), which returns 0 at t = 0.
    '''
    a, b = x[0], x[1]
    residual = a + 2*b - 2
    return np.array([residual + (3/2)*np.sign(a) + a,
                     2*residual + (3/2)*np.sign(b) + b])

ElasticNet = testfunction(
    val=ElasticNet_value,
    grad=ElasticNet_gradient,
    plots=True,
    plotbounds=np.array([[-0.5, 1.2], [-0.5, 1.2]]),
    N_points=800,
    N_lines=20,
    # The value at the minimiser (0, 0.5) is 1.375, so x - 1.3 > 0.
    plot_scaling=lambda x: np.log(x-1.3),
    minimisers=np.array([0.0, 0.5]),
)

del ElasticNet_value, ElasticNet_gradient


## Equality constraint for the circle of radius 4 around the origin:
## c(x) = (16 - x_0^2 - x_1^2)/2 = 0

def Circle_val(x):
    '''
    Constraint function c(x) = (16 - x_0^2 - x_1^2)/2; it vanishes on the
    circle of radius 4 around the origin and is positive inside it.
    '''
    return 0.5*(16 - x[0]**2 - x[1]**2)

def Circle_grad(x):
    '''
    Gradient (-x_0, -x_1) of the circle constraint.
    '''
    return np.array([-x[0], -x[1]])

def Circle_hess(x):
    '''
    Constant Hessian -I of the circle constraint; x is ignored.
    '''
    return -np.identity(2)

# equality is left at its default (True): the constraint is c(x) = 0.
Circle_4 = constraint(
    val=Circle_val,
    grad=Circle_grad,
    hess=Circle_hess,
)

del Circle_val, Circle_grad, Circle_hess
