'''
Visualisation of a long-step central path following
interior point method for the course TMA4180 - Optimisation 1.

Markus Grasmair
Trondheim, April 2023 -- April 2024
'''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import TMA4180_definitions
from matplotlib import animation
import mpl_toolkits.axes_grid1
import matplotlib.widgets
import matplotlib.patches

############################################################
# definition of a Player class for interactive animations
# code by StackOverflow user ImportanceOfBeingErnest,
# taken from his answer to this question:
# https://stackoverflow.com/questions/44985966/managing-dynamic-plotting-in-matplotlib-animation-module/44989063#44989063
class Player(animation.FuncAnimation):
    '''
    FuncAnimation with a button bar for interactive navigation.

    Five buttons are added to the figure: one step back, play backwards,
    stop, play forwards, one step forward. The current frame index self.i
    moves in steps of +1/-1 within [self.min, self.max] and is passed to
    func. The frames are produced by the generator self.play(); the
    `frames` argument of __init__ is accepted but ignored.
    '''
    def __init__(self, fig, func, frames=None, init_func=None, fargs=None,
                 save_count=None, mini=0, maxi=100, pos=(0.125, 0.92), **kwargs):
        '''
        Parameters:
        fig       -- figure to animate; the buttons are added to it
        func      -- callback func(i) drawing frame i
        frames    -- ignored (frames come from self.play())
        init_func, fargs, save_count, **kwargs -- passed on to FuncAnimation
        mini      -- smallest frame index
        maxi      -- largest frame index
        pos       -- (left, bottom) of the button bar in figure coordinates
        '''
        self.i = 0
        self.min=mini
        self.max=maxi
        self.runs = True
        self.forwards = True
        self.fig = fig
        self.func = func
        self.setup(pos)
        animation.FuncAnimation.__init__(self,self.fig, self.func, frames=self.play(),
                                         init_func=init_func, fargs=fargs,
                                         save_count=save_count, **kwargs )

    def play(self):
        '''
        Frame generator: on each request move self.i one step in the
        current direction. When self.i reaches self.min or self.max the
        animation is stopped and the boundary index is yielded.
        '''
        while self.runs:
            self.i = self.i+self.forwards-(not self.forwards)
            if self.i > self.min and self.i < self.max:
                yield self.i
            else:
                # Pressing play while already at a boundary would otherwise
                # move the index outside [min, max] (e.g. -1 when playing
                # backwards from frame 0), which func cannot handle sensibly.
                self.i = min(max(self.i, self.min), self.max)
                self.stop()
                yield self.i

    def start(self):
        '''Resume the frame generator and restart the animation timer.'''
        self.runs=True
        self.event_source.start()

    def stop(self, event=None):
        '''Stop the animation (event is the unused button-click event).'''
        self.runs = False
        self.event_source.stop()

    def forward(self, event=None):
        '''Play forwards (button callback).'''
        self.forwards = True
        self.start()
    def backward(self, event=None):
        '''Play backwards (button callback).'''
        self.forwards = False
        self.start()
    def oneforward(self, event=None):
        '''Advance by a single frame (button callback).'''
        self.forwards = True
        self.onestep()
    def onebackward(self, event=None):
        '''Go back by a single frame (button callback).'''
        self.forwards = False
        self.onestep()

    def onestep(self):
        '''
        Move one frame in the current direction without exceeding
        [self.min, self.max], draw it by calling func directly and
        request a redraw of the canvas.
        '''
        if self.i > self.min and self.i < self.max:
            self.i = self.i+self.forwards-(not self.forwards)
        elif self.i == self.min and self.forwards:
            self.i+=1
        elif self.i == self.max and not self.forwards:
            self.i-=1
        self.func(self.i)
        self.fig.canvas.draw_idle()

    def setup(self, pos):
        '''
        Create the five button axes in a row starting at pos (figure
        coordinates) and connect them to the navigation callbacks.
        '''
        playerax = self.fig.add_axes([pos[0],pos[1], 0.22, 0.04])
        divider = mpl_toolkits.axes_grid1.make_axes_locatable(playerax)
        bax = divider.append_axes("right", size="80%", pad=0.05)
        sax = divider.append_axes("right", size="80%", pad=0.05)
        fax = divider.append_axes("right", size="80%", pad=0.05)
        ofax = divider.append_axes("right", size="100%", pad=0.05)
        self.button_oneback = matplotlib.widgets.Button(playerax, label=u'$\u29CF$')
        self.button_back = matplotlib.widgets.Button(bax, label=u'$\u25C0$')
        self.button_stop = matplotlib.widgets.Button(sax, label=u'$\u25A0$')
        self.button_forward = matplotlib.widgets.Button(fax, label=u'$\u25B6$')
        self.button_oneforward = matplotlib.widgets.Button(ofax, label=u'$\u29D0$')
        self.button_oneback.on_clicked(self.onebackward)
        self.button_back.on_clicked(self.backward)
        self.button_stop.on_clicked(self.stop)
        self.button_forward.on_clicked(self.forward)
        self.button_oneforward.on_clicked(self.oneforward)


