# -*- coding: utf-8 -*-
"""
This code implements receptor kinetics for realistic spiking neuron models.

Classes
-------
kineticBase() - Kinetic base class
    firstOrdKin(kineticBase) - First-order kinetic that implements
                               Equation 6 from [1].
    firstOrdKinExp(firstOrdKin) - First-order kinetic that implements
                                  an exponential decay variant from [2].
    secondOrdKin(kineticBase) - Second-order kinetic that implements
                                Equations 4 and 5 from [1].


Change log
----------
03 May 20 - Cleaned up the code to be more Pythonic and added more
            documentation. Final version for eventual upload.
18 Sep 20 - Split out the kinetics from the currents.

References
----------
[1] Wang, X.J. (1999). Synaptic Basis of Cortical Persistent 
    Activity:  the Importance of NMDA Receptors to Working Memory. 
    The Journal of Neuroscience, 19(21), 9587–9603. 
    doi: 10.1523/JNEUROSCI.19-21-09587.1999

[2] Compte, A., Brunel, N., Goldman-Rakic, P.S., & Wang, X.J. (2000). 
    Synaptic mechanisms and network dynamics underlying spatial 
    working memory in a cortical network model. Cerebral Cortex, 10,
    910-923. doi: 10.1093/cercor/10.9.910 
"""

from math import exp

from numpy import zeros, copy

class kineticBase(object):
    """
    Base class for kinetics.  This assumes that there is a single
    s kinetic, which can be expanded upon by subclasses.

    Attributes
    ----------
    size : int
        The size of the neural population.
    track : boolean
        Whether one of the kinetics should be tracked.
    track_ID : int
        The index of the neuron whose kinetic is being tracked.
    tracked_s : list(float)
        If track is True, the s kinetic value during the experiment.

    Methods
    -------
    __init__(size, track=False, track_ID=0)
        Initializes kineticBase
    project(dt)
        Predicts future channel gating.
    update(dt, spk_times, weights)
        Updates the kinetics.
    kinetic(dt)
        Gets the kinetic value.
    reset()
        Resets the kinetics and tracking.
    """

    def __init__(self, size, track=False, track_ID=0):
        """
        Initializes kineticBase.

        Parameters
        ----------
        size : int
            Size of the neural population
        track : boolean
            Whether the kinetics should be tracked. The default is False.
        track_ID : uint
            Which neuron's kinetics should be tracked. The default is 0.

        Raises
        ------
        ValueError
            If track_ID is greater than or equal to size.
        """
        self.size = size      # size of the population
        # tracking variables
        self.track = track
        self.track_ID = track_ID
        # The index is validated even when tracking is disabled.
        if self.track_ID >= self.size:
            raise ValueError('track_ID was set to a value larger than the'
                             + ' size of the neural population.')
        self.tracked_s = []
        # S kinetic parameter
        self._s = zeros([self.size,], dtype='float')      # channel gating
        # snapshot of the gating as it was before the latest update
        self._prev_s = copy(self._s)

    def project(self, dt):
        """
        Predicts future channel gating with the assumption of no new firing.

        Parameters
        ----------
        dt : float
            The change in time since this kinetic was last updated.

        Raises
        ------
        NotImplementedError
            This function is not implemented in the kineticBase.

        Returns
        -------
        s_kinetic : numpy.ndarray(float)
            Predicted channel gatings at dt.
        """
        raise NotImplementedError(
            'Trying to call "project" with the kineticBase class')

    def update(self, dt, spk_times, weights):
        """
        Updates the kinetics based on the population spiking and the 
        connection weights.

        Parameters
        ----------
        dt : float
            The change in time.
        spk_times : numpy.ndarray(float; spikes)
            When spikes occur during this dt.
        weights : numpy.ndarray(float; spikes x size)
            Matrix of connection weights of the presynaptic neuron to 
            postsynaptic neurons. 

        Raises
        ------
        NotImplementedError
            This function is not implemented in the kineticBase.
        """
        raise NotImplementedError(
            'Trying to call "update" with the kineticBase class')

    def kinetic(self, dt):
        """
        Returns the kinetic value at the beginning or end of the dt.
        Should only be called after update() has been called.

        Parameters
        ----------
        dt : float
            The change in time since this kinetic was last updated.
            A value of 0 selects the gating stored before the last
            update; any other value selects the current gating.

        Returns
        -------
        s_kinetic : numpy.ndarray(float)
            Channel gating by neuron at dt. This is a copy, so the
            caller may modify it without affecting the kinetic.
        """
        # Returns the kinetic value at the start of this time step
        if dt == 0:
            return copy(self._prev_s)
        return copy(self._s)

    def reset(self):
        """
        Resets the s kinetic and tracking for a new run.
        """
        self._s = zeros([self.size,], dtype='float')
        self._prev_s = copy(self._s)
        self.tracked_s = []

