"""
Main simulation file for replicating
"Cellular and neurochemical basis of sleep stages in the thalamocortical network," eLife, 2016.
"""
import config
from datetime import datetime
from neuron import h
import numpy as np
# Load NEURON's standard run library through hoc instead of `from neuron import gui`.
# Per the original author, this is needed for the simulation to be reproducible
# and to avoid an LFP flatline.
h('load_file("stdgui.hoc")')
from network_class import Net
# Turn on CVode's use_fast_imem option; see use_fast_imem() at
# https://www.neuron.yale.edu/neuron/static/new_doc/simctrl/cvode.html
# NOTE(review): presumably required by the LFP recording callback in config -- confirm.
h('CVode[0].use_fast_imem(1)')
def _stage_schedule(awake, s2, s3, rem):
    """Return the nine schedule values for the awake -> S2 -> S3 -> REM -> S2 sequence.

    The values line up one-to-one with the nine-entry time vector built in
    ``onerun``: every stage value is repeated so it stays constant between
    the end of one transition and the start of the next, and the final S2
    value is repeated so it is held to the end of the simulation.
    """
    return [awake, s2, s2, s3, s3, rem, rem, s2, s2]


def _collect_cortical_sections(net):
    """Return a flat list of every section of every pyramidal and inhibitory cell.

    Only these (cortical) sections are handed to the LFP recording callback.
    Pyramidal cells come first, then inhibitory cells.
    """
    sections = []
    for gid_list in (net.pyr_gidList, net.inh_gidList):
        for gid in gid_list:
            sections.extend(config.pc.gid2cell(gid).all)
    return sections


def _play_synaptic_schedules(net, t):
    """Play stage-dependent ``gmax`` schedules into the synapses of all cells.

    For every connection type a base schedule (connection strength times the
    stage scaling factors) is built, divided by an attribute of the
    post-synaptic cell (the original author describes it as the in-degree
    normalization), and played into ``cell.synlist[index]._ref_gmax`` with
    linear interpolation between the time points in ``t``.

    Returns the list of played vectors. The caller must keep it alive for the
    whole run: according to the NEURON documentation a play vector is removed
    from the system's list if the vector is destroyed.
    """
    gaba_thal = _stage_schedule(config.awake_GABA_thal, config.s2_GABA_thal,
                                config.s3_GABA_thal, config.rem_GABA_thal)
    ampa_thal = _stage_schedule(config.awake_AMPA_thal, config.s2_AMPA_thal,
                                config.s3_AMPA_thal, config.rem_AMPA_thal)
    ampa_cort = _stage_schedule(config.awake_AMPA_cort, config.s2_AMPA_cort,
                                config.s3_AMPA_cort, config.rem_AMPA_cort)
    ampa_pyrpyr = _stage_schedule(config.awake_AMPA_pyrpyr, config.s2_AMPA_pyrpyr,
                                  config.s3_AMPA_pyrpyr, config.rem_AMPA_pyrpyr)
    gaba_d2 = _stage_schedule(config.awake_GABA_D2, config.s2_GABA_D2,
                              config.s3_GABA_D2, config.rem_GABA_D2)

    # (base strength, stage factors, post-synaptic gids, normalization attribute, synlist index)
    synaptic_targets = (
        # GABA_A and GABA_B strengths in thalamus
        (config.re2tc_gaba_a_str, gaba_thal, net.tc_gidList, 'k_RE_TC_GABA_A', 0),
        (config.re2tc_gaba_b_str, gaba_thal, net.tc_gidList, 'k_RE_TC_GABA_B', 1),
        (config.re2re_gaba_a_str, gaba_thal, net.re_gidList, 'k_RE_RE', 2),
        # AMPA strengths for connections terminating in thalamus
        (config.tc2re_ampa_str, ampa_thal, net.re_gidList, 'k_TC_RE', 0),
        (config.pyr2tc_ampa_str, ampa_thal, net.tc_gidList, 'k_PY_TC', 2),
        (config.pyr2re_ampa_str, ampa_thal, net.re_gidList, 'k_PY_RE', 1),
        # AMPA connections terminating in cortex, other than PYR->PYR
        (config.tc2pyr_ampa_str, ampa_cort, net.pyr_gidList, 'k_TC_PY', 0),
        (config.tc2inh_ampa_str, ampa_cort, net.inh_gidList, 'k_TC_IN', 0),
        (config.pyr2inh_ampa_d2_str, ampa_cort, net.inh_gidList, 'k_PY_IN_AMPA', 1),
        # PYR->PYR connections only
        (config.pyr2pyr_ampa_d2_str, ampa_pyrpyr, net.pyr_gidList, 'k_PY_PY_AMPA', 1),
        # GABA_A_D2 strength for INH->PYR connections alone
        (config.inh2pyr_gaba_a_d2_str, gaba_d2, net.pyr_gidList, 'k_IN_PY', 3),
    )

    played = []
    for strength, stages, gids, norm_attr, syn_index in synaptic_targets:
        base_vec = h.Vector(strength * np.array(stages))
        for gid in gids:
            cell = config.pc.gid2cell(gid)
            # each cell gets its own vector because the normalization is per cell
            scaled_vec = base_vec / getattr(cell, norm_attr)
            # 'True' makes values linearly interpolated between the points in t
            scaled_vec.play(cell.synlist[syn_index]._ref_gmax, t, True)
            played.append(scaled_vec)
    return played


