from pylab import *
import scipy.io
import mytools
from matplotlib.collections import PatchCollection

# Simulation output to plot; loaded below with scipy.io.loadmat.
# Alternative data file (its name also encodes 40 neurons per population):
#filename = 'MMNs_2pm_sep_noISDIDD_Nperpop40_paramSD0.3_stimA150_130_gAMPA17.5_30.0_80.0_gNMDA5.827500000000001_9.99_26.64_gGABA35.0_35.0_dep1000_0.0_0.95_tau10.0_10.0_10.0_250.0.mat'
filename = 'MMNs_2pm_sep_noISDIDD_model0_CTRLpop_AUCbased_seed1.mat'

# Neuron-id threshold used to split the 'standard'/'deviant' spike trains into
# their two sub-populations, and to set the y-range of every panel.
Nperpop = 40
  
def boxoff(ax, whichxoff='top'):
    """Hide one horizontal spine and the right spine of ``ax`` (MATLAB-like "box off").

    Parameters
    ----------
    ax : matplotlib Axes
        Axes modified in place.
    whichxoff : str, optional
        Key of the horizontal spine to hide, 'top' (default) or 'bottom'.
        The x ticks are placed on the opposite, still-visible spine.

    The right spine is always hidden and the y ticks are kept on the left.
    """
    ax.spines[whichxoff].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Put the x ticks on whichever horizontal spine remains visible.
    if whichxoff == 'bottom':
        ax.get_xaxis().tick_top()
    else:
        ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()

# 9 x 4 grid of raster panels: one row per population (see `titles`), one
# column per MMN paradigm (see `MMNorder`).
fig1, axs = subplots(9,4)
# Row-major flat list of the axes: axarr[row*4 + col] is axs[row, col].
axarr = axs.ravel().tolist()
nrows, ncols = axs.shape
for iax in range(ncols):
    for iay in range(nrows):
        # Manual layout in figure coordinates: columns 0.24 apart, rows 0.1
        # tall, row 0 at the top of the figure.
        axs[iay, iax].set_position([0.08+0.24*iax, 0.04+0.1*(nrows-1-iay), 0.18, 0.1])
        # No x ticks on any panel, including the bottom row.
        axs[iay, iax].set_xticks([])
for ax in axarr:
    ax.tick_params(axis='both', which='major', labelsize=4)
    boxoff(ax)
    ax.set_yticks([])
    for spine_name in ['top','bottom','left','right']:
        ax.spines[spine_name].set_linewidth(0.2)
    ax.set_xlim([0,3800])
    # Headroom of 7 above the neuron ids; presumably keeps the caption strip
    # drawn over the top of each row clear of the raster -- confirm visually.
    ax.set_ylim([0,Nperpop+7])
  
#axs[0,0].text(0,Nperpop+7,'Excitatory deviant detecting output (EO)',fontsize=4,ha='left',va='top',fontweight='bold')
#axs[1,0].text(0,Nperpop+7,'Excitatory population for standards (ES)',fontsize=4,ha='left',va='top')
#axs[2,0].text(0,Nperpop+7,'Inhibitory population for standards (IS)',fontsize=4,ha='left',va='top')
#axs[3,0].text(0,Nperpop+7,'Excitatory population for standards, delayed (ESD)',fontsize=4,ha='left',va='top')
#axs[4,0].text(0,Nperpop+7,'Excitatory population for deviants (ED)',fontsize=4,ha='left',va='top')
#axs[5,0].text(0,Nperpop+7,'Inhibitory population for deviants (ID)',fontsize=4,ha='left',va='top')
#axs[6,0].text(0,Nperpop+7,'Excitatory pop. for deviants, delayed (EDD)',fontsize=4,ha='left',va='top')
#axs[7,0].text(0,Nperpop+7,'Exc. timer pop. receiving phase-locked input (EP)',fontsize=4,ha='left',va='top')
#axs[8,0].text(0,Nperpop+7,'Exc. timer pop. receiving phase-locked input, alt. phase (EP2)',fontsize=4,ha='left',va='top')

