# -*- coding: utf-8 -*-
"""

Perfect IF model with slow adaptation or fractional noise
When run, this script will generate 5 plots with different simulation conditions:
    1) Noise in voltage (V)
    2) Noise in the adaptation (Z)
    3) Noise in both
    4) Noise in adaptation with larger time constant
    5) Fractional noise with alpha=0.7 (the H argument of the fBM routines, default 0.7)
    
Plots will contain 9 panels (from top to bottom, then left to right):
    1) Short time course of both variables (V,Z)
    2) Long (500s) inter-spike interval (ISI) plot
    3) Windowed Kolmogorov-Smirnov test, to visually show stationarity of ISIs
    4) R/S (rescaled range) plot for the whole ISI sequence. Red line is the 
       fit of all points to a power law (straight line in loglog). The slope (H),
       r value and p-value of the fit are indicated. Colored lines are fits of
       small sets of points, in a moving window fashion.
    5) Plot of the moving slopes (colored lines in above plot) against the length
       of the sequence. Red line and blue shadow indicate the mean and 2*SD
       obtained with 100x surrogate data
    6) Autocorrelation plot of the ISI sequence, in log-linear scale
    7) Detrended fluctuation analysis (DFA) plot. Red and colored lines are the
       same as in plot 4
    8) Similar to plot 5, but with the DFA
    9) Autocorrelation plot of the ISI sequence, in log-log scale

@author: porio
"""

import numpy as np
import matplotlib.pyplot as plt
#import hurst
import time
import scipy.stats
import mainPlots
import fBm

def stocV_detz(x):
    """
    Right-hand side of the model with noise only in the voltage equation.

    Parameters
    ----------
    x : sequence of two floats
        Current state, unpacked as (v, z). Only z enters the derivatives.

    Returns
    -------
    ndarray (2,)
        Derivatives [dV/dt, dZ/dt]; one normal random number is drawn for V.

    Notes
    -----
    Reads the module-level names mu, gamma, tauz, zinf and sqDdt. The last
    two are assigned as globals by run_vtrace / run_SpikesOnly.
    """
    _, adapt = x  # the voltage value itself is not used by the equations
    dv = mu - gamma*adapt + sqDdt*np.random.normal()
    dz = (zinf - adapt)/tauz
    return np.array([dv, dz])

def detV_stochz(x):
    """
    Right-hand side of the model with noise only in the adaptation equation.

    Parameters
    ----------
    x : sequence of two floats
        Current state, unpacked as (v, z). Only z enters the derivatives.

    Returns
    -------
    ndarray (2,)
        Derivatives [dV/dt, dZ/dt]; one normal random number is drawn for Z.

    Notes
    -----
    Reads the module-level names mu, gamma, tauz, zinf and sqDzdt. The last
    two are assigned as globals by run_vtrace / run_SpikesOnly.
    """
    _, adapt = x  # the voltage value itself is not used by the equations
    dv = mu - gamma*adapt
    dz = (zinf - adapt)/tauz + sqDzdt*np.random.normal()
    return np.array([dv, dz])

def stocV_stochz(x):
    """
    Right-hand side of the model with noise in both voltage and adaptation.

    Parameters
    ----------
    x : sequence of two floats
        Current state, unpacked as (v, z). Only z enters the derivatives.

    Returns
    -------
    ndarray (2,)
        Derivatives [dV/dt, dZ/dt]. Two normal random numbers are drawn,
        first the one for V and then the one for Z.

    Notes
    -----
    Reads the module-level names mu, gamma, tauz, zinf, sqDdt and sqDzdt.
    The last three are assigned as globals by run_vtrace / run_SpikesOnly.
    """
    _, adapt = x  # the voltage value itself is not used by the equations
    # Keep the draw order (V first, Z second) so the random stream is unchanged
    dv = mu - gamma*adapt + sqDdt*np.random.normal()
    dz = (zinf - adapt)/tauz + sqDzdt*np.random.normal()
    return np.array([dv, dz])

# Model parameters. They are read as module-level globals by the derivative
# functions (stocV_detz, detV_stochz, stocV_stochz) and the simulation routines.
mu = 0.03    # V/ms - base current (constant term of dV/dt)
gamma = 0.3    # V/ms - max adaptation current (multiplies Z in dV/dt)
tauz = 200  # time constant for slow adaptation = 1/lambdaZ (presumably ms, like dt - confirm)
Dz = 2.5e-6  # == sigma' - slow adaptation noise intensity
D = 0.0002   # == sigma - current noise intensity, V²/ms

dt = 0.05   # ms - integration step

