# $Id: run.py,v 1.56 2012/09/20 14:06:03 samn Exp $ 

from pyinit import *
from geom import *
import time # to time the simulation
import datetime # to format time of simulation
# from network import *
# from params import *
import sys
# Run analysisPlottingCode.py in this module's namespace (Python 3 replacement for
# execfile). Reading via 'with' closes the file; compiling with the real filename
# makes tracebacks raised inside that code name analysisPlottingCode.py, not <string>.
with open("./analysisPlottingCode.py") as _apcfile:
    _apcsrc = _apcfile.read()
exec(compile(_apcsrc, "./analysisPlottingCode.py", "exec"))
try:
    # filt is optional: it is only needed by gethilbnqs(..., usemlab=False)
    if '/usr/site/nrniv/local/python' not in sys.path:
        sys.path.append('/usr/site/nrniv/local/python')
    import filt
except ImportError:
    print ("Couldn't import filt routines used in gethilbnqs")

# handler for printing out time during simulation run
def fi():
    '''FInitializeHandler callback: schedule hoc "print <t>" CVode events every
    100 ms of simulated time, from t=0 up to (not including) int(h.tstop), so the
    current simulation time is printed while the run progresses.'''
    tev = 0
    tlimit = int(h.tstop)
    while tev < tlimit:
        h.cvode.event(tev, "print " + str(tev))
        tev += 100

fih = h.FInitializeHandler(1, fi)  # module-level reference to the registered handler

# create CVode object in python
cvode = h.CVode()  # Python-side handle to NEURON's CVode object

def setCVevents_displaySimTime(tstop):
    '''Schedule hoc "print <t>" events on cvode every 100 ms from 0 up to (not
    including) int(tstop). Meant to be called after restoring a saved state.'''
    event_times = range(0, int(tstop), 100)
    for tev in event_times:
        cvode.event(tev, "print " + str(tev))


# initialize random # generators of NetStims - forces it at beginning of each sim
def myInitNetStims():
    '''FInitializeHandler callback: re-initialize the NetStim random number
    generators via net.init_NetStims() at the start of every simulation.'''
    net.init_NetStims()
        

# sets up external inputs
# Noise inputs are created once here for the whole h.tstop duration; their random
# generators are re-initialized by myInitNetStims via a type-0 FInitializeHandler.
if net.noise:
    net.set_noise_inputs(h.tstop) # h.tstop sets duration of inputs for make noise case
    fihns = h.FInitializeHandler(0, myInitNetStims)

# sets up signal to all pyramidal cells
if net.DoMakeSignal:
    net.set_signal_input(h.tstop)
    
# setup spike recording of the pyramidal-cell inputs and of the signal;
# recPyrInputSpikes / recSignalSpikes are config parameters read in network.py
if recPyrInputSpikes: net.RecPYRInputs()
if recSignalSpikes: net.recordSignal_spikes()

# fihns = h.FInitializeHandler(1, myInitNetStims)

# handler for washin/washout
fiwash = None  # FInitializeHandler that installs the washin/washout events, if one is created

# Each *Wash pair is [value applied at washin, value applied at washout] of an
# NMDA/AMPA ratio passed to set_r.
olmWash  = [0] * 2  # OLM somaNMDA
basWash  = [0] * 2  # basket somaNMDA
pyrWashA = [0] * 2  # pyramidal Adend3NMDA
pyrWashB = [0] * 2  # pyramidal BdendNMDA
washinT, washoutT = 0, 0  # simulation times (ms) of the washin / washout events

def _applywash(idx):
    '''Apply entry idx (0 = washin, 1 = washout) of each *Wash pair as the NMDA/AMPA
    ratio of OLM somaNMDA, basket somaNMDA, pyramidal BdendNMDA and Adend3NMDA.'''
    net.olm.set_r("somaNMDA",olmWash[idx])
    net.bas.set_r("somaNMDA",basWash[idx])
    net.pyr.set_r("BdendNMDA",pyrWashB[idx])
    net.pyr.set_r("Adend3NMDA",pyrWashA[idx])

def dowashin():
    '''CVode event callback: print and apply the washin (index 0) NMDA/AMPA ratios.'''
    print ("washIN at ", washinT, " = ", h.t , " ", olmWash[0], basWash[0], pyrWashB[0], pyrWashA[0])
    _applywash(0)

