from pylab import *
import scipy.io
from matplotlib.collections import PatchCollection
import mytools
from os.path import exists

# Optional suffix taken from the first command-line argument (e.g. "_seed2").
# It is appended to the simulation file loaded for panels F-L and to the output PDF name.
seedAdd = sys.argv[1] if len(sys.argv) > 1 else ''

def format4(x):
    """Render x with four decimals, then drop trailing zeros and a dangling decimal point.

    Used to build simulation file names, so the result must match the names on disk exactly
    (e.g. 0.05 -> "0.05", 1.0 -> "1").
    """
    fixed_point = "{:.4f}".format(x)
    without_zeros = fixed_point.rstrip("0")
    return without_zeros.rstrip(".")

def mybar(ax,x,y,facecolor=None,linewidth=0.3,w=0.4):
  """Draw a box-and-whisker glyph summarising the sample y at horizontal position x on ax.

  The filled box spans the 25th-75th percentiles of y. A single black polyline draws
  horizontal ticks at the minimum, the median and the maximum, joined by a vertical stem.

  Parameters
  ----------
  ax : Axes to draw on.
  x : horizontal centre of the glyph (data coordinates).
  y : sample to summarise; must be non-empty (numpy.quantile raises otherwise).
  facecolor : fill colour of the box. None (or the legacy empty list) keeps the
      collection's default face colour.
  linewidth : line width of both the box outline and the whisker polyline.
  w : half-width of the box and of the horizontal ticks.

  Returns
  -------
  [PatchCollection, list of Line2D] : the box collection and the lines returned by ax.plot.
  """
  q_min, q1, q_med, q3, q_max = quantile(y, [0,0.25,0.5,0.75,1])
  box = Polygon(array([[x-w,x+w,x+w,x-w],[q1,q1,q3,q3]]).T)
  p = PatchCollection([box], cmap=matplotlib.cm.jet)
  # None and [] both mean "no explicit colour"; [] is kept for backward compatibility.
  if facecolor is not None and not (isinstance(facecolor, list) and len(facecolor) == 0):
    p.set_facecolor(facecolor)
  p.set_edgecolor('#000000')
  # Honour the caller's linewidth for the box outline too (it was hard-coded to 0.3 before).
  p.set_linewidth(linewidth)
  ax.add_collection(p)
  # One polyline: min tick -> stem up to the median -> median tick -> stem up to the max -> max tick.
  xs = [x-w,x+w,x,x,x-w,x+w,x,x,x-w,x+w]
  ys = [q_min,q_min,q_min,q_med,q_med,q_med,q_med,q_max,q_max,q_max]
  a2 = ax.plot(xs,ys,'k-',lw=linewidth)
  return [p,a2]

# One figure with 14 axes; axarr[k] is later labelled chr(ord('B')+k), i.e. panels B..O.
f,axarr = subplots(14,1)
for ax in axarr:
  ax.tick_params(axis='both', which='major', labelsize=4, direction='out', width=0.4, length=2, pad=2)

# Hand-placed [left, bottom, width, height] rectangles (figure coordinates), one per panel.
panel_rects = (
  [[0.09,0.54,0.46,0.45],    # B: grid search
   [0.62,0.88,0.33,0.1],     # C
   [0.62,0.71,0.33,0.1],     # D
   [0.62,0.54,0.33,0.1],     # E
   [0.09,0.34,0.25,0.12],    # F
   [0.395,0.34,0.25,0.12],   # G
   [0.70,0.34,0.25,0.12]]    # H
  + [[0.1,0.21-0.08*row,0.17,0.07] for row in range(0,3)]   # I, J, K stacked vertically
  + [[0.35,0.21-0.08*2,0.17,0.23],        # L
     [0.59,0.21-0.08*2,0.16,0.23],        # M
     [0.82,0.21-0.08*2+0.11,0.13,0.12],   # N
     [0.82,0.21-0.08*2,0.13,0.05]]        # O
)
for ax, rect in zip(axarr, panel_rects):
  ax.set_position(rect)

