# $Id: run.py,v 1.53 2010/12/15 22:20:27 samn Exp $ 

from pyinit import *
from geom import *
from network import *
from params import *
import sys

# set up external inputs: when net.noise is enabled, noise inputs are created
# lasting for the whole run (h.tstop sets their duration)
if net.noise:
    net.set_noise_inputs(h.tstop)

# handler for printing out time during simulation run
def fi():
    """Queue CVode events that print the simulation time every 100 ms.

    Events are placed at t = 0, 100, 200, ... up to (not including) h.tstop;
    each one executes a hoc ``print <t>`` statement.
    """
    for tprint in range(0, int(h.tstop), 100):
        h.cvode.event(tprint, "print %d" % tprint)

# registered as a type-1 FInitializeHandler so the print events are queued
# again on every h.finitialize()
fih = h.FInitializeHandler(1, fi)

# initialize random # generators of NetStims - forces it at beginning of each sim
def myInitNetStims():
    """Re-initialize the NetStim random number generators via net.init_NetStims().

    Registered below as an FInitializeHandler so it runs at the start of every
    simulation.  (Debug aid, disabled: print soma/dendrite voltages of the
    first 19 pyr, olm and bas cells here.)
    """
    net.init_NetStims()

# type-0 handler: runs during every h.finitialize()
fihns = h.FInitializeHandler(0, myInitNetStims)
# alternative: fihns = h.FInitializeHandler(1, myInitNetStims)

# handler for washin/washout (assigned by the user; see the example below)
fiwash = None

# NMDA/AMPA ratios: element [0] is applied at washin, element [1] at washout
olmWash  = [0, 0]  # olm somaNMDA
basWash  = [0, 0]  # basket somaNMDA
pyrWashA = [0, 0]  # pyramidal Adend3NMDA
pyrWashB = [0, 0]  # pyramidal BdendNMDA
washinT, washoutT = 0, 0  # washin time, washout time (simulation time, ms)

def dowashin():
    """Apply the washin NMDA/AMPA ratios (element [0] of each *Wash list).

    Prints the scheduled and actual time, then sets the ratio on OLM soma,
    basket soma, pyramidal Bdend and pyramidal Adend3 NMDA synapses.
    """
    print("washIN at ", washinT, " = ", h.t , " ", olmWash[0], basWash[0], pyrWashB[0], pyrWashA[0])
    settings = (
        (net.olm, "somaNMDA", olmWash[0]),
        (net.bas, "somaNMDA", basWash[0]),
        (net.pyr, "BdendNMDA", pyrWashB[0]),
        (net.pyr, "Adend3NMDA", pyrWashA[0]),
    )
    for pop, synname, ratio in settings:
        pop.set_r(synname, ratio)

def dowashout():
    """Apply the washout NMDA/AMPA ratios (element [1] of each *Wash list).

    Prints the scheduled and actual time, then sets the ratio on OLM soma,
    basket soma, pyramidal Bdend and pyramidal Adend3 NMDA synapses.
    """
    print("washOUT at ", washoutT, " = " , h.t, " ", olmWash[1], basWash[1], pyrWashB[1], pyrWashA[1])
    settings = (
        (net.olm, "somaNMDA", olmWash[1]),
        (net.bas, "somaNMDA", basWash[1]),
        (net.pyr, "BdendNMDA", pyrWashB[1]),
        (net.pyr, "Adend3NMDA", pyrWashA[1]),
    )
    for pop, synname, ratio in settings:
        pop.set_r(synname, ratio)

def setwash():
    """Queue the washin/washout events at washinT and washoutT.

    Intended to be registered with h.FInitializeHandler so the events are
    queued again on every h.finitialize().

    The Python callables are handed to cvode.event directly. The previous
    'nrnpython("dowashin()")' strings were executed in the __main__
    namespace, so they failed with NameError whenever this file was used as
    an imported module (``import run``), which is the documented usage.
    """
    print("washinT ", washinT, " washoutT ", washoutT)
    h.cvode.event(washinT, dowashin)
    h.cvode.event(washoutT, dowashout)

# example to do washin/washout, after loading sim:
# import run
# h.tstop=100
# run.olmWash =  [0, 1]
# run.basWash =  [1, 1]
# run.pyrWashA = [1, 1]
# run.pyrWashB = [1, 1]
# run.washinT  = 30
# run.washoutT = 60
# run.fiwash = h.FInitializeHandler(1, run.setwash)
# h.run()

