from neuron import h
import numpy as np
from paramsdict import p_init as pderp

# Module-wide debug switch: set truthy to enable optional diagnostic prints
# (e.g. gen_netstim's notice when pstable['rtype'] is unsupported).
VERBOSE=0

def sine_func(t,freq=4,maxi=100,mini=20):
	"""Sinusoidal rate kernel oscillating between `mini` and `maxi`.

	The phase is 2*pi*1e-3*freq*t, so with t presumably in ms `freq` acts as an
	oscillation frequency in Hz (confirm against callers). At t=0 the value is
	the midpoint (maxi+mini)/2. Works elementwise when `t` is a numpy array.
	"""
	phase = 2*np.pi*1e-3*freq*t
	unit_wave = (1.0+np.sin(phase))/2.0  # sine shifted and scaled into [0, 1]
	return unit_wave*(maxi-mini) + mini


def _make_synapse(target, x, syn_type, pstable):
	"""Create the synapse point process for `syn_type` on `target(x)`.

	Returns None when `target` is None. 'nmda' -> Exp2NMDA (K0 from
	pstable['knmda']); 'gabab' -> Exp2GABAb (tau 74/188, e=-90); 'chrhod' ->
	Exp2Syn (tau 1.23/6.13, e=0); any other type -> Exp2Syn with tau1=0.2,
	tau2=2.4, where 'gabaa' takes e from pstable['ecl'] and 'ampa' uses e=0,
	tau2=3.
	"""
	if target is None:
		return None
	seg = target(x)
	if syn_type == 'nmda':
		synapse = h.Exp2NMDA(seg)
		synapse.K0 = pstable['knmda']
	elif syn_type == 'gabab':
		# h.GABAb(...) was tried earlier but did not have the right syntax.
		synapse = h.Exp2GABAb(seg)
		synapse.tau1 = 74
		synapse.tau2 = 188
		synapse.e = -90
	elif syn_type == 'chrhod':
		synapse = h.Exp2Syn(seg)
		synapse.tau1 = 1.23
		synapse.tau2 = 6.13
		synapse.e = 0
	else:
		synapse = h.Exp2Syn(seg)
		synapse.tau1 = 0.2
		synapse.tau2 = 2.4
		if syn_type == 'gabaa':
			synapse.e = pstable['ecl']
		elif syn_type == 'ampa':
			synapse.e = 0
			synapse.tau2 = 3
	return synapse


def _make_netcon(synapse, syn_type, pstable):
	"""Create a source-less NetCon onto `synapse` (delay 2) and set its weight.

	Returns None when `synapse` is None. Only 'gabaa', 'nmda' and 'gabab' get
	a weight from pstable ('gabaamp', 'nmdaamp', 'gababamp').
	"""
	if synapse is None:
		return None
	ncobj = h.NetCon(None, synapse)
	ncobj.delay = 2
	weight_key = {'gabaa': 'gabaamp', 'nmda': 'nmdaamp', 'gabab': 'gababamp'}.get(syn_type)
	if weight_key is not None:
		ncobj.weight[0] = pstable[weight_key]
	# NOTE(review): 'ampa' and 'chrhod' keep NEURON's default NetCon weight;
	# callers presumably set it themselves -- confirm.
	return ncobj


def _copy_input_spikes(input_spikes):
	"""Return caller-supplied spike times as a list, or None if not iterable.

	A Python list is returned unchanged (the same object, not a copy). Any
	other iterable (e.g. a numpy array or NEURON Vector) is copied element by
	element into a new list.
	"""
	if isinstance(input_spikes, list):
		return input_spikes
	try:
		return list(input_spikes)
	except TypeError:
		print('input spikes expected to be an array object, generating from scratch')
		return None