##########################################################################
# Definitions


def IntPoint_ConvergencePlot(c,points_x,points_y,errors,speed):
    '''
    Animate the iterates of the interior point method in the plane.

    The background shows the linear cost (x_1, x_2) -> c_1 x_1 + c_2 x_2 on
    [-0.5, 5.5]^2 as a colour map with contour lines, together with the
    (hard-coded) boundary of the feasible polygon of the example problem
    in red. Frame i draws the step leading to iterate i in red and the
    path through the earlier iterates in black.

    Parameters:
    c        -- cost vector; only the first two components are used
    points_x -- first coordinates of the iterates
    points_y -- second coordinates of the iterates
    errors   -- duality measures of the iterates (currently not plotted)
    speed    -- delay between frames in milliseconds

    Returns the Player animation object. Keep a reference to it: an
    animation that is garbage collected stops running.
    '''
    animation_fig,ax1 = plt.subplots(1,1)
    # All drawing goes through the Axes object itself. The previous
    # plt.axes(ax1) calls re-added ax1 via the *current* figure, which
    # breaks as soon as another figure has become current (e.g. when the
    # animation callback of an older figure fires).
    x_lower = -0.5
    x_upper = 5.5
    y_lower = -0.5
    y_upper = 5.5
    N_points = 200
    delta_x = (x_upper-x_lower)/N_points
    delta_y = (y_upper-y_lower)/N_points
    xx = np.arange(x_lower,x_upper,delta_x)
    yy = np.arange(y_lower,y_upper,delta_y)
    X, Y = np.meshgrid(xx,yy)
    Z = c[0]*X + c[1]*Y
    aspect_ratio = (x_upper-x_lower)/(y_upper-y_lower)
    ax1.imshow(Z, interpolation='bilinear',
               origin='lower', cmap=cm.Blues,
               extent=(x_lower,x_upper,y_lower,y_upper),aspect=aspect_ratio)
    # boundary of the feasible polygon of the example problem
    ax1.plot([0.,0.,1.,4.,5.,2.,0.],[0.,2.,5.,4.,1.,0.,0.],lw=2,color='r')
    N_lines = 20
    ax1.contour(X,Y,Z,N_lines)

    current_line, = ax1.plot([],[],lw=1,color='r')
    previous_line, = ax1.plot([],[],lw=1,color='k')
    def animation_init():
        current_line.set_data([],[])
        previous_line.set_data([],[])
        return(current_line,previous_line,)

    def animation_animate(i):
        if i==0:
            # first step only; clear the path so that frame 0 looks the
            # same no matter from which frame we navigated here
            current_line.set_data(points_x[0:2],points_y[0:2])
            previous_line.set_data([],[])
        else:
            # red: step from iterate i-1 to iterate i;
            # black: path through iterates 0, ..., i-1
            current_line.set_data(points_x[i-1:i+1],points_y[i-1:i+1])
            previous_line.set_data(points_x[:i],points_y[:i])
        return(current_line,previous_line,)

    ani = Player(animation_fig, animation_animate,
                 mini = 0, maxi = len(points_x),
                 init_func = animation_init,
                 interval = speed,
    )
    return ani



