import mytools
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from pylab import *
from os.path import exists
from matplotlib.collections import PatchCollection
def format3(x):
    """Format *x* with three decimals, then drop trailing zeros and a bare '.'.

    E.g. 1.5 -> '1.5', 2.0 -> '2', 0.12345 -> '0.123'.
    """
    text = f"{x:.3f}"
    return text.rstrip("0").rstrip(".")
def format4(x):
    """Format *x* with four decimals, then drop trailing zeros and a bare '.'.

    Used to build the simulation file names, e.g. 0.02 -> '0.02', 3 -> '3'.
    """
    text = f"{x:.4f}"
    return text.rstrip("0").rstrip(".")
def hex_to_RGB(hex_color):
    """Convert a 'RRGGBB' hex string (leading '#' optional) to an (R, G, B) tuple of ints.

    Each channel is parsed from two hexadecimal digits.
    """
    digits = hex_color.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return (red, green, blue)
def RGB_to_hex(rgb_color):
    """Convert an (R, G, B) triple to an upper-case '#RRGGBB' string.

    Components are converted with int(), i.e. fractional values are truncated.
    """
    red, green, blue = int(rgb_color[0]), int(rgb_color[1]), int(rgb_color[2])
    return "#{:02X}{:02X}{:02X}".format(red, green, blue)
def generate_gradient(start_hex, end_hex, steps):
    """Return *steps* '#RRGGBB' colours linearly interpolated from start_hex to end_hex.

    Both endpoints are included when steps > 1; steps == 1 yields only the
    start colour and steps <= 0 yields an empty list.  Channel values are
    passed to RGB_to_hex, which truncates them to ints.
    """
    start_rgb = hex_to_RGB(start_hex)
    end_rgb = hex_to_RGB(end_hex)
    last = steps - 1
    palette = []
    for step in range(steps):
        frac = 0 if steps <= 1 else step / last
        channels = tuple(lo + frac * (hi - lo) for lo, hi in zip(start_rgb, end_rgb))
        palette.append(RGB_to_hex(channels))
    return palette
def detect_bursts_population(spike_times, isi_threshold):
    """
    Detect bursts in a population spike train using an ISI threshold.

    A burst is a maximal run of at least two consecutive spikes in which
    every inter-spike interval (ISI) is <= ``isi_threshold``.

    Parameters
    ----------
    spike_times : array-like
        Sorted spike times (float).  The input is not sorted here; with
        unsorted input, negative ISIs would count as "short".
    isi_threshold : float
        Maximum ISI to consider spikes part of the same burst.

    Returns
    -------
    bursts : list of numpy.ndarray
        One array (a copy) of spike times per burst, in order of occurrence.
        Empty when fewer than two spikes are given or no ISI is short enough.
    """
    spike_times = np.asarray(spike_times)
    if spike_times.size < 2:
        # Always return a list of bursts (this used to return a tuple here,
        # which callers iterating over bursts could not handle).
        return []
    # in_burst[k] is True when spike k follows spike k-1 within the
    # threshold, i.e. spike k continues a burst that includes spike k-1.
    in_burst = np.zeros(spike_times.shape[0], dtype=bool)
    in_burst[1:] = np.diff(spike_times) <= isi_threshold
    n_spikes = len(spike_times)
    bursts = []
    i = 0
    while i < n_spikes:
        if not in_burst[i]:
            i += 1
            continue
        start = i - 1  # the spike before the first short ISI opens the burst
        j = i
        while j < n_spikes and in_burst[j]:
            j += 1
        # copy() so the bursts do not alias the caller's array
        bursts.append(spike_times[start:j].copy())
        i = j
    return bursts
import sys  # explicit: `sys` was otherwise only reachable through `from pylab import *`