def run_vtrace(fun,tstop):
    """
    Integrate the model for a short time, keeping the full state trajectory.

    The state (V, Z) starts at (0, 0.01) and is advanced with the Euler
    scheme. Whenever V exceeds 1 a spike is recorded, V is reset to 0 and
    the global zinf is set to 1; zinf goes back to 0 once the current time
    is more than 1 time unit past the last spike.

    Parameters
    ----------
    fun : function
        Derivative function; called with the state array, returns (dV/dt, dZ/dt).
    tstop : float
        Total simulated time. The step dt is a module-level setting.

    Returns
    -------
    X_t : ndarray (npoints,2)
        V and Z at the beginning of every integration step.
    spikes : list
        Spike times (end of the step in which V crossed 1).

    Notes
    -----
    Assigns the globals zinf, sqDzdt and sqDdt, which is how the derivative
    functions receive the adaptation target and the noise amplitudes.
    """
    global zinf,sqDzdt,sqDdt

    n_steps = int(tstop/dt)
    state = np.array([0, 0.01])   # initial conditions (v, z)
    X_t = np.zeros((n_steps, 2))

    # Shared with `fun` through module globals
    sqDzdt = np.sqrt(Dz/dt)
    sqDdt = np.sqrt(D/dt)
    zinf = 0
    spikes = []

    for k in range(n_steps):
        X_t[k] = state
        # End the zinf pulse once more than 1 time unit has elapsed since the last spike
        if zinf == 1 and k*dt > spikes[-1] + 1:
            zinf = 0
        state += dt * fun(state)  # Euler step
        if state[0] > 1:
            # Threshold crossed: the spike is stamped with the time after the step
            spikes.append((k + 1)*dt)
            zinf = 1
            state[0] = 0

    return X_t, spikes

def run_SpikesOnly(fun,tstop):
    """
    Integrate the model for a long time, keeping only the spike times.

    Same dynamics as run_vtrace (start at V=0, Z=0.01; Euler steps; spike
    and reset when V exceeds 1; zinf held at 1 until more than 1 time unit
    has passed since the last spike), but the trajectory is not stored.
    A progress line is printed every 500000 steps.

    Parameters
    ----------
    fun : function
        Derivative function; called with the state array, returns (dV/dt, dZ/dt).
    tstop : float
        Total simulated time. The step dt is a module-level setting.

    Returns
    -------
    spikes : list
        Spike times (end of the step in which V crossed 1).

    Notes
    -----
    Assigns the globals zinf, sqDzdt and sqDdt, which is how the derivative
    functions receive the adaptation target and the noise amplitudes.
    """
    global zinf,sqDzdt,sqDdt

    n_steps = int(tstop/dt)
    state = np.array([0, 0.01])   # initial conditions (v, z)

    # Shared with `fun` through module globals
    sqDzdt = np.sqrt(Dz/dt)
    sqDdt = np.sqrt(D/dt)
    zinf = 0
    spikes = []
    started = time.time()

    for k in range(n_steps):
        if k % 500000 == 0:
            elapsed = time.time() - started
            print("time %g of %g - real time %g"%(k*dt,tstop,elapsed))

        # End the zinf pulse once more than 1 time unit has elapsed since the last spike
        if zinf == 1 and k*dt > spikes[-1] + 1:
            zinf = 0
        state += dt * fun(state)  # Euler step
        if state[0] > 1:
            # Threshold crossed: the spike is stamped with the time after the step
            spikes.append((k + 1)*dt)
            zinf = 1
            state[0] = 0

    return spikes

# Parameters apparently intended for the fractional-noise (fBM) simulations.
# NOTE(review): neither name is referenced anywhere in the visible code; the
# fBM routines read mu and D instead - confirm which values are intended.
mu_fB = 0.0303    # V/ms - base current
D_fB = 0.01    # V²/ms - noise intensity

def run_vtrace_fBM(tstop,H=0.7):
    """
    Run a short simulation of the integrate-and-fire voltage driven by
    fractional noise and record the voltage time course.

    The voltage starts at 0 and is advanced by dt*(mu + amplitude*noise) at
    every step, with no adaptation variable. When it reaches 1 or more, a
    spike is recorded and the voltage is reset to 0.

    Parameters
    ----------
    tstop : float
        Total time to simulate. Integration time step, dt, is set outside.
    H : float
        H exponent handed to fBm.fracNoiseN to generate the noise

    Returns
    -------
    X_t : ndarray (npoints,)
        Voltage at the beginning of every integration step
    spikes : list
        Times of spikes

    Notes
    -----
    NOTE(review): the drift and noise amplitude come from the globals mu and
    D, not from mu_fB / D_fB - confirm this is intended.
    """
    n_steps = int(tstop/dt)

    print("calculating %g points of noise with H=%g"%(n_steps,H))
    noise = fBm.fracNoiseN(n_steps,H)

    X_t = np.zeros(n_steps)
    noise_amp = np.sqrt(2*D/dt)   # local value; the module-level sqDdt is not touched
    volt = 0
    spikes = []

    for k in range(n_steps):
        X_t[k] = volt
        volt += dt*(mu + noise_amp*noise[k])

        if volt >= 1:
            # Spike time is stamped after the step that crossed the threshold
            spikes.append((k + 1)*dt)
            volt = 0
    return X_t, spikes

