from neuron import h, rxd
from pylab import *
from matplotlib import pyplot
import scipy.io
import time
import re
import mytools

h.load_file('stdrun.hoc')

# Single-compartment section hosting the well-mixed reaction region 'cyt'
# (mapped onto NEURON's intracellular concentrations via nrn_region='i').
dend = h.Section(name='dend')
dend.L=1
dend.diam=0.79788
cyt = rxd.Region([dend], name='cyt', nrn_region='i')

# Read the compartment volume from the mesh description file: the value used
# is the second-to-last whitespace-separated field of the second line.
# NOTE(review): the 1e-15 factor suggests that field is in um^3
# (1 um^3 = 1e-15 l) — confirm against the mesh generator.
# The context manager guarantees the file is closed even if reading fails.
with open('mesh_general.out','r') as mesh_input_file:
  mesh_firstline = mesh_input_file.readline()
  mesh_secondline = mesh_input_file.readline()
mesh_values = mesh_secondline.split()
if len(mesh_values) < 2:
  raise ValueError("mesh_general.out: expected at least two fields on line 2, got "+repr(mesh_secondline))
my_volume = float(mesh_values[-2])*1e-15 #litres

import sys  # sys.argv is parsed below

# Default simulation settings; each may be overridden by a positional
# command-line argument (all optional, in this order):
#   1 Duration (int)           2 tolerance (CVODE atol)   3 gaba_input_onset
#   4 gaba_input_N             5 gaba_input_freq          6 gaba_input_dur
#   7 gaba_input_flux          8 Ntrains                  9 trainT
#  10 initfile (.mat from a previous run; '' or 'None' to skip)
#  11 blocked: comma-separated regex patterns; species whose whole name
#     matches get their initial value scaled
#  12 block factors: comma-separated, one per pattern (a single value is
#     applied to every pattern; default 1.0)
#  13 alteredk: comma-separated indices into ks
#  14 alteredk factors: one per index (a single value applies to all;
#     default 1.0)
#  15 output filename (overrides the generated one)
#  16 record_T (resampling interval of the saved traces)
#  17 extra species names whose atolscale is changed
#  18 their exponents x (atolscale *= 10**-x)
Duration = 1000
tolerance = 1e-6
gaba_input_onset = 800
gaba_input_N     = 100
gaba_input_freq  = 100
gaba_input_dur   = 0.005
gaba_input_flux  = 600.0
Ntrains        = 1
trainT = 3000
initfile = ''
addition = ''
blocked = []
blockeds = []
block_factor = 1.0
block_factors = []
alteredk = []
alteredks = []
alteredk_factor = 1.0
alteredk_factors = []
altered_factor = alteredk_factor # legacy name, kept for backward compatibility
record_T = 10.0 # default 10 ms recording interval
tolstochange = 'Gi,GiaGDP,Gibg'.split(',')
tolschange = [float(x) for x in '2,2,2'.split(',')]

if len(sys.argv) > 1:
  Duration = int(float(sys.argv[1]))
if len(sys.argv) > 2:
  tolerance = float(sys.argv[2])
if len(sys.argv) > 3:
  gaba_input_onset = float(sys.argv[3])
if len(sys.argv) > 4:
  gaba_input_N     = int(float(sys.argv[4]))
if len(sys.argv) > 5:
  gaba_input_freq  = float(sys.argv[5])
if len(sys.argv) > 6:
  gaba_input_dur   = float(sys.argv[6])
if len(sys.argv) > 7:
  gaba_input_flux  = float(sys.argv[7])
if len(sys.argv) > 8:
  Ntrains  = int(float(sys.argv[8]))
if len(sys.argv) > 9:
  trainT  = float(sys.argv[9])
if len(sys.argv) > 10:
  initfile = sys.argv[10]
if len(sys.argv) > 11:
  blocked = sys.argv[11]
  blockeds = blocked.split(',')
if len(sys.argv) > 12:
  block_factor = sys.argv[12]
  block_factors = [float(x) for x in block_factor.split(',')]
