from nrnTemplate.CellTypes.CellTemplateCommon import dINcell, MNcell, aINcell, xINcell, RBcell, tINcell, MHRcell, HHcell, dINcell_hull2015
import matplotlib.pylab as plt
import numpy as np
import random
import numpy.random as rnd
from mpi.mp_util import UniqueProcessMap
from util import plot_matrix_detailed, probabilistic_model_extension, inj_current, inj_current2, plotLeftRightSpikeTrain, plotLeftRightVoltageExt, plotLeftRightVoltageOffset, TimeVoltageTrace, classification_name, classify_behaviour
import time, datetime, os
import param
from shutil import copyfile
from neuron import h

# create a directory in figures/ named after date and time to store output files
path_tmp = "figures/"
today = datetime.datetime.now()
todaystr = today.isoformat()
save_path = path_tmp+todaystr+"/"
# makedirs (unlike mkdir) also creates figures/ itself when it does not exist yet
os.makedirs(path_tmp+todaystr)

# save files of specification for parameters:
# a copy of each source file is stored next to the results of the run
file_main = "main.py"
file_util = "util.py"
file_param = "param.py"
for _spec_file in (file_main,file_util,file_param):
	# `not` instead of bitwise `~`: ~True == -2 and ~False == -1 are both truthy,
	# so the previous test could never skip an already existing copy
	if not os.path.isfile(save_path+_spec_file):
		copyfile(_spec_file,save_path+_spec_file)

# unpack the parameter object into module-level names used by the functions below
par = param.create_params()

# recording / numerical integration settings
RecAll = par.RecAll
varDt = par.varDt
atol = par.atol
rtol = par.rtol
dt = par.dt
time_end = par.time_end

# cell population layout: type names, per-type cell indices (all / left / right) and plot colors
cell_types = par.cell_types
num_types = par.num_types
halves = par.halves
vect_index = par.vect_index
left_index = par.left_index
right_index = par.right_index
colors = par.colors

# synaptic delay is fixed_delay + distance*var_delay for "distance" specs;
# std_on multiplies the weight standard deviation when weights are drawn
fixed_delay = par.fixed_delay
sa_prop = par.sa_prop
var_delay = par.var_delay
std_on = par.std_on

# cell positions, synapse specifications keyed by "pre -> post", and the
# values written to the `alpha` attribute of post-synaptic synapse objects
pos = par.pos
w = par.w
alpha = par.alpha
beta1 = par.beta1
beta2 = par.beta2

# v: values from which the xin weight multiplier is drawn; the rest describes the injected stimulus
v = par.v
n_active = par.n_active
pos_active = par.pos_active
delay = par.delay
duration = par.duration
amplitude_mean = par.amplitude_mean
amplitude_std = par.amplitude_std


# Cells Creation
def CellCreation(RecAll=0,varDt=False):
	"""Create one cell object for every index in vect_index and return them in a list.

	The template class is chosen from the cell type name. Each new cell is
	tagged with its type name (whatami), color, position in the returned
	list (index), type number, rostro-caudal position (pos) and body side
	(1 when its index is in left_index for that type, 2 otherwise).

	RecAll -- accepted for interface compatibility; every template is
	          built with RecAll=1 regardless of this value
	varDt  -- forwarded to every cell template

	Raises Exception when a type name has no associated template.
	"""
	# type name -> (template class, extra keyword arguments)
	templates = {
		"rb": (MNcell,{}),
		"dla": (MNcell,{}),
		"cin": (MNcell,{}),
		"mn": (MNcell,{}),
		"dlc": (MNcell,{}),
		"ain": (MNcell,{}),
		"ecin": (MNcell,{}),
		"tst": (MNcell,{}),
		"din": (dINcell_hull2015,{"theta":0.0}),
		"xin": (xINcell,{}),
		"tin": (tINcell,{}),
		"tsp": (RBcell,{}),
		"mhr": (MHRcell,{}),
		"dinr": (dINcell,{"theta":0.0}),
	}
	cellist=[]
	for type_num,type_name in enumerate(cell_types):
		for j in vect_index[type_num]:
			if type_name not in templates:
				print(type_name)
				raise Exception("Cell created is not included in the list of available cells")
			template,extra = templates[type_name]
			new_cell = template(RecAll=1,varDt=varDt,atol=atol,rtol=rtol,**extra)
			cellist.append(new_cell)
			new_cell.whatami=type_name
			new_cell.color=colors[type_num]
			new_cell.index=len(cellist)-1
			new_cell.type_num=type_num
			new_cell.pos=pos[j]
			new_cell.body_side = 1 if j in left_index[type_num] else 2
	return cellist


def _set_alpha_if_present(post,syn_names,value):
	"""Set `.alpha = value` on each synapse object named in syn_names that post actually has."""
	for syn_name in syn_names:
		if hasattr(post,syn_name):
			getattr(post,syn_name).alpha = value


