from brian2 import *
from inputs import *
import os
import time
import numpy
# BrianLogger.suppress_hierarchy('brian2.codegen')
# BrianLogger.suppress_name('resolution_conflict')
# BrianLogger.suppress_hierarchy('brian2')
# --- Passive membrane properties ---
Cm = 1e-6 * farad * cm**-2  # specific membrane capacitance (1 uF/cm2)
Ri = 100 * ohm * cm  # axial (intracellular) resistivity
g_l = 1e-4 * siemens * cm**-2  # leak conductance density (0.1 mS/cm2)
E_l = -70 * mV  # leak reversal potential
# --- Hodgkin-Huxley Na / K channels ---
E_Na = 58 * mV
E_K = -80 * mV
# 12 mS/cm2 is the value from (Destexhe, Pare, 1999); author's note: "I doubled this value".
# Some models use 50 mS/cm2 (pyramidal neurons, Wang 1998) or 35 mS/cm2 (interneurons).
# Experiments (Stuart & Sakmann, 1994): > 4 mS/cm2.
g_Na = 12 * msiemens / cm**2
# Author's note: "I doubled this value". (Destexhe, Pare, 1999) use 10 mS/cm2;
# some models use 10.5 mS/cm2 (pyramidal neuron, Wang 1998).
g_K = 7 * msiemens / cm**2
# Voltage offset of the HH gating rate functions (they are written in terms of v - v_th).
v_th = - 63 * mV
# --- Synaptic parameters ---
E_ampa = 0 * mV  # excitatory (AMPA) reversal potential
E_gaba = -75 * mV  # inhibitory (GABA) reversal potential
tau_ampa = 5 * ms  # decay time constant of g_ampa
tau_gaba = 5 * ms  # decay time constant of g_gaba
dg_ampa = 1 * nS  # AMPA conductance increment per unit of synaptic weight (wgh)
dg_gaba = 1 * nS  # GABA conductance increment per unit of synaptic weight (wgh)
# --- Simulation configuration ---
morphology = 'point'  # 'point' or 'ball+stick'
dend_shape = 'cone'  # dendrite shape for ball+stick: 'cylinder' or 'cone'
mode = 'simple'  # 'simple' or 'detailed' (detailed: 1 s runs and per-run voltage figures)
speed = 'fast'  # 'fast' (2 s runs, 1 repetition) or 'long' (20 s runs, 10 repetitions)
chosen_param = 'wgh'  # parameter to scan: 'rate_per_weight', 'wgh', 'ch_dns' (or 'rL')
ns_ampa_tot = 200  # total number of excitatory synapses
gaba_ratio = 0.2  # number of inhibitory synapses relative to the excitatory ones
jtr = 10. * ms  # spike jitter of the input trains (plots for madex: 10 ms; plots for HH: 30 ms)
ch_dns = 1.  # dendritic Na/K channel density relative to the soma
# Synaptic weight (multiplies dg_ampa / dg_gaba, i.e. in nS) used when scanning
# over rate; overridden below for the point morphology.
wgh = 0.5
# Input-rate range (Hz) for the rate scan: the program scans 10%, 25%, 50%,
# 75% and 100% of this value.
if morphology == 'point':
    rate_range = 17.
elif morphology == 'ball+stick':
    rate_range = 5.
else:
    raise ValueError("morphology must be 'point' or 'ball+stick', got " + repr(morphology))
# Tuning notes (rate range found for other synaptic weights):
#   0.5 nS synapse: 7
#   0.2 nS synapse: 40
#   1.5 nS synapse: 1.2 (3 was too high)

