#cp drawfig3b.py drawfig2b.py
#cp drawcorticalpops_SCZ.py drawfig3.py
from pylab import *
import scipy.io
import scipy.stats
from os.path import exists
from matplotlib.collections import PatchCollection
from scipy.ndimage import gaussian_filter1d

def boxoff(ax,whichxoff='top'):
    """Remove the box outline of ``ax``.

    Hides the ``whichxoff`` spine (the top one by default) together with the
    right spine. Tick marks are kept only on the bottom x-axis and the left y-axis.
    """
    for hidden_side in (whichxoff, 'right'):
        ax.spines[hidden_side].set_visible(False)
    # Keep tick marks only on the bottom and left axes.
    x_axis = ax.get_xaxis()
    x_axis.tick_bottom()
    y_axis = ax.get_yaxis()
    y_axis.tick_left()

def mybar(ax,x,y,facecolor=None,linewidth=0.1,w=0.4,boxlinewidth=0.3):
  """Draw a box-and-whisker glyph summarising the values ``y`` at horizontal position ``x``.

  The box spans the 25th-75th percentiles of ``y`` and has a black edge. A single black
  line draws:
    - a horizontal cap at the minimum,
    - a vertical stem at ``x`` up to the median,
    - a horizontal median line,
    - a stem up to the maximum,
    - and a horizontal cap at the maximum.

  Parameters
  ----------
  ax : axes to draw on (needs ``add_collection`` and ``plot``)
  x : horizontal centre of the glyph
  y : sequence of values to summarise (percentiles via ``quantile``)
  facecolor : fill colour of the box. ``None`` or the legacy ``[]`` keeps the
      collection's default fill.
  linewidth : width of the whisker/median line
  w : half-width of the box and of the caps
  boxlinewidth : width of the box edge (previously hard-coded to 0.3)

  Returns
  -------
  [PatchCollection, list of Line2D] : the box collection and the lines returned by ``ax.plot``
  """
  qmin, q1, qmed, q3, qmax = quantile(y, [0,0.25,0.5,0.75,1])
  box = Polygon(array([[x-w,x+w,x+w,x-w],[q1,q1,q3,q3]]).T)
  # No cmap: the colour is set explicitly and no data array is attached to the collection.
  p = PatchCollection([box])
  # An empty list used to be the "no colour given" default. Keep honouring it for old callers.
  if facecolor is not None and not (isinstance(facecolor, list) and len(facecolor) == 0):
    p.set_facecolor(facecolor)
  p.set_edgecolor('#000000')
  p.set_linewidth(boxlinewidth)
  ax.add_collection(p)
  a2 = ax.plot([x-w,x+w,x,x,x-w,x+w,x,x,x-w,x+w],
               [qmin,qmin,qmin,qmed,qmed,qmed,qmed,qmax,qmax,qmax],'k-',lw=linewidth)
  return [p,a2]
  
import sys  # explicit import; the original relied on `sys` leaking through `from pylab import *`

# Number of seeds (samples) loaded per model and group. It can be overridden by the
# first command-line argument (parsed via float so that e.g. "20.0" is accepted).
Nsamp = 20
if len(sys.argv) > 1:
  Nsamp = int(float(sys.argv[1]))

areas = ['ACC','PFC']

# Per-group colours, index 0 = CTRL, 1 = SCZ: line colours (cols) and muted box fills (dimcols).
cols = ['#000000','#999900']
dimcols = ['#AAAAAA','#DDDD88']

# Onset of the deviant ("target") stimulus and of the standard ("non-target") stimuli,
# in the time units of the spike data.
it_target = 2300
its_nontarget = [800,1300,1800,2800,3300]

# Layout: two full-width panels stacked above a row of three side-by-side panels.
f,axarr = subplots(5,1)
_bottom_row_y = 0.21+0.22*1/3+0.27*1/3+0.16*1/3
_panel_rects = [
  [0.08, 0.75+0.22*1/3, 0.88, 0.22*2/3],
  [0.08, 0.43+0.22*1/3+0.27*1/3, 0.88, 0.27*2/3],
  [0.08, _bottom_row_y, 0.26, 0.16*2/3],
  [0.39, _bottom_row_y, 0.26, 0.16*2/3],
  [0.70, _bottom_row_y, 0.26, 0.16*2/3],
]
for ax, rect in zip(axarr, _panel_rects):
  ax.set_position(rect)

