from pylab import *
import scipy.io
import mytools
from matplotlib.collections import PatchCollection
from os.path import exists
from scipy.ndimage import gaussian_filter1d


# One .mat file per population size N (neurons per population); every other
# parameter encoded in the name is identical across the five files.
filenames = [
    ('MMNs_2pm_sep_noISDIDD_Nperpop{}_paramSD0.3_stimA150_130_gAMPA17.5_30.0_80.0'
     '_gNMDA5.827500000000001_9.99_26.64_gGABA35.0_35.0_dep1000_0.0_0.95'
     '_tau10.0_10.0_10.0_250.0.mat').format(N)
    for N in (20, 30, 40, 50, 60)
]

Nperpop = 40  # default population size; reassigned per file in the plotting loop
Nsamp = 1     # appears unused in the rest of this script -- NOTE(review): confirm
def boxoff(ax, whichxoff='top'):
    """Apply a "box off" look to *ax* in place.

    Hides the right spine plus the spine named by *whichxoff*, then moves the
    x ticks to the bottom axis and the y ticks to the left axis.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to restyle.
    whichxoff : str, optional
        Key of the additional spine to hide (default ``'top'``).
    """
    for spine_name in (whichxoff, 'right'):
        ax.spines[spine_name].set_visible(False)
    x_axis, y_axis = ax.get_xaxis(), ax.get_yaxis()
    x_axis.tick_bottom()
    y_axis.tick_left()
def mybar(ax, x, y, facecolor=None, linewidth=0.1, w=0.4, edgewidth=0.3):
    """Draw a box-and-whisker glyph summarising the sample *y* at position *x*.

    The filled box spans the 25th-75th percentiles of *y* and has a black
    outline. A single black polyline then runs from the minimum through the
    median to the maximum. It has horizontal caps of width ``2*w`` at each of
    those three levels.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes (modified in place).
    x : float
        Horizontal centre of the glyph.
    y : array_like
        Sample values to summarise.
    facecolor : color, optional
        Fill colour of the box. ``None`` (default) leaves the collection's
        default fill. An empty list is treated the same way: it was the old
        default and is still accepted for backward compatibility.
    linewidth : float, optional
        Line width of the min/median/max polyline.
    w : float, optional
        Half-width of the box and of the caps.
    edgewidth : float, optional
        Line width of the box outline. It was previously hard-coded to 0.3,
        which is still the default.

    Returns
    -------
    list
        ``[PatchCollection, lines]``, where ``lines`` is the list returned by
        ``ax.plot`` for the polyline.
    """
    q_min, q_lo, q_med, q_hi, q_max = quantile(y, [0, 0.25, 0.5, 0.75, 1])

    box = Polygon(array([[x - w, x + w, x + w, x - w],
                         [q_lo, q_lo, q_hi, q_hi]]).T)
    # No cmap: the collection never receives set_array(), so a colormap
    # would have no effect on the rendering.
    p = PatchCollection([box])
    # ``[]`` used to be the (mutable) default meaning "don't set a colour".
    if facecolor is not None and not (isinstance(facecolor, list) and len(facecolor) == 0):
        p.set_facecolor(facecolor)
    p.set_edgecolor('#000000')
    p.set_linewidth(edgewidth)
    ax.add_collection(p)

    # One polyline: min cap -> up to median -> median cap -> up to max -> max cap.
    line_xs = [x - w, x + w, x, x, x - w, x + w, x, x, x - w, x + w]
    line_ys = [q_min, q_min, q_min, q_med, q_med, q_med, q_med, q_max, q_max, q_max]
    a2 = ax.plot(line_xs, line_ys, 'k-', lw=linewidth)
    return [p, a2]