# Row captions from top to bottom.  The order matches the rows the spike
# rasters are drawn into (axs[0] = output, axs[1]/axs[2] = standard exc./inh.,
# axs[3]/axs[4] = deviant exc./inh., axs[5] = pacemaker, axs[6] =
# standardBoost, axs[7] = deviantBoost, axs[8] = pacemaker2).
titles = [
    'Excitatory deviant detecting output (EO)',
    'Excitatory population for standards (ES)',
    'Inhibitory population for standards (IS)',
    'Excitatory population for deviants (ED)',
    'Inhibitory population for deviants (ID)',
    'Exc. timer pop. receiving phase-locked input (EP)',
    'Excitatory population for standards, delayed (ESD)',
    'Excitatory pop. for deviants, delayed (EDD)',
    'Exc. timer pop. receiving phase-locked input, alt. phase (EP2)',
]

# Draw a light-grey strip (figure coordinates) across the top of each row
# slot and write the row caption on it.
for row, caption in enumerate(titles):
    row_bottom = 0.04+0.1*(8-row)
    strip_lo = row_bottom+0.085
    strip_hi = row_bottom+0.0995
    corners = [(0.08, strip_lo), (0.08, strip_hi), (0.98, strip_hi), (0.98, strip_lo)]
    polygon = Polygon(corners, closed=True, transform=fig1.transFigure,
                      facecolor='#EEEEEE', edgecolor=None)
    fig1.patches.append(polygon)  # attached directly to the figure, not an axes
    fig1.text(0.085, row_bottom+0.084, caption, fontsize=5.5, ha='left', va='bottom')

# Column layout: column icol displays paradigm MMNorder[icol].  Paradigm
# indices: 0 omission, 1 frequency deviant, 2 duration deviant,
# 3 inverse duration deviant.
MMNorder = [1,0,2,3]
paradigm_titles = ['           Omission',
                   '           Frequency deviant',
                   '           Duration deviant',
                   'Inv. dur. deviant']
# Iterate over columns so each title sits above the paradigm whose data that
# column shows, for any permutation stored in MMNorder.
for icol, iparadigm in enumerate(MMNorder):
    axs[0, icol].set_title(paradigm_titles[iparadigm], fontsize=8, pad=12)
# Colour used for every spike raster.
col = '#000000'
def _flatten_spike_cells(data, keys):
    """Ravel every cell of the cell arrays ``data[key]`` to 1-D, in place.

    Every cell is raveled, including empty (1, 0) and single-spike (1, 1)
    cells. As a result, ``len()`` is the spike count and boolean masks can
    select spikes. Keys missing from ``data`` are skipped, as are values that
    are not 2-D object (cell) arrays.
    """
    for key in keys:
        cells = data.get(key)
        if cells is None or getattr(cells, 'dtype', None) != object or cells.ndim != 2:
            continue
        for iy in range(cells.shape[0]):
            for ix in range(cells.shape[1]):
                cells[iy, ix] = cells[iy, ix].ravel()


def _split_population(times, ids, nperpop):
    """Split one population's spikes by neuron id.

    Returns ``((t_lo, id_lo), (t_hi, id_hi))``. The first pair holds spikes of
    neurons with ``id < nperpop``. The second pair holds spikes of neurons
    with ``id >= nperpop``, with ids shifted down by ``nperpop``.
    """
    lo = ids < nperpop
    hi = ids >= nperpop
    return (times[lo], ids[lo]), (times[hi], ids[hi] - nperpop)


def _plot_raster(ax, times, ids, counts, color):
    """Draw spikes as dots (x = time, y = neuron id) on ``ax``; append their count to ``counts``."""
    # Marker-only format string: the colour comes from ``color`` alone.
    ax.plot(times, ids, '.', ms=0.35, mew=0.35, color=color)
    counts.append(len(times))


print('Loading '+filename)
A = scipy.io.loadmat(filename)
# Presumably each A[key] is indexed [paradigm, 0 = spike times / 1 = neuron ids];
# this is inferred from the indexing below -- confirm against the simulation script.
_flatten_spike_cells(A, ['standard', 'deviant', 'pacemaker', 'pacemaker2', 'output', 'standardBoost', 'deviantBoost'])