def _sample_spike_times(func, funcargs, rtype, duration, ncells, safety_interval=100):
	"""Draw spike times on [0, duration) for `ncells` cells from func(t, **funcargs).

	Each step samples the rate at the current time t and draws one inter-spike
	interval with mean 1000.0/rate. The interval is Poisson-distributed if
	rtype == 'poisson', otherwise exponential. The step then advances t by
	that interval.
	Rates below 1/safety_interval skip ahead 10 without a spike. Intervals of
	at least safety_interval skip ahead safety_interval without a spike, so the
	rate kernel is re-sampled at least that often. This deliberately truncates
	the tail of the ISI distribution to avoid missing shifts in the rate.
	Spike times of the cells are concatenated cell by cell.
	"""
	# Kept for backward compatibility: the original implementation published
	# the last sampled rate as module-level `freq`, and other code may read it.
	global freq
	min_rate = 1.0/safety_interval
	draws = 1  # one interval per step (the caller's nspikes was always overridden to 1)
	spikes = []
	for _cell in range(ncells):
		t = 0
		while t < duration:  # can take many steps if the rate is very low
			freq = func(t, **funcargs)
			if freq < min_rate:
				t += 10
				continue
			mean_isi = 1000.0/freq
			if rtype == 'poisson':
				isi = np.random.poisson(mean_isi, draws)
			else:
				isi = np.random.exponential(mean_isi, draws)
			if max(isi) > safety_interval:
				t += safety_interval
				continue
			# If no interval is accepted (cumsum == safety_interval), advance
			# like the over-long case instead of reusing a stale offset.
			advance = safety_interval
			for offset in np.cumsum(isi):
				if offset >= safety_interval:
					break
				spikes.append(offset+t)
				advance = offset
			t = t + advance
	return spikes


def gen_netstim(func,target,x=0.5,syn_type='gabaa',nspikes=1,duration=1e4,funcargs=None,pstable=pderp,input_spikes=None,ncells=1):
	"""Build a synapse + source-less NetCon on `target` and a list of spike times.

	The NetCon has no source (NetCon(None, synapse)), so the caller delivers
	the returned spike times itself (e.g. via NetCon.event).

	Parameters
	----------
	func : callable
		Rate function called as func(t, **funcargs); its value is used with a
		mean inter-spike interval of 1000.0/rate (presumably Hz with t in ms).
	target : callable or None
		Presumably a NEURON Section; target(x) is the segment receiving the
		synapse. If None, no synapse/NetCon is built and only spikes are
		returned.
	x : float
		Location passed to target(x).
	syn_type : str
		'gabaa', 'gabab', 'nmda', 'ampa' or 'chrhod'; any other value gets a
		generic Exp2Syn.
	nspikes : int
		Accepted for backward compatibility but ignored: one interval is drawn
		per step.
	duration : float
		Spike times are generated while t < duration.
	funcargs : dict, optional
		Extra keyword arguments for func (default: none).
	pstable : dict-like
		Parameter table; reads 'rtype' ('poisson' or 'nexp', anything else
		falls back to 'nexp') and, by syn_type, 'ecl', 'knmda', 'gabaamp',
		'nmdaamp' or 'gababamp'.
	input_spikes : list or iterable, optional
		Precomputed spike times. A list is returned as-is (same object); other
		iterables are copied into a new list. If it is not iterable, spikes are
		generated instead.
	ncells : int
		Number of independent trains merged into the result (sorted when
		ncells > 1).

	Returns
	-------
	(ncobj, spikes, synapse) when target is not None, otherwise just spikes.
	"""
	if funcargs is None:
		funcargs = {}
	rtype = pstable['rtype']
	if rtype not in ('poisson', 'nexp'):  # nexp (exponential ISIs) also yields a Poisson process
		rtype = 'nexp'
		if VERBOSE:
			print('type not supported, reverting to nexp')

	synapse = _make_synapse(target, x, syn_type, pstable)
	ncobj = _make_netcon(synapse, syn_type, pstable)

	if input_spikes is not None:  # `is`: `!= None` is elementwise (ambiguous) on numpy arrays
		spikes = _copy_input_spikes(input_spikes)
		if spikes is not None:
			return spikes if target is None else (ncobj, spikes, synapse)

	if VERBOSE:
		print('freq = func(t, **%r)' % (funcargs,))
	spikes = _sample_spike_times(func, funcargs, rtype, duration, ncells)
	if ncells > 1:
		spikes.sort()

	print( len(spikes), syn_type, funcargs)
	return spikes if target is None else (ncobj, spikes, synapse)
		
def _record_states(cell, clamp, rec_list, dt):
	"""Return Vectors recording t, soma(0.5) v and each soma(0.5) variable in rec_list.

	Every Vector records through `clamp` every `dt`. A name whose
	`_ref_<name>` cannot be recorded is reported and skipped, so the result
	holds only Vectors whose record() call succeeded (in rec_list order).
	"""
	soma = cell.nrn.soma
	rec_vectors = [h.Vector(), h.Vector()]
	rec_vectors[0].record(clamp, h._ref_t, dt, sec=soma)
	rec_vectors[1].record(clamp, soma(0.5)._ref_v, dt, sec=soma)
	for name in rec_list:
		vec = h.Vector()
		try:
			vec.record(clamp, getattr(soma(0.5), '_ref_%s' % name), dt, sec=soma)
		except Exception:
			# Best-effort: report and skip. Appending only on success keeps
			# every returned Vector recording, so all rows have equal length.
			print( 'failed to record %s'  % name)
			continue
		rec_vectors.append(vec)
	return rec_vectors