# Figure with one row of four panels, all styled identically.
fig1, axs = subplots(1, 4)
axarr = axs.ravel().tolist()  # flat list of the four Axes
for iax, ax in enumerate(axarr):
    # [left, bottom, width, height] in figure coordinates. The `0*0.2` term
    # looks like a leftover row offset (row 0); kept so the value is unchanged.
    ax.set_position([0.08 + 0.24*iax, 0.78 - 0*0.2 - 0.12, 0.18, 0.2])
    ax.tick_params(axis='both', which='major', labelsize=4)
    boxoff(ax)
    for spine_name in ('top', 'bottom', 'left', 'right'):
        ax.spines[spine_name].set_linewidth(0.2)
    # Common axis limits for every panel.
    ax.set_xlim([0, 3800])
    ax.set_ylim([0, 240])

#axs[0,0].text(0,Nperpop+7,'Excitatory deviant detecting output (EO)',fontsize=4,ha='left',va='top',fontweight='bold')
#axs[1,0].text(0,Nperpop+7,'Excitatory population for standards (ES)',fontsize=4,ha='left',va='top')
#axs[2,0].text(0,Nperpop+7,'Inhibitory population for standards (IS)',fontsize=4,ha='left',va='top')
#axs[3,0].text(0,Nperpop+7,'Excitatory population for standards, delayed (ESD)',fontsize=4,ha='left',va='top')
#axs[4,0].text(0,Nperpop+7,'Excitatory population for deviants (ED)',fontsize=4,ha='left',va='top')
#axs[5,0].text(0,Nperpop+7,'Inhibitory population for deviants (ID)',fontsize=4,ha='left',va='top')
#axs[6,0].text(0,Nperpop+7,'Excitatory pop. for deviants, delayed (ESD)',fontsize=4,ha='left',va='top')
#axs[7,0].text(0,Nperpop+7,'Exc. timer pop. receiving phase-locked input (EP)',fontsize=4,ha='left',va='top')
#axs[8,0].text(0,Nperpop+7,'Exc. timer pop. receiving phase-locked input, alt. phase (EP2)',fontsize=4,ha='left',va='top')


# Panel p plots simulated condition MMNorder[p]. Titles are assigned through
# the same permutation, and [1,0,2,3] is its own inverse. Therefore the panel
# showing condition k is the one carrying the k-th title below.
MMNorder = [1, 0, 2, 3]
axs[MMNorder[0]].set_title('   Omission', fontsize=6, pad=12)
axs[MMNorder[1]].set_title('Frequency deviant', fontsize=6, pad=12)
axs[MMNorder[2]].set_title('Duration deviant', fontsize=6, pad=12)
axs[MMNorder[3]].set_title('Inv. dur. deviant', fontsize=6, pad=12)

cols = ['#666600', '#333300', '#000000', '#003333', '#006666']     # raster colour per file
dimcols = ['#FFFFEE', '#F8F8EE', '#EEEEEE', '#EEF8F8', '#EEFFFF']  # background band colour per file
# Neurons per population for each entry of `filenames`. This must follow the
# same order as the sizes used to build the file names.
pop_sizes = [20, 30, 40, 50, 60]
# Vertical offset of each file's band of raster rows (loop-invariant).
yfiles = [220, 180, 130, 70, 0]
# Keys whose 2-D (cell-array-like) entries get 1-by-n items flattened to 1-D.
cell_keys = ['standard', 'deviant', 'pacemaker', 'pacemaker2', 'output', 'standardBoost', 'deviantBoost']


def _squeeze_row_cells(mat, key):
    """Flatten, in place, every 1-by-n (n > 1) entry of the 2-D array ``mat[key]`` to 1-D.

    This is best effort, as in the original script:
    - A missing key is skipped.
    - The first entry that cannot be inspected or replaced stops processing
      of this key.
    - Entries already converted are left in place.
    """
    if key not in mat:
        return
    cells = mat[key]
    try:
        for iy in range(cells.shape[0]):
            for ix in range(cells.shape[1]):
                entry = cells[iy, ix]
                if entry.shape[0] == 1 and entry.shape[1] > 1:
                    cells[iy, ix] = entry[0]
    except (AttributeError, IndexError, TypeError, ValueError):
        # Not laid out as a 2-D array of 2-D arrays (e.g. a 1-D array, or
        # scalar entries); leave the remaining entries untouched.
        pass