# -------------------- First Order Kinetics -----------------------------
class firstOrdKin(kineticBase):
    """
    A first-order kinetic that increases by alpha with every spike, 
    is bounded, and proportionately decreases. Implements equation 6 
    from [1]. This is a subclass of kineticBase.

    References
    ----------
    [1] Wang, X.J. (1999). Synaptic Basis of Cortical Persistent 
        Activity:  the Importance of NMDA Receptors to Working Memory. 
        The Journal of Neuroscience, 19(21), 9587–9603. 
        doi: 10.1523/JNEUROSCI.19-21-09587.1999

    Attributes
    ----------
    size : int
        The size of the neural population.
    track : boolean
        Whether one of the kinetics should be tracked.
    track_ID : int
        The index of the neuron whose kinetic is being tracked.
    tracked_s : list(float)
        If track is True, the s kinetic value during the experiment.
    alpha_s : float
        Multiplier to increase the s kinetic by each spike.
    tau_s : float
        Average life expectancy of the s kinetic.

    Methods
    -------
    __init__(size, alpha_s, tau_s, **kwargs)
        Initializes FirstOrdKin
    _loss_prop(dt)
        Determines the remaining proportion of the kinetic.
    project(dt)
        Predicts future channel gating.
    update(dt, spk_times, weights)
        Updates the kinetics.
    kinetic(dt)
        Gets the kinetic value.
    reset()
        Resets the kinetics and tracking.
    """

    def __init__(self, size, alpha_s, tau_s, **kwargs):
        """
        Initialize FirstOrdKin.  

        Parameters
        ----------
        size : int
            Size of the neural population.
        alpha_s : float
            Multiplier to increase the s kinetic by each spike.
        tau_s : float
            Average life expectancy of the s kinetic (in ms).
        **kwargs : dict
            Keyword arguments. Accepts the following kwargs:
                track (boolean) 
                track_ID (int) 
        """
        super().__init__(size, **kwargs)

        self.alpha_s = alpha_s
        self.tau_s = tau_s

    def _loss_prop(self, dt): 
        """
        Multiplicative decay factor applied to the s kinetic over dt.

        Parameters
        ----------
        dt : float
            Time that the kinetic decays for.

        Returns
        -------
        prop : float
            The proportion of the kinetic that remains after dt,
            computed as 1 - dt/tau_s. The factor is not clamped, so
            it becomes negative when dt exceeds tau_s.
        """
        return (1 - dt/self.tau_s)

    # quickly predicts the channel gating
    def project(self, dt): 
        """
        Predicts future channel gating with the assumption of no new firing.

        Parameters
        ----------
        dt : float
            The change in time since this kinetic was last updated.

        Returns
        -------
        s_kinetic : numpy.ndarray(float)
            Predicted channel gatings at dt. The stored kinetic is
            not modified.
        """
        if dt == 0:
            return copy(self._s)
        # proportion after decay
        return self._s * self._loss_prop(dt)

    # updates the kinetic for this time step
    def update(self, dt, spk_times, weights):
        """
        Updates the kinetics based on the population spiking and the 
        connection weights.

        Parameters
        ----------
        dt : float
            The change in time.
        spk_times : numpy.ndarray(float; spikes)
            When spikes occur during this dt, measured from the start
            of the step. Presumably in ascending order, since each
            decay interval is taken from the previous spike time.
        weights : numpy.ndarray(float; spikes x size)
            Matrix of connection weights of the presynaptic neuron to 
            postsynaptic neurons. 
        """
        _time = 0
        # keep the pre-update gating so kinetic(0) can return it
        self._prev_s = copy(self._s)
        # Processes Equation 6 in a stepwise manner.
        for index, spk_t in enumerate(spk_times):
            # decay up to the spike, then apply the bounded increment
            self._s *= self._loss_prop(spk_t - _time)
            self._s += self.alpha_s * weights[index] * (1 - self._s)
            _time = spk_t
        # Decays the rest of the time.
        self._s *= self._loss_prop(dt - _time)
        # tracks the kinetic of a single neuron if instructed to do so
        if self.track:
            self.tracked_s.append(self._s[self.track_ID])