# Every blocked pattern needs a factor: default to block_factor (1.0) for
# each, and let a single given factor apply to all patterns.
if not block_factors:
  block_factors = [float(block_factor)]*len(blockeds)
elif len(block_factors) == 1:
  block_factors = block_factors*len(blockeds)
if isinstance(blocked, str):
  addition = '_'+blocked+'x'+str(block_factor)
if len(sys.argv) > 13:
  alteredk = sys.argv[13]
  alteredks = [int(x) for x in alteredk.split(',')]
if len(sys.argv) > 14:
  alteredk_factor = sys.argv[14]
  alteredk_factors = [float(x) for x in alteredk_factor.split(',')]
# Same defaulting/broadcast rule for the altered rate constants.
if not alteredk_factors:
  alteredk_factors = [float(alteredk_factor)]*len(alteredks)
elif len(alteredk_factors) == 1:
  alteredk_factors = alteredk_factors*len(alteredks)
if isinstance(alteredk, str):
  addition = addition+'_k'+alteredk+'x'+str(alteredk_factor)
filename = 'nrn_tstop'+str(Duration)+'_tol'+str(tolerance)+addition+'_onset'+str(gaba_input_onset)+'_n'+str(gaba_input_N)+'_freq'+str(gaba_input_freq)+'_dur'+str(gaba_input_dur)+'_flux'+str(gaba_input_flux)+'_Ntrains'+str(Ntrains)+'_trainT'+str(trainT)+'.mat'
# Shorten over-long file names (common filesystem limit ~255 chars) by
# dropping these substrings one at a time until the name fits.
toBeRemovedIfNecessary = ['_tol1e-06','_tstop15000000','3560000_600000','_Ninputs1','_pulseamp5.0']
for removable in toBeRemovedIfNecessary:
  if len(filename) > 254:
    filename = filename.replace(removable,'')
if len(sys.argv) > 15:
  filename = sys.argv[15]
if len(sys.argv) > 16:
  record_T = float(sys.argv[16])
if len(sys.argv) > 17:
  tolstochange = ('Gi,GiaGDP,Gibg,'+sys.argv[17]).split(',')
if len(sys.argv) > 18:
  tolschange = [float(x) for x in ('2,2,2,'+sys.argv[18]).split(',')]

# Default initial concentrations, one per entry of `species` (same order).
initvalues = [0.0, 0.0, 0.00039999999999999996, 0.0, 0.0026, 0.0, 0.0, 0.0, 0.0, 0.0, 0.001, 0.0, 0.0, 0.001, 9.999999999999999e-05, 0.0, 0.0, 0.0, 0.0, 0.0]

# Species names; species[i] is created below as the rxd species 'spec<i>'.
species = ['gaba', 'gabaOut', 'GABABR', 'gabaGABABR', 'Gi', 'GABABRGi', 'gabaGABABRGi', 'gabaGABABRGibg', 'GiaGTP', 'GiaGDP', 'RGS', 'GiaGTPRGS', 'Gibg', 'GIRK', 'VGCC', 'GIRKGibg', 'GIRKGibg2', 'GIRKGibg3', 'GIRKGibg4', 'VGCCGibg']

# Print the mesh-file volume next to the section's cylinder volume
# (L * (diam/2)**2 * pi) so the two can be compared by eye.
print("my_volume = "+str(my_volume)+" l ?= "+str(dend.L*(dend.diam/2)**2*3.14159265358)+" um3")

# Optionally restart from the last column of the DATA matrix of a previous
# run (row 0 of DATA is time, so species i sits in row 1+i).
# NOTE(review): assumes initfile was written in this script's savemat format.
if len(initfile) > 3 and initfile != 'None':
  DATA_init = scipy.io.loadmat(initfile)
  for ispec, specname in enumerate(species):
    initvalues[ispec] = DATA_init['DATA'][1+ispec,-1]
    if specname not in DATA_init['headers'][1+ispec]:
      print("Warning: mismatch in DATA_init, ispec = "+str(ispec))

