from pylab import *
import scipy.io
import mytools
from matplotlib.collections import PatchCollection
from os.path import exists
from scipy.ndimage import gaussian_filter1d



#filenames = ['MMNothers_2pm_sep_Nperpop40_paramSD0.3_stimA150_130_gAMPA17.5_30.0_80.0_gNMDA5.827500000000001_9.99_26.64_gGABA35.0_35.0_dep1000_0.0_0.95_tau10.0_10.0_10.0_250.0.mat',
#             'MMNothers_2pm_sep_Nperpop40_paramSD0.3_stimA150_130_gAMPA17.5_30.0_80.0_gNMDA5.827500000000001_9.99_26.64_gGABA35.0_35.0_dep1000_0.0_0.95_tau10.0_10.0_10.0_250.0_seed2.mat']
# Simulation outputs to plot, one raster panel per file. All three share the
# same parameter stem and differ only in the random-seed suffix.
_filename_stem = ('MMNothers_2pm_sep_noISDIDD_Nperpop40_paramSD0.3_stimA150_130'
                  '_gAMPA17.5_30.0_80.0_gNMDA5.827500000000001_9.99_26.64'
                  '_gGABA35.0_35.0_dep1000_0.0_0.9_tau10.0_10.0_10.0_250.0')
filenames = [_filename_stem + seed_suffix + '.mat'
             for seed_suffix in ('', '_seed2', '_seed3')]

Nperpop = 40  # neurons per population; matches 'Nperpop40' in the filenames
Nsamp = 1
def boxoff(ax, whichxoff='top'):
    """Give ``ax`` an open-box look.

    Hides the spine named ``whichxoff`` (default ``'top'``) together with the
    right spine, and keeps tick marks only on the bottom x-axis and the left
    y-axis.
    """
    for hidden in (whichxoff, 'right'):
        ax.spines[hidden].set_visible(False)
    xaxis = ax.get_xaxis()
    xaxis.tick_bottom()
    yaxis = ax.get_yaxis()
    yaxis.tick_left()
def mybar(ax, x, y, facecolor=None, linewidth=0.1, w=0.4, edgewidth=0.3):
    """Draw a box-and-whisker glyph for the samples ``y`` centred at ``x``.

    The filled box spans the 25th-75th percentiles of ``y``. A single black
    line draws a cap at the minimum, a vertical stem from the minimum to the
    maximum, a horizontal line at the median, and a cap at the maximum. The
    box and the caps are ``2*w`` wide.

    Args:
        ax: Axes to draw on.
        x: Horizontal centre of the glyph.
        y: Sample values, summarised with ``quantile``.
        facecolor: Face colour of the box. ``None`` (the default) or an empty
            list ``[]`` (the old default, still accepted) leaves the
            collection's default face colour.
        linewidth: Line width of the whiskers, caps and median line.
        w: Half-width of the box and of the caps.
        edgewidth: Line width of the box outline (previously fixed at 0.3).

    Returns:
        ``[p, lines]``: the ``PatchCollection`` holding the box and the list
        of lines returned by ``ax.plot`` for the whiskers/median.
    """
    qs = quantile(y, [0, 0.25, 0.5, 0.75, 1])
    box = Polygon(array([[x - w, x + w, x + w, x - w],
                         [qs[1], qs[1], qs[3], qs[3]]]).T)
    p = PatchCollection([box], cmap=matplotlib.cm.jet)
    # An empty list used to be the "no colour given" sentinel; keep honouring it.
    no_colour = facecolor is None or (isinstance(facecolor, list) and len(facecolor) == 0)
    if not no_colour:
        p.set_facecolor(facecolor)
    p.set_edgecolor('#000000')
    p.set_linewidth(edgewidth)
    ax.add_collection(p)
    # One polyline: min cap -> stem up to median -> median line -> stem up to
    # max -> max cap.
    lines = ax.plot([x - w, x + w, x, x, x - w, x + w, x, x, x - w, x + w],
                    [qs[0], qs[0], qs[0], qs[2], qs[2], qs[2], qs[2], qs[4], qs[4], qs[4]],
                    'k-', lw=linewidth)
    return [p, lines]