for iiMMN, iMMN in enumerate(MMNorder):
    # Column iiMMN shows paradigm iMMN; spike counts are collected in drawing order.
    plotteds_this = []
    # Row 0 (EO): output population.
    _plot_raster(axs[0, iiMMN], A['output'][iMMN, 0], A['output'][iMMN, 1], plotteds_this, col)
    # Rows 1-2 (ES, IS): 'standard' spikes split at neuron id Nperpop.
    exc, inh = _split_population(A['standard'][iMMN, 0], A['standard'][iMMN, 1], Nperpop)
    _plot_raster(axs[1, iiMMN], *exc, plotteds_this, col)
    _plot_raster(axs[2, iiMMN], *inh, plotteds_this, col)
    # Row 6 (ESD): neurons with id < Nperpop of 'standardBoost', if present.
    try:
        exc = _split_population(A['standardBoost'][iMMN, 0], A['standardBoost'][iMMN, 1], Nperpop)[0]
    except (KeyError, IndexError):
        print('standardBoost population not plotted')
    else:
        _plot_raster(axs[6, iiMMN], *exc, plotteds_this, col)
    # Rows 3-4 (ED, ID): 'deviant' spikes split at neuron id Nperpop, if present.
    try:
        exc, inh = _split_population(A['deviant'][iMMN, 0], A['deviant'][iMMN, 1], Nperpop)
    except (KeyError, IndexError):
        print('deviant population not plotted')
    else:
        _plot_raster(axs[3, iiMMN], *exc, plotteds_this, col)
        _plot_raster(axs[4, iiMMN], *inh, plotteds_this, col)
    # Row 7 (EDD): neurons with id < Nperpop of 'deviantBoost', if present.
    try:
        exc = _split_population(A['deviantBoost'][iMMN, 0], A['deviantBoost'][iMMN, 1], Nperpop)[0]
    except (KeyError, IndexError):
        print('deviantBoost population not plotted')
    else:
        _plot_raster(axs[7, iiMMN], *exc, plotteds_this, col)
    # Rows 5 and 8 (EP, EP2): pacemaker populations, plotted without splitting.
    for row, key in ((5, 'pacemaker'), (8, 'pacemaker2')):
        try:
            times, ids = A[key][iMMN, 0], A[key][iMMN, 1]
        except (KeyError, IndexError):
            print(key+' population not plotted')
        else:
            _plot_raster(axs[row, iiMMN], times, ids, plotteds_this, col)
    print("Nspikes = "+str(plotteds_this))

# Stimulus schematic drawn above the top panel of every column.
# Seven tones start every 500 x-units. Each tone has a level:
#   0 = no tone, 1 = short tone (high from +450 to +500),
#   2 = long tone (high from +400 to +500).
# Patterns are keyed by paradigm index (the values stored in MMNorder), so
# each column gets the schematic of the paradigm it displays.
stim_standard_levels = {
    0: [1,1,1,1,0,1,1],  # omission: 5th tone missing
    1: [1,1,1,1,0,1,1],  # frequency deviant: 5th tone moved to the deviant trace
    2: [1,1,1,1,2,1,1],  # duration deviant: long 5th tone
    3: [2,2,2,2,1,2,2],  # inverse duration deviant: long standards, short 5th tone
}
stim_deviant_levels = {1: [0,0,0,0,1,0,0]}  # other paradigms: empty deviant trace

standard_xs = [[0+x,400+x,400+x,450+x,450+x,500+x,500+x] for x in [0,500,1000,1500,2000,2500,3000]]
# x values of the step trace; identical for every column, so built once.
stim_xs = [x for tone_xs in standard_xs for x in tone_xs]+[3800]


def _stim_ys(levels):
    """Return the 0/1 step-trace values for per-tone ``levels``, aligned with ``stim_xs``."""
    return [y for lvl in levels for y in [0,0,1*(lvl>1),1*(lvl>1),1*(lvl>0),1*(lvl>0),0]]+[0]


for icol, iparadigm in enumerate(MMNorder):
    ax = axs[0, icol]
    standard_ys = _stim_ys(stim_standard_levels[iparadigm])
    deviant_ys = _stim_ys(stim_deviant_levels.get(iparadigm, [0]*7))
    # Standard trace at y = 54..56 and deviant trace at y = 50..52 (data units
    # of the top panel). clip_on=False keeps them visible beyond the y-limits.
    ax.plot(stim_xs, [2*y+54 for y in standard_ys], 'k-', lw=0.3, clip_on=False)
    ax.plot(stim_xs, [2*y+50 for y in deviant_ys], 'k-', lw=0.3, clip_on=False)
    
# Label 'B' placed just above the top-left panel (figure coordinates).
top_left_box = axs[0, 0].get_position()
fig1.text(top_left_box.x0, top_left_box.y1 + 0.03, 'B', fontsize=11)

# Write the finished figure to disk.
fig1.savefig('fig_onesim.pdf')