from pylab import *
import scipy.io
import mytools
from matplotlib.collections import PatchCollection
from os.path import exists
from scipy.ndimage import gaussian_filter1d
def mybar(ax, x, y, facecolor=None, linewidth=0.1, w=0.4):
    """Draw a box-and-whisker glyph summarising *y* at horizontal position *x*.

    The box spans the 25th-75th percentiles of *y*. It is added to *ax* as
    a PatchCollection with a black edge of width 0.3. A single black
    polyline then draws:

    - the minimum cap,
    - a vertical line from the minimum up to the maximum,
    - the median bar,
    - the maximum cap.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    x : float
        Horizontal centre of the glyph (data coordinates).
    y : array_like
        Samples to summarise (quantiles are taken over all of them).
    facecolor : color, optional
        Box fill colour. ``None`` keeps the collection's default fill. An
        empty list (the previous default) also keeps the default fill.
    linewidth : float, optional
        Line width of the whisker/median polyline.
    w : float, optional
        Half-width of the box and of the caps, in data units.

    Returns
    -------
    list
        ``[box_collection, lines]``, where *lines* is the list returned by
        ``ax.plot`` for the polyline, so callers can restyle both parts.
    """
    q_min, q_lo, q_med, q_hi, q_max = quantile(y, [0, 0.25, 0.5, 0.75, 1])
    box = Polygon(array([[x - w, x + w, x + w, x - w],
                         [q_lo, q_lo, q_hi, q_hi]]).T)
    # No colour array is ever attached to the collection, so no colormap is
    # needed.
    p = PatchCollection([box])
    # None is the sentinel instead of a mutable [] default. An explicit []
    # is still treated as "no fill colour given", for old callers.
    if facecolor is not None and not (isinstance(facecolor, list)
                                      and len(facecolor) == 0):
        p.set_facecolor(facecolor)
    p.set_edgecolor('#000000')
    p.set_linewidth(0.3)
    ax.add_collection(p)
    # The polyline visits, in order:
    # - the min cap, then up the centre to the median;
    # - across the median bar and back to the centre;
    # - up to the max, then across the max cap.
    xs = [x - w, x + w, x, x, x - w, x + w, x, x, x - w, x + w]
    ys = [q_min, q_min, q_min, q_med, q_med, q_med, q_med, q_max, q_max, q_max]
    a2 = ax.plot(xs, ys, 'k-', lw=linewidth)
    return [p, a2]
filenames = sys.argv[1:]
try:
NperpopIndex = 1
while 'Nperpop' not in filenames[0].split('_')[NperpopIndex]:
NperpopIndex = NperpopIndex + 1
NperpopStr = filenames[0].split('_')[NperpopIndex]
Nperpop = int(NperpopStr[7:])
except:
Nperpop = 40
# Stimulus waveform templates, indexed by the integer codes in the sequence.
# - Each entry is [standard trace, deviant trace].
# - Each trace is [times, levels]: a step waveform inside one 500-unit slot.
# - A short pulse rises at 450, a long one at 400; both end at 500.
# - A silent trace stays at 0.
def _silent_trace():
    return [[0, 500], [0, 0]]


def _pulse_trace(onset):
    return [[0, onset, onset, 500, 500], [0, 0, 1, 1, 0]]


seqAdds = [
    [_silent_trace(), _silent_trace()],    # 0: both channels silent (omission)
    [_pulse_trace(450), _silent_trace()],  # 1: short standard, deviant missing
    [_silent_trace(), _pulse_trace(450)],  # 2: standard missing, short deviant
    [_pulse_trace(400), _silent_trace()],  # 3: long standard, deviant missing
    [_silent_trace(), _pulse_trace(400)],  # 4: standard missing, long deviant
]
del _silent_trace, _pulse_trace
# Four stacked panels. All share left = 0.065 and width = 0.9 in figure
# coordinates: 0 and 1 at the top (1 is a thin 0.02-high strip), 2 and 3
# further down.
fig1, axs = subplots(4, 1)
axarr = list(axs.ravel())
panel_rects = ([0.065, 0.08 + 0.08 * 7, 0.9, 0.08],
               [0.065, 0.08 + 0.08 * 6 + 0.06, 0.9, 0.02],
               [0.065, 0.5, 0.9, 0.08],
               [0.065, 0.4, 0.9, 0.08])