def dowashout():
    '''CVode event callback: print and apply the washout (index 1) NMDA/AMPA ratios.'''
    print ("washOUT at ", washoutT, " = " , h.t, " ", olmWash[1], basWash[1], pyrWashB[1], pyrWashA[1])
    _applywash(1)

def setwash():
    '''Schedule dowashin at washinT and dowashout at washoutT (intended to be run
    from an FInitializeHandler so the events are set for every run).'''
    print ("washinT ", washinT, " washoutT ", washoutT)
    # Pass the Python callables themselves: a 'nrnpython("dowashin()")' string is
    # evaluated in __main__, where these names do not exist when the file was
    # loaded with 'import run'.
    h.cvode.event(washinT, dowashin)
    h.cvode.event(washoutT, dowashout)

# example to do washin/washout, after loading sim:
# import run
# h.tstop=100
# run.olmWash =  [0, 1]
# run.basWash =  [1, 1]
# run.pyrWashA = [1, 1]
# run.pyrWashB = [1, 1]
# run.washinT  = 30
# run.washoutT = 60
# run.fiwash = h.FInitializeHandler(1, run.setwash)  # after 'import run', setwash lives in the run module
# h.run()

# save LFP with current pyramidal cell voltages
_svNUM = 0  # index of the next intermittent LFP file

def saveLFPInterm (fbase):
    '''Compute the LFP from the current recordings, write it to
    <fbase>_svNUM_<k>_lfp.vec (k = current _svNUM), call net.clear_mem(),
    then advance _svNUM.'''
    global _svNUM
    outname = "".join([fbase, "_svNUM_", str(_svNUM), "_lfp.vec"])
    print ("time is " , h.t, " saving LFP to " , outname)
    net.calc_lfp()
    mysvvec(outname, net.vlfp)
    net.clear_mem()
    _svNUM = _svNUM + 1

# setup events to save LFP intermittently
_svFBase = "./tmp_"  # filename prefix used for the intermittent LFP files
_svINC = 1000        # interval (ms of simulated time) between intermittent LFP saves

def setSaveLFPEvents ():
    '''Reset the LFP file counter and schedule saveLFPInterm(_svFBase) every _svINC ms
    (below int(h.tstop)) plus once at h.tstop. Intended for an FInitializeHandler.'''
    global _svNUM
    _svNUM = 0

    def _save():
        # _svFBase is looked up when the event fires, as the old string form did
        saveLFPInterm(_svFBase)

    # A Python callable is used instead of 'nrnpython("saveLFPInterm(_svFBase)")':
    # that string is evaluated in __main__, where these names are missing when this
    # file is loaded with 'import run'.
    for tt in range(_svINC,int(h.tstop),_svINC):
        h.cvode.event(tt, _save)
    h.cvode.event(h.tstop, _save)

# example to save LFP intermittently: 
#  fisv = h.FInitializeHandler(0, setSaveLFPEvents)

class Power:
    '''Plain attribute container for Batch sweep results
    (n, x, timer, tp, gp, tf, gf, arch).'''

