#cp drawcorticalpops_SCZ.py drawfig3.py
from pylab import *  # kept first so the explicit imports below take precedence over star-imported names
import sys
from os.path import exists

import numpy as np
import scipy.io
import scipy.stats
from matplotlib.collections import PatchCollection
from scipy.ndimage import gaussian_filter1d
def boxoff(ax, whichxoff='top'):
    """Remove two spines from ``ax`` and keep ticks on the bottom/left only.

    The spine named ``whichxoff`` (default ``'top'``) and the right spine are
    hidden; x ticks are then placed on the bottom axis and y ticks on the left.
    """
    for spine_name in (whichxoff, 'right'):
        ax.spines[spine_name].set_visible(False)
    x_axis = ax.get_xaxis()
    x_axis.tick_bottom()
    y_axis = ax.get_yaxis()
    y_axis.tick_left()
def mybar(ax, x, y, facecolor=None, linewidth=0.3, w=0.4):
    """Draw a box-and-whisker glyph summarising the samples ``y`` at position ``x``.

    The box spans the 25th-75th percentiles of ``y`` and is ``2*w`` wide. A
    single black line marks the minimum, the median and the maximum with
    horizontal ticks of the same width and joins them with a vertical line
    through ``x``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x : horizontal centre of the glyph.
    y : sample values (passed to ``quantile``).
    facecolor : fill colour of the box. ``None`` (default) or an empty list
        leaves the collection's default fill; the empty list is still accepted
        for compatibility with the previous ``[]`` default.
    linewidth : width of both the box outline and the whisker/median line
        (previously the outline ignored this and was always 0.3).
    w : half-width of the box and of the horizontal ticks.

    Returns
    -------
    [PatchCollection, list]: the box collection and the lines returned by ``ax.plot``.
    """
    qs = quantile(y, [0, 0.25, 0.5, 0.75, 1])
    box = Polygon(array([[x-w, x+w, x+w, x-w], [qs[1], qs[1], qs[3], qs[3]]]).T)
    # No colour mapping is used (no data array is attached), so no cmap is needed.
    p = PatchCollection([box])
    has_fill = facecolor is not None and not (isinstance(facecolor, list) and len(facecolor) == 0)
    if has_fill:
        p.set_facecolor(facecolor)
    p.set_edgecolor('#000000')
    p.set_linewidth(linewidth)
    ax.add_collection(p)
    # One polyline: min tick, up to the median, median tick, up to the max, max tick.
    a2 = ax.plot([x-w, x+w, x, x, x-w, x+w, x, x, x-w, x+w],
                 [qs[0], qs[0], qs[0], qs[2], qs[2], qs[2], qs[2], qs[4], qs[4], qs[4]],
                 'k-', lw=linewidth)
    return [p, a2]
# Number of seeds (samples) loaded per model and group. The optional first
# command-line argument overrides the default; it goes through float() first,
# so both "20" and "20.0" are accepted.
Nsamp = 20
if len(sys.argv) > 1:
    Nsamp = int(float(sys.argv[1]))
if Nsamp < 1:
    # With no seeds nothing would be loaded and the seed averaging further down
    # would fail far from the cause, so stop here with a clear message.
    raise ValueError('Nsamp must be at least 1, got '+str(Nsamp))
# Group colours: index 0 is CTRL, index 1 the fewer-synapse group.
cols = ['#000000', '#999900']
# Lighter versions of the same colours, used as box fills in the bar panels.
dimcols = ['#AAAAAA', '#CCCC00']

# Onset of the analysed target (deviant) stimulus and of the non-target
# (standard) stimuli; each analysis window starts at its onset and is 500 long.
it_target = 2300
its_nontarget = [800, 1300, 1800, 2800, 3300]

# Five panels: raster (0), smoothed curves (1) and three box panels side by side (2-4).
f, axarr = subplots(5, 1)
bar_row_bottom = 0.21+0.22*1/3+0.27*1/3+0.16*1/3
panel_rects = (
    [0.08, 0.75+0.22*1/3, 0.88, 0.22*2/3],
    [0.08, 0.43+0.22*1/3+0.27*1/3, 0.88, 0.27*2/3],
    [0.08, bar_row_bottom, 0.26, 0.16*2/3],
    [0.39, bar_row_bottom, 0.26, 0.16*2/3],
    [0.70, bar_row_bottom, 0.26, 0.16*2/3],
)
for panel, rect in zip(axarr, panel_rects):
    panel.set_position(rect)

