#!/usr/bin/env python
'''
Long-term synaptic dynamics
Costa et al 2015
:: Figure 3 ::
:: Savings simulation (with two Guassian input profiles) ::
'''
import matplotlib as mpl
mpl.use('TkAgg')
from brian import *
from time import time
from numpy import *
import matplotlib.pyplot as plt
import matplotlib.pylab as plab
exportOn = 1
plotOn = -1 #-1: none, 0: all, 1: only for input1, 2: only for input2
nruns = 1; #Number of runs
for nrun in range(0,nruns):
close('all')
#Reset
#if 'neurons' in vars():
# reinit()
realtime = 0
stime = 1 * second
stime2 = 50 * second
extractFiringRatesOn = 1
resolution_export = 10; # every x ms
N = 100
taum = 10 * ms
Ee = 0 * mV
taue = 2 * ms
Fon = 50 * Hz
Foff = 3 * Hz
#s = 55.0000e-10
s = 100.0000e-10
Amax = 2.
Amin = 0
Ainit = 0.1
Umax = 1.
Umin = 0
Uinit = 0.1
dFBn = 0
dFBp = 0
dFFp = 0
#Short-term plasticity params
tau_u = 50 * ms
tau_r = 200 * ms
#prepostSTDP params: AFBn tau_FBn AFBp tau_FBp AFFp tau_FFp
params = [0.1771, 0.0327, 0.1548, 0.2302, 0.0618, 0.0666];
AFBn = params[0]
tau_FBn = params[1]*1e3 * ms
AFBp = params[2]
tau_FBp = params[3]*1e3 * ms
AFFp = params[4]
tau_FFp = params[5]*1e3 * ms
#etaU = 0.35
etaU = 0.15
etaA = 0.15
#etaA = 0.35
defaultclock.dt = 1*ms
# Adex Parameters
C = 281 * pF
gL = 30 * nS
taum = C / gL
EL = -70.6 * mV
DeltaT = 2 * mV
vti = -50.4 * mV
#vtrest = vti + 5 * DeltaT
vtrest = -45 * mV
VTmax = 18 * mV
tauvt = 50 * ms
tauw, c, b, Vr = 144 * ms, 4 * nS, 0.0805 * nA, -70.6 * mV # Regular spiking (as in the paper)
eqs_neuron = """
dvm/dt=(gL*(EL-vm)+gL*DeltaT*exp((vm-vt)/DeltaT)+I-x)/C : volt
dvt/dt=-(vt-vtrest)/tauvt : volt
dx/dt=(c*(vm-EL)-x)/tauw : amp #In the standard formulation x is w
I : amp
"""
input1_pos = 25
input2_pos = 75
rad = 5
#Define input 1
F_input1 = ones(N)*Foff
#F_input1[input1_pos-rad:input1_pos+rad] = Fon
for i in range(0,N): #Define gaussian input
F_input1[i] = exp(-((((i+1)-input1_pos)**2)/(2.0*rad**2)))*(Fon-Foff)+Foff;
#Define input 2
F_input2 = ones(N)*Foff
#F_input2[input2_pos-rad:input2_pos+rad] = Fon
for i in range(0,N): #Define gaussian input
F_input2[i] = exp(-((((i+1)-input2_pos)**2)/(2.0*rad**2)))*(Fon-Foff)+Foff;
input = PoissonGroup(N, rates=F_input1)
neurons = NeuronGroup(1, model=eqs_neuron, threshold='vm>vt', reset="vm=Vr;x+=b;vt=VTmax", freeze = True)
neurons.vt = vtrest
neurons.vm = EL
neurons.I = 0
neurons.x = 0
model='''w : 1
FFp : 1
FBp : 1
FBn : 1
R : 1
u : 1
U : 1
A : 1
dFFp/dt=-FFp/tau_FFp : 1 (event-driven)
dFBp/dt=-FBp/tau_FBp : 1 (event-driven)
dFBn/dt=-FBn/tau_FBn : 1 (event-driven)
dR/dt=(1-R)/tau_r : 1 (event-driven)
du/dt=(U-u)/tau_u : 1 (event-driven)
'''
syn = Synapses(input, neurons, model, pre='''I=s*A*R*u;
U=clip(U+etaU*(-AFBn*FBn*FBp + AFBp*FBp*FFp),Umin,Umax);
w=U*A;
FFp+=1; R-=R*u; u+=U*(1-u)''',
post='''A=A+etaA*(AFFp*FFp*FBn);
A=A-etaA*0.5*mean(AFFp*FFp*FBn);
A=clip(A,Amin,Amax);
w=U*A;
FBp+=1.;FBn+=1.''')
#syn.connect_one_to_one(input, neurons)
syn[:,:]=True
syn.FBp=0
syn.FBn=0
syn.R=1
#syn.U='rand()*Uinit'
#syn.A='rand()*Ainit'
#syn.U[:]=Umin
#syn.U[:]=0.5
for i in range(0,size(syn.U[:])): #Define gaussian input
syn.U[i] = exp(-((((i+1)-input1_pos)**2)/(2.0*(rad+0)**2)))*(Umax-Umin)+Umin;
syn.A[i] = exp(-((((i+1)-input1_pos)**2)/(2.0*(rad+3)**2)))*(Amax-Amin)+Amin;
# syn.U[input1_pos-rad:input1_pos+rad]=Umax
#syn.A[:]=Amin
#syn.A[input1_pos-rad-3:input1_pos+rad+3]=Amax
Si = SpikeMonitor(input)
So = SpikeMonitor(neurons)
Mpost = MultiStateMonitor(neurons, record=True)
Mrate = PopulationRateMonitor(neurons,bin=100*ms)
Msyn = MultiStateMonitor(syn, record=True)
synU = StateMonitor(syn, 'U', record=True)
synA = StateMonitor(syn, 'A', record=True)
#Mstdp_post = MultiStateMonitor(stdp.post_group, record=True)
#Msyn = StateMonitor(synapses, 'W', record=True)
ref = 200
if(realtime):
ion()
subplot(411)
raster_plot(Si, refresh=ref*ms, showlast=stime2*3+stime)
subplot(412)
raster_plot(So, refresh=ref*ms, showlast=stime2*3+stime)
plab.ylim([-0.5,0.5])
subplot(413)
synU.plot([input1_pos, 50, input2_pos], refresh=ref*ms, showlast=stime2*3+stime)
plab.ylim([-0.05, 1.05])
subplot(414)
synA.plot([input1_pos, 50, input2_pos], refresh=ref*ms, showlast=stime2*3+stime)
plab.ylim([-0.05, 3.05])
show()
start_time = time()
run(stime)
#input1.rate = 0.1
input.rate = F_input2
run(stime2)
input.rate = F_input1
run(stime2)
input.rate = F_input2
run(stime2)
print "Simulation time:", time() - start_time
#G = NeuronGroup(...)
#spikemon = SpikeMonitor(G)
#statemon = StateMonitor(G, 'V', record=range(5))
#subplot(211)
#raster_plot(spikemon, refresh=10*ms, showlast=200*ms)
#subplot(212)
#statemon.plot(refresh=10*ms, showlast=200*ms)
#run(1*second)
if plotOn>=0:
plt.figure()
ion()
subplot(411)
raster_plot(Si)
subplot(412)
raster_plot(So)
subplot(413)
plot(syn.U[:], '.')
subplot(414)
plot(syn.A[:], '.')
show()
plt.figure()
i1 = 25
pFBn, = plt.plot(Msyn['FBn'].times, Msyn['FBn'][i1,:])
pFBp, = plt.plot(Msyn['FBp'].times, Msyn['FBp'][i1,:])
pPreLTD, = plt.plot(Msyn['FBn'].times, AFBn*Msyn['FBn'][i1,:]*Msyn['FBp'][i1,:])
pPreLTP, = plt.plot(Msyn['FBp'].times, AFBp*Msyn['FBp'][i1,:]*Msyn['FFp'][i1,:])
pFFp, = plt.plot(Msyn['FFp'].times, Msyn['FFp'][i1,:])
#pu, = plt.plot(Msyn['u'].times, Msyn['u'][i,:])
#pR, = plt.plot(Msyn['R'].times, Msyn['R'][i,:])
plt.legend([pFFp, pFBn, pFBp, pPreLTD, pPreLTP], ['FFp', 'FBn', 'FBp', 'preLTP', 'preLTD'])
ion()
show()
plt.figure()
i1 = input1_pos
i2 = input2_pos
pu1, = plt.plot(Msyn['U'].times, Msyn['U'][i1,:])
pA1, = plt.plot(Msyn['A'].times, Msyn['A'][i1,:])
pu2, = plt.plot(Msyn['U'].times, Msyn['U'][i2,:])
pA2, = plt.plot(Msyn['A'].times, Msyn['A'][i2,:])
plt.legend([pu1, pA1, pu2, pA2], ['U1', 'A1', 'U2', 'A2'])
ion()
show()
plt.figure()
pv, = plt.plot(Mpost['vm'].times, Mpost['vm'][0,:])
pvt, = plt.plot(Mpost['vt'].times, Mpost['vt'][0,:])
#pge, = plt.plot(Mpost['ge'].times, Mpost['ge'][0,:])
plt.legend([pv, pvt], ['Vm', 'vt'])
plt.show()
#Plot firing rate
rates = Mrate.smooth_rate(width=1000*ms,filter='gaussian')
if extractFiringRatesOn==0:
plt.figure()
pv, = plt.plot(Mrate.times, rates)
plt.show()
if exportOn:
#Export results to be plotted in matlab
import os as os
path = 'fromBrian/'
if os.path.exists(path + 'outParams.br'):
os.remove(path + 'outParams.br')
if os.path.exists(path + 'outU_run' + str(nrun) + '.br'):
os.remove(path + 'outU_run' + str(nrun) + '.br')
if os.path.exists(path + 'outA_run' + str(nrun) + '.br'):
os.remove(path + 'outA_run' + str(nrun) + '.br')
#f_handle = file(filename, 'a')
savetxt(path + 'outParams.br', [N, resolution_export, stime, stime2, input1_pos, input2_pos, rad, nruns, Amax], fmt='%f', newline='\n') # Number of postsynaptic neurons
savetxt(path + 'outU_run' + str(nrun) + '.br', Msyn['U'][:,::resolution_export], fmt='%f', newline='\n') # Save Us
savetxt(path + 'outA_run' + str(nrun) + '.br', Msyn['A'][:,::resolution_export], fmt='%f', newline='\n') # Save As
#savetxt(f_handle, Msyn['U'][:,0:2], fmt='%f', newline='\n') # Save Us
#savetxt(f_handle, Msyn['A'][:,0:2], fmt='%f', newline='\n') # Save As
#f_handle.close()
if extractFiringRatesOn:
#Extract postsynaptic firing rate for input1 and input2
Usim = Msyn['U'][:,:];
Asim = Msyn['A'][:,:];
post_nspikes1 = So.nspikes
forget(syn)
reinit()
neurons.vt = vtrest
neurons.vm = EL
neurons.I = 0
neurons.x = 0
modelAfter='''R : 1
u : 1
w : 1
U : 1
A : 1
dR/dt=(1-R)/tau_r : 1 (event-driven)
du/dt=(U-u)/tau_u : 1 (event-driven)
'''
synAfter=Synapses(input, neurons, modelAfter, pre='''I=s*A*R*u;
w=U*A;
R-=R*u; u+=U*(1-u)''',
post='''w=U*A''')
synAfter[:,:]=True
synAfter.R=1
synAfter.u=Uinit
synAfter.U=Umin
synAfter.U[input1_pos-rad:input1_pos+rad]=Umax
synAfter.A[:]=Amin
synAfter.A[input1_pos-rad:input1_pos+rad]=Amax
Uaux = synAfter.U[:]
Uaux[:] = Umin
Uaux[input2_pos-rad:input2_pos+rad]=Umax
Aaux = synAfter.A[:]
Aaux[:] = Amin
Aaux[input2_pos-rad:input2_pos+rad]=Amax
@network_operation
def loadUandA(clock):
synAfter.U[:] = Usim[:,clock.t/clock.dt]
synAfter.A[:] = Asim[:,clock.t/clock.dt]
#synAfter.U[:] = Uaux[:]
#synAfter.A[:] = Aaux[:]
#INPUT 1
MsynAfter = MultiStateMonitor(synAfter, record=True)
Mpost = MultiStateMonitor(neurons, record=True)
start_time = time()
F_input1n = F_input1
F_input1n[F_input1n<(Fon/2)] = 0
input.rate = F_input1n
#run(stime+stime2)
run(stime+stime2*3)
rates_1 = Mrate.smooth_rate(width=1000*ms,filter='gaussian')
print "Simulation time (for input1 alone):", time() - start_time
#INPUT 2
post_nspikes2 = So.nspikes
reinit()
neurons.vt = vtrest
neurons.vm = EL
neurons.I = 0
neurons.x = 0
Mpost = MultiStateMonitor(neurons, record=True)
start_time = time()
F_input2n = F_input2
F_input2n[F_input2n<(Fon/2)] = 0
input.rate = F_input2
run(stime+stime2*3)
rates_2 = Mrate.smooth_rate(width=1000*ms,filter='gaussian')
print "Simulation time (for input2 alone):", time() - start_time
post_nspikes3 = So.nspikes
if plotOn>=1:
'''
plt.figure()
ion()
subplot(411)
raster_plot(Si)
subplot(412)
raster_plot(So)
subplot(413)
plot(synAfter.U[:], '.')
subplot(414)
plot(synAfter.A[:], '.')
show()
plt.figure()
i1 = 25
i2 = 75
pu1, = plt.plot(MsynAfter['U'].times, MsynAfter['U'][i1,:])
pA1, = plt.plot(MsynAfter['A'].times, MsynAfter['A'][i1,:])
pu2, = plt.plot(MsynAfter['U'].times, MsynAfter['U'][i2,:])
pA2, = plt.plot(MsynAfter['A'].times, MsynAfter['A'][i2,:])
plt.legend([pu1, pA1, pu2, pA2], ['U1', 'A1', 'U2', 'A2'])
ion()
show()
plt.figure()
pv, = plt.plot(Mpost['vm'].times, Mpost['vm'][0,:])
pvt, = plt.plot(Mpost['vt'].times, Mpost['vt'][0,:])
#pge, = plt.plot(Mpost['ge'].times, Mpost['ge'][0,:])
plt.legend([pv, pvt], ['Vm', 'vt'])
plt.show()
'''
#Plot firing rate
plt.figure()
prates, = plt.plot(Mrate.times, rates)
prates_1, = plt.plot(Mrate.times, rates_1)
prates_2, = plt.plot(Mrate.times, rates_2)
plt.legend([prates, prates_1, prates_2], ['Learning', 'Input_1', 'Input_2'])
plt.show()
print "Nspikes before: ", post_nspikes1, "| Nspikes after (Input1): ", post_nspikes2, "| Nspikes after (Input2): ", post_nspikes3
if exportOn:
if os.path.exists(path + 'rateInput1_run' + str(nrun) + '.br'):
os.remove(path + 'rateInput1_run' + str(nrun) + '.br')
if os.path.exists(path + 'rateInput2_run' + str(nrun) + '.br'):
os.remove(path + 'rateInput2_run' + str(nrun) + '.br')
savetxt(path + 'rateInput1_run' + str(nrun) + '.br', rates_1, fmt='%f', newline='\n') # Save post spike times
savetxt(path + 'rateInput2_run' + str(nrun) + '.br', rates_2, fmt='%f', newline='\n') # Save post spike times
clear()