class Batch:
    '''Sweep NMDA/AMPA ratios on a 4-D grid (OLM somaNMDA, basket somaNMDA,
    pyramidal BdendNMDA, pyramidal Adend3NMDA), run the network at every grid
    point, archive figures/Vectors, and plot theta/gamma power and frequency.

    Results are kept in self.pow (a Power container).
    '''
    def __init__(self,net):
        self.net = net     # the network: cells, synapses, etc.
        self.pow = Power() # the sweep results

    def copydata(self,obj):
        '''Copy the sweep results (n, x, timer, tp, gp, tf, gf, arch) from obj into self.pow.'''
        for attr in ("n", "x", "timer", "tp", "gp", "tf", "gf", "arch"):
            setattr(self.pow, attr, getattr(obj, attr))

    def save(self, fn='filen.obj'):
        '''Pickle self.pow to file fn (default 'filen.obj').'''
        import pickle
        # pickle requires a binary-mode file in Python 3; 'with' closes it.
        with open(fn, 'wb') as fp:
            pickle.dump(self.pow,fp)

    def load(self, fn='filen.obj'):
        '''Replace self.pow with the object unpickled from fn (default 'filen.obj').

        NOTE: pickle.load can execute arbitrary code; only load files you trust.
        '''
        import pickle
        with open(fn, 'rb') as fp:
            self.pow = pickle.load(fp)

    def getsimstr(self,r1,r2,r3,r4):
        '''Return the label of one grid point, in the same parameter order as the loops in r:
        'olm_somaNMDA_<r1>_bas_somaNMDA_<r2>_pyr_BdendNMDA_<r3>_pyr_Adend3NMDA_<r4>_'.'''
        keys = ("olm_somaNMDA_", "bas_somaNMDA_", "pyr_BdendNMDA_", "pyr_Adend3NMDA_")
        return "".join(k + str(v) + "_" for k, v in zip(keys, (r1, r2, r3, r4)))

    def r(self, n):
        '''Run the sweep: n ratios in [0, 1] per axis (n**4 simulations).

        After each run, the FFT and raster figures and the LFP/raster Vectors are saved
        via a new Archive, and net.tp/gp/tf/gf (theta power, gamma power, theta
        frequency, gamma frequency) are stored in the 4-D arrays self.pow.tp/gp/tf/gf,
        indexed [olm, bas, pyrB, pyrA3].
        '''
        self.pow.n     = n
        self.pow.arch  = Archive()
        self.pow.x     = numpy.linspace(0,1,self.pow.n)
        self.pow.timer = h.startsw()
        shape = (self.pow.n,) * 4
        self.pow.tp    = numpy.zeros(shape)
        self.pow.gp    = numpy.zeros(shape)
        self.pow.tf    = numpy.zeros(shape)
        self.pow.gf    = numpy.zeros(shape)

        for i1,r1 in enumerate(self.pow.x):
            self.net.olm.set_r("somaNMDA",r1)
            for i2, r2 in enumerate(self.pow.x):
                self.net.bas.set_r("somaNMDA",r2)
                for i3, r3 in enumerate(self.pow.x):
                    self.net.pyr.set_r("BdendNMDA",r3)
                    for i4, r4 in enumerate(self.pow.x):
                        self.net.pyr.set_r("Adend3NMDA",r4)
                        simstr = self.getsimstr(r1,r2,r3,r4)
                        print ("NMDA/AMPA: " + simstr)
                        h.run()
                        print ("Time: ", h.startsw() - self.pow.timer)

                        self.pow.arch.reset_time_stamp()

                        self.net.calc_psd() # calculate lfp, psd and draw it, then save
                        self.pow.arch.save_fig(3,simstr+"fft")

                        self.net.rasterplot() # draw raster for all cells, then save
                        self.pow.arch.save_fig(1,simstr+"rasterogram")

                        self.pow.arch.save_vec(simstr+"lfp",self.net.vlfp) # save LFP Vector to file

                        self.net.setrastervecs() # setup raster Vectors for ALL cells & save
                        self.pow.arch.save_vec(simstr+"idvec",self.net.myidvec)
                        self.pow.arch.save_vec(simstr+"timevec",self.net.mytimevec)

                        idx = (i1,i2,i3,i4)
                        self.pow.tp[idx] = self.net.tp
                        self.pow.gp[idx] = self.net.gp
                        self.pow.tf[idx] = self.net.tf
                        self.pow.gf[idx] = self.net.gf

    def plot_r(self):
        '''Plot every result array, first as individual curves (figs 5-8), then as
        means over the other axes (figs 9-12).'''
        self.plot_fun(5, "tp","Theta_Power")
        self.plot_fun(6, "gp","Gamma_Power")
        self.plot_fun(7, "tf","Theta_Frequency")
        self.plot_fun(8, "gf","Gamma_Frequency")
        self.plot_fun(9, "tp","Theta_Power_Mean",    1)
        self.plot_fun(10,"gp","Gamma_Power_Mean",    1)
        self.plot_fun(11,"tf","Theta_Frequency_Mean",1)
        self.plot_fun(12,"gf","Gamma_Frequency_Mean",1)

    def plot_fun(self, fig, var, ylabel, mode=0):
        '''Plot self.pow.<var> against each of the four NMDA/AMPA axes in a 2x2 grid.

        fig    -- pylab figure number (cleared first)
        var    -- name of a 4-D result array in self.pow ("tp", "gp", "tf" or "gf")
        ylabel -- y-axis label; also the name the figure is archived under
        mode   -- 0: one pickable curve per combination of the other three indices,
                  labelled like "[i1,:,i2,i3]"; otherwise a single curve averaged
                  over the other three axes
        '''
        import itertools
        cond = ["olm","bas","pyrB","pyrA3"]
        data = getattr(self.pow, var)
        npts = len(self.pow.x)
        f = pylab.figure(fig)
        f.clf()
        f.canvas.mpl_connect('pick_event', onpick)

        for ax, name in enumerate(cond):
            pylab.subplot(2,2,ax+1)
            pylab.xlabel("NMDA/AMPA for " + name)
            pylab.ylabel(ylabel)
            others = tuple(a for a in range(data.ndim) if a != ax)
            if mode==0:
                for fixed in itertools.product(range(npts), repeat=len(others)):
                    sel = list(fixed)
                    sel.insert(ax, slice(None))  # vary this axis, fix the others
                    label = "[" + ",".join(":" if isinstance(s, slice) else str(s) for s in sel) + "]"
                    pylab.plot(self.pow.x, data[tuple(sel)], label=label, picker=1)
            else:
                pylab.plot(self.pow.x, data.mean(axis=others))

        self.pow.arch.save_fig(fig,ylabel)

