# Copied from stdpsynfire_noNMDA_synstim_vartau.py. Otherwise the same, but it also saves the full time course of all feedback (STDP) weights, so the result files require a lot of disk space.
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
from brian2 import *
#prefs.codegen.target = 'cython'
# Brian2: begin a fresh scope before building the network.
start_scope()
timenow = time.time()  # wall-clock start, used for the initialization timing printout


def _int_arg(text):
    """Parse an integer command-line value, accepting float notation such as '10.0'."""
    return int(float(text))


def _cli_value(position, default, convert):
    """Return sys.argv[position] passed through `convert`, or `default` when it was not given."""
    if len(sys.argv) > position:
        return convert(sys.argv[position])
    return default


# Positional command-line overrides: argv[k] replaces the default shown here.
N = _cli_value(1, 10, _int_arg)  # How many neurons per population
Lmax = _cli_value(2, 30, _int_arg)  # Length of the single long synfire chain
stim_duration = _cli_value(3, 12.0, float)  # Duration of the rhythmic stimulus (s)
tot_duration = _cli_value(4, 16.0, float)  # Duration of the whole simulation (s)
stim_freq_Hz = _cli_value(5, 2, float)  # Stimulus rate (Hz)
g_EE = _cli_value(6, 12.0, float)  # Synapse weight scale E->E
g_EI = _cli_value(7, 20.0, float)  # Synapse weight scale E->I
g_IE = _cli_value(8, 20.0, float)  # Synapse weight scale I->E
A_plus_nS = _cli_value(9, 0.05, float)  # STDP step (nS) applied to the presynaptic trace on each pre spike
A_minus_nS = _cli_value(10, 0.06, float)  # STDP step (nS) applied to the postsynaptic trace on each post spike
w_init_nS = _cli_value(11, 0.05, float)  # Initial plastic synaptic weight (nS)
A_tau_ms = _cli_value(12, 20, _int_arg)  # Decay time constant of the STDP traces (ms)
doSave = _cli_value(13, 1, _int_arg)  # Nonzero: plot the weight evolution and write the .mat results
myseed = _cli_value(14, 1, _int_arg)  # Random number seed
# Seed the random number generator(s) with myseed.
seed(myseed)
# Output-file suffixes, added only when the seed / trace time constant differ from the defaults.
seed_addition = f'_seed{myseed}' if myseed != 1 else ''
tau_addition = f'_tau{A_tau_ms}' if A_tau_ms != 20 else ''
# Upper bound of the plastic feedback weights, scaled inversely with N.
w_max = g_EE * 10/N * nS
# Starting value of every plastic feedback weight.
w_init = w_init_nS * nS
# ----------------------------
# Network parameters
# ----------------------------
stim_freq = stim_freq_Hz * Hz  # rhythmic drive frequency (2 Hz by default)
T_stim = 1.0 / stim_freq  # interval between drive cycles
sim_duration = tot_duration * second  # total simulated time
drive_off_time = stim_duration * second  # time at which the rhythmic drive stops
# Neuron model: leaky integrate-and-fire with conductance-based synapses.
C_m = 500 * pF  # membrane capacitance
g_L = 10 * nS  # leak conductance -> tau_m = C_m / g_L = 50 ms
El = -65*mV  # leak reversal potential
Vr = -65*mV  # reset potential
Vth = -50*mV  # spike threshold
# Synaptic reversal potentials and conductance decay time constants.
E_exc = 0*mV
E_inh = -80*mV
tau_exc = 5*ms
tau_inh = 10*ms
eqs = '''
dv/dt = (g_exc*(E_exc - v) + g_inh*(E_inh - v) + g_L*(El - v)) / C_m : volt (unless refractory)
dg_exc/dt = -g_exc / tau_exc : siemens
dg_inh/dt = -g_inh / tau_inh : siemens
'''
# ----------------------------
# Populations
# ----------------------------
def _make_population(size, name):
    """Create `size` neurons of the shared model `eqs`, named `name`.

    Every neuron starts at El plus its own Gaussian jitter (SD 2 mV).
    """
    group = NeuronGroup(size, eqs, threshold='v>Vth', reset='v=Vr',
                        refractory=3*ms, method='euler', name=name)
    # One draw per neuron; a scalar randn() would give the whole group the same start value.
    group.v = El + np.random.randn(size) * 2 * mV
    return group


