from neuron import h
import numpy as np

VERBOSE = 0
from new_geomnseg import *

def range_func(sect, typ, bounds):
	"""Set a range variable by linear interpolation along a section.

	For every segment ``seg`` of ``sect``, assigns
	``seg.<typ> = np.interp(seg.x, [0, 1], bounds)``, i.e. the value at the
	segment's normalized position between ``bounds[0]`` (x=0) and
	``bounds[1]`` (x=1).

	Parameters
	----------
	sect : iterable of segments (e.g. a NEURON Section)
	typ : str
		Attribute to assign, e.g. ``'gnabar_NaMark'``. A dotted path such as
		``'hh.gnabar'`` is resolved one attribute at a time.
	bounds : sequence of two numbers
		Values at x=0 and x=1.
	"""
	# Resolve the attribute path once; setattr replaces the former exec(),
	# which compiled a statement from a caller-supplied string.
	*path, name = [part.strip() for part in typ.split('.')]
	for seg in sect:
		target = seg
		for part in path:
			target = getattr(target, part)
		setattr(target, name, np.interp(seg.x, [0, 1], bounds))

def init(neuron_obj, p):
	"""Insert and parameterize the membrane mechanisms of a model neuron.

	Steps, in order:

	1. Soma + basal dendrites: passive properties and the shared active
	   channel set (``_init_somatodend_section``).
	2. ``typem`` markers for somatic / ``abd`` sections; CaV3 removed from
	   the soma when dendrites exist; ``abd`` sections get nseg >= 5.
	3. AIS (``excitozone``) sections, if present (``_init_ais_section``);
	   otherwise the soma diameter is set to 1.
	4. ``allcorr_NaMark = p['allcor']`` on every section that has NaMark.
	5. Axonal sections, if any: passive ``leak`` only.
	6. ``geom_nseg(test)`` from ``new_geomnseg`` (original note: sets nseg
	   from Ra and passive properties).
	7. ``cicr`` inserted everywhere and NMODL POINTERs linked
	   (``_link_pointers``).
	8. HCN density x4 on every section from the AIS's parent up to the root.
	9. Distance-dependent basal changes, then somatic rescaling
	   (``_adjust_basal_and_soma``), when dendrites exist.

	Parameters
	----------
	neuron_obj : cell object
		Must expose ``soma``, ``somatic``, ``basal``, ``abd``, ``excitozone``
		and ``all``; ``ais`` is used when ``excitozone`` is not None;
		``axonal`` is optional. ``basal``, ``abd`` and ``excitozone`` may be
		None.
	p : dict
		Model parameters looked up by key (e.g. ``p['gnahh']``, ``p['ra']``).

	Prints diagnostics (AIS sections, ``abd`` geometry, basal nseg and the
	total segment count). Returns None.
	"""
	test = neuron_obj

	h.distance(0, test.soma(0.5))  # all distances below are from soma(0.5)

	# Soma and basal dendrites share one channel set; densities may differ.
	has_dendrites = test.basal is not None
	somatodend = h.SectionList()
	for sec in test.somatic:
		somatodend.append(sec)
	if has_dendrites:
		for sec in test.basal:
			somatodend.append(sec)

	h.ion_register('adp', 0)  # valence 0: tracks concentration, carries no current
	h.ion_register('cacicr', 2)
	h.ion_register('cask', 2)

	for s in somatodend:
		_init_somatodend_section(s, p, has_dendrites)

	for s in test.somatic:
		s.marker_typem = 0
		if has_dendrites:
			s.gcatbar_catchan *= 0  # not a point model: no CaV3 in the soma
	if test.abd is not None:
		for s in test.abd:
			s.marker_typem = 0.25
			if s.nseg < 5:
				s.nseg = 5

	if test.excitozone is not None:
		for s in test.excitozone:
			_init_ais_section(s, p)
	else:
		test.soma.diam = 1

	for s in test.all:
		# Guarded: sections without NaMark (e.g. a passive axon) would raise.
		if h.ismembrane('NaMark', sec=s):
			s.allcorr_NaMark = p['allcor']

	# Axon (not yet implemented): minimal passive compartment.
	axonal = getattr(test, 'axonal', None)
	if axonal is not None:
		for s in axonal:
			s.v = -50
			s.nseg = 2*int(s.L/200) + 1  # nseg should be odd
			s.Ra = 100.0*p['ra']
			s.cm = 1.0*p['cm']
			s.insert('leak')
	elif VERBOSE:
		print('no axon')

	geom_nseg(test)  # may reduce nseg set above

	for s in test.all:
		_link_pointers(s)

	if test.abd is not None:
		for s in test.abd:
			print(s, s.nseg, s.diam, s.L)

	if test.excitozone is not None:
		# Walk from the AIS to the root, raising HCN density on each ancestor.
		sref = h.SectionRef(sec=test.ais)
		while sref.has_parent():
			parent = sref.parent
			sref = h.SectionRef(sec=parent)
			for seg in parent:
				seg.ghcnbar_hcn *= 4

	if has_dendrites:
		_adjust_basal_and_soma(test)
	elif VERBOSE:
		print('no dendrites')

	print(sum(s.nseg for s in test.all))  # total number of segments


