from pylab import *
import scipy.io
import mytools
from matplotlib.collections import PatchCollection
from os.path import exists
from scipy.ndimage import gaussian_filter1d

def mybar(ax,x,y,facecolor=[],linewidth=0.1,w=0.4):
  """Draw a box-and-whisker glyph summarising `y` at horizontal position `x`.

  The box spans the 25%-75% quantiles of `y` and is 2*w wide. One polyline
  draws, in order, the minimum cap, the lower whisker, the median line, the
  upper whisker and the maximum cap.

  Parameters:
    ax: axes object the glyph is added to.
    x: horizontal centre of the glyph.
    y: sample values whose quantiles are drawn.
    facecolor: fill colour of the box. The default (an empty list) means
      "do not call set_facecolor at all".
    linewidth: width of the cap/whisker/median polyline. The box outline is
      always drawn black with width 0.3.
    w: half-width of the box and of the caps.

  Returns:
    [box, lines]: the PatchCollection holding the box and the list of lines
    returned by ax.plot.
  """
  q_min, q_lo, q_med, q_hi, q_max = quantile(y, [0,0.25,0.5,0.75,1])
  left = x-w
  right = x+w

  # Inter-quartile box, given as its four corners.
  corners = array([[left,right,right,left],[q_lo,q_lo,q_hi,q_hi]]).T
  box = PatchCollection([Polygon(corners)], cmap=matplotlib.cm.jet)
  keep_default_face = type(facecolor) is list and len(facecolor) == 0
  if not keep_default_face:
    box.set_facecolor(facecolor)
  box.set_edgecolor('#000000')
  box.set_linewidth(0.3)
  ax.add_collection(box)

  # Caps, whiskers and median as one continuous path.
  path_x = [left,right,x,x,left,right,x,x,left,right]
  path_y = [q_min,q_min,q_min,q_med,q_med,q_med,q_med,q_max,q_max,q_max]
  lines = ax.plot(path_x,path_y,'k-',lw=linewidth)
  return [box,lines]

def parse_nperpop(filenames, default=40):
  """Return the population size encoded in the first file name.

  The first entry of `filenames` is split on '_' and, starting from the
  second token, the first token containing 'Nperpop' is taken; everything
  after its first seven characters is parsed as an integer.

  Parameters:
    filenames: list of file-name strings (may be empty).
    default: value returned when no usable 'Nperpop' token is found.

  Returns:
    The parsed integer, or `default` when the list is empty, no token
    matches, or the text after the seven-character prefix is not an integer.
  """
  try:
    tokens = filenames[0].split('_')
    # The search starts at the second token; the first token is never inspected.
    token = next(t for t in tokens[1:] if 'Nperpop' in t)
    return int(token[7:])
  except (IndexError, StopIteration, ValueError):
    return default

filenames = sys.argv[1:]
Nperpop = parse_nperpop(filenames)

# Stimulus traces for one 500-unit sequence slot, indexed as
# seqAdds[stimulus type][trace] = [x offsets within the slot, levels].
# Trace 0 is the one labelled "standard" below, trace 1 the "deviant".
def _flat_trace():
  """Trace that stays at level 0 for the whole slot."""
  return [[0,500],[0,0]]

def _pulse_trace(onset):
  """Trace that steps to level 1 at `onset` and back to 0 at 500."""
  return [[0,onset,onset,500,500],[0,0,1,1,0]]

seqAdds = [[_flat_trace(), _flat_trace()],      #0: both traces flat (omission)
           [_pulse_trace(450), _flat_trace()],  #1: standard short, deviant missing
           [_flat_trace(), _pulse_trace(450)],  #2: standard missing, deviant short
           [_pulse_trace(400), _flat_trace()],  #3: standard long, deviant missing
           [_flat_trace(), _pulse_trace(400)]]  #4: standard missing, deviant long

# Create the four stacked panels and place them by hand (figure coordinates,
# [left, bottom, width, height]).
fig1, axs = subplots(4,1)
axarr = axs.ravel().tolist()

panel_rects = [[0.065, 0.08+0.08*7,0.9,0.08],
               [0.065, 0.08+0.08*6+0.06,0.9,0.02],
               [0.065, 0.5,0.9,0.08],
               [0.065, 0.4,0.9,0.08]]
for panel, rect in zip(axarr, panel_rects):
  panel.set_position(rect)

axs[0].set_xticks([])
axs[1].set_yticks([])

