# $Id: network.py,v 1.128 2012/11/09 19:40:34 samn Exp $

from pyinit import *
from geom import *
import random

# Next global cell ID to hand out. Population.__init__ passes it to each new
# cell's constructor, records it as that cell's spike ID, then increments it,
# so cells are numbered consecutively in creation order.
gGID = 0 # global ID for cells

class Population:
    """A group of identical cells, each with its own spike recorder.

    Constructor arguments:
      cell_type     -- cell class (e.g. PyrAdr, Bwb, Ow); called as
                       cell_type(x, y, z, gid)
      n             -- number of cells to create
      x, y, z       -- position of the first cell; cell i sits at x + i*dx
      dx            -- x-spacing between consecutive cells
      amp, dur, delay -- copied onto each cell's somaInj (the IClamp in the soma)

    Spike times go into ltimevec (one Vector per cell). The global ID of the
    spiking cell goes into the matching Vector in lidvec.
    """

    def __init__(self, cell_type, n, x, y, z, dx, amp, dur, delay):
        global gGID
        self.n = n
        self.x, self.y, self.z = x, y, z
        self.cell = []                # cell objects, in creation order
        self.nc = []                  # one spike-detecting NetCon per cell
        self.ltimevec = h.List()      # per-cell spike-time Vectors
        self.lidvec = h.List()        # per-cell Vectors of the recorded global ID
        # Index bookkeeping per synapse name (filled in by Network.make_NetStims)
        self.nssidx, self.nseidx = {}, {}
        self.ncsidx, self.nceidx = {}, {}
        for k in range(n):
            newcell = cell_type(x + k * dx, y, z, gGID)
            self.cell.append(newcell)
            clamp = newcell.somaInj
            clamp.amp = amp
            clamp.dur = dur
            clamp.delay = delay
            soma = newcell.soma
            detector = h.NetCon(soma(0.5)._ref_v, None, sec=soma)
            self.nc.append(detector)
            # NB: every NetCon records into its own pair of Vectors; sharing them
            # caused a crash when running multithreaded.
            tvec = h.Vector()
            self.ltimevec.append(tvec)
            idvec = h.Vector()
            self.lidvec.append(idvec)
            detector.record(tvec, idvec, gGID)  # tag each spike with the global ID
            gGID += 1

    def set_r(self, syn, r):
        """Assign ``r`` to the ``syn`` synapse's point process in every cell."""
        for member in self.cell:
            member.__dict__[syn].syn.r = r

    def clear_volt(self):
        """Delegate voltage-trace clearing to each cell."""
        for member in self.cell:
            member.clear_volt()

    def clear_spikes(self):
        """Empty all spike-time Vectors, then all spike-ID Vectors."""
        for tvec in self.ltimevec:
            tvec.resize(0)
        for idvec in self.lidvec:
            idvec.resize(0)