# Per-species adjustments:
#  * a species whose whole name matches a pattern in `blockeds` has its
#    initial value multiplied by the corresponding block factor;
#  * a species listed in `tolstochange` has its CVODE atolscale multiplied
#    by 10**(-tolschange[k]).
tolscales = [1.0]*len(species)
for ispec, specname in enumerate(species):
  for iblock, pattern in enumerate(blockeds):
    if re.match(pattern+'$', specname):
      initvalues[ispec] = block_factors[iblock]*initvalues[ispec]
  for itol, tolname in enumerate(tolstochange):
    if tolname == specname:
      tolscales[ispec] *= 10**(-tolschange[itol])

# Create the rxd species in the 'cyt' region, reporting non-default atolscales.
specs = []
for ispec, specname in enumerate(species):
  specs.append(rxd.Species(cyt, name='spec'+str(ispec), charge=0, initial=initvalues[ispec], atolscale=tolscales[ispec]))
  if tolscales[ispec] != 1.0:
    print('spec'+str(ispec)+' ('+specname+') atolscale = '+str(tolscales[ispec]))
# rxd parameter holding the external GABA influx rate. It starts at 0 and is
# switched on and off by CVODE events during the run.
gaba_flux_rate = rxd.Parameter(cyt, initial=0)

# Rate constants. The comment on each entry names the reaction (and
# direction) that uses it. The index positions are part of the command-line
# interface, because `alteredks` refers to them.
ks = [
  0.0005,             # 0  gaba --> gabaOut (forward)
  5.555000000000001,  # 1  gaba + GABABR <-> gabaGABABR (forward)
  0.005,              # 2  gaba + GABABR <-> gabaGABABR (backward)
  150.0,              # 3  gabaGABABR + Gi <-> gabaGABABRGi (forward)
  0.00025,            # 4  gabaGABABR + Gi <-> gabaGABABRGi (backward)
  0.000125,           # 5  gabaGABABRGi --> gabaGABABRGibg + GiaGTP (forward)
  0.001,              # 6  gabaGABABRGibg --> gabaGABABR + Gibg (forward)
  75.0,               # 7  GABABR + Gi <-> GABABRGi (forward)
  0.000125,           # 8  GABABR + Gi <-> GABABRGi (backward)
  5.555000000000001,  # 9  gaba + GABABRGi <-> gabaGABABRGi (forward)
  0.005,              # 10 gaba + GABABRGi <-> gabaGABABRGi (backward)
  2.0,                # 11 GiaGTP + RGS <-> GiaGTPRGS (forward)
  0.002,              # 12 GiaGTP + RGS <-> GiaGTPRGS (backward)
  0.03,               # 13 GiaGTPRGS --> GiaGDP + RGS (forward)
  1250.0,             # 14 GiaGDP + Gibg --> Gi (forward)
  14.0,               # 15 GIRK + Gibg <-> GIRKGibg (forward)
  0.001,              # 16 GIRK + Gibg <-> GIRKGibg (backward)
  14.0,               # 17 GIRKGibg + Gibg <-> GIRKGibg2 (forward)
  0.001,              # 18 GIRKGibg + Gibg <-> GIRKGibg2 (backward)
  14.0,               # 19 GIRKGibg2 + Gibg <-> GIRKGibg3 (forward)
  0.001,              # 20 GIRKGibg2 + Gibg <-> GIRKGibg3 (backward)
  14.0,               # 21 GIRKGibg3 + Gibg <-> GIRKGibg4 (forward)
  0.001,              # 22 GIRKGibg3 + Gibg <-> GIRKGibg4 (backward)
  14.0,               # 23 VGCC + Gibg <-> VGCCGibg (forward)
  0.001,              # 24 VGCC + Gibg <-> VGCCGibg (backward)
]

# Scale the selected rate constants: indices come from `alteredks`, factors
# from the entry at the same position in `alteredk_factors`.
for ialteredk, kindex in enumerate(alteredks):
  ks[kindex] = alteredk_factors[ialteredk]*ks[kindex]