#Plot panel B (grid search):
# The .mat file is produced by drawfig_stdpsynfire_gridsearch3.py; run that script first.
A=scipy.io.loadmat('fig_stdpsynfire_gridsearch3_0.05_0.075.mat')

# Each entry of stuffPlotted holds [polygon vertices (2 x n), face colour, optional hatch pattern].
stuffPlotted = A['stuffPlotted']
A_pluss = A['A_pluss'][0]
A_minus_factors = A['A_minus_factors'][0]

ax_grid = axarr[0]
for cell in stuffPlotted:
  vertices = cell[0]
  cell_colour = cell[1]
  cell_hatch = cell[2]
  coll = PatchCollection([Polygon(vertices.T)], cmap=matplotlib.cm.jet)
  coll.set_facecolor(cell_colour)
  coll.set_edgecolor(None)
  ax_grid.add_collection(coll)
  if len(cell_hatch) > 0:
    coll.set_hatch(cell_hatch[0])

# Column labels: A+ values shown multiplied by 1000, centred under each grid column.
for col, a_plus in enumerate(A_pluss):
  ax_grid.text(col+0.5,-0.05,str(int(a_plus*1000)),fontsize=5,rotation=0,va='top',ha='center',clip_on=False)

# Row labels: the A- / A+ factors, right-aligned left of each grid row.
for row, factor in enumerate(A_minus_factors):
  ax_grid.text(-0.05,row+0.5,str(factor),fontsize=5,rotation=0,va='center',ha='right',clip_on=False)

# Extra room on the right (+3.5 columns) is used by the colour/hatch keys.
ax_grid.set_xlim([-0.01,len(A_pluss)+3.5])
ax_grid.set_ylim([0,len(A_minus_factors)])
ax_grid.set_xticks([])
ax_grid.set_yticks([])

# Blue-to-white colour ramp; every 5th entry is used for the 0-100 % colour key below.
facecolors_all = ['#0000CC', '#0C0CCE', '#1919D1', '#2626D3', '#3333D6', '#3F3FD8', '#4646D5', '#4C4CDB', '#5959DD', '#6666E0', '#7272E2', '#7F7FE5', '#8C8CE8', '#9999EA', '#ABABED','#BFBFF2', '#CCCCF4', '#D8D8F7', '#E5E5F9', '#F2F2FC', '#FFFFFF']
ax_key = axarr[0]
ax_key.text(12.3,7.5+0.6,'Precise\nactivations after\nstim. offset',fontsize=5.5,clip_on=False,va='bottom',ha='left')
ax_key.text(12.3,2.5+0.6,'Long or\nceaseless\nresponses',fontsize=5.5,clip_on=False,va='bottom',ha='left')

# Colour key: five filled squares labelled 0 %, 25 %, ..., 100 %.
for ip in range(5):
  y0 = 7.5-ip*0.6
  square = Polygon(array([[12.6,13.6,13.6,12.6],[y0+dy for dy in [0,0,0.5,0.5]]]).T)
  p = PatchCollection([square], cmap=matplotlib.cm.jet)
  p.set_facecolor(facecolors_all[ip*5])
  p.set_edgecolor('#000000')
  ax_key.add_collection(p)
  ax_key.text(13.9,y0+0.25,f'{25*ip}%',fontsize=6,clip_on=False,va='center',ha='left')

# Hatch key: six white squares with 0..5 slashes, labelled 0 %, 20 %, ..., 100 %.
for ip in range(6):
  y0 = 2.5-ip*0.6
  square = Polygon(array([[12.6,13.6,13.6,12.6],[y0+dy for dy in [0,0,0.5,0.5]]]).T)
  p = PatchCollection([square], cmap=matplotlib.cm.jet)
  p.set_facecolor('#FFFFFF')
  p.set_edgecolor('#000000')
  p.set_clip_on(False)
  ax_key.add_collection(p)
  ax_key.text(13.9,y0+0.25,f'{20*ip}%',fontsize=6,clip_on=False,va='center',ha='left')
  if ip:
    p.set_hatch('/'*ip)

# Left and bottom border of the parameter grid, drawn by hand because the axis frame is hidden.
ax_key.plot([0,0,10],[10,0,0],'k-')
ax_key.axis('off')