class Network:
    """Pyramidal (pyr), basket (bas) and OLM (olm) populations plus one sensor cell (psr).

    Holds the populations, their synaptic wiring and the noise inputs, and
    provides helpers for spike/voltage/LFP analysis after a run.
    """

    def __init__(self,noise=True,connections=True,DoMakeNoise=True,iseed=1234,UseNetStim=True,wseed=4321,scale=1.0,MSGain=0.0,SaveConn=False):
        """Create the populations and, optionally, wire them together.

        noise       -- stored as self.noise (not read by this class)
        connections -- if True, build the synaptic wiring via set_all_conns()
        DoMakeNoise -- set_noise_inputs() generates noise (True) or loads it from files (False)
        iseed       -- seed for the noise inputs
        UseNetStim  -- generate noise with NetStims (True) or precomputed spike trains (False)
        wseed       -- seed for the random wiring
        scale       -- multiplies the default sizes (800 pyr, 200 bas, 200 olm), rounded up
        MSGain      -- gain on the medial-septal input weights; 0 means no effect
        SaveConn    -- if True, record every connection in the NQS self.nqcon
        """
        import math
        print("Setting Cells")
        self.pyr = Population(cell_type=PyrAdr,n=int(math.ceil(800*scale)), x= 0, y=0, z=0, dx=50, amp= 50e-3, dur=1e9, delay=2*h.dt)
        self.bas = Population(cell_type=Bwb,   n=int(math.ceil(200*scale)), x=10, y=0, z=0, dx=50, amp=     0, dur=  0, delay=2*h.dt)
        self.olm = Population(cell_type=Ow,   n=int(math.ceil(200*scale)), x=20, y=0, z=0, dx=50, amp=-25e-3, dur=1e9, delay=2*h.dt)

        # psr = sensor cell to estimate the E->E connections
        self.psr = Population(cell_type=PyrAdr,n=1,   x= 0, y=0, z=0, dx=50, amp= 50e-3, dur=1e9, delay=2*h.dt)
        self.cells = [self.pyr, self.bas, self.olm, self.psr]
        self.iseed = iseed # seed for noise inputs
        self.noise = noise
        self.DoMakeNoise = DoMakeNoise
        self.UseNetStim = UseNetStim
        self.wseed = wseed # seed for 'wiring'
        self.MSGain = MSGain # gain for MS weights
        self.RecPyr = False
        self.SaveConn = SaveConn

        if connections:
            print("Setting Connections")
            self.set_all_conns()

    def _del_nqs(self, name):
        """Free the hoc NQS stored in attribute `name`, if there is one (best effort)."""
        old = getattr(self, name, None)
        if old is None:
            return
        try:
            h.nqsdel(old)
        except Exception:  # failing to free the previous table is not fatal
            pass

    def set_noise_inputs(self,simdur): #simdur only used for make_all_noise
        """Create or load the background noise inputs, according to DoMakeNoise/UseNetStim."""
        if self.DoMakeNoise:
            if self.UseNetStim:
                self.make_all_NetStims(simdur,self.iseed)
            else:
                self.make_all_noise(simdur,self.iseed)
        else:
            self.load_all_noise()
        print("Done!")

    def load_all_noise(self): #load noise from data files
        """Load precomputed noise spike trains from data/*.npy onto each population."""
        print("Loading Noise")
        print("to PYR")
        self.b_pyr_somaAMPAf=self.load_spikes("spike_noise_pyr_800_soma_AMPA_ISI_1_N_10000_noise_1.npy",self.pyr,"somaAMPAf",0.05e-3)
        self.b_pyr_Adend3AMPAf=self.load_spikes("spike_noise_pyr_800_Adend3_AMPA_ISI_1_N_10000_noise_1.npy",self.pyr,"Adend3AMPAf",0.05e-3)
        self.b_pyr_somaGABAf=self.load_spikes("spike_noise_pyr_800_soma_GABA_ISI_1_N_10000_noise_1.npy",self.pyr,"somaGABAf",0.012e-3)
        self.b_pyr_Adend3GABAf=self.load_spikes("spike_noise_pyr_800_Adend3_GABA_ISI_1_N_10000_noise_1.npy",self.pyr,"Adend3GABAf",0.012e-3)
        self.b_pyr_Adend3NMDA=self.load_spikes("spike_noise_pyr_800_Adend3_NMDA_ISI_100_N_100_noise_1.npy", self.pyr,"Adend3NMDA",6.5e-3)
        print("to BAS")
        # NB: attribute names are historical. b_bas_somaGABA holds the somaGABAf
        # input and b_bas_somaGABAf holds the somaGABAss input.
        self.b_bas_somaAMPAf=self.load_spikes("spike_noise_bas_200_soma_AMPA_ISI_1_N_10000_noise_1.npy",self.bas,"somaAMPAf",w=0.02e-3)
        self.b_bas_somaGABA=self.load_spikes("spike_noise_bas_200_soma_GABA_ISI_1_N_10000_noise_1.npy",self.bas,"somaGABAf",w=0.2e-3)
        self.b_bas_somaGABAf=self.load_spikes("spike_noise_bas_200_soma_GABAf_ISI_150_N_65_noise_0.npy",self.bas,"somaGABAss",w=1.6e-3)
        print("to OLM")
        self.b_olm_somaAMPAf=self.load_spikes("spike_noise_olm_200_soma_AMPA_ISI_1_N_10000_noise_1.npy",self.olm,"somaAMPAf",w=0.02e-3)
        self.b_olm_somaGABAf=self.load_spikes("spike_noise_olm_200_soma_GABA_ISI_1_N_10000_noise_1.npy",self.olm,"somaGABAf",w=0.2e-3)
        self.b_olm_somaGABAss=self.load_spikes("spike_noise_olm_200_soma_GABAss_ISI_150_N_65_noise_0.npy",self.olm,"somaGABAss",w=1.6e-3)

    #this should be called @ beginning of each sim - done in an FInitializeHandler in run.py
    def init_NetStims(self):
        """Reset every NetStim Random stream to its stored seed (negexp(1) distribution)."""
        # h.mcell_ran4_init(self.iseed)
        for rds, sead in zip(self.nrl, self.nrlsead):
            rds.MCellRan4(sead,sead)
            rds.negexp(1)

    #creates NetStims (and associated NetCon,Random) - provide 'noise' inputs
    #returns next useable value of sead
    def make_NetStims(self,po,syn,w,ISI,time_limit,sead):
        """Give every cell in `po` its own noisy NetStim driving synapse `syn`.

        po         -- target Population
        syn        -- synapse attribute name on each cell (e.g. "somaAMPAf")
        w          -- NetCon weight
        ISI        -- mean inter-spike interval of each NetStim (ms)
        time_limit -- simulation duration, used to size ns.number
        sead       -- first seed; each NetStim's Random gets sead, sead+1, ...

        Records the first/last indices of the new objects in self.nsl/self.ncl
        in po.nssidx/nseidx/ncsidx/nceidx[syn]. Returns the next unused seed.
        """
        po.nssidx[syn] = len(self.nsl) #index into net.nsl
        po.ncsidx[syn] = len(self.ncl) #index into net.ncl
        for cel in po.cell:
            ns = h.NetStim()
            ns.interval = ISI
            ns.noise = 1
            # Generous cap on the event count: 1000x the expected time_limit/ISI.
            ns.number = (1e3 / ISI) * time_limit
            ns.start = 0

            nc = h.NetCon(ns,cel.__dict__[syn].syn)
            nc.delay = h.dt * 2 # 0
            nc.weight[0] = w

            rds = h.Random()
            rds.negexp(1)            # set random # generator using negexp(1) - avg interval in NetStim
            rds.MCellRan4(sead,sead) # seeds are in order, shouldn't matter
            ns.noiseFromRandom(rds)  # use random # generator for this NetStim

            self.nsl.append(ns)
            self.ncl.append(nc)
            self.nrl.append(rds)
            self.nrlsead.append(sead)
            sead = sead + 1

        po.nseidx[syn] = len(self.nsl)-1
        po.nceidx[syn] = len(self.ncl)-1

        return sead

    # setup recording of pyramidal cell inputs, assumes using NetCon,NetStims
    def RecPYRInputs(self):
        """Record the event times of the pyramidal noise NetCons into self.NCV[syn] lists."""
        self.RecPyr = True
        self.NCV = {}
        self.sys = ['somaAMPAf', 'Adend3AMPAf', 'somaGABAf', 'Adend3GABAf']
        for s in self.sys:
            self.NCV[s] = []
            sidx = self.pyr.ncsidx[s]
            eidx = self.pyr.nceidx[s]
            for i in range(sidx,eidx+1):
                self.NCV[s].append(h.Vector())
                self.ncl[i].record(self.NCV[s][-1])

    # make an NQS with pyramidal cell input times
    def setnqin(self):
        """Build self.nqin (id, sy, vt): per pyr cell and input type, the recorded input times.

        id is the position within self.NCV[s]; sy is the index of s in self.sys.
        Requires RecPYRInputs() and a completed run.
        """
        self._del_nqs("nqin")
        self.nqin = h.NQS("id","sy","vt")
        nqin = self.nqin
        nqin.odec("vt")
        for jdx, s in enumerate(self.sys):
            for idx, vec in enumerate(self.NCV[s]):
                nqin.append(idx,jdx,vec)

    # make a histogram of pyramidal cell spike outputs
    def mkspkh(self,binsz):
        """Fill self.spkh with a spike-time histogram (bin size binsz ms) per pyramidal cell.

        Uses self.snq (see setsnq). Assumes the pyramidal cells hold global IDs
        0..pyr.n-1, since they are the first population created. Cells that
        never fired get an empty Vector.
        """
        snq=self.snq
        snq.verbose = 0
        self.spkh = h.List()
        try:
            for i in range(self.pyr.n):
                if snq.select("id",i) > 0:
                    vt = snq.getcol("t")
                    self.spkh.append(vt.histogram(0,h.tstop,binsz))
                else:
                    self.spkh.append(h.Vector())
        finally:
            snq.verbose=1  # restore verbosity even if a select/histogram fails

    def make_all_NetStims(self,simdur,rdmseed):
        """Create all NetStim noise inputs plus the medial-septal (MS) rhythmic input.

        simdur  -- simulation duration (ms)
        rdmseed -- first Random seed; consecutive NetStims get consecutive seeds
        """
        print("Making NetStims")
        # h.mcell_ran4_init(self.iseed)
        self.nsl = [] #NetStim List
        self.ncl = [] #NetCon List
        self.nrl = [] #Random List for NetStims
        self.nrlsead = [] #List of seeds for NetStim randoms
        print("Making Noise")
        print("to PYR")
        rdtmp = rdmseed # starting sead value - incremented in make_NetStims
        rdtmp=self.make_NetStims(po=self.pyr, syn="somaAMPAf",   w=0.05e-3,  ISI=1,  time_limit=simdur, sead=rdtmp)
        rdtmp=self.make_NetStims(po=self.pyr, syn="Adend3AMPAf", w=0.05e-3,  ISI=1,  time_limit=simdur, sead=rdtmp)
        rdtmp=self.make_NetStims(po=self.pyr, syn="somaGABAf",   w=0.012e-3, ISI=1,  time_limit=simdur, sead=rdtmp)
        rdtmp=self.make_NetStims(po=self.pyr, syn="Adend3GABAf", w=0.012e-3, ISI=1,  time_limit=simdur, sead=rdtmp)
        rdtmp=self.make_NetStims(po=self.pyr, syn="Adend3NMDA",  w=6.5e-3,   ISI=100,time_limit=simdur, sead=rdtmp)
        print("to BAS")
        rdtmp=self.make_NetStims(po=self.bas, syn="somaAMPAf",   w=0.02e-3,  ISI=1,  time_limit=simdur, sead=rdtmp)
        rdtmp=self.make_NetStims(po=self.bas, syn="somaGABAf",   w=0.2e-3,   ISI=1,  time_limit=simdur, sead=rdtmp)
        print("to OLM")
        rdtmp=self.make_NetStims(po=self.olm, syn="somaAMPAf",   w=0.0625e-3,  ISI=1,  time_limit=simdur, sead=rdtmp)
        rdtmp=self.make_NetStims(po=self.olm, syn="somaGABAf",   w=0.2e-3,   ISI=1,  time_limit=simdur, sead=rdtmp)
        #setup medial septal inputs to OLM and BASKET cells, note that MSGain can be 0 == no effect
        ns = h.NetStim()
        ns.interval = 150
        ns.noise = 0 # NO randomness for the MS inputs
        ns.number = (1e3 / 150.0) * simdur
        self.nsl.append(ns)
        for po in (self.bas, self.olm): # one shared MS NetStim drives BASKET, then OLM cells
            for cel in po.cell:
                nc = h.NetCon(ns,cel.__dict__["somaGABAss"].syn)
                nc.delay = 2*h.dt
                nc.weight[0] = 1.6e-3 * self.MSGain
                self.ncl.append(nc)

    def make_all_noise(self,simdur,rdmseed): # create noise for simdur milliseconds
        """Generate noise spike trains with numpy (seeded by rdmseed) and attach them to the cells."""
        numpy.random.seed(rdmseed) # initialize random # generator
        import math
        print("Making Noise")
        fctr = (simdur+simdur/2) / 10000.0  # draw 1.5x simdur worth of events
        print("to PYR")
        self.b_pyr_somaAMPAf=self.make_spikes(self.pyr,"somaAMPAf",0.05e-3,self.pyr.n,"soma",1,math.ceil(10000*fctr),1,simdur)
        self.b_pyr_Adend3AMPAf=self.make_spikes(self.pyr,"Adend3AMPAf",0.05e-3,self.pyr.n,"Adend3",1,math.ceil(10000*fctr),1,simdur)
        self.b_pyr_somaGABAf=self.make_spikes(self.pyr,"somaGABAf",0.012e-3,self.pyr.n,"soma",1,math.ceil(10000*fctr),1,simdur)
        self.b_pyr_Adend3GABAf=self.make_spikes(self.pyr,"Adend3GABAf",0.012e-3,self.pyr.n,"Adend3",1,math.ceil(10000*fctr),1,simdur)
        self.b_pyr_Adend3NMDA=self.make_spikes(self.pyr,"Adend3NMDA",6.5e-3,self.pyr.n,"Adend3",100,math.ceil(100*fctr),1,simdur)
        print("to BAS")
        self.b_bas_somaAMPAf=self.make_spikes(self.bas,"somaAMPAf",0.02e-3,self.bas.n,"soma",1,math.ceil(10000*fctr),1,simdur)
        self.b_bas_somaGABA=self.make_spikes(self.bas,"somaGABAf",0.2e-3,self.bas.n,"soma",1,math.ceil(10000*fctr),1,simdur)
        self.b_bas_somaGABAf=self.make_spikes(self.bas,"somaGABAss",1.6e-3,self.bas.n,"soma",150,math.ceil(65*fctr),0,simdur)
        print("to OLM")
        self.b_olm_somaAMPAf=self.make_spikes(self.olm,"somaAMPAf",0.02e-3,self.olm.n,"soma",1,math.ceil(10000*fctr),1,simdur)
        self.b_olm_somaGABAf=self.make_spikes(self.olm,"somaGABAf",0.2e-3,self.olm.n,"soma",1,math.ceil(10000*fctr),1,simdur)
        self.b_olm_somaGABAss=self.make_spikes(self.olm,"somaGABAss",1.6e-3,self.olm.n,"soma",150,math.ceil(65*fctr),0,simdur)

    def make_conn(self, preN, postN, conv):
        """Draw random convergent connectivity.

        Returns a (postN, conv) integer array. Row i lists `conv` distinct
        presynaptic indices from range(preN), sampled with the module-level
        `random` generator (seeded in set_all_conns).
        Raises ValueError (from random.sample) if conv > preN.
        """
        # int32, not int16: int16 overflows once preN > 32767 (large `scale`).
        conn = numpy.zeros((postN,conv),dtype=numpy.int32)
        pre_ids = list(range(preN))  # built once; random.sample does not modify it
        for i in range(postN):
            conn[i,:] = random.sample(pre_ids,conv)
        return conn

    def set_all_conns(self):
        """Seed `random` with wseed and create all synaptic connections."""
        random.seed(self.wseed) # initialize random # generator for wiring
        print("PYR -> X , NMDA")   # src, trg, syn, delay, weight, conv
        self.pyr_bas_NM=self.set_connections(self.pyr,self.bas, "somaNMDA", 2, 1.15*1.2e-3, 100)
        self.pyr_olm_NM=self.set_connections(self.pyr,self.olm, "somaNMDA", 2, 1.0*0.7e-3, 10)
        self.pyr_pyr_NM=self.set_connections(self.pyr,self.pyr, "BdendNMDA",2, 1*0.004e-3,  25)

        print("PYR -> X , AMPA")
        self.pyr_bas_AM=self.set_connections(self.pyr,self.bas, "somaAMPAf",2, 0.3*1.2e-3,  100)
        self.pyr_olm_AM=self.set_connections(self.pyr,self.olm, "somaAMPAf",2, 0.3*1.2e-3,  10)
        self.pyr_pyr_AM=self.set_connections(self.pyr,self.pyr, "BdendAMPA",2, 0.5*0.04e-3, 25)

        print("BAS -> X , GABA")
        # earlier bas->bas weights tried: 1.0e-3 (orig), 2*1.5*1.0e-3 (new 1)
        self.bas_bas_GA=self.set_connections(self.bas,self.bas, "somaGABAf",2, 3  *  1.5*1.0e-3, 60)#new 2
        self.bas_pyr_GA=self.set_connections(self.bas,self.pyr, "somaGABAf",2, 2  *  2*0.18e-3, 50)#new 1

        print("OLM -> PYR , GABA")
        # original olm->pyr weight was 3*6.0e-3
        self.olm_pyr_GA=self.set_connections(self.olm,self.pyr, "Adend2GABAs",2, 4.0  *  3*6.0e-3, 20)#new weight value

        #pyramidal to PSR cell -- for testing only
        print("PYR -> PSR, AMPA/NMDA")
        self.pyr_psr_NM=self.set_connections(self.pyr,self.psr, "BdendNMDA",2, 1*0.004e-3,  25)
        self.pyr_psr_AM=self.set_connections(self.pyr,self.psr, "BdendAMPA",2, 0.5*0.04e-3, 25)

    def set_conn_weight(self, conn, weight):
        """Set weight[0] of every NetCon in `conn` to `weight`."""
        for nc in conn:
            nc.weight[0] = weight

    def set_connections(self,src,trg,syn,delay,w,conv):
        """Connect `conv` random cells of `src` onto synapse `syn` of every `trg` cell.

        Each NetCon watches the presynaptic soma voltage (threshold 0) and has
        the given delay and weight w. If SaveConn is set, each pair is also
        appended to self.nqcon (id1, id2, w, syn), which is created on first use.
        Returns the list of NetCons.
        """
        conn = self.make_conn(src.n,trg.n,conv)
        if self.SaveConn:
            nqcon = getattr(self, "nqcon", None)
            if nqcon is None:
                self.nqcon = h.NQS("id1","id2","w","syn")
                self.nqcon.strdec("syn")
            else:
                print(nqcon.size())
        nc = []
        for post_id, all_pre in enumerate(conn):
            post_cell = trg.cell[post_id]
            post_syn = post_cell.__dict__[syn].syn
            for pre_id in all_pre:
                pre_cell = src.cell[pre_id]
                nc.append(h.NetCon(pre_cell.soma(0.5)._ref_v, post_syn, 0, delay, w, sec=pre_cell.soma))
                if self.SaveConn:
                    self.nqcon.append(pre_cell.id,post_cell.id,w,syn)
        return nc

    def _set_events(self,po,syn,w,events,time_limit):
        """Attach row i of `events` (times <= time_limit) to synapse `syn` of cell i; set its Vwt to w."""
        print("Begin setting events...", po)
        print(events.shape)
        for i,ii in enumerate(events):
            ii=ii[ii<=time_limit]
            po.cell[i].__dict__[syn].append(ii)
            po.cell[i].__dict__[syn].syn.Vwt = w
        print("End setting events")

    def load_spikes(self,fn,po,syn,w,time_limit=10000):
        """Load a (cells x events) spike-time array from data/<fn> and attach it to `po`. Returns the array."""
        fn = os.path.join("data",fn)
        events = numpy.load(fn)
        self._set_events(po,syn,w,events,time_limit)
        return events

    def make_spikes(self,po,syn,w,cellN,comp,ISI,eventN,noise,time_limit):
        """Generate spike trains for `po` and attach them to synapse `syn`.

        Intervals are exponential(ISI) when noise=1 and exactly ISI when
        noise=0. Returns the (cellN, eventN) array of cumulative spike times.
        `comp` is unused.
        """
        events = numpy.random.exponential(ISI, (cellN,eventN))*noise+numpy.repeat(ISI,cellN*eventN).reshape((cellN,eventN))*(1-noise)
        events = numpy.cumsum(events,axis=1)
        self._set_events(po,syn,w,events,time_limit)
        return events

    def rasterplot(self,sz=2):
        """Plot all recorded spikes (ID vs time) in h.g[0], one colour per population."""
        if h.g[0] is None:
            h.gg()
        col = [2, 4, 3, 1]  # hoc colour index for pyr, bas, olm, psr
        for pon, po in enumerate(self.cells):
            idv = h.Vector()
            tv = h.Vector()
            for i in range(po.n):
                idv.append(po.lidvec[i])
                tv.append(po.ltimevec[i])
            idv.mark(h.g[0],tv,"O",sz,col[pon],1)
        h.g[0].exec_menu("View = plot")

    def setrastervecs(self):
        """Concatenate the IDs and spike times of all cells into self.myidvec / self.mytimevec."""
        self.myidvec = h.Vector() #IDs and firing times for ALL cells
        self.mytimevec = h.Vector()
        for po in self.cells:
            for i in range(po.n):
                self.myidvec.append(po.lidvec[i])
                self.mytimevec.append(po.ltimevec[i])

    # setsnq - make an NQS with ids, spike times, types
    def setsnq(self):
        """Build self.snq (id, t, ty). ty is the population index in self.cells."""
        self._del_nqs("snq")
        self.snq = h.NQS("id","t","ty")
        vec = h.Vector()
        for ty, po in enumerate(self.cells):
            for i in range(po.n):
                self.snq.v[0].append(po.lidvec[i])
                self.snq.v[1].append(po.ltimevec[i])
                vec.resize(po.lidvec[i].size())
                vec.fill(ty)
                self.snq.v[2].append(vec)

    def clear_volt (self):
        """Clear the voltage recordings of every population."""
        for pop in self.cells: pop.clear_volt()

    def clear_spikes (self):
        """Clear the spike recordings of every population."""
        for pop in self.cells: pop.clear_spikes()

    def clear_mem (self):
        """Clear voltage and spike recordings."""
        self.clear_volt()
        self.clear_spikes()

    # get an nqs with somatic voltage from each cell - only valid after a run
    def getnqvolt (self):
        """Return an NQS with one somatic-voltage column per pyr/bas/olm cell plus a time column 't'.

        Columns are named '<pop>_<column index>'. The psr sensor cell is excluded.
        """
        pops = ['pyr', 'bas', 'olm']
        groups = self.cells[:3]
        allcells = sum(len(po.cell) for po in groups)
        nqv = h.NQS(allcells)
        cdx = 0
        for name, po in zip(pops, groups):
            for c in po.cell:
                nqv.v[cdx].copy(c.soma_volt)
                nqv.s[cdx].s = name + '_' + str(cdx)
                cdx += 1
        nqv.resize('t')
        nqv.v[int(nqv.m[0])-1].indgen(0,h.t,h.dt)
        return nqv

    # setfnq - make an NQS with ids, firing rates, types
    def setfnq (self,skipms=200):
        """Build self.fnq (id, freq, ty): per-cell firing rate (Hz) after the first skipms ms."""
        try:
            self.snq.tog("DB")  # restore the full spike table
        except Exception:  # no usable spike table yet: build it
            self.setsnq()
        self._del_nqs("fnq")
        self.fnq = h.NQS("id","freq","ty")
        tf = h.tstop - skipms
        for ty, po in enumerate(self.cells):
            for cel in po.cell:
                cid = cel.id
                n = float( self.snq.select("t",">",skipms,"id",cid) )
                self.fnq.append(cid, n*1e3/tf, ty)

    # pravgrates - print average firing rates using self.fnq
    def pravgrates(self,skipms=200):
        """Print the mean (+/- standard error) firing rate of each population type."""
        try:
            self.fnq.tog("DB")
        except Exception:  # no usable rate table yet: build it
            self.setfnq(skipms)
        for ty in range(len(self.cells)):
            self.fnq.select("ty",ty)
            vf = self.fnq.getcol("freq")
            if vf.size() > 1:
                print("ty: ", ty, " avg rate = ", vf.mean(), "+/-", vf.stderr(), " Hz")
            else:
                print("ty: ", ty, " avg rate = ", vf.mean(), "+/-", 0.0 , " Hz")

    def calc_lfp (self): # lfp is modeled as a difference between voltages in distal apical and basal compartemnts
        """Compute the LFP proxy: mean over pyr cells of (Adend3 - Bdend) voltage.

        Stored as the hoc Vector self.vlfp and the numpy array self.lfp.
        """
        self.vlfp = h.Vector(self.pyr.cell[0].Adend3_volt.size()) #lfp in neuron Vector
        for cell in self.pyr.cell:
            self.vlfp.add(cell.Adend3_volt)
            self.vlfp.sub(cell.Bdend_volt)
        self.vlfp.div(len(self.pyr.cell)) # normalize lfp by amount of pyr cells
        self.lfp=numpy.array(self.vlfp.to_python()) # convert to python array (so can do PSD)

    def getlfp (self,skipms=200,subm=True):
        """Return the LFP with skipms ms trimmed from both ends, mean-subtracted if subm."""
        v1=h.Vector()
        nsamp = int(skipms / h.dt) # integer number of samples to skip from start and end
        self.calc_lfp()
        v1.copy(self.vlfp,nsamp,self.vlfp.size()-1-nsamp)
        if subm: v1.sub(v1.mean())
        return v1

    def calc_psd(self,fig=3):
        """Plot the LFP and its periodogram in pylab figure `fig`.

        Stores the theta (3-12 Hz) and gamma (30-80 Hz) band powers (tp, gp),
        the indices of peak power in each band (ti, gi) and the peak
        frequencies (tf, gf). Returns six zeros when the LFP is not longer
        than the first 200 ms; otherwise returns None.
        """
        self.calc_lfp()
        t0   = 200 # reject first ms of the signal
        fmax = 200 # upper limit for a periodogram frequency
        div  = max(1, int(1000/h.dt/(2*fmax))) # downsample the signal (step >= 1)
        tr = [3,  12] # Theta frequency range
        gr = [30, 80] # Gamma frequency range
        t0i = int(t0/h.dt)
        if t0i >= len(self.lfp): # nothing left after rejecting the first t0 ms
            print("LFP is too short! (<200 ms)")
            return 0,0,0,0,0,0

        pylab.figure(fig)
        pylab.clf()

        pylab.subplot(2,1,1) # plot LFP
        pylab.plot(numpy.arange(len(self.lfp))*h.dt, self.lfp)

        pylab.subplot(2,1,2) # plot periodogram
        data = self.lfp[t0i::div] # downsample data
        Pxx, freqs = pylab.psd(data-data.mean(), Fs=1000/h.dt/div) # calculate FFT
        tind = numpy.where((freqs>=tr[0]) & (freqs<=tr[1]))[0] # indices of theta frequencies
        gind = numpy.where((freqs>=gr[0]) & (freqs<=gr[1]))[0] # indices of gamma frequencies
        self.tp = Pxx[tind].mean() * numpy.diff(tr) # mean power x band width ~ theta power
        self.gp = Pxx[gind].mean() * numpy.diff(gr) # mean power x band width ~ gamma power
        self.ti = self.get_lim_max(Pxx, tind) # index of the frequency with a maximal power in theta range
        self.gi = self.get_lim_max(Pxx, gind) # index of the frequency with a maximal power in gamma range
        self.tf = freqs[self.ti]
        self.gf = freqs[self.gi]
        pylab.scatter(self.tf, 10*numpy.log10(Pxx[self.ti]), 100, 'b','o')
        pylab.scatter(self.gf, 10*numpy.log10(Pxx[self.gi]), 100, 'r','o')
        pylab.xlim(0,fmax)

    def get_lim_max(self, data, ind):
        """Return the element of `ind` at which data[ind] is largest."""
        return  ind[data[ind].argmax()]