# Shared cosmetics: small tick labels, thin frame, y-range of one population.
for panel in axarr:
  panel.tick_params(axis='both', which='major', labelsize=4)
  for side in ('top','bottom','left','right'):
    panel.spines[side].set_linewidth(0.2)
  panel.set_ylim([0,Nperpop])
axarr[0].set_yticks([0,int(Nperpop/2)])

xlimmax = 0

def _shade_span(panel, x_left, x_right, height, color):
  """Add a filled, edge-less rectangle [x_left, x_right] x [0, height] to `panel`."""
  outline = array([[x_left,x_left,x_right,x_right],[0,height,height,0]]).T
  patch = PatchCollection([Polygon(outline)], cmap=matplotlib.cm.jet)
  patch.set_facecolor(color)
  patch.set_edgecolor(None)
  panel.add_collection(patch)

# Sequence slots that get highlighted; each highlight covers
# [500*slot-150, 500*slot+350].
istims = [6,11,16,21,26,31,36,42,47,52,57,62,68,73,78,84]
for iseq in istims:
  span_left = 500*iseq-150
  span_right = 500*iseq+350

  # Panels 0 and 1: one full-height highlight per slot.
  for panel in axarr[0:2]:
    _shade_span(panel, span_left, span_right, Nperpop, '#DDDDFF')

  # Panel 2: the same span, plus a grey copy shifted right by 500.
  for shift, color in ((0, '#DDDDFF'), (500, '#EEEEEE')):
    _shade_span(axarr[2], span_left+shift, span_right+shift, 5, color)

#axs[0].text(0,Nperpop-1,'Output (EO)',fontsize=4,ha='left',va='top',fontweight='bold')
#axs[1].text(0,Nperpop-1,'Standard (blue) and deviant (red) stimulus',fontsize=4,ha='left',va='top')

# NOTE(review): `cols` is not referenced by any active code visible in this
# file; the call is kept in case it is still needed elsewhere.
cols = mytools.colorsredtolila(len(filenames)+1,0.8)

# Load the simulation whose output spikes and stimulus sequence are drawn in
# the two top panels.
A = scipy.io.loadmat('MMNs_2pm_sep_noISDIDD_seq0_model0_pop_seed1.mat')
for q in ['standard', 'deviant', 'pacemaker', 'pacemaker2', 'output', 'standardBoost', 'deviantBoost']:
  try:
    shp = A[q].shape
    for iy in range(0,shp[0]):
      for ix in range(0,shp[1]):
        cell = A[q][iy,ix]
        # Replace a 1 x N cell by its single row so it can be used as a vector.
        if cell.shape[0] == 1 and cell.shape[1] > 1:
          A[q][iy,ix] = cell[0]
  except (KeyError, IndexError, AttributeError, TypeError, ValueError):
    # Best effort: a missing variable or a cell that cannot be unwrapped
    # leaves A[q] as loaded (unwrapping of that variable stops at this point).
    pass

n_slots = len(A['sequence'][0])
xlimmax = max(xlimmax,500*n_slots)

# Raster of the output spikes. The colour is given only through the keyword:
# combining it with a colour in the format string makes matplotlib warn.
for iMMN in range(0,1):
  axs[0].plot(A['output'][iMMN,0], A['output'][iMMN,1], '.', lw=0.35, ms=0.35, mew=0.35, color='#000000')

# Stimulus traces, one 500-wide slot per sequence entry: trace 0 in blue
# around y=24, trace 1 in red around y=6.
for iseq in range(0,n_slots):
  istim = int(A['sequence'][0][iseq])
  axs[1].plot([500*iseq+x for x in seqAdds[istim][0][0]],[24+10*x for x in seqAdds[istim][0][1]],'b-',lw=0.25)
  axs[1].plot([500*iseq+x for x in seqAdds[istim][1][0]],[6+10*x for x in seqAdds[istim][1][1]],'r-',lw=0.25)
# Extend both baselines by one extra slot after the last stimulus.
axs[1].plot([500*n_slots,500*n_slots+500],[24,24],'b-',lw=0.25)
axs[1].plot([500*n_slots,500*n_slots+500],[6,6],'r-',lw=0.25)

# All panels share the same x-range: the sequence plus 1000 extra units.
for iax in range(0,len(axarr)):
  axarr[iax].set_xlim([0,1000+500*n_slots])