# create synapse between neural types
def connection(pre,post,dist):
	"""Connect cell `pre` to cell `post` following the specs in w["<pre type> -> <post type>"].

	Nothing happens when w has no entry for the pair of types. Each spec is
	indexed as: [0] synapse type, [1] mean weight, [2] weight standard
	deviation, [3] delay (a number, or the string "distance").

	pre, post -- cell objects (their whatami and body_side attributes are read)
	dist      -- distance between the two cells, used for "distance" delays

	xin -> xin: all specs share one multiplier drawn from v; the delay is
	spec[3] on the same body side and 2.0 across sides.
	Other pairs: the weight is drawn from a normal distribution when the spec
	has a non-zero standard deviation, otherwise it is the mean (times a value
	drawn from v when pre is an xin); its absolute value is used.
	"""
	key = pre.whatami + " -> " + post.whatami
	if key not in w:
		return
	specs = w[key]

	if pre.whatami=="xin" and post.whatami=="xin":
		# drawn once per connection, before the loop, so every synapse type is scaled alike
		scale = rnd.choice(v)
		for spec in specs:
			if pre.body_side==post.body_side:
				syn_delay = spec[3]
			else:
				syn_delay = 2.0
			pre.connect(post, spec[0], w=spec[1]*scale, delay=syn_delay)
		_set_alpha_if_present(post,("syn_ampa_std","syn_nmda_std"),alpha)
		return

	for spec in specs:
		# identity test: `!= None` would go through the spec's own comparison operator
		if spec is None:
			continue
		syn_type = spec[0]
		w_mean = spec[1]
		w_std = spec[2]
		if w_std != 0.0:
			weight = rnd.normal(w_mean,w_std*std_on)
		elif pre.whatami=="xin":
			weight = rnd.choice(v)*w_mean
		else:
			weight = w_mean

		if spec[3]=="distance":
			syn_delay = fixed_delay+dist*var_delay
		else:
			syn_delay = spec[3]
		pre.connect(post, syn_type, w=np.abs(weight), delay=syn_delay)

	if pre.whatami=="din" and post.whatami=="din":
		_set_alpha_if_present(post,("syn_nmda_sat","syn_nmda_std"),beta1)
	if pre.whatami=="xin" and post.whatami!="xin":
		_set_alpha_if_present(post,("syn_ampa_std","syn_nmda_std"),beta2)


def _store_seed(seed):
	"""Write the seed as text to the file "seed" in the output directory (overwriting it)."""
	with open(save_path+"seed", "wt") as f_seed:
		f_seed.write(str(seed))


def _gap_couple_within(cellist,indices,gmax,max_dist=100.0):
	"""Call connect(..., "gap") for every ordered pair of `indices` whose positions differ by less than max_dist."""
	for i in indices:
		for j in indices:
			if abs(cellist[i].pos-cellist[j].pos)<max_dist:
				cellist[i].connect(cellist[j],"gap",gmax=gmax)


def CreateConfigAdjExtended(cellist,sim_num):
	"""Wire the network: synapses from a generated adjacency matrix, then gap junctions.

	The numpy generator is seeded with int(sim_num) before the adjacency
	matrix is requested from probabilistic_model_extension, and re-seeded
	with 32 fresh random bits before the synapses are created. Each seed is
	written to the "seed" file of the output directory, so the file ends up
	holding the second one.

	cellist -- list of cells; entry i corresponds to row/column i of the matrix
	sim_num -- simulation number
	"""
	adjacency_seed = int(sim_num)
	rnd.seed(adjacency_seed)
	_store_seed(adjacency_seed)
	A = probabilistic_model_extension(sim_num)
	#plot_matrix_detailed(A) # to plot the adjacency matrix
	#plt.savefig(save_path+"A"+str(sim_num)+".png")

	weight_seed = random.getrandbits(32) # int(sim_num)
	rnd.seed(weight_seed)
	_store_seed(weight_seed)
	for i,pre in enumerate(cellist):
		for j,post in enumerate(cellist):
			if A[i,j]:
				connection(pre,post,abs(pre.pos-post.pos))

	# dIN gj: cells of type number 4, coupled separately on each side
	gj_strength=0.2e-3
	_gap_couple_within(cellist,left_index[4],gj_strength)
	_gap_couple_within(cellist,right_index[4],gj_strength)