# ---- Simulation parameters --------------------------------------------------
# These must match the values encoded in the names of the simulation output
# files ("synfirefiles/stdpsynfire_synstim_N..._L..._T...") read below.
N = 50
Lmax = 70
stim_duration = 12.0   # presumably seconds (cf. the "_after_12s" checks) - confirm
tot_duration = 16.0
stim_freq_Hz = 2.0
g_EE = 12.0
g_EI = 20.0
g_IE = 20.0
# STDP grid: potentiation amplitudes A+ (nS) and depression factors,
# with A- = A+ * factor.
A_pluss = [0.02, 0.022, 0.024, 0.026, 0.027, 0.028, 0.029, 0.03, 0.032, 0.035, 0.04, 0.045]
A_minus_factors = [-0.7, -0.8, -0.9, -1.0, -1.1, -1.2, -1.5, -2.0, -3.0, -4.0]
istim_to_plot = [24, 31]  # NOTE(review): not used in the visible part of the script

# Per-(A+, A-) summaries filled by the analysis loop: the parameter values
# and, for each criterion, how many seeds satisfy it.
A_plus_saved = []
A_minus_saved = []
mean_spike_t_saved = []
std_spike_t_saved = []
Nspikes_saved = []
weights_saved = []
N_is_longbursts_saved = []
N_is_sparsebursts_saved = []
N_is_24_OKtimed_saved = []
N_is_1_OKtimed_after_12s_saved = []
N_is_3_OKtimed_after_12s_saved = []
N_is_8_OKtimed_after_12s_saved = []
N_is_3_bursts_after_12s_saved = []
N_is_8_bursts_after_12s_saved = []
N_is_spikes_after_12s_saved = []

# ---- Command-line options ---------------------------------------------------
# argv[1] -> DT: ISI threshold for burst detection, and the timing tolerance /
#            maximum burst length in the burst-timing checks.
# argv[2] -> longBurstLen: bursts longer than this count as "long".
DT = 0.05
longBurstLen = 0.075
try:
    if len(sys.argv) > 1:
        DT = float(sys.argv[1])
    if len(sys.argv) > 2:
        longBurstLen = float(sys.argv[2])
except ValueError:
    sys.exit(f"usage: {sys.argv[0]} [DT [longBurstLen]]  (both must be numbers)")
def _burst_is_timed(prev_burst, burst, burst_len, tol, period):
    """Return whether *burst* is a well-timed successor of *prev_burst*.

    With ``prev_end`` the last spike of *prev_burst*, the burst must lie
    entirely within ``[prev_end + period - tol, prev_end + period + tol]``
    and its length *burst_len* must not exceed *tol*.
    """
    prev_end = prev_burst[-1]
    return (prev_end + period - tol <= burst[0] <= burst[-1] <= prev_end + period + tol
            and burst_len <= tol)


def _first_k_ok(flags, k):
    """Return True if *flags* has at least *k* entries and the first *k* are all true."""
    # np.all is used explicitly because `from pylab import *` may rebind the
    # bare names all/any/sum to their numpy versions in this module.
    return len(flags) >= k and bool(np.all(flags[:k]))


# Stimulus timing for the burst-timing checks.  With stim_freq_Hz = 2.0 and
# stim_duration = 12.0 these are 0.5 s, 24 pulses and 12.0 s - the values
# previously hard-coded as 0.5, 24 and 24*0.5.
stim_period = 1.0 / stim_freq_Hz
n_stim = int(round(stim_duration * stim_freq_Hz))
t_stim_end = n_stim * stim_period
# Spikes later than this time count as activity after the stimulation.
# NOTE(review): 12.6 was hard-coded; its derivation is not documented - confirm.
T_LATE_SPIKE = 12.6