class Power:
    """Plain attribute container for Batch sweep results (n, x, tp, gp, tf, gf, ...)."""

class Batch:
    """Brute-force sweep of NMDA/AMPA ratios over four synapse groups.

    For every combination of ratios on OLM soma, basket soma, pyramidal Bdend
    and pyramidal Adend3 NMDA synapses, the network is run once. Figures and
    Vectors are archived, and net.tp/gp/tf/gf are stored in 4-D arrays on
    ``self.pow`` indexed [olm, bas, pyrB, pyrA3].
    """

    def __init__(self,net):
        self.net = net      # the network: cells, synapses, etc.
        self.pow = Power()  # container for the sweep results

    def copydata(self,obj):
        """Copy sweep results (n, x, timer, tp, gp, tf, gf, arch) from obj onto self.pow."""
        for attr in ("n", "x", "timer", "tp", "gp", "tf", "gf", "arch"):
            setattr(self.pow, attr, getattr(obj, attr))

    def save(self):
        """Pickle self.pow to 'filen.obj' in the current directory.

        The file is opened in binary mode because pickle writes bytes (a
        text-mode file raises TypeError under Python 3). It is closed by
        ``with`` so the data is flushed.
        """
        import pickle  # local import: do not rely on star-imports providing it
        with open('filen.obj', 'wb') as fp:
            pickle.dump(self.pow, fp)

    def load(self):
        """Restore self.pow from 'filen.obj' written by save().

        NOTE: unpickling can execute arbitrary code -- only load files you wrote.
        """
        import pickle
        with open('filen.obj', 'rb') as fp:
            self.pow = pickle.load(fp)

    #this function is based on loop in r function, to get a string for the sim params
    def getsimstr(self,r1,r2,r3,r4):
        """Return the file-name tag for one sweep point (r1..r4: olm, bas, pyrB, pyrA3 ratios)."""
        simstr = "olm_somaNMDA_" + str(r1) + "_"
        simstr = simstr + "bas_somaNMDA_" + str(r2) + "_"
        simstr = simstr + "pyr_BdendNMDA_" + str(r3) + "_"
        simstr = simstr + "pyr_Adend3NMDA_" + str(r4) + "_"
        return simstr

    def r(self, n):
        """Run the sweep with n evenly spaced ratios in [0, 1] per group (n**4 runs).

        For each run this prints the settings and the stopwatch time since
        the sweep began. It then saves the PSD figure (3), the raster figure
        (1), the LFP Vector and the raster id/time Vectors via a new Archive,
        and records net.tp, net.gp, net.tf and net.gf.
        """
        self.pow.n     = n
        self.pow.arch  = Archive()
        self.pow.x     = numpy.linspace(0,1,self.pow.n)
        self.pow.timer = h.startsw()
        shape = (self.pow.n,) * 4  # axes: [olm, bas, pyrB, pyrA3]
        self.pow.tp    = numpy.zeros(shape)
        self.pow.gp    = numpy.zeros(shape)
        self.pow.tf    = numpy.zeros(shape)
        self.pow.gf    = numpy.zeros(shape)

        for i1,r1 in enumerate(self.pow.x):
            self.net.olm.set_r("somaNMDA",r1)
            for i2, r2 in enumerate(self.pow.x):
                self.net.bas.set_r("somaNMDA",r2)
                for i3, r3 in enumerate(self.pow.x):
                    self.net.pyr.set_r("BdendNMDA",r3)
                    for i4, r4 in enumerate(self.pow.x):
                        self.net.pyr.set_r("Adend3NMDA",r4)
                        simstr = self.getsimstr(r1,r2,r3,r4)
                        print("NMDA/AMPA: " + simstr)
                        h.run()
                        print("Time: ", h.startsw() - self.pow.timer)

                        self.pow.arch.reset_time_stamp()

                        self.net.calc_psd() #calculate lfp,psd and draw it, then save
                        self.pow.arch.save_fig(3,simstr+"fft")

                        self.net.rasterplot()#draw raster for all cells, then save
                        self.pow.arch.save_fig(1,simstr+"rasterogram")

                        self.pow.arch.save_vec(simstr+"lfp",self.net.vlfp) #save LFP Vector to file

                        self.net.setrastervecs() #setup raster Vectors for ALL cells & save
                        self.pow.arch.save_vec(simstr+"idvec",self.net.myidvec)
                        self.pow.arch.save_vec(simstr+"timevec",self.net.mytimevec)

                        self.pow.tp[i1,i2,i3,i4] = self.net.tp
                        self.pow.gp[i1,i2,i3,i4] = self.net.gp
                        self.pow.tf[i1,i2,i3,i4] = self.net.tf
                        self.pow.gf[i1,i2,i3,i4] = self.net.gf

    def plot_r(self):
        """Draw and save the per-point curves (figs 5-8) and axis means (figs 9-12) for tp/gp/tf/gf."""
        self.plot_fun(5, "tp","Theta_Power")
        self.plot_fun(6, "gp","Gamma_Power")
        self.plot_fun(7, "tf","Theta_Frequency")
        self.plot_fun(8, "gf","Gamma_Frequency")
        self.plot_fun(9, "tp","Theta_Power_Mean",    1)
        self.plot_fun(10,"gp","Gamma_Power_Mean",    1)
        self.plot_fun(11,"tf","Theta_Frequency_Mean",1)
        self.plot_fun(12,"gf","Gamma_Frequency_Mean",1)

    def plot_fun(self, fig, var, ylabel, mode=0):
        """Plot result array self.pow.<var> against each swept ratio in a 2x2 figure.

        Subplot k (k = 0..3) puts ratio k on the x axis. With mode 0, one
        pickable curve is drawn for every combination of the other three
        indices, labelled e.g. "[0,:,1,2]". With any other mode, one curve is
        drawn: the mean over the other three axes. The figure is saved via
        self.pow.arch under the name ylabel.
        """
        cond = ["olm","bas","pyrB","pyrA3"]
        data = self.pow.__dict__[var]
        npts = len(self.pow.x)
        f = pylab.figure(fig)
        f.clf()
        f.canvas.mpl_connect('pick_event', onpick)

        for ax, cname in enumerate(cond):
            pylab.subplot(2,2,ax+1)
            pylab.xlabel("NMDA/AMPA for " + cname)
            pylab.ylabel(ylabel)
            if mode==0:
                # ndindex iterates last index fastest, same order as nested loops
                for idx in numpy.ndindex(npts, npts, npts):
                    sel = list(idx)
                    sel.insert(ax, slice(None))   # ':' on the swept axis
                    labs = [str(i) for i in idx]
                    labs.insert(ax, ":")
                    lab = "[" + ",".join(labs) + "]"
                    if ax == 0:
                        print(lab)  # debug output kept from the original (first subplot only)
                    pylab.plot(self.pow.x, data[tuple(sel)], label=lab, picker=1)
            else:
                others = tuple(a for a in range(4) if a != ax)
                pylab.plot(self.pow.x, data.mean(axis=others))

        self.pow.arch.save_fig(fig,ylabel)