def _select_plot_indices(cellist,body_side,plot_types):
	"""Return the indices of the cells of one body side to show in the voltage plot.

	For "xin" and "din" three cells are drawn at random (repeats possible);
	for every other type the cell with the most recorded spikes is taken.
	Types without any cell on that side are skipped.
	"""
	chosen=[]
	for type_id in plot_types:
		cells=[cell for cell in cellist if cell.whatami==type_id and cell.body_side==body_side]
		if not cells:
			# nothing to choose from (rnd.choice would raise on an empty list)
			continue
		if type_id in ("xin","din"):
			candidates=[cell.index for cell in cells]
			for _ in range(3):
				chosen.append(rnd.choice(candidates))
		else:
			spike_counts=[len(cell.record["spk"]) for cell in cells]
			chosen.append(cells[spike_counts.index(max(spike_counts))].index)
	return chosen


# run one swimming simulation
def SwimmingSimulation(tstop,sim_num):
	"""Build the network, simulate it up to tstop, and save figures and data of the run.

	Steps: create the cells, create the connectivity, inject current into
	four cells starting at pos_active, integrate, then write to save_path
	the spike-train and voltage figures, the selected voltage traces and
	positions, the voltages of type-4 cells with position below 1000, all
	spike trains, and spikes/positions/sides of the "mn" cells. All cells
	are destroyed before returning.

	tstop   -- end time of the simulation
	sim_num -- simulation number, used as seed input and as file-name suffix

	Returns tstar (from classify_behaviour) when classification_name(out)
	equals 1 or 2, otherwise None.
	"""
	t_start = time.time()
	print("Running Entire Simulation...")

	print("Create Cells ...")
	cellist=CellCreation(RecAll=RecAll,varDt=varDt)
	print("Cells Created")

	print("Create Connectivity ...")
	CreateConfigAdjExtended(cellist,sim_num)
	print("Connectivity Created")

	print("Run Numerical Integration...")

	inj_current(cellist[pos_active:pos_active+4],delay,duration,amp_mean=amplitude_mean,amp_std=amplitude_std) # trunk skin touch
	#inj_current(cellist[2046:2046+2],delay,duration,amp_mean=amplitude_mean,amp_std=amplitude_std) # head skin touch 
	#inj_current(cellist[2216:2216+13],450.0,400.0,amp_mean=0.2,amp_std=0.02) # head pressure (13 tSps activated)
	#mhr_stop_protocol_perrins_2002(cellist[2296:2297], delay=1500.0, dur=30.0, amp_mean=0.3,amp_std=0.0, num_impulses=5) # protocol of injection of a single MHR

	RunSim(tstop=tstop,dt=dt)

	# ===== PLOTTING ======

	print("End of the Integration")

	start_plot = 50
	# plot rostral spike trains
	plt.figure(1,figsize=(20,10))
	plotLeftRightSpikeTrain(cellist,[start_plot,tstop])
	plt.subplot(2,1,1)
	plt.xlim([start_plot,tstop])
	plt.subplot(2,1,2)
	plt.xlim([start_plot,tstop])

	# plot selected cells: left side (body_side 1) first, then right side (2)
	plot_types=["rb","tsp","tst","mhr","tin","dlc","dla","xin","din","cin","ain","mn"]
	plot_idxs=_select_plot_indices(cellist,1,plot_types)+_select_plot_indices(cellist,2,plot_types)

	# plot rostral dINs voltage
	plt.figure(4,figsize=(15,8))
	plotLeftRightVoltageOffset([cellist[i] for i in plot_idxs],[start_plot,tstop],offset=50.0)
	plt.savefig(save_path+"tst_dlcr_tin_dinr_volt"+str(sim_num)+".png")

	# saving voltages: [colors, time vector, one voltage trace per selected cell]
	selected = [[cellist[idx].color for idx in plot_idxs]]
	# the time vector is taken from the last selected cell, stated explicitly
	# instead of relying on a loop variable left over from earlier code
	(t,volt) = TimeVoltageTrace(cellist[plot_idxs[-1]])
	selected.append(t)
	for idx in plot_idxs:
		(t,volt) = TimeVoltageTrace(cellist[idx])
		selected.append(volt)
	np.save(save_path+"voltages_one_for_each"+str(sim_num)+".npy",selected)

	np.save(save_path+"pos_one_for_each"+str(sim_num)+".npy",[cellist[idx].pos for idx in plot_idxs])

	# cells of type number 4 with position below 1000
	# (presumably the rostral dINs, going by the file names -- confirm against param)
	hdin_idxs = [idx for idx in vect_index[4] if pos[idx]<1000]
	hdin_voltages = []
	for idx in hdin_idxs:
		(t,volt) = TimeVoltageTrace(cellist[idx])
		hdin_voltages.append(volt)
	np.save(save_path+"voltages_hdIN"+str(sim_num)+".npy",hdin_voltages)

	# algorithm for behavioural classification
	(out,tL,tR,tstar)=classify_behaviour([np.array(cell.record["spk"]) for cell in cellist],hdin_voltages,sim_num,t)

	plt.figure(1)
	plt.title(classification_name(out)+', tL='+str(tL)+', tR='+str(tR)+', tstar='+str(tstar))
	plt.savefig(save_path+"spk_train"+str(sim_num)+".png")

	# plot rostral dINs voltage
	plt.figure(5,figsize=(15,8))
	plotLeftRightVoltageExt([cellist[i] for i in hdin_idxs],[start_plot,tstop])
	plt.subplot(2,1,1)
	plt.ylim([-60,0])
	plt.subplot(2,1,2)
	plt.ylim([-60,0])
	plt.savefig(save_path+"din_volt"+str(sim_num)+".png")

	# spike trains over the last 300 time units
	# NOTE(review): this reuses figure number 5, i.e. the figure saved just above -- confirm intended
	plt.figure(5,figsize=(20,10))
	plotLeftRightSpikeTrain(cellist,[start_plot,tstop])
	plt.subplot(2,1,1)
	plt.xlim([tstop-300,tstop])
	plt.subplot(2,1,2)
	plt.xlim([tstop-300,tstop])
	#plt.savefig(save_path+"last_spk_train"+str(sim_num)+".png")

	# save spike trains
	np.save(save_path+"spikes"+str(sim_num)+".npy",[np.array(cell.record["spk"]) for cell in cellist])

	# spikes, positions and body sides of the motoneurons only
	spk_mns=[]
	pos_mns=[]
	side_mns=[]
	for cell in cellist:
		if cell.whatami=="mn":
			spk_mns.append(np.array(cell.record["spk"]))
			pos_mns.append(cell.pos)
			side_mns.append(cell.body_side)
	np.save(save_path+'spk_mns'+str(sim_num),spk_mns)
	np.save(save_path+'pos_mns'+str(sim_num),pos_mns)
	np.save(save_path+'side_mns'+str(sim_num),side_mns)

	t_end = time.time()
	print("Simulation Took {0}s.".format(t_end-t_start))

	# return reaction times
	# NOTE(review): classification_name(out) is concatenated into the figure title
	# as text, so comparing it with 1/2 may never match -- confirm whether `out`
	# itself was meant here
	if classification_name(out) in (1,2):
		react_time=tstar
	else:
		react_time=None

	destroy(cellist)
	return react_time