#Plot panels C, D, and E (example spike trains of the whole population for the whole duration):
ibest_A_plus = 6
ibest_A_minus = 3
i1_A_plus = [2, 10, 6]
i1_A_minus = [5, 1, 3]
cols_i1 = ['#BBBB00','#444400','#FF00FF']

def _field_after(name, tag):
  """Return the text of `name` that follows the first `tag`, up to (not including) the next '_'."""
  start = name.find(tag)+len(tag)
  return name[start:start+name[start:].find('_')]

for iax, (ia_plus, ia_minus, col) in enumerate(zip(i1_A_plus, i1_A_minus, cols_i1)):
  A_plus = A_pluss[ia_plus]
  A_minus = A_plus*A_minus_factors[ia_minus]
  # Outline the corresponding grid cell of panel B in this example's colour.
  axarr[0].plot([ia_plus,ia_plus,ia_plus+1,ia_plus+1,ia_plus],[ia_minus,ia_minus+1,ia_minus+1,ia_minus,ia_minus],'-',color=col)

  ax = axarr[1+iax]
  # Grey background over t = 0..12 s (presumably the entrainment period, cf. "T12.0" in the file name).
  shade = Polygon(array([[0,0,12,12],[0,3550,3550,0]]).T)
  p = PatchCollection([shade], cmap=matplotlib.cm.jet)
  p.set_facecolor('#DDDDDD')
  p.set_edgecolor(None)
  ax.add_collection(p)

  filename = 'synfirefiles/stdpsynfire_synstim_N50_L70_T12.0_16.0_2.0Hz_gEE12.0_gEI20.0_gIE20.0_A'+str(A_plus)+'_'+format4(A_minus)+'_'+str(A_plus)+'.mat'
  A = scipy.io.loadmat(filename)
  # Pool size, number of pools and stimulation duration are recovered from the file name.
  N = int(_field_after(filename,'_N'))
  Lmax = int(_field_after(filename,'_L'))
  stim_duration = float(_field_after(filename,'_T'))

  Nskip = 10   # plot only every 10th spike to keep the full-duration raster light
  for i in range(Lmax):
    # Pool i occupies rows i*N .. i*N + 0.95*(max index); the 0.95 leaves a small gap between pools.
    ax.plot(A['spikes'][0][1+i][0][::Nskip], i*N+0.95*A['spikes'][1][1+i][0][::Nskip], 'ks', markersize=0.3,mew=0.3,lw=0.3)
  ax.set_xlabel('Time (s)',fontsize=6,labelpad=3)
  ax.set_ylabel('Neuron ID',fontsize=6)
  axarr[4+iax].set_ylabel('Neuron ID',fontsize=6)
  ax.set_xlim([0,16])
  ax.set_ylim([0,3550])
  for side in ['bottom','top','left','right']:
    ax.spines[side].set_color(col)

# Reload the simulation of the pair (ibest_A_plus, ibest_A_minus) -- the same grid cell as panel E --
# this time with the optional seed suffix from the command line. Panels F-L are drawn from this file.
A_plus = A_pluss[ibest_A_plus]
A_minus = A_plus*A_minus_factors[ibest_A_minus]
filename = 'synfirefiles/stdpsynfire_synstim_N50_L70_T12.0_16.0_2.0Hz_gEE12.0_gEI20.0_gIE20.0_A'+str(A_plus)+'_'+format4(A_minus)+'_'+str(A_plus)+seedAdd+'.mat'
A = scipy.io.loadmat(filename)

#Plot panels F, G, and H (example spike trains, zoomed in)
Nskip = 1          # keep every spike here (panels C-E thin the raster by 10)
zoom_len = 0.5     # seconds of raster shown in each of panels F-H
inset_len = 0.1    # seconds shown in each inset; also the width of the grey marker in the main panel
axnew = []
for iax in [4,5,6]:
  pos = axarr[iax].get_position()
  axnew.append(f.add_axes([pos.x0+0.162,pos.y0+0.047,0.07,0.061]))