def _regular_spike_times(pstable, tstop):
	"""Return the sorted spike times of the regular ('regamp') train.

	Windows start at won+woff and repeat every won+woff until tstop. Each window
	emits pstable['number'] spikes spaced 1000/pstable['freq'] apart (none if
	freq <= 0).
	"""
	period = pstable['won'] + pstable['woff']
	if period <= 0:
		return []  # the window start would never advance (the original loop hung here)
	times = []
	t = period
	while t < tstop:
		tend = t + period
		j = 0
		while True:
			if j >= pstable['number']:
				t = tend
				break
			if pstable['freq'] > 0:
				times.append(t)
				# NOTE(review): spikes step *backwards* from the window start;
				# confirm this is intended (the list is sorted below).
				t -= 1000.0/float(pstable['freq'])
			j += 1
	times.sort()
	return times


def _write_histogram(spike_times, func, funcargs, pstable, fname, binsize=50):
	"""Bin spike_times over one stimulus period and write (t, count, rate) rows.

	Binning starts won+woff (the first period is ignored) plus half a period
	in. There are int(period/binsize) bins. Each row is also printed; the rate
	column is func(t, **funcargs).
	"""
	interval = pstable['won'] + pstable['woff']
	nintervals = 1  # was int(TSTOP/interval)
	nbins = int(interval/binsize)
	spare_change = interval % binsize
	bins = [0]*nbins
	centerpoint = pstable['won']/2 + pstable['woff']/2
	t = interval + centerpoint  # ignore the first period
	in_bin = h.Vector()
	for _ in range(nintervals):
		for j in range(nbins):
			in_bin.where(spike_times, '[]', t, t+binsize)
			bins[j] += len(in_bin)/float(nintervals)
			t += binsize
		t += spare_change

	t = pstable['won'] + centerpoint - binsize/2
	with open(fname, 'w') as fp:
		for count in bins:
			t += binsize
			# Computed directly: exec('freq = ...') cannot set a function local
			# in Python 3, so the old code reported a stale module-level freq.
			freq = func(t, **funcargs)
			print( t, count, freq)
			fp.write('%e  %e  %e\n'  % (t, count, freq))