def _play_cellular_schedules(net, t):
    """Play stage-dependent intrinsic-conductance schedules into all cells.

    Plays ``gkL_kL`` at the middle of ``dend`` for pyramidal and inhibitory
    cells and at the middle of ``soma`` for TC and RE cells, and
    ``fac_gh_TC_iar`` at the middle of ``soma`` for TC cells. One vector is
    shared by all cells of a population.

    Returns the list of played vectors; the caller must keep it alive for
    the duration of the run.
    """
    # (stage values, gids, name of the section attribute holding the kL mechanism)
    gkl_targets = (
        (_stage_schedule(config.gkl_pyr_awake, config.gkl_pyr_s2,
                         config.gkl_pyr_s3, config.gkl_pyr_rem),
         net.pyr_gidList, 'dend'),
        (_stage_schedule(config.gkl_inh_awake, config.gkl_inh_s2,
                         config.gkl_inh_s3, config.gkl_inh_rem),
         net.inh_gidList, 'dend'),
        (_stage_schedule(config.gkl_TC_awake, config.gkl_TC_s2,
                         config.gkl_TC_s3, config.gkl_TC_rem),
         net.tc_gidList, 'soma'),
        (_stage_schedule(config.gkl_RE_awake, config.gkl_RE_s2,
                         config.gkl_RE_s3, config.gkl_RE_rem),
         net.re_gidList, 'soma'),
    )

    played = []
    for stages, gids, section_name in gkl_targets:
        gkl_vec = h.Vector(np.array(stages))
        for gid in gids:
            section = getattr(config.pc.gid2cell(gid), section_name)
            gkl_vec.play(section(0.5)._ref_gkL_kL, t, True)
        played.append(gkl_vec)

    gh_TC_vec = h.Vector(_stage_schedule(config.gh_TC_awake, config.gh_TC_s2,
                                         config.gh_TC_s3, config.gh_TC_rem))
    for gid in net.tc_gidList:
        gh_TC_vec.play(config.pc.gid2cell(gid).soma(0.5)._ref_fac_gh_TC_iar, t, True)
    played.append(gh_TC_vec)
    return played


def _integrate_and_dump(net):
    """Integrate to ``config.duration`` in ``config.t_seg`` chunks, dumping data after each.

    After every chunk spikes (and, when ``config.doextra`` is set, LFP data)
    are gathered; host 0 appends them to ``raster_nhost=*.txt``,
    ``vcort_nhost=*.txt`` and ``lfp_nhost=*.txt``; then the recording buffers
    are emptied so memory use does not grow with simulated time. The output
    files are closed even if the integration raises.
    """
    is_master = config.idhost == 0
    raster_file = lfp_file = vcort_file = None
    if is_master:
        print('Running sim...')
        startTime = datetime.now()  # store sim start time
    try:
        if is_master:
            raster_file = open("raster_nhost=%g.txt"%(config.nhost), 'w')  # raster data
            lfp_file = open("lfp_nhost=%g.txt"%(config.nhost), 'w')  # biophysical LFP data
            vcort_file = open("vcort_nhost=%g.txt"%(config.nhost), 'w')  # summed intracellular voltage trace
        # actually run the simulation on all nodes
        t_curr = 0
        # the '-dt' accounts for rounding error; otherwise, may get error in writeVoltages
        while t_curr < config.duration - h.dt:
            if t_curr + config.t_seg < config.duration:
                config.pc.psolve(t_curr + config.t_seg)
                if is_master:
                    print("Numerically integrated through %.2f ms"%(t_curr+config.t_seg))
            else:
                config.pc.psolve(config.duration)
            net.gatherSpikes()  # gather spikes from all nodes onto master node
            if config.doextra:
                net.gatherLFP()  # gather LFP data
            if is_master:
                # use the bash command 'sort -k 1n,1n -k 2n,2n raster_nhost=4 > raster_nhost=4_sorted'
                # to sort the raster plots when nhost>1
                for spike_time, spike_gid in zip(net.tVecAll, net.idVecAll):
                    raster_file.write("%.3f %g\n" % (spike_time, spike_gid))
                # reset the vectors that aggregate results from all nodes,
                # so they do not grow too large as the simulation progresses
                net.tVecAll = []
                net.idVecAll = []
                if config.doextra:
                    for v_value in net.v_sum:  # cortical voltage data
                        vcort_file.write("%.3f \n" % v_value)
                    for lfp_value in net.lfp_sum:  # LFP data
                        lfp_file.write("%.3f \n" % lfp_value)
            # reset the raster vectors on each individual node as well
            net.tVec.resize(0)
            net.idVec.resize(0)
            if config.doextra:
                config.v_rec = []  # reset list to being empty
                config.lfp_rec = []  # reset list to being empty
            t_curr = t_curr + config.t_seg
    finally:
        # close in the same order as before; files are None on non-master hosts
        for out_file in (raster_file, lfp_file, vcort_file):
            if out_file is not None:
                out_file.close()
    if is_master:
        runTime = (datetime.now() - startTime).total_seconds()  # calculate run time
        print("Run time for %d sec sim = %.2f sec"%(int(config.duration/1000.0), runTime) )


