'''
Implementation of the Nelder-Mead algorithm for usage in the course 
TMA4180 - Optimisation 1.

Markus Grasmair
Trondheim, January 2023 -- January 2024
'''

import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import animation
from IPython.display import HTML
import mpl_toolkits.axes_grid1
import matplotlib.widgets
import matplotlib.patches

# TMA4180_definitions contains the definitions of the test functions
import TMA4180_definitions


# definition of a Player class for interactive animations
# code by StackOverflow user ImportanceOfBeingErnest,
# taken from his answer to this question:
# https://stackoverflow.com/questions/44985966/managing-dynamic-plotting-in-matplotlib-animation-module/44989063#44989063
# (modified so that the frame counter never leaves the range [mini, maxi])
class Player(animation.FuncAnimation):
    '''
    FuncAnimation with a row of buttons for interactive playback.

    The animation steps through the integer frame numbers mini, ..., maxi
    and calls func(i) for each of them. The buttons allow to play the
    animation backwards or forwards, to stop it, and to move a single
    frame backwards or forwards.

    Parameters
    ----------
    fig : the matplotlib figure the animation and the buttons are drawn in
    func : function called with the current frame number
    frames : ignored; the frame numbers are generated internally from
        mini and maxi (the parameter is kept for compatibility)
    init_func, fargs, save_count, **kwargs : passed on to FuncAnimation
    mini, maxi : smallest and largest frame number passed to func
    pos : position (in figure coordinates) of the lower left corner of
        the row of buttons
    '''
    def __init__(self, fig, func, frames=None, init_func=None, fargs=None,
                 save_count=None, mini=0, maxi=100, pos=(0.125, 0.92), **kwargs):
        # start at the lower bound of the frame range
        self.i = mini
        self.min=mini
        self.max=maxi
        self.runs = True
        self.forwards = True
        self.fig = fig
        self.func = func
        self.setup(pos)
        animation.FuncAnimation.__init__(self,self.fig, self.func, frames=self.play(),
                                         init_func=init_func, fargs=fargs,
                                         save_count=save_count, **kwargs )

    def _direction(self):
        '''Return +1 when playing forwards and -1 when playing backwards.'''
        return 1 if self.forwards else -1

    def _clamp(self, i):
        '''Clamp the frame number i to the range [self.min, self.max].'''
        return min(max(i, self.min), self.max)

    def play(self):
        '''
        Generator for the frame numbers consumed by FuncAnimation.

        Every step moves the frame counter by one in the current direction.
        When a bound is reached, the animation is stopped and the bound is
        yielded. The generator resumes after its last yield, so start()
        (which sets self.runs back to True) continues playback from there.
        '''
        while self.runs:
            self.i = self.i + self._direction()
            if self.min < self.i < self.max:
                yield self.i
            else:
                # clamp, so that restarting the animation at a bound does not
                # push the counter (and hence the frame passed to func) out of range
                self.i = self._clamp(self.i)
                self.stop()
                yield self.i

    def start(self):
        self.runs=True
        self.event_source.start()

    def stop(self, event=None):
        self.runs = False
        self.event_source.stop()

    def forward(self, event=None):
        self.forwards = True
        self.start()
    def backward(self, event=None):
        self.forwards = False
        self.start()
    def oneforward(self, event=None):
        self.forwards = True
        self.onestep()
    def onebackward(self, event=None):
        self.forwards = False
        self.onestep()

    def onestep(self):
        '''Move one frame in the current direction, staying within
        [self.min, self.max], and redraw the figure.'''
        self.i = self._clamp(self.i + self._direction())
        self.func(self.i)
        self.fig.canvas.draw_idle()

    def setup(self, pos):
        '''Create the five playback buttons in a row starting at pos.'''
        playerax = self.fig.add_axes([pos[0],pos[1], 0.22, 0.04])
        divider = mpl_toolkits.axes_grid1.make_axes_locatable(playerax)
        bax = divider.append_axes("right", size="80%", pad=0.05)
        sax = divider.append_axes("right", size="80%", pad=0.05)
        fax = divider.append_axes("right", size="80%", pad=0.05)
        ofax = divider.append_axes("right", size="100%", pad=0.05)
        self.button_oneback = matplotlib.widgets.Button(playerax, label=u'$\u29CF$')
        self.button_back = matplotlib.widgets.Button(bax, label=u'$\u25C0$')
        self.button_stop = matplotlib.widgets.Button(sax, label=u'$\u25A0$')
        self.button_forward = matplotlib.widgets.Button(fax, label=u'$\u25B6$')
        self.button_oneforward = matplotlib.widgets.Button(ofax, label=u'$\u29D0$')
        self.button_oneback.on_clicked(self.onebackward)
        self.button_back.on_clicked(self.backward)
        self.button_stop.on_clicked(self.stop)
        self.button_forward.on_clicked(self.forward)
        self.button_oneforward.on_clicked(self.oneforward)


