#cp ../haymod3e/runcontrol_check.py runcontrol.py
# runcontrols
# A script for determining the control neuron F-I curve and limit cycle.
#
# The input code for the hoc-interface is based on BAC_firing.hoc by Etay Hay (2011)
#
# Tuomo Maki-Marttunen, Oct 2014
# (CC BY)

import os
import pickle
import sys

import numpy as np
from neuron import h

import mytools

# Command-line arguments (genetic variants):
#   argv[1]   patientID -- integer, used to name the output file
#   argv[2:]  coefficients -- floats that scale the conductances listed in
#             mutArray, in that order
# Without a patientID the script cannot name its output, so stop immediately
# instead of failing later with a NameError.
if len(sys.argv) < 2:
  sys.exit("usage: " + sys.argv[0] + " patientID [coeff1 coeff2 ...]")
patientID = int(sys.argv[1])
coeffs = [float(arg) for arg in sys.argv[2:]]


# Mechanism suffixes whose maximal conductance is scaled by the coefficients.
mutSuffs = ['Ca_HVA', 'Ca_LVAst', 'Im', 'K_Pst']
# Matching conductance parameter names: each follows g<suffix>bar_<suffix>.
mutArray = ['g' + suffix + 'bar_' + suffix for suffix in mutSuffs]
print(patientID,coeffs)

# One F-I curve (array of firing rates, one per stimulus amplitude) per cell.
spikfreqsAll = []

for icell in range(0,1):
  morphology_file = "morphologies/cell"+str(icell+1)+".asc"
  biophys_file = "models/L5PCbiophys3.hoc"
  template_file = "models/L5PCtemplate.hoc"
  v0 = -80           # initial membrane potential (mV), passed to hoc v_init
  ca0 = 0.0001       # initial intracellular Ca2+ (mM), passed to cai0_ca_ion

  distalpoint = 620  # distance passed to L5PC.locateSites("apic", ...)

  # Build the cell in hoc: CVode, the L5PC template instance, a somatic
  # IClamp (st1) and a CVode recording of soma v(0.5) into vsoma/tvec.
  # The locateSites loop picks the widest apical site (index j) at
  # distalpoint; j is not used further by this script.
  h("""
	load_file("stdlib.hoc")
	load_file("stdrun.hoc")
	objref cvode
	cvode = new CVode()
	cvode.active(1)
	cvode.atol(0.0002)
	load_file("import3d.hoc")
	objref L5PC
	load_file(\""""+biophys_file+"""\")
	load_file(\""""+template_file+"""\")
	L5PC = new L5PCtemplate(\""""+morphology_file+"""\")
	access L5PC.soma
	objref st1
	st1 = new IClamp(0.5)
	L5PC.soma st1

	objref vsoma, sl, tvec
	vsoma = new Vector()
	tvec = new Vector()
	sl = new List()
	double siteVec[2]
	sl = L5PC.locateSites("apic","""+str(distalpoint)+""")
	maxdiam = 0
	for(i=0;i<sl.count();i+=1){
	  dd1 = sl.o[i].x[1]
	  dd = L5PC.apic[sl.o[i].x[0]].diam(dd1)
	  if (dd > maxdiam) {
	    j = i
	    maxdiam = dd
	  }
	}

	L5PC.soma cvode.record(&v(0.5),vsoma,tvec)

	""")

  # Apply the SCZ conductance variants: multiply the maximal conductance of
  # each mechanism in mutSuffs by its coefficient, in every section that has
  # the mechanism inserted. Missing coefficients default to 1.0 (unchanged).
  # Extra coefficients are ignored.
  # NOTE: scale segment by segment. Reading sec.<rangevar> returns only the
  # value at sec(0.5), and assigning it sets every segment of the section.
  # So the section-level form would flatten non-uniform conductance
  # distributions within a section, even for a coefficient of 1.0.
  scale = list(coeffs) + [1.0] * (len(mutSuffs) - len(coeffs))
  for sec in h.allsec():
    for suffix, gbar_name, coeff in zip(mutSuffs, mutArray, scale):
      if h.ismembrane(suffix, sec=sec):
        for seg in sec:
          setattr(seg, gbar_name, getattr(seg, gbar_name) * coeff)

  # F-I sweep: square somatic current steps of increasing amplitude.
  Is = np.linspace(0.0,1.0,11)  # IClamp amplitudes (nA): 0.0, 0.1, ..., 1.0
  squareDel = 200       # stimulus onset (ms)
  squareDur = 15800     # stimulus duration (ms); the step lasts until tstop
  tstop = 16000         # simulation end (ms)
  countStart = 500.0    # spikes before this time (ms) are not counted
  countWindow = (tstop - countStart)/1000.0  # counting window in s (= 15.5)
  spikfreqs = np.zeros(len(Is))
  for iI, squareAmp in enumerate(Is):
    h("""
	tstop = """+str(tstop)+"""
	v_init = """+str(v0)+"""
	cai0_ca_ion = """+str(ca0)+"""
	st1.amp = """+str(squareAmp)+"""
	st1.dur = """+str(squareDur)+"""
	st1.del = """+str(squareDel)+"""
	""")
    h.init()
    h.run()

    times = np.array(h.tvec)
    Vsoma = np.array(h.vsoma)
    # NOTE(review): -35 and 100 are presumably a voltage threshold (mV) and a
    # detection parameter of mytools.spike_times -- confirm in mytools.
    spikes = mytools.spike_times(times,Vsoma,-35,100)
    # Mean firing rate (spikes/s) over [countStart, tstop].
    spikfreqs[iI] = sum(1 for x in spikes if x >= countStart)/countWindow

  # Store an independent copy; slicing a NumPy array ([:]) only makes a view.
  spikfreqsAll.append(spikfreqs.copy())

# Save [spikfreqsAll, Is] (F-I curves and the stimulus amplitudes they were
# measured at) to saves_ACC/ACC_patientID_<patientID>.sav.
picklelist = [spikfreqsAll, Is]
outdir = 'saves_ACC'
# Create the folder if needed so a long, finished run is not lost to a
# FileNotFoundError at the very end.
os.makedirs(outdir, exist_ok=True)
# 'with' guarantees the file is closed even if pickling fails.
with open(os.path.join(outdir, 'ACC_patientID_'+str(patientID)+'.sav'), 'wb') as outfile:
  pickle.dump(picklelist, outfile)