# One row of four panels. The first three receive the per-file rasters; the
# fourth is hidden at the end of this setup.
fig1, axs = subplots(1, 4)
axarr = axs.ravel().tolist()
for ipanel, panel_ax in enumerate(axarr):
    # Manual layout: 0.18-wide panels spaced 0.24 apart, bottom edge at 0.78 - 0.12.
    panel_ax.set_position([0.08 + 0.24*ipanel, 0.78 - 0.12, 0.18, 0.08])
    panel_ax.tick_params(axis='both', which='major', labelsize=4)
    boxoff(panel_ax)
    for spine_name in ('top', 'bottom', 'left', 'right'):
        panel_ax.spines[spine_name].set_linewidth(0.2)
    panel_ax.set_xlim([0, 3800])
    panel_ax.set_ylim([0, 46])

axarr[3].set_visible(False)

#axs[0,0].text(0,Nperpop+7,'Excitatory deviant detecting output (EO)',fontsize=4,ha='left',va='top',fontweight='bold')
#axs[1,0].text(0,Nperpop+7,'Excitatory population for standards (ES)',fontsize=4,ha='left',va='top')
#axs[2,0].text(0,Nperpop+7,'Inhibitory population for standards (IS)',fontsize=4,ha='left',va='top')
#axs[3,0].text(0,Nperpop+7,'Excitatory population for standards, delayed (ESD)',fontsize=4,ha='left',va='top')
#axs[4,0].text(0,Nperpop+7,'Excitatory population for deviants (ED)',fontsize=4,ha='left',va='top')
#axs[5,0].text(0,Nperpop+7,'Inhibitory population for deviants (ID)',fontsize=4,ha='left',va='top')
#axs[6,0].text(0,Nperpop+7,'Excitatory pop. for deviants, delayed (ESD)',fontsize=4,ha='left',va='top')
#axs[7,0].text(0,Nperpop+7,'Exc. timer pop. receiving phase-locked input (EP)',fontsize=4,ha='left',va='top')
#axs[8,0].text(0,Nperpop+7,'Exc. timer pop. receiving phase-locked input, alt. phase (EP2)',fontsize=4,ha='left',va='top')


# Panel order and the title shown above each panel.
MMNorder = [0, 1, 2, 3]
panel_titles = ['Random auditory input,\nseed 1',
                'Random auditory input,\nseed 2',
                'Random auditory input,\nseed 3',
                'Inv. dur. deviant']
for ipanel, panel_title in zip(MMNorder, panel_titles):
    axs[ipanel].set_title(panel_title, fontsize=6, pad=12)

# Raster colour per data file (all black), and background shades for the
# per-file bands (only referenced by commented-out code).
cols = ['#000000'] * 3
dimcols = ['#EEEEEE'] * 3
def _unwrap_row_vectors(matdata, keys):
    """Flatten 1xK cells (K > 1) of the 2-D arrays ``matdata[key]`` in place.

    Each such cell is replaced by its first row, so it can be passed straight
    to ``plot``. This matters because ``scipy.io.loadmat`` keeps MATLAB row
    vectors 2-D by default. Keys that are absent, or whose value is not a 2-D
    array, are skipped. Cells without at least two dimensions are left as-is.
    """
    for key in keys:
        cells = matdata.get(key)
        if getattr(cells, 'ndim', None) != 2:
            continue  # missing key or not a 2-D (cell) array: nothing to unwrap
        nrows, ncols = cells.shape
        for iy in range(nrows):
            for ix in range(ncols):
                cell_shape = getattr(cells[iy, ix], 'shape', ())
                if len(cell_shape) >= 2 and cell_shape[0] == 1 and cell_shape[1] > 1:
                    cells[iy, ix] = cells[iy, ix][0]


def _step_trace(stim_row, nblocks):
    """Return 0/1 y-values of a step trace for the first ``nblocks`` of ``stim_row``.

    The trace starts at 0 if ``stim_row[0] == 0`` and at 1 otherwise. Each
    block then adds two points: the current level, followed by 1 or 0
    according to the truthiness of ``stim_row[i]``. A final repeat of the last
    level closes the trace, giving ``2*nblocks + 2`` points.
    """
    ys = [0] if stim_row[0] == 0 else [1]
    for iblock in range(nblocks):
        if (not stim_row[iblock]) == (not ys[-1]):
            ys += [ys[-1]] * 2
        else:
            ys += [ys[-1], 1 - ys[-1]]
    return ys + [ys[-1]]