def _IntPoint_StepLength(z, s, Dz, Ds, gamma, safety_factor=1e-12):
    '''
    Step length keeping the iterate in the neighbourhood z_i s_i >= gamma*mu.

    For every component i the quantity
        (z_i + alpha Dz_i)(s_i + alpha Ds_i)
            - gamma (z + alpha Dz)^T (s + alpha Ds) / n
    is the polynomial p_i + q_i alpha + r_i alpha^2. The returned step
    length is the smallest real root of any of these polynomials that is
    larger than safety_factor (smaller roots are treated as round-off),
    capped at 1.
    '''
    n = len(z)
    p = z*s - gamma*np.inner(z,s)/n
    q = Dz*s + z*Ds - gamma*(np.inner(z,Ds)+np.inner(Dz,s))/n
    r = Dz*Ds - gamma*np.inner(Dz,Ds)/n
    disc = q**2 - 4*p*r
    # the full Newton step is always a candidate
    candidates = [np.ones(1)]
    quadratic = (disc > 0) & (r != 0)
    if np.any(quadratic):
        sqrt_disc = np.sqrt(disc[quadratic])
        q_quad = q[quadratic]
        r_quad = r[quadratic]
        candidates.append((-q_quad + sqrt_disc)/(2*r_quad))
        candidates.append((-q_quad - sqrt_disc)/(2*r_quad))
    # r_i == 0: the polynomial is linear with the single root -p_i/q_i
    # (q_i != 0 since disc = q_i^2 > 0); the quadratic formula would
    # divide by zero here and lose that root.
    linear = (disc > 0) & (r == 0)
    if np.any(linear):
        candidates.append(-p[linear]/q[linear])
    alpha_all = np.concatenate(candidates)
    return min(alpha_all[alpha_all > safety_factor])