def onpick(event):
    """Matplotlib pick-event callback: print the picked artist's label and click coordinates.

    Always returns True.
    """
    print("REWR")
    label = str(event.artist.get_label())
    click = event.mouseevent
    print(label + " (" + str(click.xdata) + "," + str(click.ydata) + ")")
    return True

#save vec to fn (fn is path)
def mysvvec(fn,vec):
    """Write hoc Vector vec to path fn using Vector.vwrite.

    If the file cannot be opened, prints an error message instead of raising.
    """
    outfile = h.File()
    outfile.wopen(fn)
    if not outfile.isopen():
        print("savevec ERR: couldn't open " + fn)
        return
    vec.vwrite(outfile)
    outfile.close()

#this class is for saving output, i.e. figures and py files to backup
class Archive:
    """Save simulation output (figures, Vectors) and time-stamped source backups.

    Construction stamps the current time and copies par_sim.py and Cells.py
    into the backup directory.
    """

    def __init__(self):
        self.figprefix = "./gif"    # directory for saved figures
        self.datprefix = "./data"   # directory for saved Vectors
        self.pyprefix = "./backup"  # directory for source backups
        self.reset_time_stamp()
        for srcfile in ("par_sim.py", "Cells.py"):
            self.save_pyfile(srcfile)

    def save_fig(self, fig, name):
        """Save pylab figure number fig as <figprefix>/<time_stamp><name>.svg."""
        target = os.path.join(self.figprefix, self.time_stamp + name + ".svg")
        pylab.figure(fig)
        pylab.savefig(target, orientation='landscape', format='svg', dpi=72)

    def reset_time_stamp(self):
        """Set time_stamp to '_Y_M_D_h_m_s' from the current local time (fields not zero-padded)."""
        now = datetime.datetime.now().timetuple()
        fields = (now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
        self.time_stamp = "".join("_" + str(v) for v in fields)

    def save_pyfile(self, fn):
        """Copy file fn to <pyprefix>/<fn><time_stamp>.py."""
        shutil.copy(fn, os.path.join(self.pyprefix, fn + self.time_stamp + ".py"))

    def save_vec(self, fn, vec):
        """Save Vector vec to <datprefix>/<fn>.vec via mysvvec."""
        mysvvec(os.path.join(self.datprefix, fn + ".vec"), vec)

#run a sim and save data
def minrunsv(simstr,tstop=1200,dt=0.1):
    """Run one simulation and save its LFP and spike-time NQS under ./data/.

    Sets h.tstop/h.dt, runs, then writes net.vlfp to ./data/<simstr>_lfp.vec
    and net.snq to ./data/<simstr>_snq.nqs.
    """
    h.tstop, h.dt = tstop, dt
    h.run()
    print("saving output data")
    prefix = "./data/" + simstr
    net.calc_lfp()
    mysvvec(prefix + "_lfp.vec", net.vlfp)
    net.setsnq() # make NQS with spike times
    net.snq.sv(prefix + "_snq.nqs")
    # NOTE(review): nothing after this message actually makes or saves figures
    print("making and saving output figures")

#read a Vector from file, fn is file-path, vec is a Vector
def myrdvec(fn,vec):
    """Read a Vector written with vwrite from path fn into vec.

    Returns True on success. If fn cannot be opened, prints an error and
    returns False.
    """
    infile = h.File()
    infile.ropen(fn)
    if infile.isopen():
        vec.vread(infile)
        infile.close()
        return True
    print("myrdvec ERRA: Couldn't open " + fn)
    return False

#load data from minrunsv into net.vlfp,net.snq
def loadminrundat(simstr):
    """Load data saved by minrunsv(simstr) into net.vlfp and net.snq.

    Afterwards, each cell's spike ids/times are copied from net.snq into
    its population's lidvec/ltimevec, so net.rasterplot can draw them.

    Returns True on success. Returns False if the spike NQS could not be
    read; previously the function carried on and failed on the
    missing/deleted net.snq.
    """
    fs = "./data/"+simstr+"_lfp.vec"
    try:
        net.vlfp.resize(0)
    except AttributeError: # net has no usable vlfp Vector yet
        net.vlfp = h.Vector()
    myrdvec(fs,net.vlfp)  # prints its own error if the file is missing
    fs = "./data/"+simstr+"_snq.nqs"
    try:
        h.nqsdel(net.snq)
    except (AttributeError, RuntimeError): # no previous snq to delete
        pass
    try:
        net.snq=h.NQS(fs)
    except Exception: # hoc/NQS load failure
        print("loadminrundat ERRB: couldn't read snq from " + fs)
        return False
    # next, copy snq into vectors so can plot with net.rasterplot;
    # NQS verbose output is turned off during the per-cell selects and
    # always turned back on, even if copying fails
    net.snq.verbose=0
    try:
        for po in net.cells:
            for i in range(len(po.lidvec)):
                cellid = po.cell[i].id
                po.lidvec[i].resize(0)
                po.ltimevec[i].resize(0)
                if net.snq.select("id",cellid):
                    po.lidvec[i].copy(net.snq.getcol("id"))
                    po.ltimevec[i].copy(net.snq.getcol("t"))
    finally:
        net.snq.verbose=1
    return True

def testrun():
    """Run once with OLM somaNMDA ratio 0 and save raster, voltage and FFT figures.

    Figures 1-3 are saved through a new Archive as tmp_rasterplot,
    tmp_psr_soma_volt and tmp_fft. A terminal bell is printed at the end.
    """
    net.olm.set_r("somaNMDA",0)
    h.run()
    archive = Archive()
    # (draw step, figure number, output name), executed in order
    # NOTE(review): `net.psr` -- every other reference in this file uses `net.pyr`; confirm
    steps = (
        (lambda: net.rasterplot(1), 1, "tmp_rasterplot"),
        (lambda: net.psr.cell[0].plot_volt("soma",2), 2, "tmp_psr_soma_volt"),
        (lambda: net.calc_psd(3), 3, "tmp_fft"),
    )
    for draw, fignum, figname in steps:
        draw()
        archive.save_fig(fignum, figname)
    print("\a")

def batchrun():
    """Sweep NMDA/AMPA ratios with 3 points per synapse group on net, then plot the results."""
    sweep = Batch(net)
    sweep.r(3)
    sweep.plot_r()

def myrast(spikes,times,sz=12):
    """Mark a spike raster on hoc graph g[0] and rescale its view.

    spikes supplies the y values and times the x values (Vector.mark). sz is
    the marker size. If g[0] does not exist yet, hoc's gg() is called first
    (presumably to create it -- confirm in the hoc library).
    """
    if h.g[0] is None:   # identity test for the unset objref (was `== None`)
        h.gg()
    spikes.mark(h.g[0],times,"O",sz,1,1)
    h.g[0].exec_menu("View = plot")

# testsame - for debugging two runs to make sure output is the same
def testsame(ts,v1=None,v2=None):
    """Run the simulation twice for ts ms and check whether the two LFPs are identical.

    v1, v2: unused. The original immediately overwrote them with fresh
    Vectors, so they are kept only as optional arguments for backward
    compatibility.

    Prints and returns the result of Vector.eq on the two LFP copies.
    """
    h.tstop = ts
    lfps = []
    for _ in range(2):
        h.run()
        net.calc_lfp()
        snapshot = h.Vector()
        snapshot.copy(net.vlfp)  # copy: net.vlfp is refilled by the next run
        lfps.append(snapshot)
    same = lfps[0].eq(lfps[1])
    print("same = " , same)
    return same

#gethilbnqs - make two NQS objects out of LFP with phase/amplitude/filered signals in theta and gamma bands
def gethilbnqs(vlfp,minth=3,maxth=12,ming=30,maxg=80,usemlab=True):
    """Return [nqtheta, nqgamma]: Hilbert-transform NQS objects for the theta and gamma bands.

    The sampling rate is 1e3/h.dt Hz. With usemlab, hoc's mathilbert is used
    on the Vector directly. Otherwise the LFP is converted to a numpy array
    and passed to filt.gethilbnq.
    """
    sampr = 1e3/h.dt # sampling rate in Hertz
    bands = ((minth, maxth), (ming, maxg))  # theta band, gamma band
    if usemlab:
        return [h.mathilbert(vlfp, sampr, lo, hi) for lo, hi in bands]
    samples = numpy.array(vlfp.to_python()) # hoc Vector -> python -> numpy
    return [filt.gethilbnq(samples, sampr, lo, hi) for lo, hi in bands]

#getampphnq - get an nqs with gamma amplitude vs theta phase - uses NQS objects created by gethilbnqs
def getampphnq(nqtheta,nqgamma,phbins=100,skipms=200):
    """Build an NQS of average gamma amplitude as a function of theta phase.

    Samples of nqgamma's "amp" column are binned by the same-index sample of
    nqtheta's "phase" column. The first skipms ms of samples are skipped.

    Returns an NQS with one row per phase bin and these columns:
      avgamp -- mean gamma amplitude in the bin
      phase  -- bin start phase
      n      -- number of samples in the bin
      err    -- standard error
      minamp, maxamp -- avgamp -/+ err, for easier display
    Bins with no samples are left at 0.
    """
    colp = int(nqgamma.fi("phase")) # column index for phase (assumed same layout in nqtheta)
    cola = int(nqgamma.fi("amp"))   # column index for amp
    vthph = nqtheta.v[colp]          # theta phase: the quantity being binned
    vgamp = nqgamma.v[cola]          # gamma amplitude: the quantity being averaged
    # The bin range comes from the theta phase, which is what gets binned.
    # Taking it from the gamma phase could place theta samples outside the bins.
    phmin=vthph.min()
    phmax=vthph.max()
    phrng=phmax-phmin
    nq = h.NQS("avgamp","phase","n","err","minamp","maxamp")
    vamp, vph, vn, ve, vmin, vmax = [nq.v[i] for i in range(6)]
    if phrng <= 0:
        print("getampphnq: theta phase range is empty, no bins created")
        return nq
    vph.indgen(phmin,phmax,phrng/phbins) # init range of phases (bin starts)
    nq.pad()
    vamp.fill(0)
    vn.fill(0) # init counts to 0
    nbins = int(vph.size())
    lv = h.List() # one Vector of amplitude samples per bin
    for i in range(nbins):
        lv.append(h.Vector())
    sz=int(nqgamma.v[0].size())
    startx=int(skipms/h.dt)
    for i in range(startx,sz):
        b = int(phbins*(vthph[i]-phmin)/phrng)
        # clamp: phase == phmax maps to bin phbins, and float rounding in
        # indgen can leave fewer than phbins+1 bins
        b = min(max(b, 0), nbins-1)
        lv.o(b).append(vgamp[i])
    for i in range(int(vamp.size())):
        samples = lv.o(i)
        cnt = samples.size()
        if cnt == 0: # if no samples, skip
            continue
        av = samples.mean()
        er = samples.stderr() if cnt > 1 else 0 # stderr needs at least 2 samples
        vamp.x[i] = av
        vn.x[i] = cnt
        ve.x[i] = er
        vmin.x[i] = av - er
        vmax.x[i] = av + er
    return nq

# checkbase - compares baseline to OLM activity off
# returns results in a python list
def checkbase(endt=3e3,skipms=200,justone=False):
    """Compare the baseline network against one with OLM somaNMDA ratio set to 0.

    Runs endt ms with the OLM somaNMDA ratio at 1.0, then at 0.0 (only the
    first run if justone). For each run it collects:
      vlfp -- a copy of the LFP
      vtmp -- the LFP with the first skipms ms dropped
      nqp  -- h.matpmtm spectrum of the de-meaned vtmp
      nqh  -- [theta, gamma] Hilbert NQS from gethilbnqs
      nqa  -- amplitude-vs-phase NQS from getampphnq
      snq, fnq -- copies of net.snq and net.fnq
    Returns [vlfp,vtmp,nqp,nqa,nqh,snq,fnq], each a list with one entry per run.
    """
    vlfp, vtmp, nqp, nqa, nqh, snq, fnq = [], [], [], [], [], [], []
    h.tstop=endt
    dt = h.dt
    for j, ratio in enumerate((1.0, 0.0)):
        print("set olm NMDA to ", ratio)
        net.olm.set_r("somaNMDA",ratio)
        print("running for " , endt , " ms ")
        h.run()
        net.calc_lfp()
        lfp = h.Vector()
        lfp.copy(net.vlfp) # own copy: net.vlfp may be reused by the next run
        vlfp.append(lfp)
        seg = h.Vector()
        vtmp.append(seg)
        seg.copy(lfp,skipms/dt,lfp.size()-1) # drop the first skipms ms
        seg.sub(seg.mean())                  # remove mean before the spectrum
        nqp.append( h.matpmtm(seg,1e3/dt) )
        seg.copy(lfp,skipms/dt,lfp.size()-1) # re-copy without mean removal
        nqh.append( gethilbnqs(seg,3,12,30,80) )
        nqa.append( getampphnq(nqh[j][0],nqh[j][1]) )
        net.setsnq()
        snq.append( h.NQS() )
        snq[j].cp(net.snq)
        net.setfnq(skipms)
        fnq.append( h.NQS() )
        fnq[j].cp(net.fnq)
        net.pravgrates()
        if justone:
            break
    return [vlfp,vtmp,nqp,nqa,nqh,snq,fnq]

############################
#   setup multithreading   #
############################
# A single ParallelContext for this module; the thread count is hard-coded
# to 32 (adjust for the host machine).
pc = h.ParallelContext()
pc.nthread(32)

# Alternative (disabled) setup via ParallelComputeTool:
#h.load_file('parcom.hoc')
#pc = h.ParallelComputeTool()
#pc.nthread(8)
#p.multisplit(True)
############################

# Disabled manual entry points: change a condition to True to run it.

if False: # single run with figures saved (see testrun)
    testrun()

if False: # one BdendNMDA spike at 100 ms into pyramidal cell 0, then plot its soma voltage
    h.tstop=200
    net.pyr.cell[0].set_spikes([100],"BdendNMDA", 28*0.04e-3)
    h.run()
    net.pyr.cell[0].plot_volt("soma")

if False: # full NMDA/AMPA parameter sweep (see batchrun)
    batchrun()
####################################################################################################