# Full and abbreviated names of the four MMN paradigms, in plotting order.
MMNtypes = ['Frequency deviant', 'Omission', 'Duration deviant', 'Inv. duration deviant']
MMNtypes_short = ['Freq.', 'Om.', 'Dur.', 'Inv. dur.']
# Shared tick styling for every panel; hide the top and right spines.
for iax in range(0, len(axarr)):
    axarr[iax].tick_params(axis='both', which='major', labelsize=4, direction='out', width=0.4, length=2)
    boxoff(axarr[iax])

# Background bands behind the raster in panel 0: rows 40-80 hold the CTRL
# group (band 0) and rows 0-40 the fewer-synapse group (band 1).
for iband in [0, 1]:
    polygon = Polygon(array([[-20, -20, 5500, 5500],
                             [40-40*iband, 80-40*iband, 80-40*iband, 40-40*iband]]).T)
    p = PatchCollection([polygon])
    p.set_facecolor('#F9F9F9' if iband == 0 else '#F7F7FF')
    p.set_edgecolor(None)
    axarr[0].add_collection(p)

# Shade the stimulus windows in panels 0 and 1. For MMN type iMMN the target
# (deviant) window starts at x = 1400*iMMN and the non-target (standard)
# window 600 further right; both are 500 wide. Both panels use the same x
# layout, so the previously duplicated per-panel branches are merged.
for iax in [0, 1]:
    for iMMN in [0, 1, 2, 3]:
        for istandard in [0, 1]:  # 0: target (deviant) window, 1: non-target (standard) window
            x = 1400*iMMN + 600*istandard
            polygon = Polygon(array([[x, x+500, x+500, x], [-1e5, -1e5, 1e5, 1e5]]).T)
            print('iMMN='+str(iMMN)+', x = '+str(x))
            p = PatchCollection([polygon])
            p.set_facecolor('#EEEEEE' if istandard == 1 else '#DDDDFF')
            p.set_edgecolor(None)
            axarr[iax].add_collection(p)

# Group labels to the left of the raster bands, and y ranges of panels 0 and 1.
axarr[0].text(-55, 20, 'Fewer\nsyn.', fontsize=5, rotation=90, ha='right', va='center', color=cols[1])
axarr[0].text(-55, 60, 'CTRL', fontsize=5, rotation=90, ha='right', va='center', color=cols[0])
axarr[0].set_ylim([0, 85])
axarr[1].set_ylim([-0.01, 4.5])
# ---------------------------------------------------------------------------
# Load the simulation outputs and collect, for every model (0-15), group
# (0: CTRL files, 1: '_19percentMorePruned' files) and seed (1..Nsamp):
#   Nspikes_target_all[imodel][igroup][isamp][iMMNtype]
#       number of spikes in the target (deviant) window [it_target, it_target+500)
#   Nspikes_nontarget_all[imodel][igroup][isamp][iMMNtype]
#       spikes per non-target (standard) window, averaged over its_nontarget
#   curves_all[imodel][igroup][iMMNtype]
#       Gaussian-smoothed binned spike curve, averaged over the loaded seeds
# Seed 1 of model 0 is also drawn as a raster in panel 0.
# NOTE(review): seeds whose file is missing are skipped, so per-sample lists may
# differ in length between models; later per-sample averaging across models
# indexes by position and assumes equal lengths.
# ---------------------------------------------------------------------------
A_iMMNtype_order = [1, 0, 2, 3]  # index into A['cortOutput'] for each entry of MMNtypes
mysigma = 25  # 25 ms std of the Gaussian smoothing kernel
n_samples = 4000  # number of unit-width bins of the binned spike train


def _count_spikes_in_window(spike_times, t0, width=500):
    """Return how many spike times fall in the half-open window [t0, t0+width)."""
    return len([1 for t in spike_times if t >= t0 and t < t0 + width])


def _smoothed_spike_curve(spike_times, length=n_samples, sigma=mysigma):
    """Round spike times to integer bins in [0, length), count them, and Gaussian-smooth the counts."""
    rounded = np.round(spike_times).astype(int)
    valid = rounded[(rounded >= 0) & (rounded < length)]
    spike_train = np.bincount(valid, minlength=length).astype(float)
    return gaussian_filter1d(spike_train, sigma=sigma)