def _init_somatodend_section(s, p, has_dendrites):
	"""Insert and parameterize the channel set of one soma/basal section."""
	if s.nseg < 3:
		# at least 3 segments, more for larger Ra scaling; kept odd
		s.nseg = max(3, int(3*p['ra']) + (int(3*p['ra']) % 2 - 1))
	s.Ra = 100.0*p['ra']
	s.cm = 1.0*p['cm']

	s.insert('girk')
	s.girkbar_girk = 1e-6*p['girk']
	s.insert('hhb')  # Kv4 and Kdr
	s.tfast_hhb = p['tfast']
	s.insert('km')  # muscarinic-sensitive potassium channel
	s.insert('leak')
	s.insert('NaMark')  # reduced Markov model of NaV1.6
	s.gnabar_NaMark = 1000e-6*p['gnahh']
	s.insert('cabalcicr')  # Ca/ADP stoichiometry: ADP created 1:1 with Ca2+ extrusion
	s.insert('hcn')  # non-specific voltage-gated cation channel
	if p['gkatp'] > 0:
		s.insert('katpStoch')  # KATP with stochastic Markov scheme
	s.insert('kca')  # SK
	s.insert('catchan')  # CaV3
	s.insert('canchan')  # CaV2.x
	s.insert('calhh')  # CaL1.3 with slow partial voltage-gated inactivation (Shin et al 2022)
	s.insert('erg')  # ether-a-go-go related potassium channel
	s.insert('bkc')  # BK
	s.insert('typem')  # range variable marking compartment type

	s.gkbar_bkc = 500e-6*p['gbk']
	s.bkscale_bkc = 10  # 1 -> ~0.1, 100 -> ~0.5 peak activation for a spike in a conventional cell
	s.FAST_NaMark = p['fast']  # fast-inactivated -> closed transfer speed
	s.gkmbar_km = 100e-6*p['gkm']
	# Kv4 inactivation half/slope shifts (Tarfa et al, Costa et al 2023)
	s.ashift_hhb = p['kasv']
	s.asshift_hhb = p['kass']
	if h.ismembrane('katpStoch', sec=s):
		s.gkatpbar_katpStoch = 100e-6*p['gkatp']
		s.alphm_katpStoch *= p['skatp']  # 'reuptake' of adp
		s.km_katpStoch *= p['kmatp']  # 'MgADP' ec50
		s.n_katpStoch = p['nkatp']  # effective exponent
	s.gkabar_hhb = 300e-6*p['gka']*1.5  # 1.5 to match Roeper-lab data with fast+slow summing to 1
	s.mbar_hcn = p['mbar']
	# fast-slow analysis: fptog fixes Kv4 inactivation at qfast/qslow
	s.qslow_hhb = p['qslow']
	s.qfast_hhb = p['qfast']
	s.fptog_hhb = p['fixkv4']
	s.nspeed_hhb = p['nspeed']  # minimum Kdr time constant at hyperpolarized voltages
	s.gkhhbar_hhb = 200e-6*p['gkhh']
	s.scale_NaMark = p['scale']
	s.hshift_NaMark = p['na_hshift']
	s.gergbar_erg = 100e-6*p['gkerg']
	s.gkbar_leak = 1.0e-6*p['glk']
	s.gnabar_leak = 1.0e-6*p['glna']
	s.gkbar_kca = 100e-6*p['gkca']
	s.tausk_kca = p['tausk']
	s.uselocal_kca = p['local']
	s.icapumpmax_cabalcicr *= p['capump']
	s.proxy_cabalcicr *= p['proxyadp']
	s.gcanbar_canchan = 200e-6
	s.gcalbar_calhh = 10.0e-6*p['gcal']
	s.pf_calhh = p['pfcal']
	s.skcoup_calhh = p['coup']  # coupling of L-type to the SK pool
	s.mhalf_calhh = -35  # half activation
	s.mslope_calhh = 5.0
	s.ghcnbar_hcn = 50.0e-6*p['ghcn']
	s.scale_hcn = 1
	s.mhalf_hcn = -75.0
	s.taukv4_hhb = p['taukv4']
	s.dist_NaMark = p['slow']
	s.v = -50  # -40 starts in depolarization block
	s.cai = 1e-9
	s.nslope_hhb = p['nslope']
	s.nshift_hhb = p['nshift']
	s.kchip_hhb = p['kchip']
	s.fchip_hhb = p['fchip']

	if p['DISABLE'] == 0:
		# NOTE(review): an older comment called this "turn off sodium
		# inactivation"; the value written is 0 - confirm polarity in NaMark.
		s.DISABLE_NaMark = p['DISABLE']
		s.oinit_NaMark = 0.5

	if h.ismembrane('calmark', sec=s):  # Markov alternative of CaL, if pre-inserted
		s.gcalbar_calmark *= 0.0
		s.moff_calmark = -30
		s.mslope_calmark = 5.0

	for seg in s:  # per segment: diam is not fixed over the section
		seg.TotalBuffer_cabalcicr = 0.03
		seg.SCALE_cabalcicr = 1.0
		seg.tog_cabalcicr = p['cicrtog']  # IP3-driven Ca release toggle
		seg.DCa_cabalcicr = p['dca']  # radial calcium diffusion constant
		seg.imetamax_cabalcicr *= p['meta']
		seg.kadp_cabalcicr = p['kadp']
		seg.constrict_cabalcicr = p['constrict']
		seg.dense_cabalcicr = p['dense']
		seg.dsk_cabalcicr = p['dsk']
		seg.dcicr_cabalcicr = p['dcicr']

	s.phi_h_catchan = 1.5
	s.shift_catchan = 0
	s.hhalf_catchan = -80
	s.tcicr_catchan = p['tcicr']
	s.tsk_catchan = p['catbuff']
	s.hfixed_catchan = p['catfix']
	if not has_dendrites:
		s.gcatbar_catchan *= p['gcat']
	else:
		tclose = 150  # within this distance: no CaV3, extra Kv4
		for seg in s:
			dist = h.distance(seg.x, sec=s)
			if dist < tclose:
				seg.gkabar_hhb *= 1.5
				seg.gcatbar_catchan = 0
			else:
				# CaV3 ramps up linearly over the next tclose of distance
				seg.gcatbar_catchan *= p['gcat']*min(1, (dist - tclose)/(1.0*tclose))
				seg.gkabar_hhb *= (1 - p['taperkv4'])
			if dist > 150:
				seg.gnabar_NaMark *= 0.75
				if seg.x > 0.5 and seg.tcicr_catchan > 1:
					seg.tcicr_catchan = 1

	if h.ismembrane('nmda', sec=s):
		s.cafrac_nmda = p['cafrac']

	if p['musc'] > 0:  # simulating effects of ACh
		s.gkmbar_km *= 0.2