#make the Network - use params in rseed.txt if the file exists -- makes it easier to run a batch
#if rseed.txt doesn't exist, the Network is created with default params
def _read_rseed(fname="./rseed.txt"):
    """Read (iseed, wseed, MSGain) from a batch seed file.

    Line 1 is the noise seed (int), line 2 the wiring seed (int), and an
    optional non-blank line 3 the MS gain (float, default 0.0).
    Returns None if the file cannot be opened or parsed.
    """
    try:
        with open(fname, "r") as fp:
            ls = fp.readlines()
    except (IOError, OSError):
        return None  # no seed file: use the default network
    try:
        iseed = int(ls[0])
        wseed = int(ls[1])
        msg = float(ls[2]) if len(ls) > 2 and ls[2].strip() else 0.0
    except (ValueError, IndexError) as err:
        print("could not parse", fname, ":", err)
        return None
    return iseed, wseed, msg

# Only the file parsing is guarded. Errors while *building* the network must
# propagate: retrying with Network() after a partial build would create cells
# whose global IDs no longer start at 0, breaking the ID ranges set in sethocix.
_rseed = _read_rseed()
if _rseed is not None:
    ISEED, WSEED, MSG = _rseed
    #create the network
    net = Network(noise=True,connections=True,DoMakeNoise=True,iseed=ISEED,UseNetStim=True,wseed=WSEED,scale=1.0,MSGain=MSG)
    print("set network from rseed.txt : iseed=",ISEED,", WSEED=",WSEED,", MSG = ",MSG)
else:
    net = Network()
    print("set network from default constructor")

#setup some variables in hoc
def sethocix(pops=None):
    """Export cell-type constants and per-population global-ID ranges to hoc.

    Defines the hoc variables PYRt, BASKETt, OLMt and PSRt (type indices
    0-3) and names them in CTYP. For each type it sets ix[] (first global
    cell ID), ixe[] (last global cell ID) and numc[] (cell count).

    pops -- populations in type order [pyr, bas, olm, psr]; defaults to
            net.cells. Sizes are read from the populations rather than
            hard-coded, so the ranges stay correct for any Network `scale`.
            Assumes the populations were created in this order starting
            from global ID 0, as Network.__init__ does.
    """
    if pops is None:
        pops = net.cells
    names = ["PYRt", "BASKETt", "OLMt", "PSRt"]
    for ty, name in enumerate(names):
        h("%s=%d" % (name, ty))
    for name in names:
        h('CTYP.o(%s).s="%s"' % (name, name))
    first = 0
    for name, po in zip(names, pops):
        h("ix[%s]=%d" % (name, first))
        h("ixe[%s]=%d" % (name, first + po.n - 1))
        first += po.n
    for name, po in zip(names, pops):
        h("numc[%s]=%d" % (name, po.n))

sethocix()