def onpick(event):
    '''Matplotlib 'pick_event' callback: print a marker line, then the picked
    artist's label and the clicked data coordinates; always returns True.'''
    print ("REWR")
    lbl = str(event.artist.get_label())
    xs = str(event.mouseevent.xdata)
    ys = str(event.mouseevent.ydata)
    print ("{} ({},{})".format(lbl, xs, ys))
    return True

#save vec to fn (fn is path)
def mysvvec(fn,vec):
    '''Write hoc Vector vec to path fn with Vector.vwrite.

    Returns True on success, or False (after printing an error) if fn could not be
    opened for writing. The file is closed even if vwrite raises.
    '''
    fp = h.File()
    fp.wopen(fn)
    if not fp.isopen():
        print ("savevec ERR: couldn't open " + fn)
        return False
    try:
        vec.vwrite(fp)
    finally:
        fp.close()
    return True

#this class is for saving output, i.e. figures and py files to backup   
class Archive:
    '''Save run output: SVG figures under ./gif, hoc Vectors under ./data, and
    time-stamped backup copies of par_sim.py and Cells.py under ./backup.'''
    def __init__(self):
        self.figprefix = "./gif"    # directory for saved figures
        self.datprefix = "./data"   # directory for saved Vectors
        self.pyprefix = "./backup"  # directory for source-file backups
        self.reset_time_stamp()
        for src in ("par_sim.py", "Cells.py"):
            self.save_pyfile(src)

    def save_fig(self, fig, name):
        '''Save pylab figure number fig as <figprefix>/<time_stamp><name>.svg.'''
        target = os.path.join(self.figprefix, self.time_stamp + name + ".svg")
        pylab.figure(fig)
        pylab.savefig(target, orientation='landscape', format='svg', dpi=72)

    def reset_time_stamp(self):
        '''Set self.time_stamp to "_Y_M_D_h_m_s" for the current local time
        (fields are not zero-padded).'''
        now = datetime.datetime.now().timetuple()
        fields = (now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
        self.time_stamp = "".join("_" + str(v) for v in fields)

    def save_pyfile(self, fn):
        '''Copy file fn to <pyprefix>/<fn><time_stamp>.py.'''
        shutil.copy(fn, os.path.join(self.pyprefix, fn + self.time_stamp + ".py"))

    def save_vec(self, fn, vec):
        '''Write hoc Vector vec to <datprefix>/<fn>.vec via mysvvec.'''
        mysvvec(os.path.join(self.datprefix, fn + ".vec"), vec)

#run a sim and save data
def minrunsv (simstr,tstop=1200,dt=0.1,savevolt=False,datadir="./data/"):
    '''Run a simulation and save its output for later loading with loadminrundat.

    Sets h.tstop and h.dt, runs, then writes (datadir is a plain prefix, keep its
    trailing '/'; the default matches loadminrundat's):
      <datadir><simstr>_lfp.vec    -- LFP Vector from net.calc_lfp
      <datadir><simstr>_snq.nqs    -- spike-time NQS from net.setsnq
      <datadir><simstr>_nqvolt.nqs -- voltage NQS from net.getnqvolt, only if savevolt
    '''
    h.tstop=tstop
    h.dt=dt
    h.run()
    print ("saving output data")
    net.calc_lfp()
    mysvvec(datadir+simstr+"_lfp.vec",net.vlfp)
    net.setsnq() # make NQS with spike times
    net.snq.sv(datadir+simstr+"_snq.nqs")
    if savevolt:
        nqv = net.getnqvolt()
        nqv.sv(datadir+simstr+'_nqvolt.nqs')

#read a Vector from file, fn is file-path, vec is a Vector
def myrdvec(fn,vec):
    '''Read hoc Vector vec from file fn with Vector.vread.

    Returns True on success; prints an error and returns False if fn cannot be opened.
    '''
    fp = h.File()
    fp.ropen(fn)
    if fp.isopen():
        vec.vread(fp)
        fp.close()
        return True
    print ("myrdvec ERRA: Couldn't open " + fn)
    return False

# concat a series of LFPs - fbase is base of filename
def catlfp (fbase,svn):
    '''Concatenate the LFP files <fbase>_svNUM_<k>_lfp.vec for k = 0..svn-1 (the
    names saveLFPInterm writes) into one new hoc Vector; files myrdvec cannot
    open are skipped.'''
    vout = h.Vector()
    vpiece = h.Vector()  # scratch Vector reused for each file
    for k in range(svn):
        if myrdvec(fbase + "_svNUM_" + str(k) + "_lfp.vec", vpiece):
            vout.append(vpiece)
    return vout

#load data from minrunsv into net.vlfp,net.snq
def loadminrundat(simstr,datadir="./data/",rdvolt=False):
    '''Load the output saved by minrunsv into net.

    Reads <datadir><simstr>_lfp.vec into net.vlfp and <datadir><simstr>_snq.nqs into
    net.snq, then copies each cell's spike ids/times from net.snq into the per-cell
    po.lidvec / po.ltimevec Vectors so net.rasterplot can draw them. If rdvolt, also
    reads <datadir><simstr>_nqvolt.nqs into net.nqv. Returns early (after printing an
    error) if the spike NQS cannot be read.
    '''
    fs = datadir+simstr+"_lfp.vec"
    # Reuse net.vlfp if it exists, otherwise create it; then ALWAYS read the file.
    # (Previously the read was inside the except branch, so an existing net.vlfp
    # was emptied and never refilled.)
    try:
        net.vlfp.resize(0)
    except AttributeError:
        net.vlfp = h.Vector()
    myrdvec(fs,net.vlfp)

    fs = datadir+simstr+"_snq.nqs"
    try:
        h.nqsdel(net.snq)  # best-effort release of a previously loaded spike NQS
    except (AttributeError, RuntimeError):
        pass
    try:
        net.snq=h.NQS(fs)
    except RuntimeError:
        print ("loadminrundat ERRB: couldn't read snq from " + fs)
        return

    # copy snq into per-cell vectors so the raster can be plotted with net.rasterplot
    net.snq.verbose=0
    try:
        for po in net.cells:
            for i in range(len(po.lidvec)):
                ID = po.cell[i].id
                po.lidvec[i].resize(0)
                po.ltimevec[i].resize(0)
                if net.snq.select("id",ID):
                    po.lidvec[i].copy(net.snq.getcol("id"))
                    po.ltimevec[i].copy(net.snq.getcol("t"))
    finally:
        net.snq.verbose=1

    if rdvolt:
        try:
            h.nqsdel(net.nqv)  # best-effort release of a previously loaded voltage NQS
        except (AttributeError, RuntimeError):
            pass
        fs = datadir+simstr+'_nqvolt.nqs'
        try:
            net.nqv=h.NQS(fs)
        except RuntimeError:
            print ("loadminrundat ERRC: couldn't read nqvolt from " + fs)



def testrun():
    '''Quick single-run check: set the OLM somaNMDA NMDA/AMPA ratio to 0, run, and save
    via a new Archive a raster plot (fig 1), pyramidal cell 0 soma voltage (fig 2)
    and the LFP spectrum from net.calc_psd (fig 3); then print a terminal bell.'''
    net.olm.set_r("somaNMDA",0)
    h.run()
    arch = Archive()
    net.rasterplot(1)
    arch.save_fig(1,"tmp_rasterplot")
    # was net.psr: every other reference to the pyramidal population uses net.pyr
    # (output file name kept unchanged for existing consumers)
    net.pyr.cell[0].plot_volt("soma",2)
    arch.save_fig(2,"tmp_psr_soma_volt")
    net.calc_psd(3)
    arch.save_fig(3,"tmp_fft")
    print ("\a")

def batchrun(n=3):
    '''Run the 4-D NMDA/AMPA sweep with n values per axis (n**4 simulations), plot
    the results, and return the Batch so its data can be inspected or saved.'''
    bat = Batch(net)
    bat.r(n)
    bat.plot_r()
    return bat

def myrast(spikes,times,sz=12):
    '''Mark Vector spikes against Vector times on hoc Graph h.g[0] using "O" markers
    of size sz, calling h.gg() first when h.g[0] is unset, then rescale the view.'''
    if h.g[0] is None:  # identity test for None instead of '=='
        h.gg()
    spikes.mark(h.g[0],times,"O",sz,1,1)
    h.g[0].exec_menu("View = plot")

# testsame - for debugging two runs to make sure output is the same
def testsame(ts,v1=None,v2=None):
    '''Debug helper: run the simulation twice for ts ms and check the LFP is identical.

    v1, v2 -- optional hoc Vectors that receive the LFP of the first / second run
              (fresh Vectors are created when omitted; previously the passed-in
              Vectors were silently replaced and the results lost)
    Returns the value of v1.eq(v2), which is also printed.
    '''
    h.tstop = ts
    if v1 is None:
        v1 = h.Vector()
    if v2 is None:
        v2 = h.Vector()
    h.run()
    net.calc_lfp()
    v1.resize(0)  # start empty so no stale data from a caller-supplied Vector remains
    v1.copy(net.vlfp)
    h.run()
    net.calc_lfp()
    v2.resize(0)
    v2.copy(net.vlfp)
    same = v1.eq(v2)
    print ("same = " , same)
    return same

#gethilbnqs - make two NQS objects out of LFP with phase/amplitude/filtered signals in theta and gamma bands
def gethilbnqs(vlfp,minth=3,maxth=12,ming=30,maxg=80,usemlab=True):
    '''Return [nqtheta, nqgamma]: Hilbert-transform NQS objects of vlfp band-passed
    to [minth, maxth] Hz and [ming, maxg] Hz.

    usemlab selects h.mathilbert; otherwise filt.gethilbnq is used on a numpy copy.
    '''
    sampr = 1e3/h.dt  # sampling rate in Hz (h.dt is in ms)
    if not usemlab:
        nar = numpy.array(vlfp.to_python())  # hoc Vector -> numpy array
        return [filt.gethilbnq(nar,sampr,minth,maxth),
                filt.gethilbnq(nar,sampr,ming,maxg)]
    return [h.mathilbert(vlfp,sampr,minth,maxth),
            h.mathilbert(vlfp,sampr,ming,maxg)]

#getampphnq - get an nqs with gamma amplitude vs theta phase - uses NQS objects created by gethilbnqs
def getampphnq(nqtheta,nqgamma,phbins=100,skipms=200):
    '''Build an NQS of average gamma amplitude as a function of theta phase.

    nqtheta, nqgamma -- NQS objects with "phase" and "amp" columns, as returned by
                        gethilbnqs; assumed sample-aligned (same dt and start time)
    phbins           -- number of phase bins across the observed theta-phase range
    skipms           -- initial time (ms) excluded from the averages

    Returns an NQS with columns avgamp, phase, n, err, minamp, maxamp: per theta-phase
    bin, the mean gamma amplitude, bin phase, sample count, standard error, and
    avgamp -/+ err (for plotting error bands). Bins without samples stay 0.
    '''
    colth = int(nqtheta.fi("phase"))  # theta phase column, looked up in nqtheta itself
    cola = int(nqgamma.fi("amp"))     # gamma amplitude column
    vthph = nqtheta.v[colth]
    vgamp = nqgamma.v[cola]
    # Bin over the range of the quantity actually being binned (theta phase). The old
    # code used the gamma phase range, so theta samples outside it gave out-of-range
    # bin indices and crashed the lookup.
    phmin = vthph.min()
    phmax = vthph.max()
    phrng = phmax - phmin
    nq = h.NQS("avgamp","phase","n","err","minamp","maxamp")
    vamp, vph, vn, ve, vmin, vmax = [nq.v[k] for k in range(6)]
    if not phrng > 0:
        print ("getampphnq ERR: theta phase has no range; returning empty NQS")
        return nq
    vph.indgen(phmin,phmax,phrng/phbins) # init range of phases
    nq.pad()
    vamp.fill(0)
    vn.fill(0) # init counts to 0
    nbins = int(vph.size())
    samples = [h.Vector() for _ in range(nbins)]  # gamma amplitude samples per phase bin
    sz = int(min(vthph.size(), vgamp.size()))
    startx = int(skipms/h.dt)
    nclip = 0
    for i in range(startx,sz):
        pbin = int(phbins*(vthph[i]-phmin)/phrng)
        if pbin < 0 or pbin >= nbins:
            # can only happen through float rounding at the range edges
            nclip += 1
            pbin = min(max(pbin, 0), nbins-1)
        samples[pbin].append(vgamp[i])
    if nclip:
        print ("getampphnq: clipped", nclip, "samples into the edge phase bins")
    for i, vs in enumerate(samples):
        cnt = vs.size()
        if cnt > 0: # if no samples, skip
            av = vs.mean()
            er = vs.stderr() if cnt > 1 else 0 # stderr needs at least 2 samples
            vamp.x[i] = av
            vn.x[i] = cnt
            ve.x[i] = er
            vmin.x[i] = av - er
            vmax.x[i] = av + er
    return nq

# checkbase - compares baseline to OLM activity off
# returns results in a python list
def checkbase(endt=3e3,skipms=200,justone=False):
    '''Compare the baseline network (OLM somaNMDA NMDA/AMPA ratio 1.0) with the OLM
    somaNMDA ratio set to 0.0, running each setting for endt ms.

    For each setting it collects:
      vlfp -- a copy of the LFP (net.vlfp after net.calc_lfp)
      vtmp -- the LFP with the first skipms ms removed
      nqp  -- h.matpmtm of the trimmed, mean-subtracted LFP (presumably a multitaper spectrum)
      nqh  -- [theta, gamma] NQS pair from gethilbnqs on the trimmed LFP
      nqa  -- gamma amplitude vs theta phase NQS from getampphnq
      snq  -- copy of net.snq after net.setsnq
      fnq  -- copy of net.fnq after net.setfnq(skipms)
    If justone, only the baseline setting is run.
    Returns [vlfp,vtmp,nqp,nqa,nqh,snq,fnq], each a list with one entry per setting.
    '''
    vlfp = []
    vtmp = []
    nqp = []
    nqa = []
    nqh = []
    snq = []
    fnq = []
    h.tstop=endt
    dt = h.dt
    for j, olmr in enumerate((1.0, 0.0)):
        print ("set olm NMDA to ", olmr)
        net.olm.set_r("somaNMDA",olmr)
        print ("running for " , endt , " ms ")
        h.run()
        net.calc_lfp()
        # Store a copy: if calc_lfp refills the same net.vlfp Vector, the next run
        # would otherwise overwrite the baseline LFP held in vlfp[0].
        vlfp.append(net.vlfp.c())
        vtmp.append(h.Vector())
        vtmp[j].copy(vlfp[j],skipms/dt,vlfp[j].size()-1)  # drop the first skipms ms
        vtmp[j].sub(vtmp[j].mean())                       # remove the mean before the spectrum
        nqp.append( h.matpmtm(vtmp[j],1e3/dt) )
        vtmp[j].copy(vlfp[j],skipms/dt,vlfp[j].size()-1)  # back to the trimmed, un-demeaned LFP
        nqh.append( gethilbnqs(vtmp[j],3,12,30,80) )
        nqa.append( getampphnq(nqh[j][0],nqh[j][1]) )
        net.setsnq()
        snq.append( h.NQS() )
        snq[j].cp(net.snq)
        net.setfnq(skipms)
        fnq.append( h.NQS() )
        fnq[j].cp(net.fnq)
        net.pravgrates()
        if justone:
            break
    return [vlfp,vtmp,nqp,nqa,nqh,snq,fnq]


def myrun():
    '''Run the simulation with h.run() and print the wall-clock time it took.'''
    t_begin = time.time()
    h.run()
    elapsed = time.time() - t_begin
    print ('\nsim runtime: ' + print_secondsConverted(elapsed))


def print_secondsConverted(seconds):
    '''Return a string splitting a duration in seconds into minutes and seconds,
    adding hours once there are at least 60 minutes and days once there are at
    least 24 hours. Values are shown as produced by divmod (floats stay floats).'''
    mins, secs = divmod(seconds, 60)
    if mins >= 60:
        hrs, mins = divmod(mins, 60)
        if hrs >= 24:
            days, hrs = divmod(hrs, 24)
            day_fmt = '{days} days, {hours} hours, {minutes} minutes, {seconds} seconds'
            return day_fmt.format(days = days, hours = hrs, minutes = mins, seconds = secs)
        hour_fmt = '{hours} hours, {minutes} minutes, {seconds} seconds'
        return hour_fmt.format(hours = hrs, minutes = mins, seconds = secs)
    min_fmt = '{min} minutes, {sec} seconds'
    return min_fmt.format(min = mins, sec = secs)

    

## NOT USED
# #  for t in numpy.arange(tstart, tstop, 500): h.cvode.event(t, displaySimTime)#displays h.t every 500 msec

# def displaySimTime():
#   '''displays h.t as the simulation is running, as numbers that are changing dynamically - helpful to keep track of simulations that are running for long'''
#   sys.stdout.write('\rh.t: {0} msec...'.format(h.t))
#   sys.stdout.flush()


def savestate(statestr, statedir = './data/stateFiles'):
    ''' Save the current simulation state (h.SaveState) to the file statedir/statestr.

    The path is built with os.path.join, so statedir works with or without a trailing
    separator. Previously plain concatenation with this default (no trailing '/')
    produced './data/stateFiles<statestr>', which restorestate's default directory
    './data/stateFiles/' does not contain.
    Returns the SaveState object.
    '''
    import os
    s = os.path.join(statedir, statestr)
    f = h.File(s)
    ss = h.SaveState()
    ss.save()
    ss.fwrite(f)
    print('saved states')
    return ss


def restorestate(statestr, statedir = './data/stateFiles/'):
    ''' Restore the simulation state stored in the file statedir/statestr.

    The path is built with os.path.join, so statedir works with or without a trailing
    separator (plain concatenation silently dropped the separator when it was missing).
    Returns the h.SaveState object used for the restore.
    '''
    import os
    s = os.path.join(statedir, statestr)
    f = h.File(s)
    ss = h.SaveState()
    ss.fread(f)
    ss.restore()
    print('restored states')
    return ss


def runFromSavedState(statestr, tstop = None, statedir = './data/stateFiles/'):
    ''' Initialize, restore the state saved in statedir/statestr, and continue the run to tstop.

    tstop defaults to the value of h.tstop at call time. (The former default
    `tstop = h.tstop` was evaluated once when this file was loaded, so later
    changes to h.tstop were ignored.)
    Prints the simulation time after the restore and the wall-clock runtime.
    '''
    if tstop is None:
        tstop = h.tstop
    h.stdinit()
    restorestate(statestr, statedir)
    setCVevents_displaySimTime(tstop)  # events must be set after the restore
    print('time after restoring state: {time}'.format(time = h.t))
    clockStart = time.time()
    h.continuerun(tstop)
    clockEnd = time.time()
    print ('\nsim runtime: ' + print_secondsConverted(clockEnd - clockStart))



############################
#   setup multithreading   #
pc = h.ParallelContext()   # NEURON ParallelContext; used here to set the thread count
pc.nthread(32)             # NOTE(review): hard-coded 32 threads - adjust to the machine
# pc.nthread(1)  # use this instead to run without threading

# Older alternative via parcom.hoc's ParallelComputeTool, kept for reference:
#h.load_file('parcom.hoc')
#pc = h.ParallelComputeTool()
#pc.nthread(8)
#p.multisplit(True)

############################

# Manual toggles: change `if 0:` to `if 1:` to run one of these when the file loads.

# single test run (see testrun)
if 0:
    testrun()

# single-cell check on pyramidal cell 0: set_spikes([100], "BdendNMDA", 28*0.04e-3)
# presumably delivers one input spike at t=100 ms to its BdendNMDA synapse; then
# run 200 ms and plot the soma voltage
if 0:
    h.tstop=200
    net.pyr.cell[0].set_spikes([100],"BdendNMDA", 28*0.04e-3)
    h.run()
    net.pyr.cell[0].plot_volt("soma")

# full NMDA/AMPA parameter sweep (see batchrun)
if 0:
    batchrun()
####################################################################################################