xstarts = [11.45,11.95,14.95]   # window start times (s) for panels F, G and H
# Grey patch marking, in each main panel, the stretch that its inset magnifies.
for i in range(0,3):
  polygon = Polygon(array([[xstarts[i],xstarts[i],xstarts[i]+inset_len,xstarts[i]+inset_len],[0,200,200,0]]).T)
  p = PatchCollection([polygon], cmap=matplotlib.cm.jet)
  p.set_facecolor('#DDDDDD')
  p.set_edgecolor(None)
  axarr[4+i].add_collection(p)

raster_style = dict(markersize=0.3,mew=0.3,lw=0.3)
for i in range(0, Lmax):  # plot every pool
  ts = A['spikes'][0][1+i][0][::Nskip]
  inds = i*N+0.95*A['spikes'][1][1+i][0][::Nskip]
  # Vectorised window selection: one boolean mask per window instead of six list
  # comprehensions that re-scanned ts and shadowed the pool index i.
  for j, x0 in enumerate(xstarts):
    in_zoom = (x0 <= ts) & (ts <= x0+zoom_len)
    axarr[4+j].plot(ts[in_zoom], inds[in_zoom], 'ks', **raster_style)
    in_inset = (x0 <= ts) & (ts <= x0+inset_len)
    axnew[j].plot(ts[in_inset]*1000, inds[in_inset], 'ks', **raster_style)  # insets use milliseconds

# Frame panels F-H in the colour of the panel-E example they zoom into.
for i in [4,5,6]:
  for side in ['bottom','top','left','right']:
    axarr[i].spines[side].set_color(cols_i1[2])