# Input rate (Hz) used when the rate is not the scanned parameter; the 'wgh'
# and 'ch_dns' scans replace it per scan step.
rate_per_weight = 0.75
method = 'rk2'  # Brian2 integration method; alternatives: 'linear', 'exponential_euler', 'heun'
dt = 0.05 * ms  # integration time step (author's note: keep 0.05 ms)
diam_soma = 40 * um
# Dendrite: total length and compartment length give the number of compartments.
length = 1000 * um
len_comp = 5 * um
n_cmp = int(length / um / (len_comp / um))
# Cylindrical dendrite diameter.
diam_cln = 1 * um
# Conical dendrite: diameters in um (unitless here) taper from diam_0 to diam_f.
diam_0 = 5
diam_f = 0.5
beta = 0.05  # taper exponent: smaller gives a steeper initial decrease of the diameter (0.1 also tried)
# Compartment boundary positions along the dendrite (um), n_cmp + 1 points.
lengthT = linspace(0, length / um, num=n_cmp + 1)
if morphology == 'point':
    # All excitatory synapses target the single compartment.
    ns_ampa = ns_ampa_tot
    ns_gaba_soma = int(ceil(ns_ampa * gaba_ratio))
    wgh = 0.105  # nS; 0.11 is good for 0.5 synapse and 10.3 Hz rate
elif morphology == 'ball+stick':
    if dend_shape not in ('cylinder', 'cone'):
        raise ValueError("dend_shape must be 'cylinder' or 'cone', got " + repr(dend_shape))
    # Excitatory synapses are spread evenly over the dendritic compartments.
    ns_ampa = int(ns_ampa_tot / n_cmp)
    ns_gaba_soma = int(ceil(ns_ampa_tot * gaba_ratio))
# Candidate values for every scannable parameter; the main loop iterates over
# param_to_scan[chosen_param] only.
param_to_scan = {}

# Input rate: 10%, 25%, 50%, 75% and 100% of rate_range, stored as plain floats.
param_to_scan['rate_per_weight'] = [fraction * rate_range for fraction in (0.1, 0.25, 0.5, 0.75, 1.)]

# Synaptic weight (nS). rate_wgh pairs each weight with an input rate (Hz)
# chosen so that rate * weight stays roughly constant.
if chosen_param == 'wgh':
    if morphology == 'ball+stick':
        wgh_list = [0.25, 0.5, 0.75, 1.]
        rate_wgh = np.array([4., 2., 1.5, 1.])
        param_to_scan['wgh'] = wgh_list
    elif morphology == 'point':
        wgh_list = [0.05, 0.1, 0.15, 0.2]
        rate_wgh = np.array([23., 11.5, 7.7, 5.75])
        param_to_scan['wgh'] = wgh_list

# Dendritic channel density (relative to the soma). rate_ch_dns is the input
# rate (Hz) used for each density.
if chosen_param == 'ch_dns':
    if morphology == 'ball+stick':
        param_to_scan['ch_dns'] = [0., 0.25, 0.5, 0.75, 1.]
        # Earlier candidates (wgh 0.5 nS, jtr 15 ms, about 10 Hz without correlation):
        #   [10.3, 8., 4., 2.3, 1.7]   (2.5 for 0.75 was too high, 1.6 for 1 too low)
        #   [7, 2., 1, 0.5, 0.25]
        rate_ch_dns = np.array([11.5, 9., 4.8, 2.5, 2.])
    elif morphology == 'point':
        # A point neuron has no dendrite to scale, so only 0 is scanned.
        param_to_scan['ch_dns'] = [0.]
        rate_ch_dns = np.array([11.5])

# Local correlation factor (multiplies rG for the within-compartment correlation).
param_to_scan['rL'] = [0.1, 0.5, 1.]
# Global correlation levels rG swept in every run set.
corr_scan = np.arange(0., 1.1, 0.1)
# Duration of each run and number of repetitions of the correlation scan.
if speed == 'fast':
    t_sim = 2000 * ms
    n_rep = 1
elif speed == 'long':
    t_sim = 20000 * ms
    n_rep = 10
else:
    raise ValueError("speed must be 'fast' or 'long', got " + repr(speed))