class firstOrdKinExp(firstOrdKin):
    """
    A first-order kinetic that increases by alpha with every spike, 
    is unbounded, and exponentially decreases [2]. This is a subclass 
    of FirstOrdKin.

    References
    ----------
    [2] Compte, A., Brunel, N., Goldman-Rakic, P.S., & Wang, X.J. (2000). 
        Synaptic mechanisms and network dynamics underlying spatial 
        working memory in a cortical network model. Cerebral Cortex, 10,
        910-923. doi: 10.1093/cercor/10.9.910

    Attributes
    ----------
    size : int
        The size of the neural population.
    track : boolean
        Whether one of the kinetics should be tracked.
    track_ID : int
        The index of the neuron whose kinetic is being tracked.
    tracked_s : list(float)
        If track is True, the s kinetic value during the experiment.
    alpha_s : float
        Multiplier to increase the s kinetic by each spike.
    tau_s : float
        Average life expectancy of the s kinetic.

    Methods
    -------
    __init__(size, alpha_s, tau_s, **kwargs)
        Initializes FirstOrdKinExp
    _loss_prop(dt)
        Determines the remaining proportion of the kinetic.
    project(dt)
        Predicts future channel gating.
    update(dt, spk_times, weights)
        Updates the kinetics.
    kinetic(dt)
        Gets the kinetic value.
    reset()
        Resets the kinetics and tracking.
    """

    def __init__(self, size, alpha_s, tau_s, **kwargs):
        """
        Initialize FirstOrdKinExp.  

        Parameters
        ----------
        size : int
            Size of the neural population.
        alpha_s : float
            Multiplier to increase the s kinetic by each spike.
        tau_s : float
            Average life expectancy of the s kinetic (in ms).
        **kwargs : dict
            Keyword arguments. Accepts the following kwargs:
                track (boolean) 
                track_ID (int) 
        """
        super().__init__(size, alpha_s, tau_s, **kwargs)

    def _loss_prop(self, dt): 
        """
        Multiplicative decay factor applied to the s kinetic over dt.

        Parameters
        ----------
        dt : float
            Time that the kinetic decays for.

        Returns
        -------
        prop : float
            The proportion of the kinetic that remains after dt,
            computed as exp(-dt/tau_s).
        """
        return exp(-dt/self.tau_s)

    # updates the kinetic for this time step
    def update(self, dt, spk_times, weights):
        """
        Updates the kinetics based on the population spiking and the 
        connection weights.

        Parameters
        ----------
        dt : float
            The change in time.
        spk_times : numpy.ndarray(float; spikes)
            When spikes occur during this dt, measured from the start
            of the step. Presumably in ascending order, since each
            decay interval is taken from the previous spike time.
        weights : numpy.ndarray(float; spikes x size)
            Matrix of connection weights of the presynaptic neuron to 
            postsynaptic neurons. 
        """
        _time = 0
        # keep the pre-update gating so kinetic(0) can return it
        self._prev_s = copy(self._s)
        for index, spk_t in enumerate(spk_times):
            # decay up to the spike, then apply the unbounded increment
            self._s *= self._loss_prop(spk_t - _time)
            self._s += self.alpha_s * weights[index]                
            _time = spk_t
        # Decays the rest of the time.
        self._s *= self._loss_prop(dt - _time)
        # tracks the kinetic of a single neuron if instructed to do so.
        if self.track:
            self.tracked_s.append(self._s[self.track_ID])

