#cp sim_mmns_2pm_sep_nopm2_savespikesonly.py sim_mmns_2pm_sep_nopm2_nopm_savespikesonly.py #removed the primary pacemaker population too, everything else (especially the argv inputs) as previously
#cp sim_mmns_2pm_sep_savespikesonly.py sim_mmns_2pm_sep_nopm2_savespikesonly.py #removed the pacemaker2 population, everything else (especially the argv inputs) as previously
#after _old: add fourth protocol (long standards, short deviants)
#after _old3: add pacemaker2
#cp sim_mmns_savespikesonly.py sim_mmns_sep_savespikesonly.py #Add slowly excited population
from brian2 import *
from pylab import *
# Explicit imports are placed after the star imports so that nothing the star
# imports bring in can shadow them; sys and numpy are imported explicitly
# instead of being relied upon from the star imports.
import scipy.io
import sys
import time

import numpy as np


# Default parameters. Each can be overridden by a positional command-line
# argument (parsed below). Times are in ms unless noted otherwise.
Nperpop = 10           # output population size; input populations have 2*Nperpop neurons
Nstandards = 4         # standard tones before the deviant slot
Nstandards_after = 2   # standard tones after the deviant slot
stimUnitDt = 50        # minimum length (ms) of DC stimuli (as well as non-stimulus periods)
stimT = 500            # stimulus interval (ms)
stimA1 = 100           # stimulus amplitude
stimA2 = 90            # pacemaker stimulus amplitude; no pacemaker in this variant, only used in the output file name
nSP = int(stimT/stimUnitDt)-1  # this many silent periods (of duration stimUnitDt) before a stimulus period

paramSD = 0.0          # relative SD of the per-neuron jitter applied to tau and g_leak
gAMPA1 = 25.0          # standard/deviant -> output AMPA conductance (gAMPA1/Nperpop per synapse)
gAMPA2 = 25.0          # pacemaker -> output AMPA; unused here except in the output file name
# gAMPA3 needs its own default: it drives the boost pathways and is part of the
# output file name even when it is not given on the command line.
gAMPA3 = 25.0          # standard boost/deviant boost -> output AMPA conductance
gNMDA1 = 1.0           # standard/deviant -> output NMDA conductance
gNMDA2 = 1.0           # pacemaker -> output NMDA; unused here except in the output file name
gNMDA3 = 1.0           # standard boost/deviant boost -> output NMDA conductance
gGABA1 = 30            # GABA conductances: no inhibitory pathway exists in this variant,
gGABA2 = 30            #   so both only appear in the output file name
tauD = 800             # recovery time constant (ms) of presynaptic glutamate availability D_glut
min_glut = 0.1         # floor towards which each presynaptic spike depresses D_glut
pv = 0.6               # fraction of the distance (D_glut - min_glut) removed per presynaptic spike
tauNeur1 = 10.0        # time constant (ms) of v in the standard and deviant populations
tauNeur2 = 10.0        # unused here except in the output file name
tauNeur3 = 10.0        # time constant (ms) of v in the output population
tauNeur4 = 30.0        # time constant (ms) of v in the boost populations
myseed = 1             # random seed; a non-default seed is appended to the output file name

# Positional command-line overrides: sys.argv[k] (k = 1..20) replaces the k-th
# parameter in this table after conversion with the paired function. Missing
# arguments keep their defaults; arguments beyond the 20th are ignored.
_CLI_OVERRIDES = (
    ('Nperpop', int),
    ('paramSD', float),
    ('stimA1', int),
    ('stimA2', int),
    ('gAMPA1', float),    # standard/deviant -> output AMPA conductance
    ('gAMPA2', float),    # pacemaker -> output AMPA (file name only in this variant)
    ('gAMPA3', float),    # standard boost/deviant boost -> output AMPA conductance
    ('gNMDA1', float),    # standard/deviant -> output NMDA conductance
    ('gNMDA2', float),    # pacemaker -> output NMDA (file name only in this variant)
    ('gNMDA3', float),    # standard boost/deviant boost -> output NMDA conductance
    ('gGABA1', float),    # file name only in this variant
    ('gGABA2', float),    # file name only in this variant
    ('tauD', int),
    ('min_glut', float),
    ('pv', float),
    ('tauNeur1', float),
    ('tauNeur2', float),
    ('tauNeur3', float),
    ('tauNeur4', float),
    ('myseed', lambda text: int(float(text))),  # accepts e.g. "3.0"
)
for (_paramName, _convert), _rawValue in zip(_CLI_OVERRIDES, sys.argv[1:]):
    # Module-level script: rebinding through globals() is the same as a plain assignment.
    globals()[_paramName] = _convert(_rawValue)
                 