# Long and short MMN-type labels, in the order used throughout the figure.
MMNtypes = ['Frequency deviant','Omission','Duration deviant','Inv. duration deviant']
MMNtypes_short = ['Freq.','Om.','Dur.','Inv. dur.']
for ax in axarr:
  ax.tick_params(axis='both', which='major', labelsize=4, direction='out', width=0.4, length=2)
  boxoff(ax)
# Background bands in the raster panel (axarr[0]). Rows 40-80 hold the CTRL raster
# (light grey) and rows 0-40 hold the SCZ raster (light yellow).
for iax in [0,1]:
  polygon = Polygon(array([[-20,-20,5500,5500],[40-40*iax,80-40*iax,80-40*iax,40-40*iax]]).T)
  p = PatchCollection([polygon])  # no cmap: colour is set explicitly, no data array attached
  p.set_facecolor('#F9F9F9' if iax == 0 else '#FFFFF0')
  p.set_edgecolor(None)
  axarr[0].add_collection(p)

# Vertical stripes in the raster and rate panels mark the 500-ms response windows.
# For each MMN type the deviant window starts at x = 1400*iMMN (blue tint, itarget == 0),
# and the standard window starts 600 later (grey, itarget == 1).
for iax in [0,1]:
  for iMMN in [0,1,2,3]:
    for itarget in [0,1]:
      x = 1400*iMMN+600*itarget
      polygon = Polygon(array([[x,x+500,x+500,x],[-1e5,-1e5,1e5,1e5]]).T)
      print('iMMN='+str(iMMN)+', x = '+str(x))
      p = PatchCollection([polygon])
      p.set_facecolor('#EEEEEE' if itarget == 1 else '#DDDDFF')
      p.set_edgecolor(None)
      axarr[iax].add_collection(p)

# Group labels beside the raster bands. The escaped backslash produces the same string as
# before (a literal backslash before '%') without an invalid-escape SyntaxWarning.
axarr[0].text(-55,20,'SCZ \n(-19\\%)\ncond.',fontsize=5,rotation=90,ha='right',va='center',color=cols[1])
axarr[0].text(-55,60,'CTRL',fontsize=5,rotation=90,ha='right',va='center',color=cols[0])
axarr[0].set_ylim([0,85])
axarr[1].set_ylim([-0.01,1.5])
    
def _in_window(times, start, width=500):
  """Boolean mask selecting the entries of ``times`` in the half-open window [start, start+width)."""
  return (times >= start) & (times < start+width)


def _smoothed_rate(spikes, n_samples=4000, sigma=25):
  """Histogram spike times into unit-width bins over [0, n_samples) and smooth them with a Gaussian.

  Times are rounded to the nearest bin, and spikes outside [0, n_samples) are dropped.
  scipy normalises the kernel, so the values are smoothed spike counts per bin, pooled
  over all spikes in ``spikes``.
  """
  bins = np.round(spikes).astype(int)
  bins = bins[(bins >= 0) & (bins < n_samples)]
  return gaussian_filter1d(np.bincount(bins, minlength=n_samples).astype(float), sigma=sigma)


def _mean_window(curve, starts, width=500):
  """Average the ``width``-long segments of ``curve`` that begin at each index in ``starts``."""
  total = np.zeros(width)
  for start in starts:
    total = total + curve[start:start+width]
  return total/len(starts)


def _spike_ratio(t, n):
  """Target/non-target spike-count ratio. 0/0 counts as 1, and t/0 with t > 0 gives the sentinel 1e8."""
  return t/n if n > 0 else (1 if t == 0 else 1e8)