for ax, rect in zip(axarr, panel_rects):
    ax.set_position(rect)
axarr[0].set_xticks([])
axarr[1].set_yticks([])
# Common styling: tiny tick labels, hairline spines, y range 0..Nperpop.
for ax in axarr:
    ax.tick_params(axis='both', which='major', labelsize=4)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(0.2)
    ax.set_ylim([0, Nperpop])
axarr[0].set_yticks([0, int(Nperpop / 2)])
def _add_band(ax, x0, x1, ytop, colour):
    """Add a borderless filled rectangle [x0, x1] x [0, ytop] to *ax*."""
    band = PatchCollection(
        [Polygon(array([[x0, x0, x1, x1], [0, ytop, ytop, 0]]).T)],
        cmap=matplotlib.cm.jet)
    band.set_facecolor(colour)
    band.set_edgecolor(None)
    ax.add_collection(band)


xlimmax = 0
# Sequence slots whose transitions are analysed and labelled below.
istims = [6, 11, 16, 21, 26, 31, 36, 42, 47, 52, 57, 62, 68, 73, 78, 84]
for iseq in istims:
    # Each band runs from 150 before to 350 after the slot start (500*iseq).
    start = 500 * iseq - 150
    # Panels 0 and 1 get one light-blue band per slot, full panel height.
    for iax in (0, 1):
        _add_band(axarr[iax], start, start + 500, Nperpop, '#DDDDFF')
    # Panel 2 gets the same band (blue), plus a grey copy shifted right by
    # one slot.
    for iblock, colour in enumerate(('#DDDDFF', '#EEEEEE')):
        offset = 500 * iblock
        _add_band(axarr[2], start + offset, start + offset + 500, 5, colour)

# Colour palette from the project helper: len(filenames) + 1 colours.
# NOTE(review): not referenced by any live plotting code in this script.
cols = mytools.colorsredtolila(len(filenames) + 1, 0.8)
# Example simulation (model 0, seed 1): draw its output raster on axs[0]
# and the stimulus sequence that drove it on axs[1].
A = scipy.io.loadmat('MMNs_2pm_sep_noISDIDD_seq0_model0_pop_seed1.mat')

# loadmat keeps cell contents as 2-D arrays. Flatten every (1, n) element
# of these variables to a 1-D vector so it can go straight into plot().
for q in ['standard', 'deviant', 'pacemaker', 'pacemaker2', 'output',
          'standardBoost', 'deviantBoost']:
    try:
        cells = A[q]
        for iy in range(cells.shape[0]):
            for ix in range(cells.shape[1]):
                if cells[iy, ix].shape[0] == 1 and cells[iy, ix].shape[1] > 1:
                    cells[iy, ix] = cells[iy, ix][0]
    except (KeyError, IndexError, AttributeError):
        # Leave the variable as loaded (best effort) when it is absent from
        # the file or is not a 2-D array of 2-D elements.
        pass

seq = A['sequence'][0]
seq_end = 500 * len(seq)  # one 500-unit slot per sequence element
xlimmax = max(xlimmax, seq_end)

# Raster: A['output'][0, 0] (spike times) against A['output'][0, 1]
# (presumably the emitting unit's index -- confirm).
# Plain '.' marker: the colour comes from color=, not from the format string.
axs[0].plot(A['output'][0, 0], A['output'][0, 1], '.', lw=0.35, ms=0.35,
            mew=0.35, color='#000000')

# Stimulus strip: standard channel (blue) around y = 24, deviant (red)
# around y = 6; a pulse raises the trace by 10.
for iseq, code in enumerate(seq):
    std_trace, dev_trace = seqAdds[int(code)]
    axs[1].plot([500 * iseq + t for t in std_trace[0]],
                [24 + 10 * v for v in std_trace[1]], 'b-', lw=0.25)
    axs[1].plot([500 * iseq + t for t in dev_trace[0]],
                [6 + 10 * v for v in dev_trace[1]], 'r-', lw=0.25)
# Continue both baselines through one extra slot after the sequence.
axs[1].plot([seq_end, seq_end + 500], [24, 24], 'b-', lw=0.25)
axs[1].plot([seq_end, seq_end + 500], [6, 6], 'r-', lw=0.25)
for ax in axarr:
    ax.set_xlim([0, 1000 + seq_end])