base = _make_population(N, 'base')
# ---------- Inhibitory population for base ----------
# Integer division: N/4 is a float in Python 3 (e.g. 2.5 for N = 10), not a neuron count.
N_inh = max(1, N // 4)
inhbase = _make_population(N_inh, 'inhbase')
# One long chain: excitatory pools chain[i], each paired with an inhibitory pool inhchain[i].
chain = []
inhchain = []
for i in range(Lmax):
    chain.append(_make_population(N, f'chain_pool_{i}'))
    inhchain.append(_make_population(N_inh, f'inhchain_pool_{i}'))
# ----------------------------
# External rhythmic stimulus
# ----------------------------
# Every stimulus neuron fires once per drive cycle (period T_stim) until stim_duration.
_period_s = float(T_stim / second)
stim_times = np.arange(0.0, float(stim_duration), _period_s) * second
_n_pulses = len(stim_times)
# Neuron k receives the whole list of cycle times: indices [0..0, 1..1, ...] paired with tiled times.
stim = SpikeGeneratorGroup(
    N,
    indices=np.arange(N).repeat(_n_pulses),
    times=np.tile(stim_times, N),
    name='stim',
)
A_stim = g_EE * 10/N * nS  # stimulus conductance increment, scaled inversely with N
S_stim = Synapses(stim, base, on_pre='g_exc_post += A_stim', name='stim_to_base', method='euler', delay=5*ms)
S_stim.connect(p=0.5)
# ----------------------------
# Feedforward connections
# ----------------------------
w_ff = g_EE * 10/N * nS  # E->E feedforward conductance increment, scaled inversely with N
delay_per_hop = 5 * ms  # transmission delay of every chain hop
on_pre_ee = '''
g_exc_post += w_ff
'''


def _ff_link(source, target, label):
    """Connect `source` -> `target` with the feedforward rule at connection probability 0.5."""
    link = Synapses(source, target, on_pre=on_pre_ee, delay=delay_per_hop, name=label, method='euler')
    link.connect(p=0.5)
    return link


# base -> chain[0] first, then chain[k] -> chain[k+1] along the chain.
ff_syns = [_ff_link(base, chain[0], 'base_to_chain0')]
ff_syns.extend(_ff_link(chain[k], chain[k + 1], f'chain_{k}_to_{k+1}') for k in range(Lmax - 1))
on_pre_ei = '''
g_exc_post += w_ei
'''
on_pre_ie = '''
g_inh_post += w_ie
'''
# E -> I and I -> E synapses (base <-> inhbase, and each chain pool <-> its inhibitory pool)
w_ei = g_EI * 10/N * nS  # E->I excitatory conductance increment, scaled inversely with N
w_ie = g_IE * 10/N * nS  # I->E inhibitory conductance increment, scaled inversely with N


def _ei_link(source, target, rule, prob, label, **options):
    """Create a 5 ms-delay synapse `source` -> `target` running `rule`, connected with probability `prob`."""
    link = Synapses(source, target, on_pre=rule, delay=5*ms, name=label, **options)
    link.connect(p=prob)
    return link


S_ei = _ei_link(base, inhbase, on_pre_ei, 0.5, 'E_to_I')
S_ie = _ei_link(inhbase, base, on_pre_ie, 0.8, 'I_to_E')
ei_syns = []
ie_syns = []
for pool_idx, (exc_pool, inh_pool) in enumerate(zip(chain, inhchain)):
    ei_syns.append(_ei_link(exc_pool, inh_pool, 'g_exc_post += w_ei', 0.5,
                            f'chain_{pool_idx}_to_I', method='euler'))
    ie_syns.append(_ei_link(inh_pool, exc_pool, 'g_inh_post += w_ie', 0.8,
                            f'I_to_chain_{pool_idx}', method='euler'))
# ----------------------------
# STDP feedback connections: from each chain pool -> base
# ----------------------------
# Pair-based STDP with exponentially decaying traces:
#   - presynaptic spike: deliver w, raise the pre trace by A_plus, then add the post trace to w;
#   - postsynaptic spike: lower the post trace by A_minus, then add the pre trace to w.
# A_minus is a positive magnitude, so post-before-pre pairings depress w. The previous
# `post += A_minus` made post-before-pre pairings potentiate as well, although A_minus_nS
# is meant to be the amount subtracted for post->pre pairings.
# NOTE(review): the traces are clock-driven (integrated every time step for every synapse);
# '(event-driven)' is the usual cheaper form. Left unchanged.
tau_pre = A_tau_ms*ms
tau_post = A_tau_ms*ms
A_plus = A_plus_nS * nS
A_minus = A_minus_nS * nS
stdp_model = '''
w : siemens
dpre/dt = -pre/tau_pre : siemens (clock-driven)
dpost/dt = -post/tau_post : siemens (clock-driven)
'''
stdp_on_pre = '''
g_exc_post += w
pre += A_plus
w = clip(w + post, 0*nS, w_max)
'''
stdp_on_post = '''
post -= A_minus
w = clip(w + pre, 0*nS, w_max)
'''
stdp_syns = []
for i, pool in enumerate(chain):
    S = Synapses(pool, base, model=stdp_model, on_pre=stdp_on_pre, on_post=stdp_on_post,
                 name=f'stdp_chain{i}_to_base', method='euler')
    S.connect(p=0.5)
    S.w = w_init
    stdp_syns.append(S)
# ----------------------------
# Monitors
# ----------------------------
# Spike monitors, in order: base, every chain pool, then the base inhibitory population.
spike_monitors = [SpikeMonitor(grp, record=True) for grp in [base, *chain, inhbase]]
# Every STDP weight of every pool, sampled every 50 ms.
weight_monitors = [StateMonitor(syn, 'w', record=True, dt=50*ms) for syn in stdp_syns]
# Membrane potential of neuron 0 in base, in every chain pool, and in inhbase.
voltage_monitors = [StateMonitor(grp, 'v', record=0) for grp in [base, *chain, inhbase]]
_network_objects = [
    base, *chain,                 # excitatory populations
    inhbase, *inhchain,           # inhibitory populations
    S_ei, S_ie,                   # base <-> inhbase synapses
    stim, S_stim,                 # stimulus generator and its synapses onto base
    *ff_syns,                     # feed-forward chain synapses
    *ie_syns, *ei_syns,           # chain pool <-> inhibitory pool synapses
    *stdp_syns,                   # plastic feedback synapses
    *spike_monitors, *voltage_monitors, *weight_monitors,
]
net = Network(*_network_objects)
# ----------------------------
# Simulation
# ----------------------------
init_elapsed = time.time() - timenow
print(f"Initialization done in {init_elapsed} seconds")
timenow = time.time()  # restart the wall-clock timer for the run itself
net.run(sim_duration)
print(f"Whole simulation done in {time.time() - timenow} seconds")
# ----------------------------
# Plot results
# ----------------------------
# Figure 1, top: raster of the base population, with two insets showing the same raster
# over the first 100 ms and membrane potentials over the first second.
f, axarr = subplots(2, 1)
axarr[0].plot(spike_monitors[0].t/second, spike_monitors[0].i, '.', markersize=2)
axarr[0].set_title('Base population spikes')
axarr[0].set_xlabel('Time (s)')
axarr[0].set_ylabel('Neuron index')
axnew = []
axnew.append(f.add_axes([0.5, 0.65, 0.2, 0.25]))
axnew.append(f.add_axes([0.78, 0.65, 0.2, 0.25]))
axnew[0].plot(spike_monitors[0].t/second, spike_monitors[0].i, '.', markersize=2)
axnew[0].set_xlim([0, 0.1])
axnew[1].plot(voltage_monitors[0].t/second, (voltage_monitors[0].v/mV)[0], 'r-', lw=0.3, label='base')
# voltage_monitors is [base, chain[0], ..., chain[Lmax-1], inhbase]. Show chain[4] when it
# exists, otherwise the last chain pool: a fixed index 5 would plot inhbase when Lmax == 4
# and raise IndexError when Lmax < 4 -- after the whole run, before anything is saved.
if Lmax > 0:
    probe = min(4, Lmax - 1)
    probe_mon = voltage_monitors[1 + probe]
    # Label is 1-based ('pool5' is chain[4]).
    axnew[1].plot(probe_mon.t/second, (probe_mon.v/mV)[0], 'm-', lw=0.3, label=f'pool{probe + 1}')
axnew[1].set_xlim([0, 1])
# Figure 1, bottom: rasters of base and every chain pool, stacked with a 2-row gap.
spm = spike_monitors[0]
axarr[1].plot(spm.t/second, spm.i, '.', markersize=1.5)
offset = N + 2
for i in range(Lmax):  # plot every pool
    spm = spike_monitors[1 + i]
    axarr[1].plot(spm.t/second, spm.i + offset, '.', markersize=1.5)
    offset += N + 2
axarr[1].set_xlabel('Time (s)')
axarr[1].set_ylabel('Chain pools (stacked)')
axarr[1].set_title('Synfire chain activity (subset)')
tight_layout()
if doSave:
    # Figure 2: mean STDP feedback weight of each chain pool over time.
    f, axarr = subplots(1, 1)
    for i, S in enumerate(stdp_syns):
        if len(S.w):  # skip pools that received no feedback synapses
            mean_w = np.mean(weight_monitors[i].w, axis=0)
            axarr.plot(weight_monitors[i].t/second, mean_w, label=f'Pool {i}')
            try:
                axarr.text(weight_monitors[i].t[-1]/second, mean_w[-1], 'Pool '+str(i),
                           fontsize=6, ha='right', va='bottom')
            except IndexError:
                # The monitor recorded no samples, so there is no last point to label.
                print('Problem text weight '+str(i))
    axarr.axvline(float(drive_off_time/second), color='k', linestyle='--', label='Drive off')
    axarr.set_xlabel('Time (s)')
    axarr.set_ylabel('Mean feedback weight')
    axarr.set_title('Evolution of feedback synapse weights (pool#base)')
    # NOTE(review): neither figure is written to disk; only the .mat file below is saved.
    # Create the output directory so a missing folder cannot crash the save after a long run.
    os.makedirs("synfirefiles", exist_ok=True)
    # The file name encodes every parameter; the _tau/_seed suffixes appear only for non-defaults.
    out_file = (
        f"synfirefiles/stdpsynfire_synstim_allweights_N{N}_L{Lmax}_T{stim_duration}_{tot_duration}"
        f"_{stim_freq_Hz}Hz_gEE{g_EE}_gEI{g_EI}_gIE{g_IE}_A{A_plus_nS}_{A_minus_nS}_{w_init_nS}"
        f"{tau_addition}{seed_addition}.mat"
    )
    # 'spikes': [[spike times (s) per monitor], [neuron indices per monitor]], in spike_monitors order.
    # 'weights': [sample times (s), [all STDP weights of each pool, in pS]].
    spikes = [[mon.t/second for mon in spike_monitors], [mon.i[:] for mon in spike_monitors]]
    weights = [weight_monitors[0].t/second, [mon.w/psiemens for mon in weight_monitors]]
    scipy.io.savemat(out_file, {'spikes': spikes, 'weights': weights})