def _per_sample_model_mean(target_all, nontarget_all, igroup, immn, fn):
  """Return one value per sample: fn(target_count, nontarget_count) averaged over all models.

  ``target_all``/``nontarget_all`` are indexed [model][group][sample][MMN type]. The number of
  samples is taken from model 0 of group ``igroup``.
  NOTE(review): samples are matched by position, not by seed, so a file missing in one
  model shifts the seed alignment across models.
  """
  nsamples = len(target_all[0][igroup])
  return [np.mean([fn(target_all[imodel][igroup][i][immn], nontarget_all[imodel][igroup][i][immn])
                   for imodel in range(len(target_all))])
          for i in range(nsamples)]


# Maps the figure's MMN-type order (MMNtypes) to the entry index in A['output'].
A_iMMNtype_order = [1,0,2,3]

# Spike counts indexed [model][group][sample][MMN type]. Group 0 = CTRL, 1 = SCZ (19% more pruned).
Nspikes_target_all = []     # spikes in the deviant window [it_target, it_target+500)
Nspikes_nontarget_all = []  # spikes per standard window, averaged over its_nontarget
# Smoothed curves indexed [model][group] -> array (MMN type, time), averaged over samples.
curves_all = []
for imodel in range(0,16):
  Nspikes_target = []
  Nspikes_nontarget = []
  curves = []
  for igroup in range(0,2):
    print('Working on model '+str(imodel))
    Nspikes_target_thisgroup = []
    Nspikes_nontarget_thisgroup = []
    curves_thisgroup = []
    for myseed in range(1,1+Nsamp):
      if igroup == 1:
        filename = 'MMNs_2pm_sep_noISDIDD_19percentMorePruned_model'+str(imodel)+'_CTRLpop_seed'+str(myseed)+'.mat'
      else:
        filename = 'MMNs_2pm_sep_noISDIDD_model'+str(imodel)+'_CTRLpop_AUCbased_seed'+str(myseed)+'.mat'
      if not exists(filename):
        print(filename+' does not exist')
        continue
      A = scipy.io.loadmat(filename)
      if myseed == 1:
        print('Loaded '+filename)
      # The raster (axarr[0]) is drawn only from seed 1 of model 0.
      draw_raster = myseed == 1 and imodel == 0
      raster_yoffset = 40*(igroup==0)  # CTRL rows are drawn 40 above SCZ rows
      curves_thissamp = []
      Nspikes_target_thissamp = []
      Nspikes_nontarget_thissamp = []
      for iMMNtype in range(0,4):
        entry = A['output'][A_iMMNtype_order[iMMNtype]]
        spikes = np.asarray(entry[0][0]).ravel()   # spike times
        spikers = np.asarray(entry[1][0]).ravel()  # presumably the neuron id of each spike (raster row)

        # Target (deviant) window:
        in_target = _in_window(spikes, it_target)
        Nspikes_target_thissamp.append(int(np.count_nonzero(in_target)))
        if draw_raster:
          axarr[0].plot(spikes[in_target]-it_target+iMMNtype*1400,spikers[in_target]+raster_yoffset,'.',
                        ms=0.5,mew=0.5,lw=0.5,color=cols[igroup])
          print('igroup = '+str(igroup)+', iMMNtype = '+str(iMMNtype)+', deviant, '+str(int(np.count_nonzero(in_target)))+' plotted')

        # Non-target (standard) windows: spike count averaged over the standards.
        Nspikes_nontarget_this = 0
        for it in its_nontarget:
          in_window = _in_window(spikes, it)
          Nspikes_nontarget_this += int(np.count_nonzero(in_window))
          if draw_raster and it == its_nontarget[2]:
            # The third standard is drawn 600 to the right of the deviant.
            axarr[0].plot(spikes[in_window]-it+iMMNtype*1400+600,spikers[in_window]+raster_yoffset,'.',
                          ms=0.5,mew=0.5,lw=0.5,color=cols[igroup])
            print('igroup = '+str(igroup)+', iMMNtype = '+str(iMMNtype)+', standards, '+str(int(np.count_nonzero(in_window)))+' plotted')
        Nspikes_nontarget_thissamp.append(Nspikes_nontarget_this/len(its_nontarget))
        curves_thissamp.append(_smoothed_rate(spikes))  # 25-sample (ms) Gaussian smoothing
      Nspikes_target_thisgroup.append(Nspikes_target_thissamp)
      Nspikes_nontarget_thisgroup.append(Nspikes_nontarget_thissamp)
      curves_thisgroup.append(curves_thissamp)
    if not curves_thisgroup:
      # Without this check, averaging an empty list later fails with an opaque IndexError.
      raise RuntimeError('No data files found for model '+str(imodel)+', group '+str(igroup))
    Nspikes_target.append(Nspikes_target_thisgroup)
    Nspikes_nontarget.append(Nspikes_nontarget_thisgroup)
    curves.append(np.mean(np.array(curves_thisgroup),axis=0))
  Nspikes_target_all.append(Nspikes_target)
  Nspikes_nontarget_all.append(Nspikes_nontarget)
  curves_all.append(curves)