def run_SpikesOnly_fBM(tstop,H=0.7):
    """
    Run a long simulation of the integrate-and-fire voltage driven by
    fractional noise and return only the spike times.

    The voltage starts at 0 and is advanced by dt*(mu + amplitude*noise) at
    every step, with no adaptation variable. When it reaches 1 or more, a
    spike is recorded and the voltage is reset to 0. A progress line is
    printed every 500000 steps.

    Parameters
    ----------
    tstop : float
        Total time to simulate. Integration time step, dt, is set outside.
    H : float
        H exponent handed to fBm.fracNoiseN to generate the noise

    Returns
    -------
    spikes : list
        Times of spikes

    Notes
    -----
    NOTE(review): the drift and noise amplitude come from the globals mu and
    D, not from mu_fB / D_fB - confirm this is intended.
    """
    n_steps = int(tstop/dt)

    print("calculating %g points of noise with H=%g"%(n_steps,H))
    noise = fBm.fracNoiseN(n_steps,H)

    noise_amp = np.sqrt(2*D/dt)   # local value; the module-level sqDdt is not touched
    volt = 0
    spikes = []
    started = time.time()

    for k in range(n_steps):
        if k % 500000 == 0:
            elapsed = time.time() - started
            print("time %g of %g - real time %g"%(k*dt,tstop,elapsed))

        volt += dt*(mu + noise_amp*noise[k])

        if volt >= 1:
            # Spike time is stamped after the step that crossed the threshold
            spikes.append((k + 1)*dt)
            volt = 0
    return spikes


# Driver script: for each condition, run a short simulation (tstop=1000) to get
# a state trace and a long one (tstop=500000) to get the spike times, then hand
# both to mainPlots.AllPlots. The cells run in order and share module globals.

# Time axis from 0 to 1000 in steps of dt, passed to every AllPlots call
time1=np.arange(0,1000,dt)

#%% Noise in Voltage only
print("Noise in Voltage only: simulating short voltage trace")
# The spike list of the short run is discarded; only the trace is kept
trace,_ = run_vtrace(stocV_detz,1000)
print("simulating long ISI sequence...")
spikes = run_SpikesOnly(stocV_detz,500000)

mainPlots.AllPlots(trace,spikes,time1,label='Stochastic V, Deterministic Z')

plt.show()

#%% Noise in Adaptation only
print("Noise in Adaptation only: simulating short voltage trace")
trace,_ = run_vtrace(detV_stochz,1000)
print("simulating long ISI sequence...")
spikes = run_SpikesOnly(detV_stochz,500000)

mainPlots.AllPlots(trace,spikes,time1,label='Deterministic V, Stochastic Z')

plt.show()

#%% Noise in Both
print("Noise in Adaptation and Voltage: simulating short voltage trace")
trace,_ = run_vtrace(stocV_stochz,1000)
print("simulating long ISI sequence...")
spikes = run_SpikesOnly(stocV_stochz,500000)

mainPlots.AllPlots(trace,spikes,time1,label='Stochastic V, Stochastic Z')

plt.show()

#%% Noise in Adaptation with larger time scale
# Rebinds the module-level tauz read by the derivative functions; the new value
# stays in effect for the rest of the script (the fBM routines do not read tauz)
tauz=2000
print("Noise in Adaptation with larger time scale: simulating short voltage trace")
trace,_ = run_vtrace(detV_stochz,1000)
print("simulating long ISI sequence...")
spikes = run_SpikesOnly(detV_stochz,500000)

mainPlots.AllPlots(trace,spikes,time1,label=r'Deterministic V, Stochastic Z, $\tau_Z$=2000')

plt.show()

#%% LIF model with fractional noise
# Both fBM runs use the default H=0.7.
# NOTE(review): here trace is 1-D (voltage only), unlike the (npoints,2) traces
# of the previous cells - confirm mainPlots.AllPlots accepts both shapes.
print("LIF with fBm: simulating short voltage trace")
trace,_ = run_vtrace_fBM(1000)
print("simulating long ISI sequence...")
spikes = run_SpikesOnly_fBM(500000)

mainPlots.AllPlots(trace,spikes,time1,label=r'LIF with fBm')

plt.show()