# 'detailed' mode uses shorter 1 s runs (it also saves per-run voltage figures).
if mode == 'detailed':
    t_sim = 1000 * ms
# Short name of the scanned parameter used in output directory and figure names.
# The rate scan is abbreviated to 'rate'; previously an unconditional assignment
# overwrote this shorthand, so it never took effect.
if chosen_param == 'rate_per_weight':
    chosen_param_short = 'rate'
else:
    chosen_param_short = chosen_param
# Assemble the Brian2 morphology for the selected neuron model.
if morphology == 'point':
    # A lone soma: one compartment whose "length" is the soma diameter.
    morpho = Soma(diameter = diam_soma)
    length = diam_soma
    n_cmp = 1
elif morphology == 'ball+stick':
    # Soma with a single dendrite of n_cmp compartments.
    morpho = Soma(diameter=diam_soma)
    if dend_shape == 'cylinder':
        morpho.dendrite = Cylinder(diameter=diam_cln, length=length, n=n_cmp)
    elif dend_shape == 'cone':
        # Diameter profile d(x) = diam_0 - alpha * x**beta (x in um), with alpha
        # chosen so that the diameter reaches diam_f at x = length; it is
        # evaluated at the n_cmp + 1 compartment boundaries in lengthT.
        alpha = (diam_0 - diam_f) / ((length / um) ** beta)
        diameterT = diam_0 - alpha * lengthT ** beta
        comp_lengths = [len_comp / um] * n_cmp * um
        morpho.dendrite = Section(diameter=diameterT * um, length=comp_lengths, n=n_cmp)
# NOTE: Brian2's magic network raises an error when a run has no input spikes
# (see the simulation loop below).

# Each invocation writes into a fresh, timestamped results directory and
# chdirs into it, so every figure and .npy file below is saved there.
date_time = time.strftime('%y-%m-%d--%H-%M-%S', time.localtime())
output_dir = os.path.join('..', 'DATA', 'HH',
                          date_time + '_[scan_' + chosen_param_short + ']_[' + morphology + ']')
os.makedirs(output_dir, exist_ok=True)
main_path = os.getcwd()  # directory the script was started from
os.chdir(output_dir)
# Run metadata. It was previously assembled but never written; it is now saved
# next to the results so each data directory documents its own settings.
saved_param = {'date_time': date_time,
               'morphology': morphology,
               'diam_soma': diam_soma,
               'diam_0': diam_0,
               'diam_f': diam_f,
               'beta': beta,
               'length': length,
               'number_of_comp': n_cmp,
               'ns_ampa': ns_ampa,
               'gaba_ratio': gaba_ratio,
               'chosen_param': chosen_param,
               'rate_per_weight': rate_per_weight,
               'jitter': jtr,
               'wgh': wgh,
               'param_to_scan': param_to_scan[chosen_param],
               'corr_scan': corr_scan,
               }
if chosen_param == 'ch_dns':
    saved_param['rate_ch_dns'] = rate_ch_dns