class secondOrdKin(kineticBase):
    """
    A second-order kinetic that increases by alpha with every spike, 
    is bounded, and proportionately decreases. Implements equations 
    4 and 5 from [1]. This is a subclass of kineticBase.

    References
    ----------
    [1] Wang, X.J. (1999). Synaptic Basis of Cortical Persistent 
        Activity:  the Importance of NMDA Receptors to Working Memory. 
        The Journal of Neuroscience, 19(21), 9587–9603. 
        doi: 10.1523/JNEUROSCI.19-21-09587.1999

    Attributes
    ----------
    size : int
        The size of the neural population.
    track : boolean
        Whether one of the kinetics should be tracked.
    track_ID : int
        The index of the neuron whose kinetic is being tracked.
    tracked_s : list(float)
        If track is True, the s kinetic value during the experiment.
    alpha_x : float
        Multiplier to increase the x kinetic by each spike.
    alpha_s : float
        Rate that the x kinetic influences the s kinetic.
    tau_x : float
        Average life expectancy of the x kinetic (in ms).
    tau_s : float
        Average life expectancy of the s kinetic (in ms).

    Methods
    -------
    __init__(size, alpha_x, alpha_s, tau_x, tau_s, **kwargs)
        Initializes SecondOrdKin
    _dsdt(dt)
        Change in the s kinetic over time.
    project(dt)
        Predicts future channel gating.
    update(dt, spk_times, weights)
        Updates the kinetics.
    kinetic(dt)
        Gets the kinetic value.
    reset()
        Resets the kinetics and tracking.
    """
    def __init__(self, size, alpha_x, alpha_s, tau_x, tau_s, **kwargs):
        """
        Initialize SecondOrdKin.

        Parameters
        ----------
        size : int
            Size of the neural population.
        alpha_x : float
            Multiplier to increase the x kinetic by each spike.
        alpha_s : float
            Rate that the x kinetic influences the s kinetic.
        tau_x : float
            Average life expectancy of the x kinetic (in ms).
        tau_s : float
            Average life expectancy of the s kinetic (in ms).
        **kwargs : dict
            Keyword arguments. Accepts the following kwargs:
                track (boolean) 
                track_ID (int)
        """
        super().__init__(size, **kwargs)
        self.alpha_x = alpha_x
        self.tau_x = tau_x
        self.alpha_s = alpha_s
        self.tau_s = tau_s
        # spike kinetic
        self._x = zeros([self.size,], dtype='float')
        # snapshot of the x kinetic as it was before the latest update
        self._prev_x = copy(self._x)
        # tracked values
        self.tracked_x = []
    def reset(self):
        """
        Resets the s and x kinetics and tracking for a new run.
        """
        # The base class clears the s kinetic, its snapshot and tracked_s.
        super().reset()
        # Use the same attribute names that __init__ creates so that the
        # existing state is actually cleared instead of leaving stale
        # values behind under differently named attributes.
        self._x = zeros([self.size,], dtype='float')
        self._prev_x = copy(self._x)
        self.tracked_x = []
    # project the channel gating into the future
    def project(self, dt):
        """
        Predicts future channel gating with the assumption of no new firing.

        Parameters
        ----------
        dt : float
            The change in time since this kinetic was last updated.

        Returns
        -------
        s_kinetic : numpy.ndarray(float)
            Predicted channel gatings at dt. The stored kinetic is
            not modified.
        """
        if dt == 0:
            return copy(self._s)
        # Single step forward from the current state; must sit outside
        # the dt == 0 branch so that it is reachable.
        return (self._s + self._dsdt(dt))
    def _dsdt(self, dt):
        """
        Change in the s kinetic over time.

        Parameters
        ----------
        dt : float
            Time since the kinetics were last updated.

        Returns
        -------
        dsdt : np.array(float)
            Change in the s kinetic over dt. The x kinetic is scaled
            by (1 - dt/tau_x) before it drives s; the stored kinetics
            are not modified.
        """
        return (dt * (self.alpha_s 
                      * self._x 
                      * (1 - dt/self.tau_x) 
                      * (1 - self._s) 
                      - self._s/self.tau_s))
    def update(self, dt, spk_times, weights):
        Updates the kinetics based on the population spiking and the 
        connection weights.

        dt : float
            The change in time.
        spk_times : numpy.ndarray(float; spikes)
            When spikes occur during this dt.
        weights : numpy.ndarray(float; spikes x size)
            Matrix of connection weights of the presynaptic neuron to 
            postsynaptic neurons. 


        _time = 0
        _d_dur = 0
        self._prev_x = copy(self._x)
        self._prev_s = copy(self._s)        
        for index, spk_t in enumerate(spk_times):
            _d_dur = spk_t - _time
            self._s += self._dsdt(_d_dur)
            self._x *= (1 - _d_dur/self.tau_x)
            self._x += self.alpha_x * weights[index]
            _time = spk_t
        # Decays the rest of the time.
        _d_dur = dt - _time
        self._s += self._dsdt(_d_dur)
        self._x *= (1 - _d_dur/self.tau_x)
        # tracks the kinetics of a single neuron if instructed to do so.
        if self.track: 