for iA_plus, A_plus_nS in enumerate(A_pluss):
    w_init_nS = A_plus_nS  # initial weight equals A+ (it is encoded in the file name)
    for ifA, A_minus_factor in enumerate(A_minus_factors):
        A_minus_nS = A_plus_nS * A_minus_factor
        # Per-seed results for this (A+, A-) combination.
        spikes_all = []
        weights_all = []
        is_longbursts = []
        is_sparsebursts = []
        is_24_OKtimed = []
        is_too_many_after_12s = []  # never filled; kept for compatibility
        is_1_OKtimed_after_12s = []
        is_3_OKtimed_after_12s = []
        is_8_OKtimed_after_12s = []
        is_3_bursts_after_12s = []
        is_8_bursts_after_12s = []
        is_spikes_after_12s = []
        # The file name depends only on the parameters, not on the seed.
        filename_body = (
            f"stdpsynfire_synstim_N{N}_L{Lmax}_T{stim_duration}_{tot_duration}_"
            f"{stim_freq_Hz}Hz_gEE{g_EE}_gEI{g_EI}_gIE{g_IE}_"
            f"A{format4(A_plus_nS)}_{format4(A_minus_nS)}_{format4(w_init_nS)}"
        )
        for myseed in range(1, 21):
            # Seed 1 has no suffix; the other seeds end in "_seed<n>".
            seed_addition = '' if myseed == 1 else '_seed' + str(myseed)
            mat_path = "synfirefiles/" + filename_body + seed_addition + ".mat"
            if not exists(mat_path):
                continue
            try:
                A = scipy.io.loadmat(mat_path)
            except Exception:  # was a bare except, which also swallowed KeyboardInterrupt
                print("Could not load " + mat_path + ", continuing")
                continue
            # Layout written by the simulation:
            # spikes  = [[spike_monitors[i].t/second ...], [spike_monitors[i].i ...]]
            # weights = [weight_monitors[0].t/second, [np.mean(weight_monitors[i].w, axis=0) ...]]
            spikes_t = A['spikes'][0]
            spikes_i = A['spikes'][1]
            weights = A['weights']
            if max(weights[0][1][0]) < 1e-4:
                print(mat_path + " is not in pS!")
                weights[0][1][0] = weights[0][1][0] * 1e12
            spikes_all.append([spikes_t[:], spikes_i[:]])
            weights_all.append(weights[:])

            # Spike times of the first monitor; assumed sorted, as
            # detect_bursts_population requires.
            t_spk = np.asarray(spikes_t[0][0])
            bursts = detect_bursts_population(t_spk, DT)
            burstlens = [b[-1] - b[0] for b in bursts]
            # Empty spike trains / no bursts used to crash max() and min().
            is_spikes_after_12s.append(t_spk.size > 0 and np.max(t_spk) > T_LATE_SPIKE)
            is_longbursts.append(bool(np.any(np.asarray(burstlens) > longBurstLen)))
            early_sizes = [len(b) for b in bursts if b[0] < t_stim_end]
            is_sparsebursts.append(bool(np.any(np.asarray(early_sizes) < 5)))

            NburstsOK = 0
            burstsOK_after_12s = []  # timing flag of each burst starting after t_stim_end
            for iburst in range(1, len(bursts)):
                ok = _burst_is_timed(bursts[iburst - 1], bursts[iburst],
                                     burstlens[iburst], DT, stim_period)
                if ok:
                    NburstsOK += 1
                if bursts[iburst][0] >= t_stim_end:
                    burstsOK_after_12s.append(ok)
            # NburstsOK >= n_stim implies len(bursts) > n_stim, so the index is valid.
            t_last = (n_stim - 1) * stim_period
            is_24_OKtimed.append(
                NburstsOK >= n_stim
                and t_last - DT <= bursts[n_stim - 1][0] <= bursts[n_stim - 1][-1] <= t_last + DT
            )
            # The 3-burst test used to check the first three flags for mutual
            # equality, which also accepted three *failed* bursts.
            is_1_OKtimed_after_12s.append(_first_k_ok(burstsOK_after_12s, 1))
            is_3_OKtimed_after_12s.append(_first_k_ok(burstsOK_after_12s, 3))
            is_8_OKtimed_after_12s.append(_first_k_ok(burstsOK_after_12s, 8))
            n_late_bursts = len([b for b in bursts if b[0] >= t_stim_end])
            is_3_bursts_after_12s.append(n_late_bursts > 2)
            is_8_bursts_after_12s.append(n_late_bursts > 7)
        # Number of seeds meeting each criterion.  np.count_nonzero always
        # returns an int; if `sum` resolves to numpy.sum (via the pylab star
        # import) an empty list sums to 0.0, which breaks list indexing later.
        A_plus_saved.append(A_plus_nS)
        A_minus_saved.append(A_minus_nS)
        N_is_spikes_after_12s_saved.append(np.count_nonzero(is_spikes_after_12s))
        N_is_longbursts_saved.append(np.count_nonzero(is_longbursts))
        N_is_sparsebursts_saved.append(np.count_nonzero(is_sparsebursts))
        N_is_24_OKtimed_saved.append(np.count_nonzero(is_24_OKtimed))
        N_is_1_OKtimed_after_12s_saved.append(np.count_nonzero(is_1_OKtimed_after_12s))
        N_is_3_OKtimed_after_12s_saved.append(np.count_nonzero(is_3_OKtimed_after_12s))
        N_is_8_OKtimed_after_12s_saved.append(np.count_nonzero(is_8_OKtimed_after_12s))
        N_is_3_bursts_after_12s_saved.append(np.count_nonzero(is_3_bursts_after_12s))
        N_is_8_bursts_after_12s_saved.append(np.count_nonzero(is_8_bursts_after_12s))