np.save('saved_param.npy', saved_param)
# parameters from Pare 1998
# Hodgkin-Huxley membrane model with conductance-based AMPA/GABA point synapses.
# g_Na and g_K are per-compartment parameters so the dendrite can be given
# ch_dns times the somatic density. The model text is identical for every run,
# so it is defined once, outside the loops.
eqs = """
Im = g_l * (E_l-v) + g_Na * m**3 * h * (E_Na-v) + g_K * n**4 * (E_K-v) : amp/meter**2
dm/dt = alpham * (1-m) - betam * m : 1
dh/dt = alphah * (1-h) - betah * h : 1
dn/dt = alphan * (1-n) - betan * n : 1
alpham = - 0.32/mV * (v - v_th - 13 * mV) / (exp(-(v - v_th - 13 * mV)/(4 * mV)) - 1) /ms : Hz
betam = 0.28/mV * (v - v_th - 40 * mV) / (exp((v - v_th - 40 * mV)/(5 * mV)) -1) /ms : Hz
alphah = 0.128 * exp(-(v - v_th - 17 * mV) / (18 * mV)) /ms : Hz
betah = 4 / (1 + exp(-(v - v_th - 40 * mV) / (5 * mV))) /ms : Hz
alphan = -0.032 / mV * (v - v_th - 15 * mV) / (exp(-(v - v_th - 15 * mV) / (5 * mV)) - 1) / ms : Hz
betan = 0.5 * exp(-(v - v_th - 10 * mV) / (40 * mV)) /ms : Hz
Is = g_ampa * (E_ampa - v) + g_gaba * (E_gaba - v) + I : amp (point current)
dg_ampa/dt = -g_ampa/tau_ampa : siemens
dg_gaba/dt = -g_gaba/tau_gaba : siemens
I : amp
g_Na : siemens / meter**2
g_K : siemens / meter**2
"""


def _flatten_spike_trains(spike_trains, train_ids, target_ids, n_sources):
    """Flatten nested spike trains into parallel (times, targets) arrays.

    ``spike_trains[k][i_s]`` is read for every ``(k, target)`` pair in
    ``zip(train_ids, target_ids)`` and every source ``i_s`` in
    ``range(n_sources)``. Only strictly positive times are kept (presumably the
    generator pads trains with non-positive values -- TODO confirm in
    ``inputs.spike_trains_hierarch``).

    Returns a float array of spike times (seconds) and an int array with the
    target compartment of each spike; both are empty when there are no spikes.
    """
    time_parts = []
    target_parts = []
    for k, target in zip(train_ids, target_ids):
        for i_s in range(n_sources):
            times = np.asarray(spike_trains[k][i_s], dtype=float)
            times = times[times > 0.]
            time_parts.append(times)
            target_parts.append(np.full(times.size, target, dtype=int))
    if not time_parts:
        return np.zeros(0), np.zeros(0, dtype=int)
    return np.concatenate(time_parts), np.concatenate(target_parts)


def _make_spike_input(times, targets, target_group, on_pre):
    """Drive ``target_group`` with one SpikeGeneratorGroup source per spike.

    Source ``i`` fires once at ``times[i]`` (seconds) and is connected to
    compartment ``targets[i]``; ``on_pre`` is the Synapses update statement.
    Returns the created Brian2 objects, or an empty list when there are no spikes.
    """
    if times.size == 0:
        return []
    indices = np.arange(times.size)
    generator = SpikeGeneratorGroup(times.size, indices, times * second)
    synapses = Synapses(generator, target_group, on_pre=on_pre)
    synapses.connect(i=indices, j=targets.astype(int))
    return [generator, synapses]


def _count_spikes(v_mV, threshold_mV):
    """Return the number of spikes in a voltage trace given in mV.

    Every strict sign change of ``v - threshold`` between consecutive samples
    counts as one crossing; a spike crosses twice (up and down), hence the
    halving. The result is a half-integer if the trace ends above threshold.
    """
    above = np.asarray(v_mV, dtype=float) - threshold_mV
    n_crossings = np.count_nonzero(above[1:] * above[:-1] < 0.)
    return n_crossings / 2.