#axarr[1].text(500*6,-20,'short standard to short deviant',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*11,-20,'short deviant to short standard',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*16,-20,'short standard to long standard',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*21,-20,'long standard to short standard',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*26,-20,'short standard to long deviant',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*31,-20,'long deviant to short standard',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*36,-20,'omission from short standards',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*42,-20,'short deviant to long standard',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*47,-20,'long standard to short deviant',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*52,-20,'short deviant to long deviant',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*57,-20,'long deviant to short deviant',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*62,-20,'omission from short deviant',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*68,-20,'long standard to long deviant',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*73,-20,'long deviant to long standard',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*78,-20,'omission from long standard',rotation=90,fontsize=5,ha='center',va='top')
#axarr[1].text(500*84,-20,'omission from long deviant',rotation=90,fontsize=5,ha='center',va='top')

# For each highlighted slot, draw a figure-level polygon above the top panel.
# Its base (y=0.72) spans the slot's highlight and it widens towards y=0.835.
# Data x is converted to figure x as 0.065 + 0.9 * x / x_range, i.e. the
# panels' left edge and width combined with their x-limits.
x_range = 1000+500*len(A['sequence'][0])
outline_offsets = [(-150,0.72),(-550,0.73),(-550,0.82),(-150,0.835),
                   (1150,0.835),(1750,0.82),(1750,0.76),(350,0.72)]
for iseq in istims[:16]:
  verts = [(0.065+0.9*((500*iseq+dx)/x_range), fig_y) for dx, fig_y in outline_offsets]
  polygon = Polygon(verts, closed=True, transform=fig1.transFigure,
                    facecolor='#DDDDFF', edgecolor=None, zorder=-1)
  fig1.patches.append(polygon)  # Attach directly to figure


# Rotated captions for the highlighted slots: (slot index, caption text).
# The leading spaces inside each caption stagger its lines along the 45-degree
# diagonal. y=168 places the captions well above the stimulus traces.
slot_captions = [
  (6, '                  short\n            standard\n      to short\ndeviant'),
  (11, '                  short\n            deviant\n      to short\nstandard'),
  (16, '                  short\n            standard\n      to long\nstandard'),
  (21, '                  long\n            standard\n      to short\nstandard'),
  (26, '                  short\n            standard\n      to long\ndeviant'),
  (31, '                  long\n            deviant\n      to short\nstandard'),
  (36, '              omission\n            from short\n      standards'),
  (42, '                  short\n            deviant\n      to long\nstandard'),
  (47, '                  long\n            standard\n      to short\ndeviant'),
  (52, '                  short\n            deviant\n      to long\ndeviant'),
  (57, '                  long\n            deviant\n      to short\ndeviant'),
  (62, '              omission\n            from short\n  deviant'),
  (68, '                  long\n            standard\n      to long\ndeviant'),
  (73, '                  long\n            deviant\n      to long\nstandard'),
  (78, '              omission\n            from long\n      standard'),
  (84, '              omission\n            from long\n  deviant'),
]
for icaption, (slot, caption) in enumerate(slot_captions):
  # Only the first caption carries an explicit zorder.
  extra = {'zorder': 10} if icaption == 0 else {}
  axarr[1].text(500*slot,168,caption,rotation=45,fontsize=5,ha='center',va='bottom',**extra)

fig1.savefig('fig_longseq.pdf')

# Analysis windows. Each target window is [it, it+500) with it starting 200
# before a highlighted slot; the non-target window is the 500 units just
# before it, [it-500, it).
its_target = [500*i-200 for i in istims]
its_nontarget = [x-500 for x in its_target]

# Per-model results, filled below:
#   curves_all[imodel]                  smoothed spike-count curve (mean over seeds)
#   Nspikes_target_all[imodel][iseed]   spike count in each target window
#   Nspikes_nontarget_all[imodel][iseed] spike count in each non-target window
curves_all = []
Nspikes_target_all = []
Nspikes_nontarget_all = []

mysigma = 25 #25 ms std
n_samples = 43000  # length of the binned spike train (one bin per time unit)

