#should be renamed 'single cell functions'

from neuron import h
import numpy as np
from paramsdict import p_init as pderp
#from scipy.interpolate import interp1d
#from scipy.signal import butter, lfilter
from random import choice
#import matplotlib.pyplot as plt


# Set to a truthy value to print diagnostic messages (gen_netstim uses it to
# report when an unsupported 'rtype' falls back to 'nexp').
VERBOSE=0

def sine_func(t,freq=4,maxi=100,mini=20):
	"""Return a sinusoid in `t` that oscillates between `mini` and `maxi`.

	The phase is 2*pi*1e-3*freq*t, so the value at t=0 is the midpoint of
	`mini` and `maxi`. `t` may be a scalar or a numpy array.
	NOTE(review): the 1e-3 factor suggests t in ms and freq in Hz -- confirm.
	"""
	phase = 2*np.pi*1e-3*freq*t
	# rescale sin from [-1, 1] to [0, 1] before stretching to [mini, maxi]
	unit_wave = (1.0+np.sin(phase))/2.0
	return unit_wave*(maxi-mini) + mini



def lfnoise(): # dummy func
	"""Placeholder with no behaviour; calling it returns None.

	gaba_run compares its rate-function arguments against this object to
	decide whether a low-frequency-noise rate function should be built.
	"""
	return None




