from pylab import *

# Neuron model parameters (read by aEIF below)
th = 10.        # V_peak: aEIF marks a spike (sets u to 20) as soon as u exceeds this value
g_L = 30.       # [nS]                leak conductance (multiplies u-E_L and the exponential term)
V_T = -50.4     # [mV]                threshold inside the exponential term
Delta_T = 2.    # [mV]                slope factor of the exponential term
tau_w = 144.    # [ms]                time constant of the adaptation variable w
a = 4.          # [nS]                coupling of w to u-E_L
b = 0.0805      # [nA]                increment added to w on the step after a spike
                # NOTE(review): aEIF adds b to w unscaled, while the other current terms are
                # nS*mV products (i.e. pA) -- confirm whether 80.5 (pA) was intended.
eps = 3.        # small current injected per presynaptic spike: stands in for the conjunction of many presynaptic spikes
E_L = -70.6     # [mV]                resting pot
C = 281.        # [pF]                neuron capacitance

# adaptive exponential (AdEx) neuron model
def aEIF(u,w,I):
    """Advance the AdEx neuron by one explicit Euler step (implicit unit step).

    u -- membrane potential; the exact value 20 is the marker written by the
         previous call when a spike was detected
    w -- adaptation variable
    I -- input current for this step

    Returns the updated pair (u, w).
    """
    # The spike marker from the previous step triggers the reset and the
    # spike-driven jump of the adaptation variable.
    fired_last_step = (u==20)
    if fired_last_step:
        u, w = E_L, w+b
    leak = -g_L*(u-E_L)
    upswing = g_L*Delta_T*exp((u-V_T)/Delta_T)
    u += 1/C*(leak + upswing - w + I)
    # Crossing the threshold is recorded by clamping u to the marker value.
    u = 20 if u>th else u
    # The adaptation variable relaxes using the (possibly clamped) new u.
    w += 1/tau_w*(a*(u-E_L) - w)
    return u,w

# spiketrain generator
def sp_tr(pathw,burst,pulse,freq,ibi):#[Hz],[ms]
    """Build the presynaptic spike raster for one stimulation protocol.

    pathw -- row (pathway index) that receives the spikes
    burst -- number of bursts
    pulse -- number of pulses per burst
    freq  -- pulse frequency within a burst [Hz]; may be int or float
    ibi   -- gap between the last pulse of a burst and the first pulse of
             the next one [ms]

    Returns an int array of shape (Np, T+100) holding 1 at every spike time
    of row `pathw` and 0 elsewhere; the last 100 columns are always silent.
    """
    # Period in whole ms. Truncating explicitly keeps the result identical
    # for integer freq and stops a float freq (or true division) from
    # producing a float array shape / slice step.
    p = int(1000./freq)
    ibi = int(ibi)
    T = burst*p*(pulse-1)+(burst-1)*ibi+1
    sptr = zeros((Np,T+100),int)
    for i in range(burst):
        start = i*((pulse-1)*p+ibi)
        stop = (i+1)*(pulse-1)*p+i*ibi+1
        sptr[pathw,start:stop:p] = 1
    return sptr

# triplet voltage (TriVo) rule parameters (read by e_trace)
A_ltd = 1e-2                 # amplitude for depression (scales rho_l)
A_ltp = 1e-3                 # amplitude for potentiation (scales rho_h)
tau_ltp = 100.               # time constant of the low-pass filtered membrane pot u_m2 [ms] (pot part)
theta_ltd = E_L              # threshold for depression; e_trace rectifies both filtered potentials (u_m1, u_m2) against it
theta_ltp = -50.             # threshold for potentiation; e_trace rectifies the instantaneous potential v against it
tau_ltd = 1000.              # time constant of the low-pass filtered membrane pot u_m1 [ms] (dep part)
tau_x = 100.                 # time constant of the low-pass filtered presynaptic spike train r [ms]