for imodel in range(0,16):
  Nspikes_target_thisgroup = []
  Nspikes_nontarget_thisgroup = []
  curves_thisgroup = []
  for myseed in [1]:
    filename = 'MMNs_2pm_sep_noISDIDD_seq0_model'+str(imodel)+'_pop_seed1.mat'
    if not exists(filename):
      print(filename+' does not exist')
      continue
    A = scipy.io.loadmat(filename)
    if myseed == 1:
      print('Loaded '+filename)

    # Spike times of the output population. Depending on how the file was
    # saved there is one more or one less level of nesting: if the deeper
    # index already yields a scalar, the vector is one level up.
    spikes = A['output'][0][0][0]
    if np.ndim(spikes) == 0:
      spikes = A['output'][0][0]

    # Bin the spikes (rounded to the nearest integer time, out-of-range
    # spikes dropped) and smooth the counts with a Gaussian kernel.
    rounded_spikes = np.round(spikes).astype(int)
    valid_spikes = rounded_spikes[(rounded_spikes >= 0) & (rounded_spikes < n_samples)]
    spike_train = np.bincount(valid_spikes, minlength=n_samples).astype(float)
    thiscurve = gaussian_filter1d(spike_train, sigma=mysigma)

    # Count the spikes in every target window and in the window preceding it.
    Nspikes_target_thissamp = []
    Nspikes_nontarget_thissamp = []
    for it in its_target:
      in_target = (spikes >= it) & (spikes < it+500)
      in_nontarget = (spikes >= it-500) & (spikes < it)
      Nspikes_target_thissamp.append(int(np.count_nonzero(in_target)))
      Nspikes_nontarget_thissamp.append(int(np.count_nonzero(in_nontarget)))

    Nspikes_target_thisgroup.append(Nspikes_target_thissamp)
    Nspikes_nontarget_thisgroup.append(Nspikes_nontarget_thissamp)
    curves_thisgroup.append(thiscurve)

  if not curves_thisgroup:
    # Without any data the mean below is a scalar NaN and the downstream
    # indexing fails with an unrelated-looking IndexError; fail clearly instead.
    raise FileNotFoundError('No simulation output could be loaded for model '+str(imodel))

  Nspikes_target_all.append(Nspikes_target_thisgroup)
  Nspikes_nontarget_all.append(Nspikes_nontarget_thisgroup)
  mean_curves = np.mean(np.array(curves_thisgroup),axis=0)
  curves_all.append(mean_curves)

# Average the smoothed curves over models, then draw for every target window
# the curve inside the window followed by the curve of the preceding window
# (the latter shifted right by 1000 so that it is plotted after the former).
mean_curve = mean(curves_all,axis=0)
for iwin, t0 in enumerate(its_target):
  axarr[2].plot(range(t0,t0+500),mean_curve[t0:t0+500],'k-',lw=0.5)
  axarr[2].plot(range(t0+500,t0+1000),mean_curve[t0-500:t0],'k-',lw=0.5)

  # Per-model difference between target and non-target spike counts
  # (first seed only).
  ddi = [n_tgt[0][iwin]-n_non[0][iwin]
         for n_tgt, n_non in zip(Nspikes_target_all, Nspikes_nontarget_all)]

  # Windows 0, 2, 3 and 6 are drawn grey with black lines, all others olive.
  if iwin in [0,2,3,6]:
    fill, outline = '#999999', '#000000'
  else:
    fill, outline = '#AA9900', '#555500'
  box, whiskers = mybar(axarr[3],t0+250,ddi,facecolor=fill,linewidth=0.3,w=400)
  whiskers[0].set_color(outline)
  box.set_edgecolor(outline)
  box.set_linewidth(0.3)

# Final y-ranges of the two lower panels (they were initialised to [0, Nperpop]).
axarr[2].set_ylim([0,1.5])
axarr[3].set_ylim([-10,180])

# No x tick marks anywhere; a scale bar in the top panel replaces them.
for ax in axarr:
  ax.set_xticks([])
# The bar spans x=500..1500 at y=70; clipping is disabled so it is drawn even
# when it falls outside the panel's y-range.
axarr[0].plot([500,1500],[70,70],'k-',lw=0.5,clip_on=False)
axarr[0].text(1000,73,'1000 ms',fontsize=5,ha='center',va='bottom',clip_on=False)
axarr[0].set_yticks([])
axarr[2].set_ylabel('Firing rate    \n(spikes/sec)    ',fontsize=5)
# Raw string: '\m' is not a valid Python escape sequence, so the backslash of
# the mathtext command must not be left to the string parser.
axarr[3].set_ylabel(r'$f_{\mathrm{dd}}$ (A.U.)',fontsize=5)

# Panel letters, placed in figure coordinates.
fig1.text(0.03, 0.8, 'A', fontsize=11)
fig1.text(0.03, 0.59, 'B', fontsize=11)
fig1.text(0.03, 0.46, 'C', fontsize=11)

# Thin, short y tick marks on every panel.
for ax in axarr:
  for line in ax.yaxis.get_ticklines():
    line.set_markeredgewidth(0.3)
    line.set_markersize(2)

fig1.savefig('fig_longseq.pdf')