def gen_netstim(func,target,x=0.5,syn_type='gabaa',nspikes=1,duration=1e4,funcargs={},pstable=pderp,input_spikes=None,ncells=1):
	"""Create a synapse/NetCon pair on `target` and build a spike-time list for it.

	Parameters
	----------
	func : callable
		Rate function, called as func(t, **funcargs); its return value is used
		as the instantaneous rate when drawing the next inter-spike interval
		(mean interval = 1000.0/rate).
	target : NEURON section or None
		Section that receives the synapse at position `x`. With None, no
		synapse or NetCon is created and only the spike list is returned.
	x : float
		Location along `target` passed to target(x).
	syn_type : str
		'gabaa', 'ampa', 'nmda', 'gabab' or 'chrhod'; any other value gets a
		plain Exp2Syn (tau1=0.2, tau2=2.4) and the NetCon weight is left at
		its default.
	nspikes : int
		Accepted for backward compatibility; intervals are always drawn one
		at a time (the original code overwrote this with 1 inside the loop).
	duration : float
		Spike generation stops once the running time reaches this value.
	funcargs : dict
		Keyword arguments for `func`. Values are passed as floats (the
		original exec-based call also passed float literals, but rounded to
		6 decimals). The dict is never mutated here.
	pstable : dict
		Parameter table; reads 'rtype' and, depending on `syn_type`, 'ecl',
		'knmda', 'gabaamp', 'nmdaamp', 'gababamp', 'chrhodamp'.
	input_spikes : list, array-like or None
		If given, no spikes are generated: a list is returned as-is (same
		object), any other iterable is copied into a new list. A
		non-iterable value prints a warning and spikes are generated.
	ncells : int
		Number of independent trains to generate; their spikes are merged
		and sorted when ncells > 1.

	Returns
	-------
	(ncobj, spikes, synapse) when `target` is not None, otherwise `spikes`.
	"""
	# Kept as a module global: the original implementation published the last
	# sampled rate this way and other code in this file reads it.
	global freq

	rtype = pstable['rtype']
	if rtype not in ['poisson','nexp']: # note that nexp creates a poisson distribution
		rtype = 'nexp'
		if VERBOSE:
			print('type not supported, reverting to nexp')
	safety_interval = 100 # how long an interval is allowed before the rate function is re-sampled.
	# this creates an error in the tail of the distribution to avoid missing shifts in the rate kernel
	SAVETOFILE=0

	# `is None` rather than `== None`: the equality form is evaluated
	# element-wise for numpy arrays and is not a reliable None test.
	if target is None:
		synapse = None
		ncobj = None
	else:
		if syn_type == 'nmda':
			synapse = h.Exp2NMDA(target(x))
			synapse.K0 = pstable['knmda']
		elif syn_type == 'gabab':
			synapse = h.Exp2GABAb(target(x))
			synapse.tau1=74
			synapse.tau2=188
			synapse.e = -90
		elif syn_type == 'chrhod':
			synapse = h.Exp2Syn(target(x))
			synapse.tau1=1.23
			synapse.tau2=6.13
			synapse.e = 0
		else:
			synapse = h.Exp2Syn(target(x))
			synapse.tau1=0.2
			synapse.tau2=2.4
			if syn_type=='gabaa':
				synapse.e = pstable['ecl']
			if syn_type=='ampa':
				synapse.e = 0
				synapse.tau2=3

		# NetCon without a source: events are delivered explicitly by the caller
		ncobj=h.NetCon(None, synapse)
		ncobj.delay=2
		weight_keys = {'gabaa':'gabaamp','nmda':'nmdaamp','ampa':'nmdaamp','gabab':'gababamp','chrhod':'chrhodamp'}
		if syn_type in weight_keys:
			ncobj.weight[0] = pstable[weight_keys[syn_type]]

	if input_spikes is not None:
		if isinstance(input_spikes,list): # if it is a python list, hand it back untouched
			spikes = input_spikes
		else: # if it is array like (eg numpy array or neuron vector), copy element by element
			try:
				spikes = list(input_spikes)
			except TypeError:
				print('input spikes expected to be an array object, generating from scratch')
				spikes = None
		if spikes is not None:
			if target is not None:
				return ncobj, spikes, synapse
			return spikes

	rate_kwargs = dict((key, float(funcargs[key])) for key in funcargs)
	spikes = []
	sftinv = 1.0/safety_interval
	for j in range(ncells):
		t=0
		while t < duration: # this can cause problems if the minimum frequency is very low.
			freq = func(t, **rate_kwargs)
			if freq < sftinv:
				# rate too low to produce a spike inside the safety interval; step on and re-sample
				t+=10
				continue
			if rtype=='poisson': # this makes a 'meta poisson'
				isi = np.random.poisson(1000.0/freq,1)
			else: # 'nexp': this actually makes a poisson process
				isi = np.random.exponential(1000.0/freq,1)
				# ISI based on frequency at t=t0 not mean over interval from t0 to t1
				# this method is only accurate when ISI is fast compared with changes in stimulus
			step = isi[0]
			if step >= safety_interval:
				# this is a hedge against creating huge steps when rate is low - essentially chopping the tail off the nexp.
				# `>=` (not `>`): an interval exactly equal to the safety interval used to
				# fall through without a spike and advance t by a stale previous interval.
				t+=safety_interval
				continue
			spikes.append(step+t)
			t = t + step

	if ncells > 1:
		spikes.sort()

	if SAVETOFILE and target is not None:
		temp = str(target).split('[')[1]
		temp = temp.split(']')[0]
		number = int(temp)
		fname = '%s_%s_input_raster.dat' %(syn_type, str(target).replace('[','').replace(']',''))
		with open(fname,'w') as fp:
			for t in spikes:
				fp.write('%e %d\n' % (t, number))

	if target is not None:
		return ncobj, spikes, synapse
	return spikes # use this for distributed
		
		
		