# Model-averaged rate curves (axarr[1]) and per-sample spike-rate boxes (axarr[2..4]).
for iMMNtype in range(0,4):
  for igroup in range(0,2):
    mean_curve = np.mean(np.array([curves_all[imodel][igroup][iMMNtype] for imodel in range(0,len(curves_all))]),axis=0)
    target_curve = mean_curve[it_target:it_target+500]
    nontarget_curve = _mean_window(mean_curve, its_nontarget)
    axarr[1].plot(range(1400*iMMNtype,1400*iMMNtype+500),target_curve,'-',lw=0.45,color=cols[igroup])
    axarr[1].plot(range(1400*iMMNtype+600,1400*iMMNtype+1100),nontarget_curve,'-',lw=0.45,color=cols[igroup])

    # Divide the spike counts by 40 and by 0.5 to get the spiking rate per neuron.
    xbar = iMMNtype*4+igroup
    mybar(axarr[2],xbar,_per_sample_model_mean(Nspikes_target_all,Nspikes_nontarget_all,igroup,iMMNtype,
                                               lambda t,n: t/40/0.5),facecolor=dimcols[igroup])
    mybar(axarr[3],xbar,_per_sample_model_mean(Nspikes_target_all,Nspikes_nontarget_all,igroup,iMMNtype,
                                               lambda t,n: n/40/0.5),facecolor=dimcols[igroup])
    mybar(axarr[4],xbar,_per_sample_model_mean(Nspikes_target_all,Nspikes_nontarget_all,igroup,iMMNtype,
                                               lambda t,n: t/40/0.5-n/40/0.5),facecolor=dimcols[igroup])

for iax in [2,3,4]:
  axarr[iax].set_xlim([-0.5,14])
  axarr[iax].set_xticks([])
  for iMMN in range(0,4):
    axarr[iax].text(4*iMMN+0.5,0.2*(iax in [2,4]),MMNtypes_short[iMMN],va='top',ha='center',fontsize=5)
axarr[2].set_ylim([-0.9,7])
axarr[3].set_ylim([-0.4,2.0])
axarr[4].set_ylim([-0.9,6.8])