# Reaction network. specs[i] is the rxd species for species[i], so the
# trailing comments give the species names. Rate constants come from ks
# (see the per-index comments there). Irreversible reactions pass a single
# rate constant; reversible ones pass (forward, backward).
# NOTE(review): each reaction is bound to a module-level name. Presumably
# this keeps the objects alive, since rxd may hold only weak references to
# reactions. Confirm before collapsing these into temporaries.
reaction000 = rxd.Reaction(specs[0], specs[1], ks[0])  # gaba -> gabaOut
reaction001 = rxd.Reaction(specs[0] + specs[2], specs[3], ks[1], ks[2])  # gaba + GABABR <-> gabaGABABR
reaction002 = rxd.Reaction(specs[3] + specs[4], specs[6], ks[3], ks[4])  # gabaGABABR + Gi <-> gabaGABABRGi
reaction003 = rxd.Reaction(specs[6], specs[7] + specs[8], ks[5])  # gabaGABABRGi -> gabaGABABRGibg + GiaGTP
reaction004 = rxd.Reaction(specs[7], specs[3] + specs[12], ks[6])  # gabaGABABRGibg -> gabaGABABR + Gibg
reaction005 = rxd.Reaction(specs[2] + specs[4], specs[5], ks[7], ks[8])  # GABABR + Gi <-> GABABRGi
reaction006 = rxd.Reaction(specs[0] + specs[5], specs[6], ks[9], ks[10])  # gaba + GABABRGi <-> gabaGABABRGi
reaction007 = rxd.Reaction(specs[8] + specs[10], specs[11], ks[11], ks[12])  # GiaGTP + RGS <-> GiaGTPRGS
reaction008 = rxd.Reaction(specs[11], specs[9] + specs[10], ks[13])  # GiaGTPRGS -> GiaGDP + RGS
reaction009 = rxd.Reaction(specs[9] + specs[12], specs[4], ks[14])  # GiaGDP + Gibg -> Gi
reaction010 = rxd.Reaction(specs[13] + specs[12], specs[15], ks[15], ks[16])  # GIRK + Gibg <-> GIRKGibg
reaction011 = rxd.Reaction(specs[15] + specs[12], specs[16], ks[17], ks[18])  # GIRKGibg + Gibg <-> GIRKGibg2
reaction012 = rxd.Reaction(specs[16] + specs[12], specs[17], ks[19], ks[20])  # GIRKGibg2 + Gibg <-> GIRKGibg3
reaction013 = rxd.Reaction(specs[17] + specs[12], specs[18], ks[21], ks[22])  # GIRKGibg3 + Gibg <-> GIRKGibg4
reaction014 = rxd.Reaction(specs[14] + specs[12], specs[19], ks[23], ks[24])  # VGCC + Gibg <-> VGCCGibg

reaction_gaba_flux = rxd.Rate(specs[0], gaba_flux_rate) # adds gaba_flux_rate to d[gaba]/dt (external GABA input)
# Recorders: simulation time, plus the concentration of each species at the
# section midpoint. vecs[i] belongs to species[i], because specs was built in
# the same order.
vecs = []
vec_t = h.Vector()
vec_t.record(h._ref_t)
for spec in specs:
  vec = h.Vector()
  vec.record(spec.nodes(dend)(0.5)[0]._ref_concentration)
  vecs.append(vec)

# Variable-step (CVODE) integration with absolute tolerance `tolerance`.
# Per-species atolscale factors were passed when the species were created.
cvode = h.CVode()
cvode.active(1)
hmax = cvode.maxstep(1000)
hmin = cvode.minstep(1e-10)
cvode.atol(tolerance)

h.finitialize(-65)
def set_param(param, val):
    """Set every node of the rxd parameter ``param`` to ``val``, then re-initialise CVODE.

    Used as the body of the CVODE event callbacks that switch the GABA input
    on and off. The re-init presumably lets the variable-step integrator
    restart cleanly after this discontinuous change.
    """
    target_nodes = param.nodes
    target_nodes.value = val
    h.cvode.re_init()