def gaba_run(rec_list,neuron_class,func=sine_func,pstable=pderp,funcargs=None,args=('',''),ncells=1,stype='gabaa',nmdafunc=None,nmdaargs=None,nmdacells=1):
	"""Simulate one cell under current clamps plus synaptic drive and write result files.

	Steps:
	- Build neuron_class(0, pstable) and insert the 'nmda' mechanism in every
	  section of cell.nrn.all.
	- Add two somatic IClamps: iamp for idur at idel, then i2amp for i2dur at
	  idel+idur+offset.
	- Run CVode for TSTOP = 2*idel + idur.
	- Drive the cell by `stype`:
	  'bpulse' gives a GABAa train plus a windowed GABAb train.
	  'gabaa'/'gabab' gives func(t, **funcargs)-driven trains.
	  Any other value gives no inhibitory drive.
	  Optional NMDA(/AMPA) input comes from nmdafunc.
	  A regular GABAa train is added when pstable['regamp'] > 0.

	Files written (suffix = args[0] + args[1]):
	- GABA_<suffix>.dat when pstable['control'] < 20: one row per sample of
	  t, soma v and the rec_list variables, plus func(t, **funcargs).
	- spikes_<suffix>.dat and Histo_<suffix>.dat when funcargs is empty or has
	  a window key ('foff', 'fon', 'on', 'off').

	Returns 0.
	"""
	if funcargs is None:
		funcargs = {}
	if nmdaargs is None:
		nmdaargs = {}
	affix, affix2 = args[0], args[1]
	dt = 0.1

	cell = neuron_class(0, pstable)
	iclamp = h.IClamp(cell.nrn.soma(0.5))
	iclamp2 = h.IClamp(cell.nrn.soma(0.5))
	for sec in cell.nrn.all:
		sec.insert('nmda')
		sec.gnmdabar_nmda = pstable['gnmda']
		sec.cMg_nmda = pstable['cmg']
		sec.cafrac_nmda = pstable['cafrac']

	TSTOP = pstable['idel']*2 + pstable['idur']

	iclamp.delay = pstable['idel']
	iclamp.amp = pstable['iamp']
	iclamp.dur = pstable['idur']
	iclamp2.delay = pstable['idel'] + pstable['idur'] + pstable['offset']
	iclamp2.dur = pstable['i2dur']
	iclamp2.amp = pstable['i2amp']

	cell.ntc.record(cell.srec)  # this might be doubling up on record commands
	rec_vectors = _record_states(cell, iclamp, rec_list, dt)

	cv = h.CVode()
	cv.active(1)

	# Synaptic drive. pstable is passed explicitly so synapse parameters come
	# from this run's table, not gen_netstim's default. Each synapse stays bound
	# to a local name until cv.solve(); NOTE(review): presumably NEURON would
	# otherwise free the point process.
	nc, spikes = None, []
	nc_b, spikes_b = None, []
	if stype == 'bpulse':  # using window function exclusively
		nc, spikes, synapse = gen_netstim(func, cell.nrn.soma, nspikes=100, duration=TSTOP,
			funcargs={'off':TSTOP,'on':0,'fon':0,'foff':pstable['foff']},
			ncells=ncells, syn_type='gabaa', pstable=pstable)
		nc_b, spikes_b, synapse_b = gen_netstim(func, cell.nrn.soma, nspikes=1, duration=TSTOP,
			funcargs={'off':pstable['woff'],'on':pstable['won'],'fon':pstable['fon'],'foff':0},
			ncells=ncells, syn_type='gabab', pstable=pstable)
	elif stype in ('gabaa', 'gabab'):
		nc, spikes, synapse = gen_netstim(func, cell.nrn.soma, nspikes=1, duration=TSTOP,
			funcargs=funcargs, ncells=ncells, syn_type=stype, pstable=pstable)

	if nmdafunc is not None:
		nc_nmda, spikes_nmda, synapse_nmda = gen_netstim(nmdafunc, cell.nrn.soma, nspikes=1, duration=TSTOP,
			funcargs=nmdaargs, ncells=nmdacells, syn_type='nmda', pstable=pstable)
		if pstable['ampa_ratio'] > 0:
			# duration=0: only the AMPA synapse/NetCon are built; it reuses the NMDA spikes.
			nc_ampa, _unused, synapse_ampa = gen_netstim(nmdafunc, cell.nrn.soma, duration=0,
				funcargs=nmdaargs, ncells=nmdacells, syn_type='ampa', pstable=pstable)
			nc_nmda.weight[0] *= (1-pstable['ampa_ratio'])
			# NOTE(review): gen_netstim sets no 'ampa' weight, so this scales
			# NEURON's default NetCon weight -- confirm intended.
			nc_ampa.weight[0] *= pstable['ampa_ratio']

	if pstable['regamp'] > 0:
		# Guard freq > 0: the width check divided by freq; freq <= 0 emits no spikes anyway.
		if pstable['freq'] > 0 and 1000.0*pstable['number']/float(pstable['freq']) > pstable['won']+pstable['woff']:
			print('error: stimulus is wider than ergodic interval')
			raise SystemExit(1)
		# input_spikes=[] makes gen_netstim build only the NetCon and synapse.
		ncr, _unused, rsyn = gen_netstim(func, cell.nrn.soma, pstable=pstable, input_spikes=[], syn_type='gabaa')
		ncr.weight[0] = pstable['regamp']
		rspikes = _regular_spike_times(pstable, TSTOP)

	h.t = 0
	h.finitialize()

	for t in spikes:
		nc.event(t)
	for t in spikes_b:
		nc_b.event(t)
	if nmdafunc is not None:
		for t in spikes_nmda:
			nc_nmda.event(t)
			if pstable['ampa_ratio'] > 0:
				nc_ampa.event(t)
	if pstable['regamp'] > 0:
		for t in rspikes:
			print( t)
			ncr.event(t)

	cv.solve(TSTOP)

	if pstable['control'] < 20:
		with open('GABA_%s%s.dat' % (affix,affix2),'w') as fp:
			for row in zip(*rec_vectors):
				for value in row:
					fp.write('%e  ' % value)
				# Rate column uses this run's funcargs; func's defaults alone
				# ignored caller settings such as a custom freq.
				fp.write('%e\n' % func(row[0], **funcargs))

	window_keys = ('foff','fon','on','off')
	# Spike/histogram output only when funcargs look like window-function args (or are empty).
	if not funcargs or any(key in funcargs for key in window_keys):
		with open('spikes_%s%s.dat' % (affix,affix2),'w') as fp2:
			for spike_time in cell.srec:
				fp2.write('%e  %d\n' % (spike_time, pstable['control']))
		_write_histogram(cell.srec, func, funcargs, pstable, 'Histo_%s%s.dat' % (affix,affix2))

	return 0