# CTRL vs SCZ rank-sum tests per MMN type. Measures, in order: target count, non-target count,
# target/non-target ratio, target-non-target difference. Each is averaged over models per sample.
_measures = [lambda t,n: t, lambda t,n: n, _spike_ratio, lambda t,n: t-n]
for iMMNtype in range(0,4):
  for igroup in [1]:
    print('MMNtype = '+str(iMMNtype)+':')
    ctrl_vals = [_per_sample_model_mean(Nspikes_target_all,Nspikes_nontarget_all,0,iMMNtype,fn) for fn in _measures]
    scz_vals = [_per_sample_model_mean(Nspikes_target_all,Nspikes_nontarget_all,igroup,iMMNtype,fn) for fn in _measures]
    pval1, pval2, pval3, pval4 = [scipy.stats.ranksums(c, s)[1] for c, s in zip(ctrl_vals, scz_vals)]

    # Both groups are summarised the same way: the median over samples of the model-averaged values.
    print('  p-val of number of spikes (target) = '+str(pval1)+', Nspikes = '+str(np.median(ctrl_vals[0]))+' vs '+str(np.median(scz_vals[0])))
    print('  p-val of number of spikes (nontarget) = '+str(pval2))
    print('  p-val of number of spikes (target/nontarget) = '+str(pval3)+' ('+str(np.median(ctrl_vals[2]))+' vs '+str(np.median(scz_vals[2]))+')')
    print('  p-val of number of spikes (target-nontarget) = '+str(pval4)+' ('+str(np.median(ctrl_vals[3]))+' vs '+str(np.median(scz_vals[3]))+')')

    # Significance brackets: Bonferroni-corrected threshold over the 4 MMN types.
    if pval1 < 0.05/4:
      axarr[2].plot([iMMNtype*4,iMMNtype*4,iMMNtype*4+igroup,iMMNtype*4+igroup],[6,6+0.4*igroup,6+0.4*igroup,6],'k-',lw=0.3)
      axarr[2].text(np.mean([iMMNtype*4,iMMNtype*4+igroup]),5.5+0.4*igroup+0.1,'*',fontsize=5,ha='center',va='bottom')
    if pval2 < 0.05/4:
      axarr[3].plot([iMMNtype*4,iMMNtype*4,iMMNtype*4+igroup,iMMNtype*4+igroup],[1.5,1.5+0.15*igroup,1.5+0.15*igroup,1.5],'k-',lw=0.3)
      axarr[3].text(np.mean([iMMNtype*4,iMMNtype*4+igroup]),1.25+0.15*igroup+0.1,'*',fontsize=5,ha='center',va='bottom')
    if pval4 < 0.05/4:
      axarr[4].plot([iMMNtype*4,iMMNtype*4,iMMNtype*4+igroup,iMMNtype*4+igroup],[5.3,5.3+0.4*igroup,5.3+0.4*igroup,5.3],'k-',lw=0.3)
      axarr[4].text(np.mean([iMMNtype*4,iMMNtype*4+igroup]),4.8+0.4*igroup+0.1,'*',fontsize=5,ha='center',va='bottom')

# 250-ms scale bars, with labels, in the raster (axarr[0]) and rate (axarr[1]) panels.
axarr[0].plot([5100,5350],[58,58],'k-',lw=0.45)
axarr[1].plot([5100,5350],[1.05,1.05],'k-',lw=0.45)
for iax in [0,1]:
  axarr[iax].text(5225,61-59.9*iax,'250 ms',fontsize=5,ha='center',va='bottom')
  axarr[iax].set_xlim([-50,5400])
  axarr[iax].set_xticks([])
axarr[0].set_yticks([])

# MMN-type headers, centred over each deviant+standard window pair (x from 1400*iMMN to 1400*iMMN+1100).
for iax in [0,1]:
  for iMMN in [0,1,2,3]:
    axarr[iax].text(1400*iMMN+550,1.5+85*(iax==0),MMNtypes[iMMN],ha='center',va='bottom',fontsize=5,clip_on=False)

axarr[2].set_title('$f_{\\mathrm{deviant}}$',fontsize=6)
axarr[3].set_title('$f_{\\mathrm{standard}}$',fontsize=6)
# Backslash escaped like the titles above; '\m' is an invalid escape sequence in Python.
axarr[4].set_title('$f_{\\mathrm{dd}}$',fontsize=6)

# Panel letters H-L, placed just left of the top-left corner of each axes.
for iax in range(0,5):
  pos = axarr[iax].get_position()
  f.text(pos.x0 - 0.05 - 0.02*(iax==1), pos.y1 - 0.01, chr(ord('H')+iax), fontsize=11)

axarr[1].set_ylabel('Firing rate (spikes/sec)',fontsize=6)
axarr[2].set_ylabel('(A.U.)',fontsize=6)
axarr[3].set_ylabel('(A.U.)',fontsize=6)
axarr[4].set_ylabel('(A.U.)',fontsize=6)
f.savefig("fig_SCZ_allcortical_c_Nsamp"+str(Nsamp)+".pdf")