for ifile, filename in enumerate(filenames):
    Nperpop = pop_sizes[ifile]
    y0 = yfiles[ifile]

    print('Loading ' + filename)
    A = scipy.io.loadmat(filename)
    for q in cell_keys:
        _squeeze_row_cells(A, q)

    for iiMMN in range(4):
        ax = axarr[iiMMN]
        # Light background band covering this file's Nperpop rows.
        band = Polygon(array([[0, 3900, 3900, 0],
                              [y0, y0, y0 + Nperpop, y0 + Nperpop]]).T)
        p = PatchCollection([band])
        p.set_facecolor(dimcols[ifile])
        p.set_edgecolor(None)
        ax.add_collection(p)

        iMMN = MMNorder[iiMMN]
        # Columns 0/1 are presumably spike times and neuron indices of the
        # output population -- TODO confirm against the .mat writer.
        # The fmt string is '.' only: the colour comes from `color=`. A colour
        # letter in the fmt as well ('r.') is overridden and makes matplotlib
        # warn about the redundant definition.
        ax.plot(A['output'][iMMN, 0], A['output'][iMMN, 1] + y0, '.',
                lw=0.35, ms=0.35, mew=0.35, color=cols[ifile])

    # Row label to the left of the first panel; the third file (N=40) is bold.
    axarr[0].text(-500, 0.5*Nperpop + y0, 'N=' + str(Nperpop), rotation=0,
                  ha='right', va='center', fontsize=5,
                  fontweight='bold' if ifile == 2 else 'normal')

# Stimulus schematic drawn above each panel; clip_on=False lets it sit outside
# the axes. There are seven slots starting every 500 ms. Blue traces the
# standard channel and red the deviant channel. Within a slot:
# - level 0 is off;
# - level 1 is on for the last 50 ms (450-500);
# - level 2 is on for the last 100 ms (400-500).
standard_xs = [[x, x + 400, x + 400, x + 450, x + 450, x + 500, x + 500]
               for x in range(0, 3500, 500)]
level_ys = {
    0: [0, 0, 0, 0, 0, 0, 0],
    1: [0, 0, 0, 0, 1, 1, 0],
    2: [0, 0, 1, 1, 1, 1, 0],
}
# (standard levels, deviant levels) for panels 0-3, one entry per slot.
panel_levels = [
    ([1, 1, 1, 1, 0, 1, 1], [0, 0, 0, 0, 1, 0, 0]),
    ([1, 1, 1, 1, 0, 1, 1], [0, 0, 0, 0, 0, 0, 0]),
    ([1, 1, 1, 1, 2, 1, 1], [0, 0, 0, 0, 0, 0, 0]),
    ([2, 2, 2, 2, 1, 2, 2], [0, 0, 0, 0, 0, 0, 0]),
]
trace_xs = [t for slot in standard_xs for t in slot] + [3800]
for iMMN, (st_on, dev_on) in enumerate(panel_levels):
    standard_trace = [v for level in st_on for v in level_ys[level]] + [0]
    deviant_trace = [v for level in dev_on for v in level_ys[level]] + [0]
    axarr[iMMN].plot(trace_xs, [4*v + 258 for v in standard_trace], 'b-', lw=0.3, clip_on=False)
    axarr[iMMN].plot(trace_xs, [4*v + 250 for v in deviant_trace], 'r-', lw=0.3, clip_on=False)

# Panel letter "A" just above and to the left of the first panel.
first_pos = axarr[0].get_position()
fig1.text(first_pos.x0 - 0.05, first_pos.y1 + 0.02, 'A', fontsize=11)

# Time scale bar inside the first panel, labelled '500 ms'.
axarr[0].plot([3000, 3500], [20, 20], 'k-', lw=0.5)
axarr[0].text(3250, 22, '500 ms', ha='center', va='bottom', fontsize=5)

# Remove all ticks, then thin/shorten any y tick lines that remain.
for ax in axarr:
    ax.set_yticks([])
    ax.set_xticks([])
    for tick_line in ax.yaxis.get_ticklines():
        tick_line.set_markeredgewidth(0.3)
        tick_line.set_markersize(2)

fig1.savefig('fig_rob_sizes.pdf')