def _plot_raster_window(ax, spike_times, spikers, t0, x_offset, y_offset, color):
    """Plot the spikes in [t0, t0+500) as dots at (time - t0 + x_offset, spiker + y_offset).

    Returns the number of plotted spikes. ``spikers`` is presumably the index of
    the spiking neuron for each spike time (used as the raster row).
    """
    in_window = [i for i in range(len(spike_times)) if spike_times[i] >= t0 and spike_times[i] < t0 + 500]
    ax.plot([spike_times[i] - t0 + x_offset for i in in_window],
            [spikers[i] + y_offset for i in in_window], '.',
            ms=0.5, mew=0.5, lw=0.5, color=color)
    return len(in_window)


Nspikes_target_all = []
Nspikes_nontarget_all = []
curves_all = []
for imodel in range(0, 16):
    Nspikes_target = []
    Nspikes_nontarget = []
    curves = []
    for igroup in range(0, 2):
        groupExt = '' if igroup == 0 else '_19percentMorePruned'
        print('Working on model '+str(imodel))
        Nspikes_target_thisgroup = []
        Nspikes_nontarget_thisgroup = []
        curves_thisgroup = []
        for myseed in range(1, 1+Nsamp):
            # e.g. MMNs_2pm_sep_noISDIDD_limtau_withnoisyfittedcortical_Nperpop40_SD0.3_model7_CTRLpop_seed13.mat
            filename = ('MMNs_2pm_sep_noISDIDD_limtau_withnoisyfittedcortical'+groupExt
                        +'_Nperpop40_SD0.3_model'+str(imodel)+'_CTRLpop_seed'+str(myseed)+'.mat')
            if not exists(filename):
                print(filename+' does not exist')
                continue
            A = scipy.io.loadmat(filename)
            if myseed == 1:
                print('Loaded '+filename)
            draw_raster = myseed == 1 and imodel == 0
            raster_y_offset = 40*(igroup == 0)  # CTRL rows on top (40-80), other group below (0-40)
            curves_thissamp = []
            Nspikes_target_thissamp = []
            Nspikes_nontarget_thissamp = []
            for iMMNtype in range(0, 4):
                cort = A['cortOutput'][A_iMMNtype_order[iMMNtype]]
                spikes = cort[0][0]
                spikers = cort[1][0]
                # Target (deviant) window.
                Nspikes_target_thissamp.append(_count_spikes_in_window(spikes, it_target))
                if draw_raster:
                    nplotted = _plot_raster_window(axarr[0], spikes, spikers, it_target,
                                                   iMMNtype*1400, raster_y_offset, cols[igroup])
                    print('igroup = '+str(igroup)+', iMMNtype = '+str(iMMNtype)+', deviant, '+str(nplotted)+' plotted')
                thiscurve = _smoothed_spike_curve(spikes)
                # Non-target (standard) windows: average the spike count, draw the middle one.
                Nspikes_nontarget_this = 0
                for it in its_nontarget:
                    Nspikes_nontarget_this = Nspikes_nontarget_this + _count_spikes_in_window(spikes, it)
                    if it == its_nontarget[2] and draw_raster:
                        nplotted = _plot_raster_window(axarr[0], spikes, spikers, it,
                                                       iMMNtype*1400+600, raster_y_offset, cols[igroup])
                        print('igroup = '+str(igroup)+', iMMNtype = '+str(iMMNtype)+', standards, '+str(nplotted)+' plotted')
                Nspikes_nontarget_thissamp.append(Nspikes_nontarget_this/len(its_nontarget))
                curves_thissamp.append(thiscurve[:])
            Nspikes_target_thisgroup.append(Nspikes_target_thissamp[:])
            Nspikes_nontarget_thisgroup.append(Nspikes_nontarget_thissamp[:])
            curves_thisgroup.append(curves_thissamp[:])
        if not curves_thisgroup:
            # Averaging an empty list yields a NaN scalar that later indexing cannot use.
            raise RuntimeError('No data files were found for model '+str(imodel)+', group '+str(igroup))
        Nspikes_target.append(Nspikes_target_thisgroup[:])
        Nspikes_nontarget.append(Nspikes_nontarget_thisgroup[:])
        # Seed-averaged curves, one per MMN type.
        curves.append(mean(array(curves_thisgroup), axis=0))
    Nspikes_target_all.append(Nspikes_target[:])
    Nspikes_nontarget_all.append(Nspikes_nontarget[:])
    curves_all.append(curves[:])
def _per_neuron_rate(n_spikes):
    """Convert a window spike count to a rate per neuron: divide by 40 neurons and by the 0.5-s window."""
    return n_spikes/40/0.5