# Vertical offset of each file's raster (all zero: each file has its own panel).
yfiles = [0, 0, 0]
# Row of A['output'] whose spikes are drawn.
iMMN = 3

# Stimulus-trace coordinates. They do not depend on the loaded data, so they
# are computed once rather than per file.
standard_xs = [[0+x, 400+x, 400+x, 450+x, 450+x, 500+x, 500+x] for x in [0, 500, 1000, 1500, 2000, 2500, 3000]]
standard_xs_random = [50*i for i in range(1, 75)]
standard_xs_this_syst = [x for y in standard_xs for x in y] + [3700]
# NOTE(review): with these x-values, block i's level is drawn over
# [50*(i+1), 50*(i+2)] ms. Block 0 therefore covers 0-100 ms and the last
# block has zero width. Confirm this matches the simulation's stimulus timing.
standard_xs_this = [0] + list(kron(standard_xs_random, [1, 1])) + [3700]
nblocks = len(standard_xs_random)
pm_on = [1, 1, 1, 1, 1, 1, 1]
pm_ys = [[0, 0, 1*(x > 1), 1*(x > 1), 1*(x > 0), 1*(x > 0), 0] for x in pm_on]
pm_ys_this = [x for y in pm_ys for x in y] + [0]

for ifile, filename in enumerate(filenames):
    print('Loading '+filename)
    A = scipy.io.loadmat(filename)
    _unwrap_row_vectors(A, ['standard', 'deviant', 'pacemaker', 'pacemaker2',
                            'output', 'standardBoost', 'deviantBoost'])

    ax = axarr[ifile]
    # Raster: cells [iMMN,0] / [iMMN,1] are presumably spike times (ms) and
    # neuron indices -- confirm against the simulation script. The colour
    # comes from color=; a colour-free '.' format avoids matplotlib's
    # "color is redundantly defined" warning that 'r.' triggered.
    ax.plot(A['output'][iMMN, 0], A['output'][iMMN, 1] + yfiles[ifile], '.',
            lw=0.35, ms=0.35, mew=0.35, color=cols[ifile])

    standard_ys_this = _step_trace(A['stimListStandard'][0], nblocks)
    deviant_ys_this = _step_trace(A['stimListDeviant'][0], nblocks)

    # Stimulus traces drawn above the ylim (46), hence clip_on=False.
    ax.plot(standard_xs_this, [2*y + 52 for y in standard_ys_this], 'b-', lw=0.3, clip_on=False)
    ax.plot(standard_xs_this, [2*y + 47 for y in deviant_ys_this], 'r-', lw=0.3, clip_on=False)
    ax.plot(standard_xs_this_syst, [2*y + 42 for y in pm_ys_this], 'g-', lw=0.3, clip_on=False)

# Panel letter 'C', placed just above and left of the first panel.
pos = axarr[0].get_position()
fig1.text(pos.x0 - 0.05, pos.y1 + 0.02, 'C', fontsize=11)

# 500 ms scale bar. It extends past the axes' x-limit of 3800, so clipping
# must be off. With the default clip_on=True, only the 3600-3800 part (200 ms)
# was drawn under the '500 ms' label.
axarr[0].plot([3600, 4100], [20, 20], 'k-', lw=0.5, clip_on=False)
axarr[0].text(3850, 22, '500 ms', ha='center', va='bottom', fontsize=5)

# The scale bar replaces the axes: remove all tick marks.
for ax in axarr:
    ax.set_yticks([])
    ax.set_xticks([])

# Thin, short y tick marks. Currently a no-op because the y ticks were just
# removed; it only takes effect if ticks are re-enabled.
for ax in axarr:
    for line in ax.yaxis.get_ticklines():
        line.set_markeredgewidth(0.3)
        line.set_markersize(2)
fig1.savefig('fig_rob_randominputs.pdf')