print('\n *** Chosen parameter: ' + str(chosen_param) + ' ***\n\n')
t0 = time.time()
sim_seconds = float(t_sim / second)  # run length as a plain float, for rates in Hz
for counter, param in enumerate(param_to_scan[chosen_param]):
    print('\n\n *** ' + str(chosen_param) + ' = ' + str(round(param, 4)) + ' *** \n')
    run_dir = str(chosen_param) + '_' + str(param)
    if mode == 'detailed':
        os.makedirs(run_dir, exist_ok=True)
    # Publish the scanned value under its module-level name (wgh, ch_dns, rL or
    # rate_per_weight) so all code below picks it up.
    globals()[chosen_param] = param
    # Weight and channel-density scans pair each value with its own input rate.
    if chosen_param == 'ch_dns':
        rate_per_weight = rate_ch_dns[counter]
    elif chosen_param == 'wgh':
        rate_per_weight = rate_wgh[counter]
    rate_ampa = rate_per_weight * Hz
    rate_gaba = rate_per_weight * Hz

    corr_rows = []  # one [rG, somatic output rate in Hz] row per run
    for rG in corr_scan.tolist() * n_rep:
        if chosen_param != 'rL':
            rL = 1
        print('\n\nCorrelation: ' + str(round(rG, 2)) + '\n')
        neuron_with_dendrite = SpatialNeuron(morphology=morpho, model=eqs,
                                             method=method, dt=dt, Cm=Cm, Ri=Ri)
        neuron_with_dendrite.v = -70 * mV
        neuron_with_dendrite.h = 0.
        neuron_with_dendrite.m = 0.
        neuron_with_dendrite.n = 0.
        neuron_with_dendrite.I = 0.
        # Compartment 0 is the soma; the remaining compartments get ch_dns times
        # the somatic channel density.
        neuron_with_dendrite.g_Na[0] = g_Na
        neuron_with_dendrite.g_Na[1:] = ch_dns * g_Na
        neuron_with_dendrite.g_K[0] = g_K
        neuron_with_dendrite.g_K[1:] = ch_dns * g_K
        trace = StateMonitor(neuron_with_dendrite, 'v', record=True)

        # Excitatory (AMPA) input from the project's spike_trains_hierarch; rG and
        # rG * rL are passed as its correlation arguments (rL: local correlation factor).
        spt_exc = spike_trains_hierarch(n_cmp, ns_ampa, rate_ampa / Hz, t_sim / second, rG, rG * rL, jtr / second)
        # For ball+stick the soma (compartment 0) receives no excitatory input.
        exc_compartments = range(1, n_cmp) if morphology == 'ball+stick' else range(n_cmp)
        times_all_exc, compartments_all_exc = _flatten_spike_trains(
            spt_exc, exc_compartments, exc_compartments, ns_ampa)
        # Excitatory synapses increment g_ampa by the AMPA step (was dg_gaba by mistake).
        exc_objects = _make_spike_input(times_all_exc, compartments_all_exc, neuron_with_dendrite,
                                        'g_ampa += ' + str(wgh) + '* dg_ampa')

        # Inhibitory (GABA) input onto the soma (correlation arguments 0).
        spt_inh = spike_trains_hierarch(1, ns_gaba_soma, rate_gaba / Hz, t_sim / second, 0, 0, jtr / second)
        times_all_inh, compartments_all_inh = _flatten_spike_trains(spt_inh, [0], [0], ns_gaba_soma)
        inh_objects = _make_spike_input(times_all_inh, compartments_all_inh, neuron_with_dendrite,
                                        'g_gaba += ' + str(wgh) + '* dg_gaba')

        # Run an explicit Network holding only this run's objects. The magic run()
        # collects every Brian object still referenced, so after a run without
        # input spikes the previous iteration's generator/synapse objects were
        # mixed with new ones and Brian2 raised a MagicError.
        net = Network(neuron_with_dendrite, trace, *(exc_objects + inh_objects))
        tr = time.time()
        net.run(t_sim, report='text')
        print('Time of run: ' + str(time.time() - tr) + 's')

        if mode == 'detailed':
            # Each plot gets its own figure (closed after saving) so traces do not
            # pile up on a previously used figure.
            fig = figure()
            plot(trace.t / ms, trace.v[0] / mV)
            xlabel('time (ms)')
            ylabel('membrane potential (mV)')
            savefig(run_dir + '/figure_V_t_wgh_' + str(wgh) + '_corr_' + str(rG) + '.png', dpi=300)
            close(fig)
            # The dendritic trace (compartment 100) and the space-time contour need
            # a dendrite; the point model has a single compartment.
            if morphology == 'ball+stick':
                fig = figure()
                plot(trace.t / ms, trace.v[100] / mV)
                xlabel('time (ms)')
                ylabel('dendritic membrane potential (mV)')
                savefig(run_dir + '/figure_V_t_dend_wgh_' + str(wgh) + '_corr_' + str(rG) + '.png', dpi=300)
                close(fig)
                fig = figure()
                contourf(cumsum(neuron_with_dendrite.length) / cm, trace.t / ms, trace.v.T / mV,
                         cmap='YlOrRd', alpha=1.,
                         levels=np.linspace(np.amin(trace.v / mV), np.amax(trace.v / mV), num=100))
                xlabel('Position [cm]')
                ylabel('Time [ms]')
                savefig(run_dir + '/figure_V_t_contour_wgh' + str(wgh) + '_corr_' + str(rG) + '.png', dpi=300)
                close(fig)

        # Somatic spikes: crossings of 0 mV.
        n_sp = _count_spikes(trace.v[0] / mV, 0.)
        frq = n_sp / sim_seconds  # plain float, Hz
        print('Frequency of spikes: ' + str(frq) + ' Hz')
        corr_rows.append([rG, frq])
        if morphology == 'ball+stick':
            # Dendritic spikes at compartment 100: crossings of -20 mV.
            n_sp_dend = _count_spikes(trace.v[100] / mV, -20.)
            frq_dend = n_sp_dend / sim_seconds
            print('Frequency of dendritic spikes: ' + str(frq_dend) + ' Hz')

    corr_frq = np.array(corr_rows)
    # Rows follow corr_scan repeated n_rep times, so a stride of len(corr_scan)
    # selects all repetitions of one correlation level.
    n_corr_levels = len(corr_scan)
    corr_mean_frq = np.array([np.mean(corr_frq[i::n_corr_levels], axis=0) for i in range(n_corr_levels)])
    fig_n_sp = figure()
    plot(corr_frq[:, 0], corr_frq[:, 1], 'b+')
    plot(corr_mean_frq[:, 0], corr_mean_frq[:, 1], 'r+')
    # Shared tag for the figure and data file of this scan value.
    result_tag = ('_weight_' + str(wgh) + '_jtr_' + str(jtr) + '_rate_per_weight_' + str(round(rate_per_weight, 4))
                  + '_rL_' + str(rL) + '_ch_dns_' + str(ch_dns))
    savefig('figure_frq' + result_tag + '.png', dpi=300)
    close(fig_n_sp)
    np.save('corr_frq' + result_tag + '.npy', corr_frq)