def destroy(cellist):
	"""Call destroy() on every cell of cellist, in order."""
	for doomed in cellist:
		doomed.destroy()

def SimForAll(sim_num):
	"""Seed numpy's generator with 32 random bits, log the seed, and run one simulation.

	The seed is written as text to the file "seed" in the current working
	directory. Returns whatever SwimmingSimulation(time_end, sim_num) returns.
	"""
	run_seed = int(random.getrandbits(32))
	rnd.seed(run_seed)
	with open("seed", "wt") as seed_file:
		seed_file.write(str(run_seed))
	return SwimmingSimulation(time_end,sim_num)

def RunManySim():
	"""Run simulations number 1 to 100 through a UniqueProcessMap of 5 processes.

	The list of per-simulation results is saved to out.npy in the output
	directory and returned.
	"""
	n_workers = 5 # number of cores to use
	sim_numbers = range(1,101)
	pool = UniqueProcessMap(n_workers)
	results = pool.map(SimForAll,sim_numbers)
	np.save(save_path+"out.npy",results)
	return results

# run simulation using Euler/CVode method
def RunSim(v_init=-80.0,tstop=0.0,dt=0.01):
	"""Initialize NEURON and advance the simulation until h.t reaches tstop.

	v_init -- value passed to h.finitialize
	tstop  -- the loop stops once h.t is no longer below this value
	dt     -- value assigned to h.dt before initialization

	Returns the list of h.t values read just before each h.fadvance() call.
	"""
	all_times=[]
	h.dt = dt
	h.t = 0.0
	h.finitialize(v_init)
	while h.t<tstop:
		all_times.append(h.t)
		h.fadvance()
	return all_times


def mhr_stop_protocol_perrins_2002(cellist, delay=0.0, dur=0.0, amp_mean=0.0,amp_std=0.0, num_impulses=0):
	"""Attach a train of current-clamp impulses to the soma of every cell in cellist.

	Each cell gets a fresh `impulses` list holding num_impulses IClamp
	objects placed at soma(0.5). Impulse k starts at delay + (dur+20)*k,
	lasts dur, and has an amplitude drawn from a normal distribution with
	mean amp_mean and standard deviation amp_std.
	"""
	for cell in cellist:
		cell.impulses = []
		gap = 20 # added to dur to space consecutive impulse onsets
		for k in range(num_impulses):
			clamp = h.IClamp(cell.soma(0.5))
			cell.impulses.append(clamp)
			clamp.delay = delay + (dur+gap)*k
			clamp.dur = dur
			clamp.amp = rnd.normal(amp_mean,amp_std)