from pylab import *
import scipy.io
import os
from os.path import exists
import numpy


Nperpop = 40
paramSD = 0.3
tauD = 1000

stimAmp1 = [120,125,130,140,150]
gAMPA1 = [12.5, 15.0, 17.5, 20.0, 22.5]
gAMPA2 = [25.0, 30.0, 35.0, 40.0, 45.0]
gAMPA3 = [80.0, 90.0, 100.0, 110.0, 120.0, 130.0, 140.0]
nmdaAmpaRatio = [0.5, 0.333]
gGABA = [10.0, 15.0, 20.0, 25.0, 30.0, 35.0]
pvs = [0.9, 0.95]
min_gluts = [0.0]
tauNeur1s = [10.0]
tauNeur2s = [200.0, 250.0]

numpy.random.seed(1)
simIDs = list(range(0,len(stimAmp1)*len(gAMPA1)*len(gAMPA2)*len(gAMPA3)*len(nmdaAmpaRatio)*len(gGABA)*len(pvs)*len(min_gluts)*len(tauNeur1s)*len(tauNeur2s)))
shuffle(simIDs)

istart = int(sys.argv[1])
iend = int(sys.argv[2])

NoutputOutsides_all = []
NoutputInsides_all = []
Nbetweens_all = []
NoutputOutsidesUnique_all = []
NoutputInsidesUnique_all = []
NbetweensUnique_all = []
strs_all = []
filenames_all = []
for isamp in range(istart,iend):
    
    simID = simIDs[isamp]
    stA1 = stimAmp1[simID%len(stimAmp1)]
    stA2 = stA1 - 20
    gA1 = gAMPA1[int(simID/len(stimAmp1))%len(gAMPA1)]
    gA2 = gAMPA2[int(simID/len(stimAmp1)/len(gAMPA1))%len(gAMPA2)]
    gA3 = gAMPA3[int(simID/len(stimAmp1)/len(gAMPA1)/len(gAMPA2))%len(gAMPA3)]
    NAratio = nmdaAmpaRatio[int(simID/len(stimAmp1)/len(gAMPA1)/len(gAMPA2)/len(gAMPA3))%len(nmdaAmpaRatio)]
    gG = gGABA[int(simID/len(stimAmp1)/len(gAMPA1)/len(gAMPA2)/len(gAMPA3)/len(nmdaAmpaRatio))%len(gGABA)]
    pv = pvs[int(simID/len(stimAmp1)/len(gAMPA1)/len(gAMPA2)/len(gAMPA3)/len(nmdaAmpaRatio)/len(gGABA))%len(pvs)]
    min_glut = min_gluts[int(simID/len(stimAmp1)/len(gAMPA1)/len(gAMPA2)/len(gAMPA3)/len(nmdaAmpaRatio)/len(gGABA)/len(pvs))%len(min_gluts)]
    tauN1 = tauNeur1s[int(simID/len(stimAmp1)/len(gAMPA1)/len(gAMPA2)/len(gAMPA3)/len(nmdaAmpaRatio)/len(gGABA)/len(pvs)/len(min_gluts))%len(tauNeur1s)]
    tauN2 = tauNeur2s[int(simID/len(stimAmp1)/len(gAMPA1)/len(gAMPA2)/len(gAMPA3)/len(nmdaAmpaRatio)/len(gGABA)/len(pvs)/len(min_gluts)/len(tauNeur1s))%len(tauNeur2s)]

    gN1 = gA1 * NAratio
    gN2 = gA2 * NAratio
    gN3 = gA3 * NAratio
    strs_all.append('python3 sim_mmns_2pm_sep_noISDIDD_savespikesonly.py '+str(Nperpop)+' '+str(paramSD)+' '+str(stA1)+' '+str(stA2)+' '+str(gA1)+' '+str(gA2)+' '+str(gA3)+' '+str(gN1)+' '+str(gN2)+' '+str(gN3)+' '+str(gG)+' '+str(gG)+' '+str(tauD)+' '+str(min_glut)+' '+str(pv)+' '+str(tauN1)+' '+str(tauN1)+' '+str(tauN1)+' '+str(tauN2))
    filenames_all.append('MMNs_2pm_sep_noISDIDD_Nperpop'+str(Nperpop)+'_paramSD'+str(paramSD)+'_stimA'+str(stA1)+'_'+str(stA2)+'_gAMPA'+str(gA1)+'_'+str(gA2)+'_'+str(gA3)+'_gNMDA'+str(gN1)+'_'+str(gN2)+'_'+str(gN3)+'_gGABA'+str(gG)+'_'+str(gG)+'_dep'+str(tauD)+'_'+str(min_glut)+'_'+str(pv)+'_tau'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN2)+'.mat')
    if not exists('MMNs_2pm_sep_noISDIDD_Nperpop'+str(Nperpop)+'_paramSD'+str(paramSD)+'_stimA'+str(stA1)+'_'+str(stA2)+'_gAMPA'+str(gA1)+'_'+str(gA2)+'_'+str(gA3)+'_gNMDA'+str(gN1)+'_'+str(gN2)+'_'+str(gN3)+'_gGABA'+str(gG)+'_'+str(gG)+'_dep'+str(tauD)+'_'+str(min_glut)+'_'+str(pv)+'_tau'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN2)+'.mat'):
      print('python3 sim_mmns_2pm_sep_noISDIDD_savespikesonly.py '+str(Nperpop)+' '+str(paramSD)+' '+str(stA1)+' '+str(stA2)+' '+str(gA1)+' '+str(gA2)+' '+str(gA3)+' '+str(gN1)+' '+str(gN2)+' '+str(gN3)+' '+str(gG)+' '+str(gG)+' '+str(tauD)+' '+str(min_glut)+' '+str(pv)+' '+str(tauN1)+' '+str(tauN1)+' '+str(tauN1)+' '+str(tauN2))
      os.system('python3 sim_mmns_2pm_sep_noISDIDD_savespikesonly.py '+str(Nperpop)+' '+str(paramSD)+' '+str(stA1)+' '+str(stA2)+' '+str(gA1)+' '+str(gA2)+' '+str(gA3)+' '+str(gN1)+' '+str(gN2)+' '+str(gN3)+' '+str(gG)+' '+str(gG)+' '+str(tauD)+' '+str(min_glut)+' '+str(pv)+' '+str(tauN1)+' '+str(tauN1)+' '+str(tauN1)+' '+str(tauN2))
      A = scipy.io.loadmat('MMNs_2pm_sep_noISDIDD_Nperpop'+str(Nperpop)+'_paramSD'+str(paramSD)+'_stimA'+str(stA1)+'_'+str(stA2)+'_gAMPA'+str(gA1)+'_'+str(gA2)+'_'+str(gA3)+'_gNMDA'+str(gN1)+'_'+str(gN2)+'_'+str(gN3)+'_gGABA'+str(gG)+'_'+str(gG)+'_dep'+str(tauD)+'_'+str(min_glut)+'_'+str(pv)+'_tau'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN2)+'.mat')
      if isamp % 1000 != 0:
        print('rm MMNs_2pm_sep_noISDIDD_Nperpop'+str(Nperpop)+'_paramSD'+str(paramSD)+'_stimA'+str(stA1)+'_'+str(stA2)+'_gAMPA'+str(gA1)+'_'+str(gA2)+'_'+str(gA3)+'_gNMDA'+str(gN1)+'_'+str(gN2)+'_'+str(gN3)+'_gGABA'+str(gG)+'_'+str(gG)+'_dep'+str(tauD)+'_'+str(min_glut)+'_'+str(pv)+'_tau'+str(tauN1\
)+'_'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN2)+'.mat')
        os.system('rm MMNs_2pm_sep_noISDIDD_Nperpop'+str(Nperpop)+'_paramSD'+str(paramSD)+'_stimA'+str(stA1)+'_'+str(stA2)+'_gAMPA'+str(gA1)+'_'+str(gA2)+'_'+str(gA3)+'_gNMDA'+str(gN1)+'_'+str(gN2)+'_'+str(gN3)+'_gGABA'+str(gG)+'_'+str(gG)+'_dep'+str(tauD)+'_'+str(min_glut)+'_'+str(pv)+'_tau'+str(tauN1\
)+'_'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN2)+'.mat')
    else:
      A = scipy.io.loadmat('MMNs_2pm_sep_noISDIDD_Nperpop'+str(Nperpop)+'_paramSD'+str(paramSD)+'_stimA'+str(stA1)+'_'+str(stA2)+'_gAMPA'+str(gA1)+'_'+str(gA2)+'_'+str(gA3)+'_gNMDA'+str(gN1)+'_'+str(gN2)+'_'+str(gN3)+'_gGABA'+str(gG)+'_'+str(gG)+'_dep'+str(tauD)+'_'+str(min_glut)+'_'+str(pv)+'_tau'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN1)+'_'+str(tauN2)+'.mat')
    NoutputOutsides = []
    NoutputInsides = []
    Nbetweens = []
    NoutputOutsidesUnique = []
    NoutputInsidesUnique = []
    NbetweensUnique = []
    for MMNtype in [0,1,2,3]:
      s1 = A['standard'][MMNtype]
      s2 = A['deviant'][MMNtype]
      s3 = A['pacemaker'][MMNtype]
      s4 = A['output'][MMNtype]
      while len(s4) > 0:
        if type(s4[0]) == ndarray:
          s4 = s4[0]
        else:
          break
      ss4 = A['output'][MMNtype]
      while len(ss4) > 0:
        if type(ss4[-1]) == ndarray:
          ss4 = ss4[-1]
        else:
          break
      #if len(s4) == 1 and prod(s4.shape) > 1:
      #  s4 = s4[0]
      #if len(s4) == 2 and prod(s4.shape) > 1:
      #  s4 = s4[0]
      
      #NoutputOutside = len(s4) - len([1 for t in s4[0] if t >= 2350 and t < 2500])
      NoutputInsideUnique = len(unique([ss4[j] for j in range(0,len(s4)) if s4[j] > 2350 and s4[j] <= 2500]))
      NoutputOutsideUnique = [len(unique([ss4[j] for j in range(0,len(s4)) if s4[j] > 500*i-150 and s4[j] <= 500*i])) for i in [1,2,3,4,6,7]]
      NbetweenUnique = [len(unique([ss4[j] for j in range(0,len(s4)) if s4[j] > 500*i-450 and s4[j] <= 500*i-150])) for i in [1,2,3,4,6,7]]

      NoutputInside = len([1 for t in s4 if t > 2350 and t <= 2500])
      NoutputOutside = [len([1 for t in s4 if t > 500*i-150 and t <= 500*i]) for i in [1,2,3,4,6,7]]
      Nbetween = [len([1 for t in s4 if t > 500*i-450 and t <= 500*i-150]) for i in [1,2,3,4,6,7]]
      
      NoutputInsides.append(NoutputInside)
      NoutputOutsides.append(NoutputOutside[:])
      Nbetweens.append(Nbetween[:])

      NoutputInsidesUnique.append(NoutputInsideUnique)
      NoutputOutsidesUnique.append(NoutputOutsideUnique[:])
      NbetweensUnique.append(NbetweenUnique[:])

    NoutputInsides_all.append(NoutputInsides[:])
    NoutputOutsides_all.append(NoutputOutsides[:])
    Nbetweens_all.append(Nbetweens[:])

    NoutputInsidesUnique_all.append(NoutputInsidesUnique[:])
    NoutputOutsidesUnique_all.append(NoutputOutsidesUnique[:])
    NbetweensUnique_all.append(NbetweensUnique[:])

scipy.io.savemat('gridsearch40_2pm_sep7_'+str(istart)+'-'+str(iend)+'.mat', {'NoutputInsides_all': NoutputInsides_all, 'NoutputOutsides_all': NoutputOutsides_all, 'Nbetweens_all': Nbetweens_all, 'strs_all': strs_all, 'filenames_all': filenames_all, 'NoutputInsidesUnique_all': NoutputInsidesUnique_all, 'NoutputOutsidesUnique_all': NoutputOutsidesUnique_all, 'NbetweensUnique_all': NbetweensUnique_all, 'params': [gAMPA1,gAMPA2,nmdaAmpaRatio,gGABA,pvs,min_gluts,tauNeur1s,tauNeur2s]})