# Wall-clock cost of the whole scan.
t_cmp = (time.time() - t0) / 3600.
print('Time of computation: ' + str(t_cmp) + ' hours')
# One run per scanned value, per repetition, per correlation level. Counted as
# a product of lengths: repeating param_to_scan[...] with `*` only works for
# lists (a NumPy array would be scaled element-wise instead).
n_runs = len(param_to_scan[chosen_param]) * n_rep * len(corr_scan)
sim_seconds = float(t_sim / second)
print('For ' + str(n_runs) + ' of ' + str(sim_seconds) + ' s runs \n')
if n_runs > 0:
    print('Which gives ' + str(round((t_cmp * 3600) / (n_runs * sim_seconds), 2)) + ' s of computation for 1 s')
# --- Summary: output firing rate vs. input correlation for every scanned value ---
import numpy as np
import matplotlib.pyplot as plt
# Aliased so brian2's `cm` (centimetre) unit is not shadowed.
import matplotlib.cm as mpl_cm

plt.rcParams["font.family"] = "sans-serif"
plt.rcParams['font.sans-serif'] = 'Museo Sans'
plt.rcParams["font.size"] = "22"

# Per scanned value (keyed by str(param)): the raw [rG, rate] rows and their
# mean / std over repetitions at each correlation level of corr_scan.
corr_frq = {}
corr_frq_mean = {}
corr_frq_std = {}
low_plot = 0
high_plot = len(param_to_scan[str(chosen_param)])
maxT = []  # peak mean rate of each scanned value, used for the y-limit
for counter, param in enumerate(param_to_scan[chosen_param]):
    # Restore the globals that appear in the result file name for this value.
    if chosen_param == 'ch_dns':
        rate_per_weight = rate_ch_dns[counter]
    elif chosen_param == 'wgh':
        rate_per_weight = rate_wgh[counter]
    globals()[chosen_param] = param
    key = str(param)
    # The rate is rounded to 4 digits, matching the name written by the scan
    # loop (rounding to 3 here could point at a non-existent file).
    corr_frq[key] = np.load('corr_frq_weight_' + str(wgh) + '_jtr_' + str(jtr) + '_rate_per_weight_' + str(
        round(rate_per_weight, 4)) + '_rL_' + str(rL) + '_ch_dns_' + str(ch_dns) + '.npy')
    corr_list = corr_frq[key][:, 0]
    level_means = []
    level_stds = []
    for rG in corr_scan:
        rows = corr_frq[key][corr_list == rG]
        level_means.append(np.mean(rows, axis=0))
        level_stds.append(np.std(rows, axis=0))
    corr_frq_mean[key] = np.array(level_means)
    corr_frq_std[key] = np.array(level_stds)
    # Correlation levels without data have NaN means; ignore them for the maximum.
    maxT.append(np.nanmax(corr_frq_mean[key][:, 1]))