def _plot_simplex(f,simplex,step,show_plots,save_plots):
    '''
    Plot f (via f.plotf()) together with the outline of the simplex
    (first two coordinates of its vertices), then save the figure as
    ./plotsNelderMead/plot<step>.png and/or show it.
    '''
    f.plotf()
    plt.fill(simplex[:,0],simplex[:,1],fill=False)
    if save_plots:
        # save before showing: plt.show() can close the figure (e.g. once
        # its window is closed, or in Jupyter's inline backend), which would
        # otherwise leave an empty figure to be saved
        os.makedirs("./plotsNelderMead",exist_ok=True)
        plt.savefig("./plotsNelderMead/plot"+str(step)+".png")
    if show_plots:
        plt.show()
    if save_plots:
        plt.close()


def _animate_simplices(f,simplices,animationspeed):
    '''
    Create an interactive animation (a Player) of a sequence of simplices.

    simplices is an array of shape (number of frames, n+1, n) holding the
    vertices of the simplex after each iteration; only the first two
    coordinates are drawn. animationspeed is passed to the Player as the
    frame interval. Returns the Player object.
    '''
    animation_fig,ax = plt.subplots()
    animation_simplex, = ax.plot([],[],lw=1,color='k')
    f.plotf()
    num_iterations = simplices.shape[0]
    num_digits = len(str(num_iterations-1))

    def draw_simplex(i):
        # close the polygon by repeating the first vertex at the end
        data_x = np.append(simplices[i,:,0],simplices[i,0,0])
        data_y = np.append(simplices[i,:,1],simplices[i,0,1])
        animation_simplex.set_data(data_x,data_y)
        return(animation_simplex,)

    # the titles are set on ax explicitly: the Player adds its buttons as new
    # axes, which can make one of those the current axes used by plt.title
    def animation_init():
        ax.set_title('Nelder-Mead Method - Iteration 0')
        return draw_simplex(0)

    def animation_animate(i):
        # guard against frame numbers outside the stored iterations
        i = min(max(i,0),num_iterations-1)
        ax.set_title('Nelder-Mead Method - Iteration {:-{count}}'.format(i,count=num_digits))
        return draw_simplex(i)

    return Player(animation_fig, animation_animate,
                  mini = 0, maxi = num_iterations-1,
                  init_func = animation_init,
                  interval=animationspeed,
    )