# THIS IS ALL HORRIBLE KLUDGE CODE AND SHOULD NOT BE USED AS A TEMPLATE FOR ANYTHING.	
def gaba_run(rec_list,neuron_class,func=sine_func,pstable=pderp,funcargs={},args=['',''],ncells=1,stype='gabaa',nmdafunc=None,nmdaargs={},nmdacells=1, HISTO=0, RAMP=0, DISTRIBUTE=0,SUMMATE=1,UNIFORM=1,VARNMDA=1, VARGABA=0,DISTAL=0,SUM_INDS=[],DRAW=0):
	"""Build one cell, drive it with synaptic input, simulate, and write result files.

	Parameters
	----------
	rec_list : list of str
		Variables to record. A name without '.' is recorded at soma(0.5) as
		`_ref_<name>`; 'section.var' is recorded at cell.nrn.<section>(1).
		NOTE: this list is mutated -- 'gnmda_nmda' and 'inmda_nmda' are
		appended when pstable['gnmda'] > 0.
	neuron_class : callable
		Called as neuron_class(0, pstable); the result must expose `.nrn`
		(sections), `.ntc` and `.srec` (spike recording).
	func, funcargs : callable, dict
		Rate function and keyword arguments for the GABA input. If `func` is
		the `lfnoise` placeholder it is replaced by genlfnfunc(...) built from
		`funcargs`. Unless VARGABA is set, `funcargs` is replaced by a window
		dict built from pstable ('fon'/'on'/'foff'/'off').
	pstable : dict
		Parameter table (stimulus timing, amplitudes, synaptic weights...).
	args : sequence of two str
		Affixes used in output file names.
	ncells, nmdacells : int
		Number of independent input trains for GABA / NMDA.
	stype : str
		'gabaa', 'gabab' or 'bpulse' for the non-distributed GABA input.
	nmdafunc, nmdaargs : callable or None, dict
		Rate function and arguments for the NMDA/AMPA input.
	HISTO : write spike times and a spike histogram.
	RAMP : play a triangular current ramp instead of a square pulse.
	DISTRIBUTE : place synapses on every segment and pick targets by area.
	SUMMATE : with DISTRIBUTE, record and sum synaptic conductances.
	UNIFORM : seed numpy's RNG from pstable['control'].
	VARNMDA, VARGABA : select which window arguments the inputs use.
	DISTAL : with DISTRIBUTE, dendrite index for the regular stimulus.
	SUM_INDS : indices into rec_list to be summed (area weighted) over all
		segments of cell.nrn.somden instead of recorded at the soma.
	DRAW : produce shape-plot snapshots around the end of the stimulus.

	Returns
	-------
	0, always. Results are written to 'GABA_<affix><affix2>.dat',
	'spikes_<...>.dat' and 'Histo_<...>.dat' depending on the flags.
	"""
	affix=args[0]
	affix2=args[1]
	dt =0.1

	if UNIFORM:
		seed = int(pstable['control'])
		np.random.seed(seed)

	if func==lfnoise:
		func=_lfnoise_rate_func(funcargs,pstable)
		funcargs={}

	if nmdafunc==lfnoise:
		nmdafunc=_lfnoise_rate_func(nmdaargs,pstable)
		nmdaargs={}

	cell = neuron_class(0,pstable)
	# `blah` is the main stimulus clamp, `blah2` a second pulse after it. The names are
	# kept because the exec-built record statements below refer to `blah` by name.
	blah = h.IClamp(cell.nrn.soma(0.5))
	blah2=h.IClamp(cell.nrn.soma(0.5))
	for s in cell.nrn.somatic:
		s.insert('nmda')
		for seg in s:
			seg.gnmdabar_nmda = pstable['gnmda']
			seg.cMg_nmda = pstable['cmg']
			seg.cafrac_nmda = pstable['cafrac']

	if RAMP:
		print('ramp=1')
		ampmax = pstable['iamp']
		ramp = []
		basal = pstable['basal']

		# int(): 'idel' may be stored as a float, which range() rejects
		for i in range(int(pstable['idel'])):
			ramp.append(basal)
		for i in range(int(pstable['idur'])):
			if i < (pstable['idur'])/2.0:
				ramp.append(basal+2.0*ampmax*(i)/float(pstable['idur'])) # @ i = idur/2 this is basal + ampmax
			else:
				ramp.append(basal+2.0*ampmax*(pstable['idur']-i)/float(pstable['idur']))
		for i in range(int(pstable['idel'])):
			ramp.append(basal)

		blah.dur = 1e9
		blah.delay = 0
		ampvec= h.Vector(ramp)
		ampvec.play(blah._ref_amp, 1)

		TSTOP=2*pstable['idel']+pstable['idur']

	else:
		blah.delay = pstable['idel']
		blah.amp = pstable['iamp']
		blah.dur = pstable['idur']

		TSTOP = pstable['idel']+pstable['idur']+7000

	blah2.delay = pstable['idel']+ pstable['idur']+pstable['offset']
	blah2.dur = pstable['i2dur']
	blah2.amp = pstable['i2amp']

	if pstable['gnmda'] > 0:
		# switch the density NMDA conductance on only while the main clamp is active
		gnmda_t =[0,blah.delay,blah.delay+blah.dur]
		gnmda_g = [0,pstable['gnmda'],0]
		gnmda_t_vec = h.Vector(gnmda_t)
		gnmda_g_vec = h.Vector(gnmda_g)
		rec_list.append('gnmda_nmda')
		rec_list.append('inmda_nmda')
		for s in cell.nrn.all:
			if h.ismembrane('nmda',sec=s):
				gnmda_g_vec.play(s(0.5)._ref_gnmdabar_nmda,gnmda_t_vec)

	cell.ntc.record(cell.srec) # this might be doubling up on record commands

	# rec_vectors layout: 0 = t, 1 = V, 2 = applied current, 3.. = rec_list entries
	rec_vectors = [h.Vector(),h.Vector(),h.Vector()]
	i=3

	rec_vectors[0].record(blah,h._ref_t,dt,sec=cell.nrn.soma)
	rec_vectors[1].record(blah,cell.nrn.soma(0.5)._ref_v,dt,sec=cell.nrn.soma)
	if RAMP:
		rec_vectors[2].record(blah._ref_amp,dt)
	else:
		rec_vectors[2].record(blah._ref_i,dt)

	h.finitialize()
	# summed_rec_vecs[k] / summed_areas[k] parallel rec_vectors[k]; they are only
	# filled for entries listed in SUM_INDS (one vector + area per segment).
	summed_rec_vecs = [[],[],[]]
	summed_areas = [[],[],[]]
	for things in rec_list:
		rec_vectors.append(h.Vector())
		summed_rec_vecs.append([])
		summed_areas.append([])
		if (i-3) in SUM_INDS:
			print(i-3,things)
			for s in cell.nrn.somden:
				loc = str(s).split('.')[1]
				for seg in s:
					summed_rec_vecs[-1].append(h.Vector())
					# remember this segment's own area for the weighted sum at write time
					summed_areas[-1].append(h.area(seg.x,sec=s))
					# exec reads blah/cell/seg/dt/s/summed_rec_vecs from this function's locals
					exec('summed_rec_vecs[-1][-1].record(blah,cell.nrn.%s(seg.x)._ref_%s,dt,sec=s)' % (loc,things) )
		else:
			if len(things.split('.')) <2:
				exec('rec_vectors[%d].record(blah,cell.nrn.soma(0.5)._ref_%s,dt,sec=cell.nrn.soma)' % (i,things) ) # blah indicates what cell to record from
			else:
				loc = things.split('.')
				exec('rec_vectors[%d].record(blah,cell.nrn.%s(1)._ref_%s,dt,sec=cell.nrn.soma)' % (i,loc[0],loc[1]) )
		i+=1

	if not DISTRIBUTE:
		# NOTE(review): these gen_netstim calls do not forward `pstable`, so synapse
		# parameters come from gen_netstim's default table -- confirm that is intended.
		# The returned synapses/NetCons must stay referenced until the run finishes.
		if stype == 'bpulse': # using window function exclusively
			nc, spikes,synapse = gen_netstim(func, cell.nrn.soma,nspikes=100,duration=TSTOP,funcargs={'off':TSTOP,'on':0,'fon':0,'foff':pstable['foff']},ncells=ncells,syn_type='gabaa')
			nc_b, spikes_b, synapse_b = gen_netstim(func, cell.nrn.soma,nspikes=1,duration=TSTOP,funcargs={'off':pstable['woff'],'on':pstable['won'],'fon':pstable['fon'],'foff':0},ncells=ncells,syn_type='gabab')

		if stype in ['gabaa','gabab']:
			if not VARGABA:
				funcargs = {'fon':pstable['foff'],'on':pstable['won'],'foff':pstable['foff'],'off':pstable['woff']}
			nc, spikes,synapse = gen_netstim(func, cell.nrn.soma,nspikes=1,duration=TSTOP,funcargs=funcargs,ncells=ncells,syn_type=stype)

		if nmdafunc is not None:
			if VARNMDA:
				nmdaargs = {'fon':pstable['fon'],'on':pstable['won'],'foff':pstable['foff'],'off':pstable['woff']}
			else:
				nmdaargs = {'fon':pstable['foff'],'on':pstable['won'],'foff':pstable['foff'],'off':pstable['woff']}
			nc_nmda, spikes_nmda, synapse_nmda = gen_netstim(nmdafunc, cell.nrn.soma,nspikes=1,duration=TSTOP,funcargs=nmdaargs,ncells=nmdacells,syn_type='nmda')
			if pstable['ampa_ratio'] > 0:
				nc_ampa, dummy, synapse_ampa = gen_netstim(nmdafunc,cell.nrn.soma,duration=0,funcargs=nmdaargs,syn_type='ampa') # creates netcon and synapse
				nc_ampa.weight[0] *= pstable['ampa_ratio']
				nc_nmda.weight[0] *= (1.0-pstable['ampa_ratio'])

		if pstable['regamp'] > 0:
			if 1000.0*pstable['number']/float(pstable['freq']) > pstable['won']+pstable['woff']:
				print('error: stimulus is wider than ergodic interval')
				quit()
			# DISTRIBUTE is false in this branch, so the regular stimulus always targets the soma
			ncr,rspikes,rsyn = gen_netstim(func,cell.nrn.soma,pstable=pstable,input_spikes=[],syn_type=pstable['regtype']) # with input_spikes set to [], this only creates the netcon and synapse
			ncr.weight[0] = pstable['regamp']
			_append_regular_bursts(rspikes,pstable,TSTOP)

		h.t = 0

		cv = h.CVode()
		cv.active(1)

		h.finitialize()
		for t in spikes:
			nc.event(t)

		if stype=='bpulse':
			for t in spikes_b:
				nc_b.event(t)

		if nmdafunc is not None:
			for t in spikes_nmda:
				nc_nmda.event(t)
				if pstable['ampa_ratio'] >0:
					nc_ampa.event(t)

		if pstable['regamp'] > 0:
			# drop the first regular spike (was `rpsikes.pop(0)`, a NameError).
			# NOTE(review): the DISTRIBUTE branch drops the whole first burst instead -- confirm which is intended.
			del rspikes[:1]
			for t in rspikes:
				ncr.event(t)

	if DISTRIBUTE:
		if not VARGABA:
			funcargs = {'fon':pstable['foff'],'on':pstable['won'],'foff':pstable['foff'],'off':pstable['woff']}
		spikes_gaba = gen_netstim(func, None ,nspikes=1,duration=TSTOP,funcargs=funcargs,ncells=ncells,syn_type=stype)
		if VARNMDA:
			nmdaargs = {'fon':pstable['fon'],'on':pstable['won'],'foff':pstable['foff'],'off':pstable['woff']}
		spikes_nmda = gen_netstim(nmdafunc, None,nspikes=1,duration=TSTOP,funcargs=nmdaargs,ncells=nmdacells,syn_type='nmda')

		areas = []

		gaba_a_ntcs = []
		glut_nmda_ntcs = []
		glut_ampa_ntcs = []

		gaba_a_syns = []
		glut_nmda_syns = []
		glut_ampa_syns = []

		gaba_recs = []
		nmda_recs = []
		ampa_recs = []

		# one GABA-A, one NMDA and one AMPA synapse on every segment of every section
		ind = 0
		for s in cell.nrn.all:
			for seg in s:
				areas.append(h.area(seg.x,sec=s))

				gaba_syn = h.Exp2Syn(s(seg.x))
				gaba_syn.tau1=0.2
				gaba_syn.tau2=2.4
				gaba_syn.e = pstable['ecl']
				gaba_ntc = h.NetCon(None, gaba_syn)
				gaba_ntc.weight[0] = pstable['gabaamp']
				gaba_a_ntcs.append(gaba_ntc)
				gaba_a_syns.append(gaba_syn)
				if SUMMATE:
					gaba_recs.append(h.Vector())
					gaba_recs[ind].record(gaba_syn._ref_g,dt)

				nmda_syn = h.Exp2NMDA(s(seg.x))
				ampa_syn = h.Exp2Syn(s(seg.x))
				ampa_syn.e = 0
				ampa_syn.tau2=3
				ampa_syn.tau1=0.2
				nmda_syn.K0 = pstable['knmda']
				nmda_ntc = h.NetCon(None, nmda_syn)
				nmda_ntc.weight[0] = pstable['nmdaamp']*(1-pstable['ampa_ratio'])
				ampa_ntc = h.NetCon(None, ampa_syn)
				ampa_ntc.weight[0] = pstable['nmdaamp']*(pstable['ampa_ratio'])
				glut_nmda_ntcs.append(nmda_ntc)
				glut_ampa_ntcs.append(ampa_ntc)
				glut_nmda_syns.append(nmda_syn)
				glut_ampa_syns.append(ampa_syn)
				if SUMMATE:
					nmda_recs.append(h.Vector())
					nmda_recs[ind].record(nmda_syn._ref_g,dt)
					ampa_recs.append(h.Vector())
					ampa_recs[ind].record(ampa_syn._ref_g,dt)

				ind += 1

		if pstable['regamp'] > 0:
			if 1000.0*pstable['number']/float(pstable['freq']) > pstable['won']+pstable['woff']:
				print('error: stimulus is wider than ergodic interval')
				quit()
			if DISTAL > 0:
				ncr,rspikes,rsyn = gen_netstim(func,cell.nrn.dend[DISTAL],pstable=pstable,input_spikes=[],syn_type=pstable['regtype']) # with input_spikes set to [], this only creates the netcon and synapse
			else:
				ncr,rspikes,rsyn = gen_netstim(func,cell.nrn.soma,pstable=pstable,input_spikes=[],syn_type=pstable['regtype']) # with input_spikes set to [], this only creates the netcon and synapse

			ncr.weight[0] = pstable['regamp']
			_append_regular_bursts(rspikes,pstable,TSTOP)

		numsyns = ind
		cv = h.CVode()
		cv.active(1)
		h.finitialize()

		# probability of hitting a segment is proportional to its membrane area
		aprob = np.array(areas)
		aprob /= np.sum(aprob)

		# Targets are drawn by index (as the NMDA/AMPA path always did) so numpy never
		# has to coerce a list of hoc objects; one uniform draw per spike either way.
		for t in spikes_gaba:
			index = np.random.choice(numsyns, p=aprob)
			gaba_a_ntcs[index].event(t)

		for t in spikes_nmda:
			index = np.random.choice(numsyns, p=aprob)
			glut_ampa_ntcs[index].event(t)
			glut_nmda_ntcs[index].event(t)

		if pstable['regamp'] > 0:
			# drop the first burst of regular spikes
			del rspikes[:int(pstable['number'])]
			for t in rspikes:
				ncr.event(t)
		#### END DISTRIBUTE

	if DRAW:
		try:
			cv.solve(pstable['idur']+pstable['idel']-1)
		except Exception:
			print('there is no escape')
			cv.solve(pstable['idur']+pstable['idel']-1)
		from matplotlib import pyplot, cm
		import plotly
		print('attempting to draw')
		# Own name for the shape plot: the original rebound `blah` here, which dropped the
		# only Python reference to the stimulus IClamp part-way through the run.
		# NOTE(review): NEURON presumably destroys a point process once unreferenced -- confirm.
		shape_plot = h.PlotShape(False)
		shape_plot.view(-700,-500,1400,1000,0,0,1000,500)
		shape_plot.exec_menu("Shape Plot")
		shape_plot.exec_menu("View = plot")

		seclist = [cell.nrn.dend[3],cell.nrn.dend[1],cell.nrn.dend[0],cell.nrn.soma,cell.nrn.dend[2],cell.nrn.dend[5]]
		points = []
		h.distance(sec=cell.nrn.soma)
		# the first three sections are plotted at negative distance from the soma
		for index, s in enumerate(seclist):
			sign = -1 if index < 3 else 1
			for seg in s:
				print(s,h.distance(seg.x,sec=s))
				points.append([sign*h.distance(seg.x,sec=s),seg.v,seg.caski])
		np.savetxt('%s_%s%s_%d.dat' % ('prev',affix,affix2,int(h.t)),points)
		h.fadvance()

		var = 'caski'

		shape_plot.variable(var)
		shape_plot.scale(0,0.0003)
		shape_plot.show(1)

		h.doNotify()

		try:
			cv.solve(h.t+25)
		except Exception:
			cv.solve(h.t+25)
		for i in range(10):
			shape_plot.flush()
			for index, s in enumerate(seclist):
				sign = -1 if index < 3 else 1
				for seg in s:
					print(seg.x,seg.gcatbar_catchan,seg.caski)
					points.append([sign*h.distance(seg.x,sec=s),s(seg.x).v,s(seg.x).caski])
				print()
			# `points` keeps accumulating across snapshots, as in the original
			points.sort()
			np.savetxt('%s_%s%s_%d.dat' % ('prev',affix,affix2,int(h.t)),points)
			shape_plot.show(1)
			shape_plot.printfile('%s_%s%s_%d.ps' % (var,affix,affix2,int(h.t)))
			try:
				cv.solve(h.t+25)
			except Exception:
				cv.solve(h.t+25)

		if h.t < TSTOP:
			cv.solve(TSTOP)

	else:
		while h.t < TSTOP:
			# `except Exception` (not bare) so Ctrl-C can still interrupt a long run
			try:
				h.fadvance()
			except Exception:
				try:
					cv.solve(TSTOP)
				except Exception:
					h.fadvance()
					print(h.t, "no breaks")
			# progress print near multiples of 1000; the upper bound used to be
			# 10000.0-1e-6, which h.t%1000 can never exceed
			if 1e-6 > h.t%1000 or h.t%1000 > (1000.0-1e-6):
				print(h.t)

	size = len(rec_vectors[0])
	nstates = len(rec_vectors)
	if pstable['control'] < pstable['nprint']:
		dv = h.Vector()
		dv.deriv(rec_vectors[1],dt)
		ln = 0
		for s in cell.nrn.all:
			ln+=1
		if ln == 1:
			# single-compartment cell: total capacitance, to convert dV/dt into a current
			carea = 0
			for seg in cell.nrn.soma:
				carea+= h.area(seg.x,sec=cell.nrn.soma)*cell.nrn.soma(seg.x).cm*1e-5
			#um^2*uf/cm2*1e-8cm2/um^2*1e3nf/uf = nf
			# nf * mV/ms = nA

		if DISTRIBUTE and SUMMATE:
			for i in range(numsyns-1):
				# summate the synapse vectors into element 0
				ampa_recs[0].add(ampa_recs[i+1])
				nmda_recs[0].add(nmda_recs[i+1])
				gaba_recs[0].add(gaba_recs[i+1])

		# constants of the Mg block factor applied to the summed NMDA conductance
		c1 = 1.2/4.1
		c2 = (0.001)*(-2)*0.8*9.648e4/8.315/(273+35)

		with open('GABA_%s%s.dat' % (affix,affix2),'w') as fp:
			for i in range(size-1):
				for j in range(nstates):
					if j == 1:
						v = rec_vectors[j][i]
						fp.write('%e  %e  ' % (rec_vectors[j][i], dv[i]))
						if ln == 1:
							fp.write('%e  ' % (-carea*dv[i])) # area is in um^2, cm is in uf/cm2 dv is in mV/ms
					elif j-3 in SUM_INDS:
						# area-weighted sum over segments, each with its own area (the original
						# multiplied every term by the area of whatever seg/s were left over
						# from earlier loops)
						temp = 0
						for k in range(len(summed_rec_vecs[j])):
							temp += summed_rec_vecs[j][k][i]*summed_areas[j][k]
						fp.write('%e  ' % temp)
					else:
						fp.write('%e  ' % rec_vectors[j][i])
				if DISTRIBUTE and SUMMATE:
					fp.write('%e  ' % gaba_recs[0][i])
					fp.write('%e  ' % ampa_recs[0][i])
					fp.write('%e  ' % nmda_recs[0][i])
					mgblock = 1 / (1 + (c1)*np.exp(c2*v))
					fp.write('%e  ' % (nmda_recs[0][i]*mgblock)	)
				fp.write('\n')

	check_keys = ['foff','fon','on','off']
	check = any(item in funcargs.keys() for item in check_keys) or not len(funcargs.keys()) # checks to see if func args are consistent with windowfunc
	ERGO=0
	if check and HISTO:
		with open('spikes_%s%s.dat' % (affix,affix2),'w') as fp2:
			for times in cell.srec:
				fp2.write('%e  %d\n' % (times,pstable['control']))
		if ERGO:
			interval = pstable['won']+pstable['woff']
			centerpoint = (pstable['won']+pstable['woff'])/2
			nintervals = int(TSTOP/interval)
			start = pstable['won'] # ignore first step
		else:
			interval = min(pstable['idur']+pstable['idel'],10000)
			nintervals = 1
			start = max(0,pstable['idel']-5000)
			centerpoint = 0

		t=start
		binsize=100
		nbins = int(interval/binsize)#should be 200 at 50
		spare_change = interval%binsize
		bins = [0 for i in range(nbins)]
		# function starts with on

		t+=centerpoint
		temp = h.Vector()
		for i in range(nintervals):
			for j in range(nbins):
				temp.where(cell.srec,'[]',t,t+binsize)
				bins[j]+=len(temp)/float(nintervals)
				t+=binsize
			t+= spare_change

		# Direct call instead of exec('freq = func(t, ...)'): inside a function exec
		# cannot bind `freq`, so the file used to get a stale module-level value.
		rate_kwargs = dict((key, float(funcargs[key])) for key in funcargs)

		t = start - binsize/2
		with open('Histo_%s%s.dat' % (affix,affix2),'w') as fp:
			for things in bins:
				t+=binsize
				rate = func(t, **rate_kwargs)
				fp.write('%e  %e  %e\n'  % (t, things, rate))

	return 0


def _lfnoise_rate_func(noiseargs,pstable):
	"""Build a rate function with genlfnfunc, taking optional overrides from `noiseargs`."""
	length=int(3*pstable['idel']+pstable['idur'])
	# maxfreq must be between 0 and 1, but filter initialization catches this error
	return genlfnfunc(length,noiseargs.get('maxfreq',0.01),mean=noiseargs.get('mean',0),var=noiseargs.get('var',1),mnx=noiseargs.get('mnx',0),step=noiseargs.get('step',1))


def _append_regular_bursts(rspikes,pstable,tstop):
	"""Append bursts of pstable['number'] spikes at pstable['freq'], one burst per won+woff window, to `rspikes`."""
	t=pstable['won']
	while t<tstop:
		tloc = t
		tend = tloc + pstable['woff']+pstable['won']
		j=0
		# tloc never changes; the loop ends through the break once the burst is complete
		while tloc < tend:
			if j >= pstable['number']:
				t=tend
				break
			if pstable['freq'] > 0:
				rspikes.append(t)
				t+=1000.0/float(pstable['freq'])
			j+=1