# dynamics during stimulus
def e_trace(h,l,z,p,x,t):# initial{h,l,z,p}, spike train, start time
    """Run the dynamics during stimulation, one step per column of x.

    Each step computes the input current from the pathway weights and the
    presynaptic spikes, advances the neuron with aEIF, derives the
    depression/potentiation rates from the filtered traces, switches h and l
    stochastically, updates p and z, and stores the state in the global
    records h_fin, l_fin, z_fin, p_fin at column t+int(i/1000.).

    h, l -- int state arrays of length 2*N (first N entries: pathway 0,
            next N: pathway 1); modified in place
    z    -- float state array of length 2*N; modified in place
    p    -- scalar protein variable
    x    -- spike raster, one row per pathway and one column per step
    t    -- start time [s], used as offset into the record arrays

    Returns (h, l, z, p). The code below is hard-wired to two pathways.
    """

    # initialization
    I = 0.                       # current due to presynaptic spikes
    wad = 0.                     # adaptation
    v = E_L                      # membrane pot
    u_m1 = E_L                   # filtered membrane pot 1
    u_m2 = E_L                   # filtered membrane pot 2
    u_m2_sig = 0.
    u_m1_sig = 0.
    r = zeros(Np,float)          # low-pass x

    n_steps = len(x[0])

    # update for the weight
    for i in range(n_steps):
        col = t+int(i/1000.)     # record column: whole seconds elapsed since t
        w_syn = g_h*h+g_l*l+g_z*z
        I = sum(eps*(array([sum(w_syn[0:N]),sum(w_syn[N:2*N])])+wo)*x[:,i]*C) # eps*w*...
        v,wad = aEIF(v, wad, I)
        u_sig = (v>theta_ltp)*(v-theta_ltp)
        # the rates use the filtered traces from the previous step
        rho_l = A_ltd*x[:,i]*u_m1_sig # depression rate
        rho_h = A_ltp*u_sig*r*u_m2_sig # potentiation rate
        r += 1./tau_x*(x[:,i]-r)
        u_m1 += 1./tau_ltd*(v-u_m1)
        u_m1_sig = (u_m1 > theta_ltd)*(u_m1-theta_ltd)
        u_m2 += 1./tau_ltp*(v-u_m2)
        u_m2_sig = (u_m2 > theta_ltd)*(u_m2-theta_ltd)
        # The second update in each branch sees the values just written by the
        # first one, so the order matters; it is picked at random every step.
        if rand()>0.5: # randomly starts with potentiation or depression
            h[0:N] = (1-h[0:N])*(1+l[0:N])*(rand(N)<rho_h[0])+h[0:N]*(rand(N)>k_h/1000.)
            h[N:2*N] = (1-h[N:2*N])*(1+l[N:2*N])*(rand(N)<rho_h[1])+h[N:2*N]*(rand(N)>k_h/1000.)
            l[0:N] = (-1-l[0:N])*(1-h[0:N])*(rand(N)<rho_l[0])+l[0:N]*(rand(N)>k_l/1000.)
            l[N:2*N] = (-1-l[N:2*N])*(1-h[N:2*N])*(rand(N)<rho_l[1])+l[N:2*N]*(rand(N)>k_l/1000.)
        else:
            l[0:N] = (-1-l[0:N])*(1-h[0:N])*(rand(N)<rho_l[0])+l[0:N]*(rand(N)>k_l/1000.)
            l[N:2*N] = (-1-l[N:2*N])*(1-h[N:2*N])*(rand(N)<rho_l[1])+l[N:2*N]*(rand(N)>k_l/1000.)
            h[0:N] = (1-h[0:N])*(1+l[0:N])*(rand(N)<rho_h[0])+h[0:N]*(rand(N)>k_h/1000.)
            h[N:2*N] = (1-h[N:2*N])*(1+l[N:2*N])*(rand(N)<rho_h[1])+h[N:2*N]*(rand(N)>k_h/1000.)
        # per-second rates are divided by 1000 because each step is 1/1000 of a record column
        p += -p/tau_pd/1000.+(1-p)/tau_pm/1000.*(1-1*(tp1<col<tp2))*(sum(h-l)>theta_p)
        z += 1./tau_z/1000.*(z*(1-z)*(z-kappa)+gamma*(h+l)*p)
        h_fin[:,col] = h
        l_fin[:,col] = l
        z_fin[:,col] = z
        p_fin[col] = p
        # progress report every 60000 steps (one minute of stimulation)
        if (i+1)%60000==0:
            print('%i [min]/%i [min]' %((i+1)//60000,int(n_steps/6e4)))
    return h,l,z,p

# maintenance rule parameters (read by e_trace and z_trace)
gamma = 0.1         # shift of bistable f-curve: weight of the (h+l)*p drive in the z update
kappa = 0.5         # unstable fix pt of the cubic term z*(1-z)*(z-kappa)
k_h = 1./3600.      # decay rate of h [s]^-1 (e_trace divides it by 1000 for its finer steps)
k_l = 1./5400.      # decay rate of l [s]^-1 (e_trace divides it by 1000 for its finer steps)
tau_z = 360.        # l-ltp (z) time constant [s]
tau_pm = 360.       # protein time constant [s] of the production term (1-p)/tau_pm
tau_pd = 3600.      # protein time constant [s] of the decay term -p/tau_pd

theta_p = 60        # protein-production threshold, compared against sum(h-l)
tp1 = 1             # block protein start: production is switched off while tp1 < time < tp2
tp2 = 0             # block protein stop; with tp2 <= tp1, as here, the blocked window is empty

# weight parameters: a synapse contributes g_h*h + g_l*l + g_z*z (+ wo)
g_l = 0.5           # l-weight
g_h = 1.            # h-weight
g_z = 2.            # z-weight
wo = 1.             # residual weight
iu = 0.3            # initial up ratio: fraction of each pathway's synapses that start with z = 1

# dynamics without external input
def z_trace(h,l,z,p,T,t):# initial{h,l,z,p}, length, start time
    """Run the dynamics without external input for T steps of one second.

    Each step lets h and l decay stochastically (per-step reset probabilities
    k_h and k_l), updates p and z, and stores the state in the global records
    h_fin, l_fin, z_fin, p_fin at column t+i.

    h, l -- int state arrays of length 2*N; modified in place
    z    -- float state array of length 2*N; modified in place
    p    -- scalar protein variable
    T    -- number of steps [s]; must be an int (a value <= 0 runs no step)
    t    -- start time [s], used as offset into the record arrays

    Returns (h, l, z, p).
    """
    for i in range(T):
        # Absolute time of this step in seconds. The protein-block window must
        # be tested against this value: steps here are whole seconds, so the
        # t+int(i/1000.) offset of the millisecond loop does not apply.
        now = t+i
        # The h and l decays are independent of each other; the random branch
        # only changes the order in which the random numbers are drawn.
        if rand()>0.5:
            h[0:N] = h[0:N]*(rand(N)>k_h)
            h[N:2*N] = h[N:2*N]*(rand(N)>k_h)
            l[0:N] = l[0:N]*(rand(N)>k_l)
            l[N:2*N] = l[N:2*N]*(rand(N)>k_l)
        else:
            l[0:N] = l[0:N]*(rand(N)>k_l)
            l[N:2*N] = l[N:2*N]*(rand(N)>k_l)
            h[0:N] = h[0:N]*(rand(N)>k_h)
            h[N:2*N] = h[N:2*N]*(rand(N)>k_h)
        # production is active only above theta_p and outside (tp1, tp2)
        p += -p/tau_pd+(1-p)/tau_pm*(1-1*(tp1<now<tp2))*(sum(h-l)>theta_p)
        z += 1./tau_z*(z*(1-z)*(z-kappa)+gamma*(h+l)*p)
        h_fin[:,now] = h
        l_fin[:,now] = l
        z_fin[:,now] = z
        p_fin[now] = p
    return h,l,z,p

# dictionary for protocols = [#burst,#pulse,freq [Hz],interburst interval [s]]
# sim treats an interval > 100 s burst by burst (z_trace runs in between);
# with an interval < 100 s all bursts go into a single spike train.
dic = {'wtet':[1,21,100,0],'stet':[3,100,100,600],'wlfs':[1,900,1,0],'slfs':[900,3,20,1],'nothing':[0,0,1,0],}

# parameters for simulation
N = 100                       # synapses per pathway
Np = 2                        # number of pathways
L = 5*3600                    # simulation time [s]
# Records, one column per second; e_trace and z_trace write into them in place.
h_fin = zeros((Np*N,L),int)   # record variables: h of every synapse
l_fin = zeros((Np*N,L),int)   # l of every synapse
z_fin = zeros((Np*N,L),float) # z of every synapse
p_fin = zeros(L,float)        # protein variable p

# simulation
def _run_protocol(prot,pathw,h,l,z,p,t):
    """Apply protocol `prot` (a key of dic) to pathway `pathw`, starting at t [s].

    Alternates burst phases (e_trace) and interburst phases (z_trace) and
    returns (h, l, z, p, t) with t advanced past the protocol.
    """
    nburst,npulse,freq,ibi = dic[prot]
    stepwise = ibi>100 # long ibi: one spike train per burst, z_trace in between
    bu = 1 # burst variable, 1=yes
    for i in range(2*(stepwise*nburst+(ibi<100))-1): # long ibi->2*nb burst-1; short ibi->1
        if bu==1:
            if nburst==1 or stepwise: # spiketrain only as long as one burst
                sptr = sp_tr(pathw,1,npulse,freq,0)
            else: # several bursts AND short ibi->one long spiketrain
                sptr = sp_tr(pathw,nburst,npulse,freq,1000*ibi)
            print('burst n.  %i' %(1+(i+1)//2))
            h,l,z,p = e_trace(h,l,z,p,sptr,t)
            t += int(len(sptr[0])/1000.)
        else:
            print('interburst')
            h,l,z,p = z_trace(h,l,z,p,ibi,t)
            t += ibi
        bu = 1-bu
    return h,l,z,p,t

# simulation
def sim(prot1,prot2,wplot): # 1st protocol, 2nd protocol (from dictionary), what plot [w1,w2,hl,p]
    """Simulate L seconds with two stimulation protocols and plot the records.

    Timeline: 30 min without input, prot1 on pathway 0, rest until
    tt = 2*30*60+10 s, prot2 on pathway 1, rest until L.

    prot1, prot2 -- keys of dic
    wplot        -- collection (or string) tested with `in` for 'w1', 'w2',
                    'hl' and 'p' to select which figures are created

    Returns None; the results are left in h_fin, l_fin, z_fin, p_fin.
    """
    print('starting %s protocol' %prot1)
    # initialization
    h = zeros(Np*N,int)
    l = zeros(Np*N,int)
    z = zeros(Np*N,float)
    z_in = zeros(N,float)
    z_in[:int(iu*N)] = 1
    for i in range(2):
        z[i*N:(i+1)*N]=permutation(z_in)
    p = 0.
    t = 0 # [s]
    tt = 2*30*60+10 # second tet/lfs start time [s]
    # simulation
    h,l,z,p = z_trace(h,l,z,p,30*60,t)
    t += 30*60
    h,l,z,p,t = _run_protocol(prot1,0,h,l,z,p,t)
    h,l,z,p = z_trace(h,l,z,p,tt-t,t)
    print('starting %s protocol' %prot2)
    t = tt
    h,l,z,p,t = _run_protocol(prot2,1,h,l,z,p,t)
    print('terminating...')
    h,l,z,p = z_trace(h,l,z,p,L-t,t)
    # plotting
    print('plotting')
    sc = N*(wo+iu*g_z)/100. # initial summed pathway weight, so that w is plotted in percent
    for key,rows in (('w1',slice(0,N)),('w2',slice(N,2*N))):
        if key in wplot:
            figure()
            plot(sum(g_z*z_fin[rows]+g_h*h_fin[rows]+g_l*l_fin[rows]+wo,axis=0)/sc)
            plot(sum(h_fin[rows],axis=0),'--')
            plot(sum(l_fin[rows],axis=0),'--')
            plot(sum(g_z*z_fin[rows]+wo,axis=0)/sc,'--')
            legend(('w','h','l','z'))
    if 'hl' in wplot:
        figure()
        subplot(131)
        imshow(z_fin[:,::150])
        subplot(132)
        imshow(h_fin[:,::150])
        subplot(133)
        imshow(l_fin[:,::150])
    if 'p' in wplot:
        figure()
        plot(p_fin)

def simoconn(f): # freq
    """Stimulate pathway 0 with three bursts of 100 pulses at f Hz, 300 s apart.

    f -- pulse frequency [Hz], passed to sp_tr

    Prints and returns the relative change [%] of the summed pathway-0 weight
    with respect to its initial value N*(g_z*iu+wo).
    """
    # initialization
    h = zeros(Np*N,int)
    l = zeros(Np*N,int)
    z = zeros(Np*N,float)
    z_in = zeros(N,float)
    z_in[:int(iu*N)] = 1
    for i in range(2):
        z[i*N:(i+1)*N]=permutation(z_in)
    p = 0.
    bu = 1 # burst variable, 1=yes
    t = 0 # [s]
    for i in range(5): # 2*nb burst-1
        if bu==1:
            sptr = sp_tr(0,1,100,f,0)
            print('burst n.  %i' %(1+(i+1)//2))
            h,l,z,p = e_trace(h,l,z,p,sptr,t)
            t += int(len(sptr[0])/1000.)
        else:
            print('interburst')
            h,l,z,p = z_trace(h,l,z,p,300,t)
            t += 300
        bu = 1-bu
    w_init = N*(g_z*iu+wo) # summed pathway weight before stimulation
    dw = (sum(g_z*z[0:N]+g_h*h[0:N]+g_l*l[0:N]+wo)-w_init)/w_init*100
    print('dw/w0 = %0.2f  [%%]' %dw)
    return dw