def IntPoint(A,b,c,x,sigma=0.5,eps_stop = 1e-10, max_steps = 500, speed = 250, show_plots = True):
    '''
    Solve the linear programme

        c^T x -> min   subject to   A x <= b,  x >= 0

    with a long-step primal-dual interior point method.

    The problem is rewritten in standard form with z = (x, y), where
    y = b - A x is the slack, so that [A I] z = b and z >= 0 (hence the
    implicit constraint x >= 0). The iterates are kept in the
    neighbourhood z_i s_i >= gamma*mu of the central path, where
    mu = z^T s / n is the duality measure.

    The code includes no safeguards against ill-conditioning, and it has
    only been tested with the example system A x <= b of this module; the
    visualisation only makes sense for that system.

    Parameters:
    A          -- constraint matrix (m x d)
    b          -- right hand side (m entries)
    c          -- cost vector (d entries)
    x          -- starting point; must satisfy x > 0 and A x < b
    sigma      -- centering parameter: every Newton step targets the
                  duality measure sigma*mu
    eps_stop   -- stop as soon as the duality measure is below this value
    max_steps  -- maximal number of iterations
    speed      -- delay between animation frames in milliseconds
    show_plots -- whether to create the animation of the iterates

    Returns:
    x, or (x, ani) if show_plots is True, where ani is the animation
    object (keep a reference to it while the animation should run).

    Raises:
    ArithmeticError if z = (x, b - A x) or the constructed dual slack s
    is not strictly positive.
    '''
    # work in floating point so that the in-place updates below also work
    # for integer input
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    c = np.asarray(c, dtype=float)
    x = np.asarray(x, dtype=float)

    # rewrite the problem in standard form
    m, d = np.shape(A)
    n = m+d
    B = np.concatenate((A,np.eye(m)),axis=1)
    # slack variable and complete primal vector z = (x, y)
    y = b-A@x
    z = np.concatenate((x,y))
    chat = np.concatenate((c,np.zeros(m)))

    # dual initialisation: lam minimises |chat - B^T lam - 1|^2 + reg*|lam|^2,
    # so that s = chat - B^T lam is close to the all-ones vector
    s_target = np.ones(n)
    reg = 1.
    lam = np.linalg.solve(reg*np.eye(m)+B@B.transpose(),B@(chat-s_target))
    print('Initialisation of lambda: {}'.format(lam))
    s = chat - B.transpose()@lam
    print('Initialisation of s: {}'.format(s))

    # Strict admissibility requires z > 0 and s > 0 componentwise;
    # z*s > 0 alone would also accept components where both are negative.
    # The negated form also rejects NaN entries.
    if not (np.all(z > 0) and np.all(s > 0)):
        raise ArithmeticError('The initialisation is not strictly admissible.')

    # initial duality measure and neighbourhood parameter
    mu = np.inner(z,s)/n
    gamma = min(z*s/mu)*0.5

    errors = [mu]
    points_x = [z[0]]
    points_y = [z[1]]

    # top part of the Newton matrix (the same for all iterations):
    # B Dz = 0 and B^T Dlam + Ds = 0
    NM_top = np.concatenate((
        np.concatenate((B,np.zeros([m,m]),np.zeros([m,n])),axis=1),
        np.concatenate((np.zeros([n,n]),B.transpose(),np.eye(n)),axis=1)))

    n_step = 0
    while mu >= eps_stop and n_step < max_steps:
        # target duality measure of this step
        tau = sigma*mu
        # bottom part of the Newton matrix: S Dz + Z Ds = tau - z*s
        NM_bottom = np.concatenate((np.diag(s),np.zeros([n,m]),np.diag(z)),axis=1)
        NM = np.concatenate((NM_top,NM_bottom))
        rhs = np.concatenate((np.zeros(m+n),tau-z*s))
        update = np.linalg.solve(NM,rhs)
        Dz = update[0:n]
        Dlam = update[n:(n+m)]
        Ds = update[(n+m):]

        # longest step that stays in the neighbourhood of the central path
        alpha = _IntPoint_StepLength(z, s, Dz, Ds, gamma)

        z += alpha*Dz
        lam += alpha*Dlam
        s += alpha*Ds

        mu = np.inner(z,s)/n

        points_x.append(z[0])
        points_y.append(z[1])
        errors.append(mu)
        n_step += 1

    # decide by the final duality measure, not by the step count: the
    # method may converge exactly in step max_steps, and a NaN duality
    # measure must not be reported as convergence
    if mu < eps_stop:
        print('Converged after {} steps.'.format(n_step))
    else:
        print('Did not converge after {} steps.'.format(n_step))
    print('The final duality measure was {}.'.format(mu))

    x = z[0:d]
    if show_plots:
        ani = IntPoint_ConvergencePlot(c,points_x,points_y,errors,speed)
        return x,ani
    return x
    
# Example problem used by RunAlg: the feasible set {x >= 0 : A x <= b}
# is the polygon with vertices (0,0), (0,2), (1,5), (4,4), (5,1), (2,0).
A = np.array([
    [-3.0, 1.0],
    [1.0, 3.0],
    [3.0, 1.0],
    [1.0, -3.0],
])
b = np.array([2.0, 16.0, 16.0, 2.0])
# cost vector of the example (RunAlg has its own default cost vector)
c = np.array([-1.0, -2.0])

def RunAlg(x0=np.array([1.0,4.8]),c = np.array([-1.0,-1.0]),sigma=0.5):
    '''
    Run the interior point method on the example system A x <= b of this
    module and show the animation of the iterates.

    Parameters:
    x0    -- strictly feasible starting point (x0 > 0, A x0 < b)
    c     -- cost vector (independent of the module-level c)
    sigma -- centering parameter passed on to IntPoint

    Returns the result of IntPoint, i.e. (x, ani). Keeping a reference to
    the returned animation object is what keeps the animation running once
    this function has returned (e.g. with a non-blocking backend);
    previously it was discarded and could be garbage collected.
    '''
    # The default arrays are shared between calls; IntPoint only reads x0
    # and c, so they are not modified.
    result = IntPoint(A,b,c,x0,sigma)
    plt.show()
    return result