# Light-blue figure-level backdrops, one per analysed slot.
# - Each starts at figure y = 0.72, spanning the slot's shaded band
#   (data x from 500*iseq-150 to 500*iseq+350).
# - It widens upward to y = 0.835, behind the rotated transition captions.
# - Data x is mapped to figure x with the panels' left/width (0.065, 0.9)
#   and the x-limits set above (0 .. x_span).
x_span = 1000 + 500 * len(A['sequence'][0])
for iseq in istims:
    xs = [500 * iseq - 150, 500 * iseq - 550, 500 * iseq + 1150,
          500 * iseq + 1750, 500 * iseq + 350]
    fx = [0.065 + 0.9 * (x / x_span) for x in xs]
    verts = [(fx[0], 0.72), (fx[1], 0.73), (fx[1], 0.82), (fx[0], 0.835),
             (fx[2], 0.835), (fx[3], 0.82), (fx[3], 0.76), (fx[4], 0.72)]
    backdrop = Polygon(verts, closed=True, transform=fig1.transFigure,
                       facecolor='#DDDDFF', edgecolor=None, zorder=-1)
    # add_artist is the public way to attach a figure-level artist. It also
    # records fig1 as the artist's figure, which appending to fig1.patches
    # did not.
    fig1.add_artist(backdrop)
# Rotated caption naming the transition probed at each analysed slot. The
# captions are placed in axarr[1] data coordinates at y = 168, i.e. above
# the thin stimulus strip.
transition_captions = [
    (6, ' short\n standard\n to short\ndeviant'),
    (11, ' short\n deviant\n to short\nstandard'),
    (16, ' short\n standard\n to long\nstandard'),
    (21, ' long\n standard\n to short\nstandard'),
    (26, ' short\n standard\n to long\ndeviant'),
    (31, ' long\n deviant\n to short\nstandard'),
    (36, ' omission\n from short\n standards'),
    (42, ' short\n deviant\n to long\nstandard'),
    (47, ' long\n standard\n to short\ndeviant'),
    (52, ' short\n deviant\n to long\ndeviant'),
    (57, ' long\n deviant\n to short\ndeviant'),
    (62, ' omission\n from short\n deviant'),
    (68, ' long\n standard\n to long\ndeviant'),
    (73, ' long\n deviant\n to long\nstandard'),
    (78, ' omission\n from long\n standard'),
    (84, ' omission\n from long\n deviant'),
]
for position, (slot, caption) in enumerate(transition_captions):
    # Only the first caption carries an explicit zorder.
    extra_kwargs = {'zorder': 10} if position == 0 else {}
    axarr[1].text(500 * slot, 168, caption, rotation=45, fontsize=5,
                  ha='center', va='bottom', **extra_kwargs)
fig1.savefig('fig_longseq.pdf')
def _count_in_window(times, t0, width=500):
    """Return how many entries of the 1-D array *times* lie in [t0, t0 + width)."""
    return int(np.count_nonzero((times >= t0) & (times < t0 + width)))


# Analysis windows per analysed slot. The target window is [it, it + 500);
# the non-target window is the 500 time units immediately before it.
its_target = [500 * i - 200 for i in istims]
its_nontarget = [x - 500 for x in its_target]

mysigma = 25       # Gaussian smoothing s.d., in time bins (25 ms per the original note)
n_samples = 43000  # binned spike-train length; spikes outside [0, n_samples) are dropped