def _mean_over_models(igroup, value_of):
    """For each sample index of group ``igroup``, average value_of(imodel, isamp) over all models."""
    n_models = len(Nspikes_target_all)
    n_samples_in_group = len(Nspikes_target_all[0][igroup])
    return [mean([value_of(imodel, isamp) for imodel in range(n_models)])
            for isamp in range(n_samples_in_group)]


for iMMNtype in range(0, 4):
    for igroup in range(0, 2):
        # Model-averaged smoothed curve, cut into the target window and the mean standard window.
        group_curves = [model_curves[igroup][iMMNtype] for model_curves in curves_all]
        mean_curve = mean(array(group_curves), axis=0)
        target_curve = mean_curve[it_target:it_target+500]
        nontarget_curve = zeros([500])
        for it in its_nontarget:
            nontarget_curve = nontarget_curve + mean_curve[it:it+500]
        nontarget_curve = nontarget_curve/len(its_nontarget)
        x_start = 1400*iMMNtype
        axarr[1].plot(range(x_start, x_start+500), target_curve, '-', lw=0.45, color=cols[igroup])
        axarr[1].plot(range(x_start+600, x_start+1100), nontarget_curve, '-', lw=0.45, color=cols[igroup])

        # Box glyphs of per-sample, model-averaged per-neuron rates.
        bar_x = iMMNtype*4+igroup
        deviant_rates = _mean_over_models(
            igroup, lambda m, s: _per_neuron_rate(Nspikes_target_all[m][igroup][s][iMMNtype]))
        standard_rates = _mean_over_models(
            igroup, lambda m, s: _per_neuron_rate(Nspikes_nontarget_all[m][igroup][s][iMMNtype]))
        difference_rates = _mean_over_models(
            igroup, lambda m, s: (_per_neuron_rate(Nspikes_target_all[m][igroup][s][iMMNtype])
                                  - _per_neuron_rate(Nspikes_nontarget_all[m][igroup][s][iMMNtype])))
        mybar(axarr[2], bar_x, deviant_rates, facecolor=dimcols[igroup])
        mybar(axarr[3], bar_x, standard_rates, facecolor=dimcols[igroup])
        mybar(axarr[4], bar_x, difference_rates, facecolor=dimcols[igroup])
# Box panels (2-4): fixed x range, no x ticks, MMN-type labels written as text
# (at y = 8 in panel 2 and y = 4 in panels 3-4).
for iax in (2, 3, 4):
    panel = axarr[iax]
    panel.set_xlim([-0.5, 14])
    panel.set_xticks([])
    label_y = 4 if iax > 2 else 8
    for iMMN, short_name in enumerate(MMNtypes_short):
        panel.text(4*iMMN+0.5, label_y, short_name, va='top', ha='center', fontsize=5)
for iax, ylim in ((2, [0, 31]), (3, [0, 17]), (4, [0, 22])):
    axarr[iax].set_ylim(ylim)
def _ratio_or_sentinel(n_target, n_nontarget):
    """Return n_target/n_nontarget; if n_nontarget is not positive return 1 when n_target == 0, else 1e8."""
    if n_nontarget > 0:
        return n_target/n_nontarget
    return 1 if n_target == 0 else 1e8


def _model_mean_per_sample(igroup, iMMNtype, stat):
    """Average stat(target_count, nontarget_count) over models, for every sample index of group ``igroup``.

    The number of samples is taken from group ``igroup`` itself (previously some
    statistics used group 0's sample count for both groups).
    """
    n_models = len(Nspikes_target_all)
    n_group_samples = len(Nspikes_target_all[0][igroup])
    return [mean([stat(Nspikes_target_all[imodel][igroup][isamp][iMMNtype],
                       Nspikes_nontarget_all[imodel][igroup][isamp][iMMNtype])
                  for imodel in range(n_models)])
            for isamp in range(n_group_samples)]


def _draw_significance_bracket(ax, x_left, x_right, y, height):
    """Draw a bracket from x_left to x_right rising ``height`` above y, with a '*' on top."""
    ax.plot([x_left, x_left, x_right, x_right], [y, y+height, y+height, y], 'k-', lw=0.3)
    ax.text(mean([x_left, x_right]), y+height+0.1, '*', fontsize=5, ha='center', va='bottom')