def _init_ais_section(s, p):
	"""Insert and parameterize the mechanisms of one AIS section."""
	print('has ais', s)
	s.nseg = int(2*s.L/10) + 1
	s.Ra = 100.0*p['ra']
	s.cm = 1.0*p['cm']
	for mech in ('hhb', 'leak', 'NaMark', 'cabalcicr', 'kca', 'canchan', 'calhh', 'typem'):
		s.insert(mech)
	s.marker_typem = 0.5

	s.gnabar_NaMark = 1000.0e-6*p['ais_na']
	s.FAST_NaMark = p['fast']
	s.dist_NaMark = p['slow_ais']
	s.aiscorr_NaMark = p['aiscor']
	s.gkhhbar_hhb = 200.0e-6*p['ais_k']
	s.gkabar_hhb = 5*300e-6*p['gka']
	s.fchip_hhb = 1
	s.kchip_hhb = 0
	s.tfast_hhb = p['tfast']
	s.nspeed_hhb = p['nspeed']
	s.ashift_hhb = p['kasv']
	s.asshift_hhb = p['kass']
	s.gkbar_leak = 1.0e-6*p['glk']
	s.gnabar_leak = 1.0e-6*p['glna']
	s.v = -50.0

	s.gkbar_kca = 0  # no SK in the AIS
	s.tausk_kca = p['tausk']
	s.gcanbar_canchan = 0  # no CaV2 in the AIS
	s.gcalbar_calhh = 10.0e-6*p['gcal']  # might be too large here
	s.skcoup_calhh = 0.0  # L-type not coupled to the SK pool
	s.mhalf_calhh = -35
	s.mslope_calhh = 5.0
	s.pf_calhh = p['pfcal']

	if h.ismembrane('calmark', sec=s):
		s.gcalbar_calmark *= 0.0
		s.moff_calmark = -30
		s.mslope_calmark = 5.0
	if h.ismembrane('hcn', sec=s):
		s.scale_hcn = 1
		s.mhalf_hcn = -75.0

	s.icapumpmax_cabalcicr *= p['capump']
	for seg in s:
		seg.TotalBuffer_cabalcicr = 0.03
		seg.SCALE_cabalcicr = 1.0
		seg.tog_cabalcicr = 0  # IP3-driven release not used in the AIS
		seg.DCa_cabalcicr = p['dca']
		seg.imetamax_cabalcicr *= p['meta']
		seg.kadp_cabalcicr = p['kadp']

	if h.ismembrane('catchan', sec=s):
		s.phi_h_catchan = 1.5
		s.shift_catchan = 0
		s.buff_catchan = p['catbuff']
	if h.ismembrane('nmda', sec=s):
		s.cafrac_nmda = p['cafrac']