curves_all = []             # per model: smoothed curve averaged over seeds
Nspikes_target_all = []     # [model][seed][window]: spike counts in target windows
Nspikes_nontarget_all = []  # same layout, for the non-target windows
for imodel in range(16):
    Nspikes_target_thisgroup = []
    Nspikes_nontarget_thisgroup = []
    curves_thisgroup = []
    for myseed in [1]:
        # The name follows myseed. It used to be fixed to seed1, so adding
        # seeds would have silently reloaded the same file.
        filename = ('MMNs_2pm_sep_noISDIDD_seq0_model' + str(imodel)
                    + '_pop_seed' + str(myseed) + '.mat')
        if not exists(filename):
            print(filename + ' does not exist')
            continue
        A = scipy.io.loadmat(filename)
        if myseed == 1:
            print('Loaded ' + filename)
        # Spike times of the output population. If A['output'][0][0] is
        # already the vector of times, the extra [0] yields a scalar; fall
        # back to the shallower lookup in that case.
        spikes = A['output'][0][0][0]
        if type(spikes) == np.float64:
            spikes = A['output'][0][0]
        spikes = np.asarray(spikes).ravel()

        # Count spikes per 1-unit bin, then smooth with a Gaussian kernel.
        bins = np.round(spikes).astype(int)
        bins = bins[(bins >= 0) & (bins < n_samples)]
        spike_train = np.bincount(bins, minlength=n_samples).astype(float)
        thiscurve = gaussian_filter1d(spike_train, sigma=mysigma)

        Nspikes_target_thisgroup.append(
            [_count_in_window(spikes, t0) for t0 in its_target])
        Nspikes_nontarget_thisgroup.append(
            [_count_in_window(spikes, t0) for t0 in its_nontarget])
        curves_thisgroup.append(thiscurve)
    if not curves_thisgroup:
        # No file could be loaded for this model. Skip it; averaging an
        # empty group used to crash the script a few lines later.
        print('No data for model ' + str(imodel) + '; skipping it')
        continue
    Nspikes_target_all.append(Nspikes_target_thisgroup)
    Nspikes_nontarget_all.append(Nspikes_nontarget_thisgroup)
    curves_all.append(mean(array(curves_thisgroup), axis=0))
if not curves_all:
    raise FileNotFoundError(
        'no MMNs_2pm_sep_noISDIDD_seq0_model*_pop_seed*.mat file could be loaded')
# Average of the per-model curves.
mean_curve = mean(curves_all, axis=0)
# Window indices drawn in the grey/black style; all others are drawn
# olive. The selection was made by the author; its meaning is not recorded
# here.
highlighted_windows = (0, 2, 3, 6)
for iit, it in enumerate(its_target):
    is_highlighted = iit in highlighted_windows
    # Model-averaged curve: target window in place, followed by the
    # preceding non-target window redrawn just to its right.
    axarr[2].plot(range(it, it + 500), mean_curve[it:it + 500], 'k-', lw=0.5)
    axarr[2].plot(range(it + 500, it + 1000), mean_curve[it - 500:it],
                  'k-', lw=0.5)
    # Per-model spike-count difference (target minus non-target window),
    # first seed only, for every model that was actually loaded.
    ddi = [target[0][iit] - nontarget[0][iit]
           for target, nontarget in zip(Nspikes_target_all,
                                        Nspikes_nontarget_all)]
    edge_colour = '#000000' if is_highlighted else '#555500'
    Q = mybar(axarr[3], it + 250, ddi,
              facecolor='#999999' if is_highlighted else '#AA9900',
              linewidth=0.3, w=400)
    Q[1][0].set_color(edge_colour)
    Q[0].set_edgecolor(edge_colour)
    Q[0].set_linewidth(0.3)
axarr[2].set_ylim([0, 1.5])
axarr[3].set_ylim([-10, 180])
# Final cosmetics:
# - no x ticks on any panel;
# - a 1000 ms scale bar above the top panel;
# - axis labels and panel letters;
# - thinner, shorter y tick marks.
for ax in axarr:
    ax.set_xticks([])
axarr[0].plot([500, 1500], [70, 70], 'k-', lw=0.5, clip_on=False)
axarr[0].text(1000, 73, '1000 ms', fontsize=5, ha='center', va='bottom',
              clip_on=False)
axarr[0].set_yticks([])
# NOTE(review): this panel shows a Gaussian-smoothed count of spikes per
# 1-unit bin, pooled over all spike times of the output. Confirm that this
# is really in spikes/sec before relying on the label.
axarr[2].set_ylabel('Firing rate \n(spikes/sec) ', fontsize=5)
# Raw string: '\m' is an invalid escape in a normal literal (SyntaxWarning
# on Python 3.12+). The resulting label text is unchanged.
axarr[3].set_ylabel(r'$f_{\mathrm{dd}}$ (A.U.)', fontsize=5)
fig1.text(0.03, 0.8, 'A', fontsize=11)
fig1.text(0.03, 0.59, 'B', fontsize=11)
fig1.text(0.03, 0.46, 'C', fontsize=11)
for ax in axarr:
    for line in ax.yaxis.get_ticklines():
        line.set_markeredgewidth(0.3)
        line.set_markersize(2)
fig1.savefig('fig_longseq.pdf')