# Seed the random generators with myseed; the output file name gets a
# '_seed<N>' tag only for non-default seeds.
seed(myseed)
np.random.seed(myseed)
myseedAdd = '_seed' + str(myseed) if myseed != 1 else ''

# Spike records collected across protocols: each MMNtype iteration appends one
# [spike times (ms), neuron indices] pair to each list.
spikes_standard, spikes_deviant, spikes_output = [], [], []
spikes_standardBoost, spikes_deviantBoost = [], []

# Simulate the four oddball protocols, each in a fresh Brian2 scope. The tone
# sequence is Nstandards standards, one deviant slot, then Nstandards_after
# standards; MMNtype selects what happens in the deviant slot:
#   0 - omission: neither the standard nor the deviant population is stimulated
#   1 - pitch deviant: the deviant population is stimulated instead of the standard one
#   2 - duration deviant: the standard tone lasts two stimUnitDt units instead of one
#   3 - long standards (two units each) with a one-unit standard tone in the deviant slot
for MMNtype in [0,1,2,3]:

  # Objects created in earlier iterations are excluded from this iteration's run().
  start_scope()

  # NMDA Mg2+ block: mggate = 1 / (1 + exp(0.062 (/mV) * -(v)) * (MgCon / 3.57 (mM))).
  # The equations below use MgCon = 1 mM, i.e. every s_nmdatotN term is divided
  # by (1+exp(-0.062*v)/3.57).

  # Neuron model shared by all populations (v is dimensionless; the mggate
  # constant suggests mV). Terms: the two TimedArray stimuli, gated per
  # population by coeffStandard/coeffDeviant; a leak (g_leak, e_leak); and six
  # synaptic slots N = 1..6 (AMPA and NMDA reverse at 0, GABA at -90). Each
  # Synapses object needs its own summed target, hence the separate slots; this
  # variant only connects slots 1, 3, 4 and 5. I is declared but not used.
  eqs = '''
dv/dt = (stimulusDeviant(t)*coeffDeviant+stimulusStandard(t)*coeffStandard+ g_leak*(e_leak-v) + s_ampatot1*(0-v) + s_nmdatot1*(0-v)/(1+exp(-0.062*v)/3.57) + s_gabatot1*(-90-v) +s_ampatot2*(0-v) + s_nmdatot2*(0-v)/(1+exp(-0.062*v)/3.57) + s_gabatot2*(-90-v) + s_ampatot3*(0-v) + s_nmdatot3*(0-v)/(1+exp(-0.062*v)/3.57) + s_gabatot3*(-90-v) + s_ampatot4*(0-v) + s_nmdatot4*(0-v)/(1+exp(-0.062*v)/3.57) + s_gabatot4*(-90-v) + s_ampatot5*(0-v) + s_nmdatot5*(0-v)/(1+exp(-0.062*v)/3.57) + s_gabatot5*(-90-v) + s_ampatot6*(0-v) + s_nmdatot6*(0-v)/(1+exp(-0.062*v)/3.57) + s_gabatot6*(-90-v))/tau : 1 (unless refractory)
I : 1
tau : second
s_ampatot1 : 1
s_nmdatot1 : 1
s_gabatot1 : 1
s_ampatot2 : 1
s_nmdatot2 : 1
s_gabatot2 : 1
s_ampatot3 : 1
s_nmdatot3 : 1
s_gabatot3 : 1
s_ampatot4 : 1
s_nmdatot4 : 1
s_gabatot4 : 1
s_ampatot5 : 1
s_nmdatot5 : 1
s_gabatot5 : 1
s_ampatot6 : 1
s_nmdatot6 : 1
s_gabatot6 : 1
coeffStandard: 1
coeffDeviant: 1
e_leak : 1
g_leak : 1
'''


  def create_synapse_equations(suffix='1'):
    """Return the Brian2 model equations for one synaptic pathway.

    Brian2 does not let several Synapses objects write to the same summed
    postsynaptic variable, so each pathway gets its own targets
    s_ampatot<suffix>, s_nmdatot<suffix> and s_gabatot<suffix>. `suffix`
    selects which slot of the neuron equations the pathway feeds.

    AMPA and NMDA gating variables s are driven by decaying rise variables x
    through alpha_s; GABA decays directly; D_glut recovers towards 1 with
    time constant tau_D_glut.
    """
    lineTemplates = (
      'ds_ampa/dt = alpha_s*x_ampa*(1-s_ampa)-s_ampa/tau_ampa : 1',
      'dx_ampa/dt = -x_ampa/tau_x_ampa : 1',
      'ds_nmda/dt = alpha_s*x_nmda*(1-s_nmda)-s_nmda/tau_nmda : 1',
      'dx_nmda/dt = -x_nmda/tau_x_nmda : 1',
      'ds_gaba/dt = -s_gaba/tau_gaba : 1',
      'dD_glut/dt = (1-D_glut)/tau_D_glut : 1',
      's_ampatot{0}_post = g_ampa*s_ampa : 1 (summed)',
      's_nmdatot{0}_post = g_nmda*s_nmda : 1 (summed)',
      's_gabatot{0}_post = g_gaba*s_gaba : 1 (summed)',
      'tau_ampa : second',
      'tau_x_ampa : second',
      'tau_nmda : second',
      'tau_x_nmda : second',
      'tau_gaba : second',
      'tau_D_glut : second',
      'min_D_glut : 1',
      'alpha_s : 1/second',
      'alpha_ampa : 1',
      'alpha_nmda : 1',
      'alpha_gaba : 1',
      'pv : 1',
      'g_ampa : 1',
      'g_nmda : 1',
      'g_gaba : 1',
    )
    indent = '    '
    body = ''.join(indent + template.format(suffix) + '\n' for template in lineTemplates)
    # Leading newline and trailing indent reproduce the layout of a triple-quoted block.
    return '\n' + body + indent

  eq_syn1 = create_synapse_equations('1')
  eq_syn2 = create_synapse_equations('2')
  eq_syn3 = create_synapse_equations('3')
  eq_syn4 = create_synapse_equations('4')
  eq_syn5 = create_synapse_equations('5')
  eq_syn6 = create_synapse_equations('6')

  eq_syn_on_pre = '''
x_ampa += alpha_ampa*D_glut
x_nmda += alpha_nmda*D_glut
s_gaba += alpha_gaba
D_glut += -pv*(D_glut-min_D_glut)
'''

  '''
stimulusStandard:responsible for the standard signal
stimulusDeviant:responsible for the deviant signal
'''

  stimListStandard = []
  stimListDeviant = []
  for istim in range(0,Nstandards + 1 + Nstandards_after):
    if istim < Nstandards or istim > Nstandards:
      if MMNtype < 3:
        stimListStandard = stimListStandard+[0]*nSP+[stimA1]
      else:
        stimListStandard = stimListStandard+[0]*(nSP-1)+[stimA1]*2
      stimListDeviant = stimListDeviant+[0]*nSP+[0]
    else:
      #Set the standard tone input
      if MMNtype == 2: #In omission and different-pitch MMN the standard tone is silent at the time of deviant
        stimListStandard = stimListStandard+[0]*(nSP-1)+[stimA1]*2
      elif MMNtype == 3:
        stimListStandard = stimListStandard+[0]*nSP+[stimA1]
      else:
        stimListStandard = stimListStandard+[0]*nSP+[0]
 
      #Set the deviant tone input
      if MMNtype == 1:
        stimListDeviant = stimListDeviant+[0]*nSP+[stimA1]
      else:
        stimListDeviant = stimListDeviant+[0]*nSP+[0]

  stimListStandard = stimListStandard + [0]*nSP
  stimListDeviant = stimListDeviant + [0]*nSP
      
  #f,axarr = subplots(3,1)
  #axarr[0].plot(stimListStandard)
  #axarr[1].plot(stimListDeviant)
  #axarr[2].plot(stimListPaceMaker)
  #f.savefig("test"+str(MMNtype)+".pdf")
  #qwe
      

  stimulusStandard = TimedArray(array(stimListStandard+[0]*100), dt=50*ms)
  stimulusDeviant = TimedArray(array(stimListDeviant+[0]*100), dt=50*ms)

  # Populations (all use the neuron model eqs):
  #   standardPopulation      - driven by the standard stimulus
  #   deviantPopulation       - driven by the deviant stimulus
  #   standardBoostPopulation - standard stimulus with the slower time constant
  #                             tauNeur4; meant to respond to prolonged standard input
  #   deviantBoostPopulation  - deviant stimulus with the slower time constant
  #                             tauNeur4; meant to respond to prolonged deviant input
  #   outputPopulation        - no direct stimulus; receives the four pathways
  def _makeNeuronGroup(size):
    """Create `size` neurons of model eqs with the shared threshold/reset/refractory settings."""
    return NeuronGroup(size, eqs, threshold='v>-40', reset='v = -80', refractory=2*ms, method='rk4')

  standardPopulation = _makeNeuronGroup(2*Nperpop)
  deviantPopulation = _makeNeuronGroup(2*Nperpop)
  standardBoostPopulation = _makeNeuronGroup(2*Nperpop)
  deviantBoostPopulation = _makeNeuronGroup(2*Nperpop)
  outputPopulation = _makeNeuronGroup(Nperpop)

  def _configurePopulation(pop, size, tauBase, stimCoeff=None):
    """Initialise v, jittered tau and g_leak, the leak reversal and (optionally) a stimulus gain.

    tauBase is in ms; tau and g_leak get relative jitter of SD paramSD. Each
    call draws randn for tau first and g_leak second, and the call order below
    fixes which random numbers each population receives.
    """
    pop.v = -80
    pop.tau = tauBase*(1+randn(size)*paramSD)*ms
    if stimCoeff is not None:
      setattr(pop, stimCoeff, [1]*size)  # route the matching stimulus to this population
    pop.g_leak = 2.0*(1+randn(size)*paramSD)
    pop.e_leak = -80

  _configurePopulation(standardPopulation, 2*Nperpop, tauNeur1, 'coeffStandard')
  _configurePopulation(deviantPopulation, 2*Nperpop, tauNeur1, 'coeffDeviant')
  _configurePopulation(outputPopulation, Nperpop, tauNeur3)
  _configurePopulation(standardBoostPopulation, 2*Nperpop, tauNeur4, 'coeffStandard')
  _configurePopulation(deviantBoostPopulation, 2*Nperpop, tauNeur4, 'coeffDeviant')

  # Circuit: four excitatory, depressing (via D_glut) pathways onto the output
  # population, each feeding its own summed slot of the output neurons:
  #   standardToOutputSynapses:      standardPopulation excites outputPopulation (slot 1)
  #   deviantToOutputSynapses:       deviantPopulation excites outputPopulation (slot 3)
  #   standardBoostToOutputSynapses: standardBoostPopulation excites outputPopulation (slot 4)
  #   deviantBoostToOutputSynapses:  deviantBoostPopulation excites outputPopulation (slot 5)
  standardToOutputSynapses = Synapses(standardPopulation, outputPopulation, model = eq_syn1, on_pre = eq_syn_on_pre)
  deviantToOutputSynapses = Synapses(deviantPopulation, outputPopulation, model = eq_syn3, on_pre = eq_syn_on_pre)
  standardBoostToOutputSynapses = Synapses(standardBoostPopulation, outputPopulation, model = eq_syn4, on_pre = eq_syn_on_pre)
  deviantBoostToOutputSynapses = Synapses(deviantBoostPopulation, outputPopulation, model = eq_syn5, on_pre = eq_syn_on_pre)

  # Random Nperpop x Nperpop connectivity masks (p = 0.5); entry [i, j] means
  # presynaptic neuron i -> output neuron j, so at most the first Nperpop
  # neurons of each 2*Nperpop input population get connections. The masks are
  # drawn in sequence from numpy's global RNG, so their order matters.
  MAT_A = np.random.rand(Nperpop, Nperpop) < 0.5  # standard -> output
  MAT_B = np.random.rand(Nperpop, Nperpop) < 0.5  # deviant -> output
  # MAT_C (pacemaker -> output) has no pathway in this variant. It is still
  # drawn so that MAT_D and MAT_E take the same positions in the random stream
  # as they would with MAT_C in use.
  MAT_C = np.random.rand(Nperpop, Nperpop) < 0.5
  MAT_D = np.random.rand(Nperpop, Nperpop) < 0.5  # standardBoost -> output
  MAT_E = np.random.rand(Nperpop, Nperpop) < 0.5  # deviantBoost -> output

  # Per-synapse conductances, filled by connect_synapses() in synapse-creation
  # order and applied by initialize_synapse().
  g_ampas, g_nmdas, g_gabas = [], [], []      # standard -> output
  g_ampas3, g_nmdas3, g_gabas3 = [], [], []   # deviant -> output
  g_ampas6, g_nmdas6, g_gabas6 = [], [], []   # standardBoost -> output
  g_ampas7, g_nmdas7, g_gabas7 = [], [], []   # deviantBoost -> output

  def connect_synapses(MAT, synapse, g_ampas_to_update, g_nmdas_to_update, g_gabas_to_update, gAMPA_this, gNMDA_this, gGABA_this=None):
    """Create the synapses selected by the boolean mask MAT and record their conductances.

    Every True entry MAT[i, j] with i, j < Nperpop creates one synapse from
    presynaptic neuron i to postsynaptic neuron j, in row-major order. For
    each created synapse one value is appended to each of the three lists:
    when gGABA_this is None the synapse is excitatory
    (gAMPA_this/Nperpop, gNMDA_this/Nperpop, 0.0); otherwise it is inhibitory
    (0.0, 0.0, gGABA_this/Nperpop).

    All synapses are created with a single connect() call instead of one
    Brian2 call per synapse.
    """
    # np.nonzero walks the mask in C (row-major) order, matching an
    # "outer loop over i, inner loop over j" creation order.
    pre_idx, post_idx = np.nonzero(np.asarray(MAT)[:Nperpop, :Nperpop])
    n_new = len(pre_idx)
    if n_new == 0:
      return  # empty mask: connect() is not called at all
    synapse.connect(i=pre_idx, j=post_idx)
    excitatory = gGABA_this is None
    g_ampas_to_update.extend([gAMPA_this / Nperpop if excitatory else 0.0] * n_new)
    g_nmdas_to_update.extend([gNMDA_this / Nperpop if excitatory else 0.0] * n_new)
    g_gabas_to_update.extend([0.0 if excitatory else gGABA_this / Nperpop] * n_new)
    

  # Wire the four excitatory pathways onto the output population.
  connect_synapses(MAT_A[:, 0:Nperpop], standardToOutputSynapses,
                   g_ampas, g_nmdas, g_gabas, gAMPA_this=gAMPA1, gNMDA_this=gNMDA1)
  connect_synapses(MAT_B[:, 0:Nperpop], deviantToOutputSynapses,
                   g_ampas3, g_nmdas3, g_gabas3, gAMPA_this=gAMPA1, gNMDA_this=gNMDA1)
  connect_synapses(MAT_D[:, 0:Nperpop], standardBoostToOutputSynapses,
                   g_ampas6, g_nmdas6, g_gabas6, gAMPA_this=gAMPA3, gNMDA_this=gNMDA3)
  connect_synapses(MAT_E[:, 0:Nperpop], deviantBoostToOutputSynapses,
                   g_ampas7, g_nmdas7, g_gabas7, gAMPA_this=gAMPA3, gNMDA_this=gNMDA3)

  # Simulated duration (ms): every tone slot of nSP+1 units plus one extra slot.
  unitsPerSlot = nSP + 1
  slotCount = Nstandards + 1 + Nstandards_after + 1
  tstop = unitsPerSlot*slotCount*stimUnitDt
  print("tstop = " + str(tstop))

  def initialize_synapse(S, tauD, pv, g_ampas_this, g_nmdas_this, g_gabas_this):
    """Set kinetic constants, per-synapse conductances and initial state of Synapses S.

    tauD: recovery time constant (ms) of D_glut.
    pv: fraction by which each presynaptic spike moves D_glut towards
        min_D_glut, which is taken from the module-level min_glut.
    g_*_this: one conductance per synapse, in synapse-creation order.
    All gating variables start at 0 and D_glut starts fully recovered at 1.
    """
    settings = (
      ('tau_x_ampa', 0.5*msecond),   # AMPA rise time constant
      ('tau_ampa', 5*msecond),       # AMPA decay time constant
      ('tau_x_nmda', 2*msecond),     # NMDA rise time constant
      ('tau_nmda', 50*msecond),      # NMDA decay time constant
      ('tau_gaba', 10*msecond),      # GABA decay time constant
      ('tau_D_glut', tauD*msecond),  # presynaptic depression recovery
      ('alpha_ampa', 1.0),
      ('alpha_nmda', 1.0),
      ('alpha_gaba', 1.0),
      ('alpha_s', 1.0/msecond),      # x -> s coupling strength
      ('pv', pv),
      ('g_ampa', g_ampas_this),
      ('g_nmda', g_nmdas_this),
      ('g_gaba', g_gabas_this),
      ('s_ampa', 0),
      ('x_ampa', 0),
      ('s_nmda', 0),
      ('x_nmda', 0),
      ('s_gaba', 0),
      ('D_glut', 1),
      ('min_D_glut', min_glut),
    )
    for varName, value in settings:
      setattr(S, varName, value)
    

  # Apply kinetics, conductances and initial state to every pathway.
  initialize_synapse(standardToOutputSynapses, tauD, pv, g_ampas, g_nmdas, g_gabas)
  initialize_synapse(deviantToOutputSynapses, tauD, pv, g_ampas3, g_nmdas3, g_gabas3)
  initialize_synapse(standardBoostToOutputSynapses, tauD, pv, g_ampas6, g_nmdas6, g_gabas6)
  initialize_synapse(deviantBoostToOutputSynapses, tauD, pv, g_ampas7, g_nmdas7, g_gabas7)

  # Only spikes are saved, so only SpikeMonitors are created. Recording v of
  # every output neuron at every time step would only cost memory and run time
  # for data that is never used; monitors do not affect the dynamics.
  outputPopulationSpikeMonitor = SpikeMonitor(outputPopulation)
  standardPopulationSpikeMonitor = SpikeMonitor(standardPopulation)
  deviantPopulationSpikeMonitor = SpikeMonitor(deviantPopulation)
  standardBoostPopulationSpikeMonitor = SpikeMonitor(standardBoostPopulation)
  deviantBoostPopulationSpikeMonitor = SpikeMonitor(deviantBoostPopulation)

  timenow = time.time()
  run(tstop*ms)
  print("Simulation "+str(MMNtype)+" run in "+str(int(time.time()-timenow))+" seconds")

  def spike_record(monitor):
    """Return [spike times in ms, indices of the spiking neurons] for a SpikeMonitor."""
    return [monitor.t/msecond, array(monitor.i)]

  spikes_standard.append(spike_record(standardPopulationSpikeMonitor))
  spikes_deviant.append(spike_record(deviantPopulationSpikeMonitor))
  spikes_output.append(spike_record(outputPopulationSpikeMonitor))
  spikes_standardBoost.append(spike_record(standardBoostPopulationSpikeMonitor))
  spikes_deviantBoost.append(spike_record(deviantBoostPopulationSpikeMonitor))

# Save every protocol's spike records to one .mat file whose name encodes the
# parameters (including those unused in this variant, so names stay comparable).
matFileName = (
    f'MMNs_2pm_sep_nopm2_nopm_Nperpop{Nperpop}_paramSD{paramSD}'
    f'_stimA{stimA1}_{stimA2}'
    f'_gAMPA{gAMPA1}_{gAMPA2}_{gAMPA3}'
    f'_gNMDA{gNMDA1}_{gNMDA2}_{gNMDA3}'
    f'_gGABA{gGABA1}_{gGABA2}'
    f'_dep{tauD}_{min_glut}_{pv}'
    f'_tau{tauNeur1}_{tauNeur2}_{tauNeur3}_{tauNeur4}'
    f'{myseedAdd}.mat'
)
spikeData = {
    'standard': spikes_standard,
    'deviant': spikes_deviant,
    'output': spikes_output,
    'standardBoost': spikes_standardBoost,
    'deviantBoost': spikes_deviantBoost,
}
scipy.io.savemat(matFileName, spikeData)