# definition of the Nelder-Mead algorithm
def NelderMead(f,simplex_init,
               max_steps = 10,
               show_plots = False,
               save_plots = False,
               create_animation = False,
               animationspeed = 500,
):
    '''
    Minimise a function with the Nelder-Mead (downhill simplex) method.

    Parameters
    ----------
    f : object
        The objective. The algorithm evaluates it through f.val(x);
        f.plotf() is additionally called to draw the background whenever
        plots or an animation are requested.
    simplex_init : array_like of shape (n+1, n)
        The initial simplex, one vertex per row. It is copied and converted
        to float64, so the caller's array is not modified.
    max_steps : int
        Number of Nelder-Mead iterations to perform.
    show_plots, save_plots : bool
        Plot the simplex initially and after every iteration, and show the
        plot and/or save it as ./plotsNelderMead/plot<step>.png.
        Only the first two coordinates are drawn, so this requires n >= 2.
    create_animation : bool
        If True, additionally return an interactive animation (a Player)
        of all simplices generated (requires n >= 2).
    animationspeed : int
        Interval between animation frames, passed on to the Player.

    Returns
    -------
    The vertex with the smallest function value in the final simplex;
    if create_animation is True, the tuple (vertex, animation).

    Raises
    ------
    ValueError
        If simplex_init does not have shape (n+1, n).
    '''
    make_plots = show_plots or save_plots

    # copy the initial simplex and convert it to floating point
    current_simplex = np.asarray(simplex_init).astype('float64')
    if (current_simplex.ndim != 2
            or current_simplex.shape[0] != current_simplex.shape[1]+1):
        raise ValueError('simplex_init must have shape (n+1, n), got shape '
                         + str(current_simplex.shape))
    # store the dimensionality of the problem
    num_dim = current_simplex.shape[1]

    # compute the function values for each vertex of the initial simplex
    f_values = np.zeros(num_dim+1)
    for i in range(num_dim+1):
        f_values[i] = f.val(current_simplex[i,:])

    # smallest function value in the simplex and its index
    # (np.argmin chooses the first index in case of ties)
    ind_min = np.argmin(f_values)
    val_min = f_values[ind_min]

    if make_plots:
        _plot_simplex(f,current_simplex,0,show_plots,save_plots)
    if create_animation:
        # store copies of all simplices for the animation
        simplex_history = [current_simplex.copy()]

    n_step = 0
    while n_step < max_steps:
        n_step += 1
        # flag indicating whether the worst vertex could be replaced
        flag_improvement = False
        # worst vertex (largest function value); copied, since the row of
        # current_simplex it comes from may be overwritten below
        ind_max = np.argmax(f_values)
        val_max = f_values[ind_max]
        x_max = current_simplex[ind_max,:].copy()
        # function values and centroid of the face opposite to the worst vertex
        values_best_face = np.delete(f_values,ind_max)
        centroid = np.mean(np.delete(current_simplex,ind_max,axis=0),axis=0)

        # reflection of the worst vertex through the centroid
        testpoint1 = 2*centroid - x_max
        val_testpoint1 = f.val(testpoint1)
        if val_testpoint1 < val_min:
            # the reflected point beats the best vertex: also try the expanded
            # point centroid + 2*(centroid - x_max) and keep the better one
            testpoint2 = 3*centroid - 2*x_max
            val_testpoint2 = f.val(testpoint2)
            if val_testpoint2 < val_testpoint1:
                new_point, new_value = testpoint2, val_testpoint2
            else:
                new_point, new_value = testpoint1, val_testpoint1
            current_simplex[ind_max,:] = new_point
            f_values[ind_max] = new_value
            val_min = new_value
            ind_min = ind_max
            flag_improvement = True
        elif val_testpoint1 < np.max(values_best_face):
            # the reflected point beats the second worst vertex: accept it
            # (it is not better than the best vertex, so ind_min is unchanged)
            current_simplex[ind_max,:] = testpoint1
            f_values[ind_max] = val_testpoint1
            flag_improvement = True
        else:
            if val_testpoint1 < val_max:
                # outside contraction: midpoint between the centroid and the
                # reflected point, i.e. centroid + 0.5*(testpoint1 - centroid);
                # accept it if it is not worse than the reflected point
                testpoint2 = 1.5*centroid - 0.5*x_max
                val_testpoint2 = f.val(testpoint2)
                accept = val_testpoint2 <= val_testpoint1
            else:
                # inside contraction: midpoint between the worst vertex and the
                # centroid; accept it if it is better than the worst vertex
                testpoint2 = 0.5*(centroid + x_max)
                val_testpoint2 = f.val(testpoint2)
                accept = val_testpoint2 < val_max
            if accept:
                current_simplex[ind_max,:] = testpoint2
                f_values[ind_max] = val_testpoint2
                flag_improvement = True
                if val_testpoint2 < val_min:
                    val_min = val_testpoint2
                    ind_min = ind_max

        if not flag_improvement:
            # no acceptable replacement for the worst vertex: shrink the whole
            # simplex towards the best vertex and recompute the function values
            x_min = current_simplex[ind_min,:].copy()
            for i in range(num_dim+1):
                if i != ind_min:
                    current_simplex[i,:] = 0.5*(current_simplex[i,:] + x_min)
                    f_values[i] = f.val(current_simplex[i,:])
            ind_min = np.argmin(f_values)
            val_min = f_values[ind_min]

        if make_plots:
            _plot_simplex(f,current_simplex,n_step,show_plots,save_plots)
        if create_animation:
            simplex_history.append(current_simplex.copy())

    if create_animation:
        ani = _animate_simplices(f,np.array(simplex_history),animationspeed)
        return current_simplex[ind_min,:],ani
    return current_simplex[ind_min,:]