def onerun(randSeed, Npyr, Ninh, Nre, Ntc):
    """Build the network, run one simulation and write the results to text files.

    Args:
        randSeed: global index applied to ALL Random123 streams.
        Npyr, Ninh, Nre, Ntc: arguments forwarded, in this order, to ``Net``
            (presumably the numbers of pyramidal, inhibitory, reticular and
            thalamocortical cells -- confirm against network_class).

    Behaviour is further controlled by ``config``: ``doextra`` enables the LFP
    recording callback, ``do_sleepstates`` enables the stage-dependent
    parameter schedules (see lines 640-777 of the C++ main.cpp), and
    ``duration`` / ``t_seg`` set the total time and the dump interval.
    """
    h.Random().Random123_globalindex(randSeed)  # this changes ALL Random123 streams

    # create network from the sizes passed in (previously the arguments were
    # ignored and config.* was always used)
    net = Net(Npyr, Ninh, Nre, Ntc)

    # only cortical cells are the ones that will generate the LFP
    cort_secs = _collect_cortical_sections(net)
    if config.doextra:
        recording_callback = (config.callback, cort_secs)
        # tells NEURON to call 'callback' on every time step, in order to compute LFP
        h.cvode.extra_scatter_gather(0, recording_callback)

    if config.do_sleepstates:
        # see https://www.neuron.yale.edu/neuron/static/py_doc/programming/math/vector.html?highlight=vector#Vector.play
        # The last entry (with "+h.dt") ensures that the last parameter values
        # remain constant to the end of the simulation ("If a constant outside
        # the range is desired, make sure the last two points have the same y
        # value and have different t values").
        t = h.Vector([
            config.awake_to_s2_start, config.awake_to_s2_end,
            config.s2_to_s3_start, config.s2_to_s3_end,
            config.s3_to_rem_start, config.s3_to_rem_end,
            config.rem_to_s2_start, config.rem_to_s2_end,
            config.rem_to_s2_end + h.dt,
        ])
        # Not read again, but must stay referenced until the run is over so
        # NEURON does not drop the play vectors.
        played_vectors = _play_synaptic_schedules(net, t)
        played_vectors += _play_cellular_schedules(net, t)

    # custom initialization to set RE voltage values
    def set_RE_voltages():
        # loop through all RE cells and set their initial voltages to -65 mV
        for re_gid in net.re_gidList:
            config.pc.gid2cell(re_gid).soma.v = -65
        if h.cvode.active():
            h.cvode.re_init()
        else:
            h.fcurrent()
        h.frecord_init()

    # see https://www.neuron.yale.edu/neuron/static/new_doc/modelspec/programmatic/network/parcon.html#ParallelContext.set_maxstep,
    # as well as section 2.4 of the Lytton/Salvador paper
    config.pc.set_maxstep(10)
    h.dt = 0.025
    # keep a reference to the handler object for the duration of the run
    fih = h.FInitializeHandler(set_RE_voltages)
    h.finitialize(-68)  # set initial voltages of all cells *except RE cells* to -68 mV

    # run sim, gather spikes and dump everything to file
    _integrate_and_dump(net)

    # Plotting and net.saveData() are deliberately not done here; the raster
    # file written above can be reloaded (e.g. with np.loadtxt) for offline plots.
    del net
    if config.doextra:
        # removes 'callback', so that we don't have more and more callbacks on progressive iterations
        h.cvode.extra_scatter_gather_remove(recording_callback)
# Run a single simulation with the seed and population sizes taken from config,
# then synchronize all hosts at a barrier before shutting NEURON down.
onerun(
    randSeed=config.randSeed,
    Npyr=config.Npyr,
    Ninh=config.Ninh,
    Nre=config.Nre,
    Ntc=config.Ntc,
)
config.pc.barrier()
h.quit()