### Set on and off the inputs to the spine
# Pulse schedule: Ntrains trains starting trainT apart, each made of
# gaba_input_N pulses spaced T = 1000/gaba_input_freq apart. Each pulse
# switches gaba_flux_rate on at its start and back to 0 after gaba_input_dur.
T = 1000./gaba_input_freq

def _gaba_flux_on():
  # NOTE(review): presumably converts a molecule-count rate into mM per unit
  # time (Avogadro 6.022e23, my_volume in litres, 1e3 for M -> mM); confirm.
  set_param(gaba_flux_rate, gaba_input_flux/6.022e23/my_volume*1e3)

def _gaba_flux_off():
  set_param(gaba_flux_rate, 0)

tnow = 0
for itrain in range(Ntrains):
  for istim in range(gaba_input_N):
    tnew = gaba_input_onset + istim*T + trainT*itrain
    h.cvode.event(tnew, _gaba_flux_on)
    h.cvode.event(tnew+gaba_input_dur, _gaba_flux_off)
    tnow = tnew

# Run the simulation to Duration and report the wall-clock time taken.
timenow = time.time()
h.continuerun(Duration)
print("Simulation done in "+str(time.time()-timenow)+" seconds")
def isFlux(t):
  """Return 1 if time ``t`` lies inside any GABA input pulse, otherwise 0.

  A pulse covers the half-open window [start, start + gaba_input_dur), where
  start = gaba_input_onset + istim*T + trainT*itrain. This is the same
  expression used to schedule the input events. Reads the module-level
  stimulus settings (Ntrains, gaba_input_N, gaba_input_onset, T, trainT,
  gaba_input_dur).
  """
  pulse_starts = (gaba_input_onset + istim*T + trainT*itrain
                  for itrain in range(Ntrains)
                  for istim in range(gaba_input_N))
  return 1 if any(start <= t < start+gaba_input_dur for start in pulse_starts) else 0
# Thin the recorded time points. A sample is kept when more than
# minDT_nonFlux has passed since the last kept sample, or, during an input
# pulse, when more than minDT_Flux has passed.
# NOTE(review): itvec2 is not read anywhere in the visible code, and isFlux
# makes this loop cost O(len(tvec) * Ntrains * gaba_input_N). Confirm it is
# still needed.
tvec = array(vec_t)
minDT_nonFlux = 20.0
minDT_Flux = 1.0
lastt = -inf
itvec2 = []
for it, tcur in enumerate(tvec):
  elapsed = tcur - lastt
  if elapsed > minDT_nonFlux or (isFlux(tcur) and elapsed > minDT_Flux):
    itvec2.append(it)
    lastt = tcur

# Labels for the saved rows: time first, then one per species. Derived from
# `species` so the labels always match the order in which `vecs` was
# recorded, instead of maintaining a second hand-written copy.
headers = ['tvec'] + species

# Resample every species trace on a regular grid with spacing record_T. The
# grid is anchored at the GABA input onset, or at 0 when the run ended before
# the onset. It starts one step before the anchor, and that first point is
# dropped if it would be negative.
tmax = max(tvec)
myonset = gaba_input_onset
if myonset > tmax:
  myonset = 0
interptimes = [myonset + record_T*i for i in range(-1,int((tmax-myonset)/record_T))]
if interptimes[0] < 0:
  interptimes = interptimes[1:]
# NOTE(review): mytools.interpolate is defined elsewhere. Presumably it
# interpolates vecs[j] (sampled at tvec) onto interptimes; confirm.
interpDATA = [mytools.interpolate(tvec, vecs[j], interptimes) for j in range(len(species))]
# Row 0 holds the sample times; rows 1.. hold the species, in headers order.
tcDATA = array([interptimes]+interpDATA)
# Column-wise maximum over the raw (not resampled) traces: time, then species.
maxDATA = c_[tvec,array(vecs).T].max(axis=0)
scipy.io.savemat(filename, {'DATA': tcDATA, 'maxDATA': maxDATA, 'headers': headers})