# ---- Plot the (A+, A-) grid --------------------------------------------------
f, axarr = plt.subplots(1, 1)
# Cell colour index = number of seeds (0..20) whose first three post-stimulus
# bursts are well timed: 0 -> '#0000CC', 20 -> '#FFFFFF'.
cols = generate_gradient('#0000CC', '#FFFFFF', 21)
# One entry per cell: [[x corners, y corners], face colour, hatch string];
# written to the .mat file at the end.
stuffPlotted = []
for iA_plus, A_plus_nS in enumerate(A_pluss):
    for ifA, A_minus_factor in enumerate(A_minus_factors):
        A_minus_nS = A_plus_nS * A_minus_factor
        # Unit square: x spans the A+ index, y spans the A- factor index.
        corners = [[iA_plus, iA_plus, iA_plus + 1, iA_plus + 1], [ifA, ifA + 1, ifA + 1, ifA]]
        p = PatchCollection([plt.Polygon(np.array(corners).T)])
        ind = [i for i, (a_plus, a_minus) in enumerate(zip(A_plus_saved, A_minus_saved))
               if a_plus == A_plus_nS and a_minus == A_minus_nS]
        if not ind:
            # No analysis results for this combination: paint the cell black.
            facecolor, myadd = '#000000', ''
            p.set_facecolor(facecolor)
            p.set_edgecolor(None)
        else:
            if len(ind) > 1:
                raise RuntimeError(
                    f"Duplicate analysis results for A+={A_plus_nS}, A-={A_minus_nS}")
            N_is_3_OKtimed_after_12s = N_is_3_OKtimed_after_12s_saved[ind[0]]
            N_is_longbursts = N_is_longbursts_saved[ind[0]]
            # int(): the count may arrive as a numpy scalar (or 0.0 for no files).
            facecolor = cols[int(N_is_3_OKtimed_after_12s)]
            p.set_facecolor(facecolor)
            p.set_edgecolor(None)
            if N_is_3_OKtimed_after_12s < 20:
                # Hatch density grows with the number of seeds showing long bursts.
                myadd = '/' * int(N_is_longbursts / 4)
                p.set_hatch(myadd)
            else:
                myadd = ''
        axarr.add_collection(p)
        stuffPlotted.append([corners, facecolor, myadd])

# Axis annotations: A- factors on the left edge, A+ amplitudes along the bottom
# (each drawn once instead of once per grid cell).
for ifA, A_minus_factor in enumerate(A_minus_factors):
    axarr.text(-0.05, ifA + 0.5, str(A_minus_factor), fontsize=5, rotation=0,
               va='center', ha='right', clip_on=False)
for iA_plus, A_plus_nS in enumerate(A_pluss):
    axarr.text(iA_plus + 0.5, -0.05, str(A_plus_nS) + ' nS', fontsize=5,
               va='top', ha='center', clip_on=False)
axarr.set_xlim([0, len(A_pluss)])
axarr.set_ylim([0, len(A_minus_factors)])
axarr.set_xticks([])
axarr.set_yticks([])

out_body = "fig_stdpsynfire_gridsearch3_" + str(DT) + "_" + str(longBurstLen)
f.savefig(out_body + ".pdf")
scipy.io.savemat(out_body + ".mat", {'stuffPlotted': stuffPlotted,
                                     'A_pluss': A_pluss,
                                     'A_minus_factors': A_minus_factors})