# Compare CTRL (group 0) with group igroup for every MMN type using rank-sum
# tests on per-sample, model-averaged spike counts. Brackets are drawn for
# p < 0.05/4 (presumably a Bonferroni correction over the four MMN types).
for iMMNtype in range(0, 4):
    for igroup in [1]:
        print('MMNtype = '+str(iMMNtype)+':')
        target_ctrl = _model_mean_per_sample(0, iMMNtype, lambda t, n: t)
        target_grp = _model_mean_per_sample(igroup, iMMNtype, lambda t, n: t)
        nontarget_ctrl = _model_mean_per_sample(0, iMMNtype, lambda t, n: n)
        nontarget_grp = _model_mean_per_sample(igroup, iMMNtype, lambda t, n: n)
        ratio_ctrl = _model_mean_per_sample(0, iMMNtype, _ratio_or_sentinel)
        ratio_grp = _model_mean_per_sample(igroup, iMMNtype, _ratio_or_sentinel)
        diff_ctrl = _model_mean_per_sample(0, iMMNtype, lambda t, n: t-n)
        diff_grp = _model_mean_per_sample(igroup, iMMNtype, lambda t, n: t-n)
        pval1 = scipy.stats.ranksums(target_ctrl, target_grp)[1]
        pval2 = scipy.stats.ranksums(nontarget_ctrl, nontarget_grp)[1]
        pval3 = scipy.stats.ranksums(ratio_ctrl, ratio_grp)[1]
        pval4 = scipy.stats.ranksums(diff_ctrl, diff_grp)[1]
        # Report the median over samples of the model-averaged values, computed
        # the same way for both groups (the CTRL target value used to be a mean of medians).
        print(' p-val of number of spikes (target) = '+str(pval1)+', Nspikes = '+str(median(target_ctrl))+' vs '+str(median(target_grp)))
        print(' p-val of number of spikes (nontarget) = '+str(pval2))
        print(' p-val of number of spikes (target/nontarget) = '+str(pval3)+' ('+str(median(ratio_ctrl))+' vs '+str(median(ratio_grp))+')')
        print(' p-val of number of spikes (target-nontarget) = '+str(pval4)+' ('+str(median(diff_ctrl))+' vs '+str(median(diff_grp))+')')
        bracket_left = iMMNtype*4
        bracket_right = iMMNtype*4+igroup
        if pval1 < 0.05/4:
            _draw_significance_bracket(axarr[2], bracket_left, bracket_right, 28, igroup)
        if pval2 < 0.05/4:
            _draw_significance_bracket(axarr[3], bracket_left, bracket_right, 12, igroup)
        if pval4 < 0.05/4:
            _draw_significance_bracket(axarr[4], bracket_left, bracket_right, 19, igroup)
# Panels 0-1: 250-ms scale bar, shared x range, no x ticks.
for iax in [0, 1]:
    axarr[iax].plot([5100, 5350], [80-77*iax, 80-77*iax], 'k-', lw=0.45)
    axarr[iax].text(5225, 81-77.9*iax, '250 ms', fontsize=5, ha='center', va='bottom')
    axarr[iax].set_xlim([-50, 5400])
    axarr[iax].set_xticks([])
axarr[0].set_yticks([])
axarr[2].set_title('$f_{\\mathrm{deviant}}$', fontsize=6)
axarr[3].set_title('$f_{\\mathrm{standard}}$', fontsize=6)
# Backslash escaped like the titles above: '\m' is an invalid escape sequence
# that newer Python versions warn about (the resulting text is unchanged).
axarr[4].set_title('$f_{\\mathrm{dd}}$', fontsize=6)
# MMN-type headings above each stimulus pair in panels 0 and 1.
for iax in [0, 1]:
    for iMMN in [0, 1, 2, 3]:
        axarr[iax].text(1400*iMMN+550, 4.5+80*(iax == 0), MMNtypes[iMMN], ha='center', va='bottom', fontsize=5, clip_on=False)
# Panel letters J-N just left of each panel's top edge.
for iax in range(0, 5):
    pos = axarr[iax].get_position()
    f.text(pos.x0 - 0.05 - 0.02*(iax == 1), pos.y1 - 0.01, chr(ord('J')+iax), fontsize=11)
axarr[1].set_ylabel('Firing rate (spikes/sec)', fontsize=6)
axarr[2].set_ylabel('(A.U.)', fontsize=6)
axarr[3].set_ylabel('(A.U.)', fontsize=6)
axarr[4].set_ylabel('(A.U.)', fontsize=6)
f.savefig("fig_SCZ_corticalsubcortical_c_Nsamp"+str(Nsamp)+".pdf")