maxmaxT = np.nanmax(maxT)

color_array = np.linspace(0.3, 1, num=len(param_to_scan[chosen_param])).tolist()
param_to_scan[chosen_param] = param_to_scan[chosen_param][low_plot:high_plot]
# Legend labels: the scanned values (rates are shown multiplied by wgh).
labels = np.array(param_to_scan[chosen_param])
if chosen_param == 'rate_per_weight':
    labels = labels * wgh
fig_corr = plt.figure(figsize=(12, 7))
ax = fig_corr.add_subplot(111)
# Shrink the axes to leave room for the legend on the right.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.9, box.height])
ax.set_ylim(bottom=0, top=1.2 * maxmaxT)
ax.set_xlabel('ratio of shared spikes', fontsize=24)
ax.set_ylabel('firing rate [Hz]', fontsize=24)
for clr, param in enumerate(param_to_scan[chosen_param]):
    globals()[chosen_param] = param
    colour = mpl_cm.YlOrRd(color_array[clr], 1)
    mean_rows = corr_frq_mean[str(param)]
    std_rows = corr_frq_std[str(param)]
    ax.plot(mean_rows[:, 0], mean_rows[:, 1], color=colour, label=labels[clr])
    # Shaded band: mean +/- one standard deviation across repetitions.
    ax.fill_between(mean_rows[:, 0], mean_rows[:, 1] - std_rows[:, 1], mean_rows[:, 1] + std_rows[:, 1],
                    color=colour, alpha=0.5)
# Build the legend once, after all curves exist, and thicken its lines.
# NOTE(review): 4.0 is chosen here; the width used originally is not
# recoverable from this file -- adjust if needed.
leg = ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), frameon=False, fontsize=22)
for line in leg.get_lines():
    line.set_linewidth(4.0)
plt.savefig('[scan_' + chosen_param_short + '].png', dpi=300)
plt.savefig('[scan_' + chosen_param_short + '].svg', dpi=300)
np.save('chosen_param.npy', chosen_param)
np.save('param_to_scan.npy', param_to_scan[chosen_param])
np.save('corr_frq_mean.npy', corr_frq_mean)
np.save('corr_frq_std.npy', corr_frq_std)