#Plot panels I, J, K and L: time courses of the pool-mean synaptic weights
# Pools (1-based in the legends): panel I shows pools 1-10, J pools 11-50, K pools 61-70 and L pools 51-60.
poolgroups = [list(range(0,10)),list(range(10,50)),list(range(60,70)),list(range(50,60))]
t_weights = A['weights'][0][0][0]      # weight-monitor sample times
mean_weights = A['weights'][0][1]      # one mean-weight trace per pool
for igroup, pools in enumerate(poolgroups):
  ax = axarr[7+igroup]
  npools = len(pools)
  poolcols = mytools.colorsredtolila(npools+npools//5,0.7)
  for k, ipool in enumerate(pools):
    ax.plot(t_weights,mean_weights[ipool],'-',lw=0.4,color=poolcols[k])
  ax.set_xlim([0,16])
  ax.set_ylim([0,174])
  pos = ax.get_position()

  # Hand-tuned legend placement: shifted right on even-indexed panels, raised on panel L.
  leg_x = pos.x0+0.005+(0.04 if igroup % 2 == 0 else 0.0)
  leg_y = pos.y0+0.014+(0.16 if igroup == 3 else 0.0)
  # Legend rows: last pool, a dotted "..." separator, second pool, first pool.
  leg_labels = [f'Pool {max(pools)+1}', '', f'Pool {min(pools)+2}', f'Pool {min(pools)+1}']
  leg_colors = [poolcols[npools-1], '#000000', poolcols[1], poolcols[0]]
  myleg = mytools.mylegend(f,[leg_x,leg_y,0.07,0.048],['-',':','-','-'],leg_labels,nx=1,dx=1.5,yplus=0.5,yplustext=0.35,colors=leg_colors,linewidths=[0.4,0.6,0.4,0.4],myfontsize=4.5)
  for side in ('top','bottom','left','right'):
    myleg.spines[side].set_linewidth(0.0)
  myleg.patch.set_alpha(0)

# Panel L is 0.23 high versus 0.07 for panels I-K; scale its y-range so pS per unit height stays the same.
axarr[10].set_ylim([0,174*0.23/0.07])
zoom_titles = ['Last entrainment stimulus','First test period','Seventh test period']
for k, title in enumerate(zoom_titles):
  axarr[4+k].set_title(title,fontsize=5,pad=2)

# Insets show the first 100 ms of each zoom window (x in ms), matching the grey marker patches.
for iax, inset in enumerate(axnew):
  inset.tick_params(axis='both', which='major', labelsize=4, direction='out', width=0.4, length=2.0, pad=1.5)
  t0_ms = xstarts[iax]*1000
  inset.set_xlim([t0_ms,t0_ms+100])
  inset.set_ylim([0,200])
  inset.set_xlabel('Time (ms)',fontsize=4.5,labelpad=1.5)
  axarr[4+iax].set_xlabel('Time (sec)',fontsize=5,labelpad=3)

#Panel M: Weight distribution after entrainment. It uses the _allweights_ files, which store every
# individual synaptic weight rather than only the population mean; these are needed to plot the distribution exactly.
spikes_all = []
weights_all = []
allweights_all = []
istim_to_plot = [23,24,25,26,27,28,29,30,31]
# base_spikes_t_pooled_all_pooled[k]: base-population spike times around stimulus istim_to_plot[k]
# (relative to 0.5*istim_to_plot[k] s), pooled over all seeds with every seed counted exactly once.
base_spikes_t_pooled_all_pooled = [[] for _ in istim_to_plot]

def _rel_spike_times(t_all, istim):
  """Return the times of t_all within [0.5*istim-0.2, 0.5*istim+0.3] s, shifted by -0.5*istim.

  NOTE(review): 0.5*istim presumably is the onset of stimulus/cycle istim at the 2 Hz rate
  that appears in the file names -- confirm.
  """
  t_ref = 0.5*istim
  return [t-t_ref for t in t_all if t_ref-0.2 <= t <= t_ref+0.3]

def _min_max_str(ts):
  """Return "min,max" of ts, or "n/a,n/a" when ts is empty (min/max would raise)."""
  if len(ts) == 0:
    return "n/a,n/a"
  return str(min(ts))+","+str(max(ts))

file_params = str(A_plus)+'_'+format4(A_minus)+'_'+str(A_plus)
for myseed in range(1,21):
  seed_suffix = '' if myseed == 1 else '_seed'+str(myseed)
  filename = 'synfirefiles/stdpsynfire_synstim_N50_L70_T12.0_16.0_2.0Hz_gEE12.0_gEI20.0_gIE20.0_A'+file_params+seed_suffix+'.mat'
  A = scipy.io.loadmat(filename)
  spikes_t = A['spikes'][0] # [[spike_monitors[i].t/second for i in range(0,len(spike_monitors))],[spike_monitors[i].i for i in range(0,len(spike_monitors))]],
  spikes_i = A['spikes'][1] # [[spike_monitors[i].t/second for i in range(0,len(spike_monitors))],[spike_monitors[i].i for i in range(0,len(spike_monitors))]],
  weights = A['weights'] # [weight_monitors[0].t/second,[np.mean(weight_monitors[i].w, axis=0) for i in range(0,len(weight_monitors))]]
  if max(weights[0][1][0]) < 1e-4:
    # Weights saved in S rather than pS. Report the file actually loaded (the old message used
    # undefined names and raised NameError) and rescale every pool, not only pool 0 that was tested.
    print(filename+" is not in pS!")
    weights[0][1][:] = weights[0][1]*1e12

  spikes_all.append([spikes_t[:],spikes_i[:]])
  weights_all.append(weights[:])

  # Spike times of monitor 0 (the base population) for THIS seed only. Previously every seed loaded
  # so far was re-pooled and added again on each iteration, so seed k was counted 21-k times.
  base_spikes_t_pooled_all = [_rel_spike_times(spikes_t[0][0], istim) for istim in istim_to_plot]
  print("  min1,max1,min2,max2 = "+_min_max_str(base_spikes_t_pooled_all[0])+","+_min_max_str(base_spikes_t_pooled_all[1]))
  for pooled, this_seed in zip(base_spikes_t_pooled_all_pooled, base_spikes_t_pooled_all):
    pooled.extend(this_seed)
  Nspikes = [len(x) for x in base_spikes_t_pooled_all]

  filename2 = 'synfirefiles/stdpsynfire_synstim_allweights_N50_L70_T12.0_16.0_2.0Hz_gEE12.0_gEI20.0_gIE20.0_A'+file_params+seed_suffix+'.mat'
  if exists(filename2):
    B = scipy.io.loadmat(filename2)
    allweights = B['weights']
    allweights_all.append(allweights[:])
  else:
    print(filename2+" does not exist")

print("min1,max1,min2,max2 = "+_min_max_str(base_spikes_t_pooled_all_pooled[0])+","+_min_max_str(base_spikes_t_pooled_all_pooled[1]))

# Per-seed pool-mean weights at weight sample 240 (presumably the end of entrainment -- TODO confirm); not plotted.
weights = array([w[0][1][:,240] for w in weights_all])
allweights_pools = []
if not allweights_all:
  # Missing _allweights_ files are tolerated above; concatenate([]) would raise here.
  print("No _allweights_ files found; panel M histograms are left empty")
else:
  for ipool in [50,53,56,59,60]:
    allweights = concatenate([aw[0][1][0][ipool][:,240] for aw in allweights_all])
    allweights_pools.append(allweights[:])
    counts, bin_edges = histogram(allweights, bins=30)
    axarr[11].step(bin_edges, list(counts)+[0], where='post',lw=0.4,label="Pool "+str(ipool+1))

#Plot panel N (distribution of spike times in the base population at the last stimulus and at the first expected post-cessation cycle):
for istim in range(0,2):
  counts, bin_edges = histogram([1000*x for x in base_spikes_t_pooled_all_pooled[istim]], bins=30)
  # Label with the reference time of the window (11.5 s and 12.0 s for istim_to_plot = 23, 24).
  axarr[12].step(bin_edges, list(counts)+[0], where='post',lw=0.4, label=str(0.5*istim_to_plot[istim])+' sec')

#Plot panel O (drift): spread of the base-population spike times for each later cycle, in ms.
for istim in range(1,len(base_spikes_t_pooled_all_pooled)):
  mybar(axarr[13],0.5*istim_to_plot[istim],[1000*x for x in base_spikes_t_pooled_all_pooled[istim]],'#ff7f0e',linewidth=0.3,w=0.2) #Default orange color

axarr[12].set_xlim([-80,40])
# The former fixed top of 3000 matched counts inflated by the repeated pooling; let the top autoscale.
axarr[12].set_ylim(bottom=0)
axarr[11].set_ylim([0,2300])
axarr[11].legend(fontsize=4.5)
axarr[12].legend(fontsize=4.5)

axarr[0].text(-1.4,4.5,'$\\alpha$',fontsize=6,clip_on=False)
axarr[0].text(4.5,-0.95,'$A^+$ (pS)',fontsize=6,clip_on=False)

axarr[8].set_ylabel('Synapse weight (pS)',fontsize=6)
axarr[10].set_ylabel('Synapse weight (pS)   ',fontsize=6)
axarr[9].set_xlabel('Time (s)',fontsize=6,labelpad=3)
axarr[10].set_xlabel('Time (s)',fontsize=6,labelpad=3)
axarr[11].set_xlabel('Weight (pS)',fontsize=6,labelpad=3)
axarr[12].set_xlabel('Rel. spike time (ms)',fontsize=6,labelpad=3)
axarr[13].set_xlabel('Time (s)',fontsize=6,labelpad=3)
axarr[13].set_ylabel('Drift     \n (ms)    ',fontsize=6,labelpad=3)


# Panel letters: axarr[k] gets chr(ord('B')+k), i.e. B..O, offset from each axes' top-left corner
# with hand-tuned per-panel nudges.
for iax, ax in enumerate(axarr):
  pos = ax.get_position()
  x = pos.x0 - 0.055
  x -= 0.01 if iax in (1,2,3,7,8,9) else 0.0
  x -= 0.01 if iax in (1,2,3,6) else 0.0
  x += 0.005 if iax == 5 else 0.0
  x += 0.015 if iax == 6 else 0.0
  y = pos.y1 - 0.03
  y += 0.02 if iax == 13 else 0.0
  y += 0.02 if iax in (4,5,6) else 0.0
  f.text(x, y, chr(ord('B')+iax), fontsize=11)

f.set_size_inches([6.8,4.8])
f.savefig(f"figstdp{seedAdd}.pdf")