def _link_pointers(s):
	"""Insert ``cicr`` into ``s`` and connect the BK and cicr/cabalcicr POINTERs."""
	s.insert('cicr')
	if h.ismembrane('bkc', sec=s):
		for seg in s:
			# couples the BK channel to the N-type (canchan) Ca current
			h.setpointer(seg._ref_ica_canchan, 'icabk_p', seg.bkc)
	if h.ismembrane('cicr', sec=s) and h.ismembrane('cabalcicr', sec=s):
		for seg in s:
			h.setpointer(seg._ref_castore_cabalcicr, 'cas_p', seg.cicr)  # original note: may not be working
			h.setpointer(seg._ref_cacicri, 'cac_p', seg.cicr)
			h.setpointer(seg._ref_ju_cicr, 'ju_p', seg.cabalcicr)
			h.setpointer(seg._ref_jcicr_cicr, 'jcicr_p', seg.cabalcicr)


def _adjust_basal_and_soma(test):
	"""Apply distance-dependent changes to basal dendrites, then rescale the soma."""
	for index, s in enumerate(test.basal):
		s.marker_typem = 1
		print(s, s.nseg)
		if index in (0, 1, 2, 7):
			# no CaV3 in these dendrites (selected by position in test.basal)
			for seg in s:
				seg.gcatbar_catchan *= 0
		for seg in s:
			d = h.distance(seg.x, sec=s)
			if d < 100:
				seg.DCa_cabalcicr = 0.3
				seg.icapumpmax_cabalcicr *= 0.5  # fewer pumps needed near the soma
			if d > 300:
				d2 = d/300.0
				# TODO(review): the original also tried to set a distal NaMark
				# 'allcorr' gradient here (na_shift + d/300). It referenced an
				# undefined `pstable` and a misspelled attribute, so it raised
				# every time and a bare `except` hid it (skipping the remaining
				# basal and all somatic adjustments). Restore once the intended
				# parameter source is confirmed.
				seg.icapumpmax_cabalcicr *= max(pow(d2, -2), 0.667)
			if d > 150:
				seg.dsk_cabalcicr *= max(pow(d/150, -1), 0.3)
			if d > 500:
				seg.tsk_catchan = min(seg.tsk_catchan*3, 1)  # capped at 1
	for s in test.somatic:
		for seg in s:
			seg.gkabar_hhb *= 4
			seg.gnabar_NaMark *= 0.5
			if h.distance(seg.x, sec=s) < 50:
				seg.dist_NaMark *= 0


def to_shreds_you_say(neuron_obj, chop_dist):  # simulate acute dissociation for realistic morphologies
	"""Delete every section whose midpoint lies farther than ``chop_dist`` from the soma.

	The distance origin is set to ``neuron_obj.soma(0.5)``. Each section of
	``neuron_obj.all`` whose ``0.5`` point is farther away is disconnected
	and deleted. Distances are evaluated lazily in iteration order, so a
	section whose ancestor was already removed is measured against the
	modified tree, as before.

	Parameters
	----------
	neuron_obj : cell object exposing ``soma`` and ``all``
	chop_dist : float
		Cut-off distance, same units as ``h.distance``.
	"""
	blest = neuron_obj
	h.distance(0, blest.soma(0.5))
	# Iterate over a snapshot: deleting sections from the SectionList being
	# iterated is unsafe.
	for s in list(blest.all):
		if h.distance(0.5, sec=s) > chop_dist:
			# These hoc functions act on the section given via `sec=`; a
			# positional SectionRef does not select the section.
			h.disconnect(sec=s)
			h.delete_section(sec=s)