""" class_net.py - establishes the Network class and related methods
Copyright (C) 2013 Shane Lee and Stephanie Jones
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import itertools as it
from neuron import h as nrn
import numpy as np
import feed
from L2_pyramidal import L2Pyr
from L5_pyramidal import L5Pyr
from L2_basket import L2Basket
from L5_basket import L5Basket
# creates the external feed params based on individual simulation params p
def create_pext(p, tstop):
    """Build the external-input parameter sets for one simulation.

    Parameters
    ----------
    p : dict
        Simulation parameters, read by the keys used below.  Side effect:
        ``p['T_pois']`` is overwritten with ``tstop`` when it is 0 or -1.
    tstop : float
        Simulation stop time; forwarded to ``feed_validate`` and used as the
        automatic Poisson stop time.

    Returns
    -------
    p_ext : list
        Whatever ``feed_validate`` returns after being offered the proximal
        and then the distal rhythmic feed dict.
    p_unique : dict
        Parameter dicts for the per-cell inputs, keyed ``'evprox0'``,
        ``'evdist'``, ``'evprox1'``, ``'extgauss'`` and ``'extpois'``.
    """
    # indexable py list of param dicts for parallel
    # turn off individual feeds by commenting out relevant line here.
    # always valid, no matter the length
    p_ext = []

    # p_unique is a dict of input param types that end up going to each cell uniquely
    p_unique = {}

    # default params
    feed_prox = {
        'f_input': p['f_input_prox'],
        't0': p['t0_input_prox'],
        'tstop': p['tstop_input_prox'],
        'stdev': p['f_stdev_prox'],
        'L2Pyr_ampa': (p['input_prox_A_weight_L2Pyr_ampa'], p['input_prox_A_delay_L2']),
        'L2Pyr_nmda': (p['input_prox_A_weight_L2Pyr_nmda'], p['input_prox_A_delay_L2']),
        'L5Pyr_ampa': (p['input_prox_A_weight_L5Pyr_ampa'], p['input_prox_A_delay_L5']),
        'L5Pyr_nmda': (p['input_prox_A_weight_L5Pyr_nmda'], p['input_prox_A_delay_L5']),
        'L2Basket_ampa': (p['input_prox_A_weight_inh_ampa'], p['input_prox_A_delay_L2']),
        'L2Basket_nmda': (p['input_prox_A_weight_inh_nmda'], p['input_prox_A_delay_L2']),
        'L5Basket_ampa': (p['input_prox_A_weight_inh_ampa'], p['input_prox_A_delay_L5']),
        'L5Basket_nmda': (p['input_prox_A_weight_inh_nmda'], p['input_prox_A_delay_L5']),
        'events_per_cycle': p['events_per_cycle_prox'],
        'distribution': p['distribution_prox'],
        'lamtha': 100.,
        'loc': 'proximal',
    }

    # ensures time interval makes sense
    p_ext = feed_validate(p_ext, feed_prox, tstop)

    # the distal feed carries no L5Basket entries
    feed_dist = {
        'f_input': p['f_input_dist'],
        't0': p['t0_input_dist'],
        'tstop': p['tstop_input_dist'],
        'stdev': p['f_stdev_dist'],
        'L2Pyr_ampa': (p['input_dist_A_weight_L2Pyr_ampa'], p['input_dist_A_delay_L2']),
        'L2Pyr_nmda': (p['input_dist_A_weight_L2Pyr_nmda'], p['input_dist_A_delay_L2']),
        'L5Pyr_ampa': (p['input_dist_A_weight_L5Pyr_ampa'], p['input_dist_A_delay_L5']),
        'L5Pyr_nmda': (p['input_dist_A_weight_L5Pyr_nmda'], p['input_dist_A_delay_L5']),
        'L2Basket_ampa': (p['input_dist_A_weight_inh_ampa'], p['input_dist_A_delay_L2']),
        'L2Basket_nmda': (p['input_dist_A_weight_inh_nmda'], p['input_dist_A_delay_L2']),
        'events_per_cycle': p['events_per_cycle_dist'],
        'distribution': p['distribution_dist'],
        'lamtha': 100.,
        'loc': 'distal',
    }

    p_ext = feed_validate(p_ext, feed_dist, tstop)

    # Create evoked response parameters
    # f_input needs to be defined as 0
    # these vals correspond to non-perceived max
    # conductance threshold in uS (Jones et al. 2007)
    sigma_early = p['sigma_t_evprox_early']
    p_unique['evprox0'] = {
        't0': p['t_evprox_early'],
        'L2_pyramidal': (p['gbar_evprox_early_L2Pyr'], 0.1, sigma_early),
        'L2_basket': (p['gbar_evprox_early_L2Basket'], 0.1, sigma_early),
        'L5_pyramidal': (p['gbar_evprox_early_L5Pyr'], 1., sigma_early),
        'L5_basket': (p['gbar_evprox_early_L5Basket'], 1., sigma_early),
        'lamtha_space': 3.,
        'loc': 'proximal',
    }
    t0_evprox0 = p_unique['evprox0']['t0']

    # see if relative start time is defined
    if p['dt_evprox0_evdist'] == -1:
        # if dt is -1, assign the input time based on the input param
        t0_evdist = p['t_evdist']
    else:
        # use dt to set the relative timing
        t0_evdist = t0_evprox0 + p['dt_evprox0_evdist']

    # relative timing between evprox0 and evprox1
    # not defined by distal time
    if p['dt_evprox0_evprox1'] == -1:
        t0_evprox1 = p['t_evprox_late']
    else:
        t0_evprox1 = t0_evprox0 + p['dt_evprox0_evprox1']

    # next evoked input is distal
    sigma_dist = p['sigma_t_evdist']
    p_unique['evdist'] = {
        't0': t0_evdist,
        'L2_pyramidal': (p['gbar_evdist_L2Pyr'], 0.1, sigma_dist),
        'L5_pyramidal': (p['gbar_evdist_L5Pyr'], 0.1, sigma_dist),
        'L2_basket': (p['gbar_evdist_L2Basket'], 0.1, sigma_dist),
        'lamtha_space': 3.,
        'loc': 'distal',
    }

    # next evoked input is proximal also
    sigma_late = p['sigma_t_evprox_late']
    p_unique['evprox1'] = {
        't0': t0_evprox1,
        'L2_pyramidal': (p['gbar_evprox_late_L2Pyr'], 0.1, sigma_late),
        'L2_basket': (p['gbar_evprox_late_L2Basket'], 0.1, sigma_late),
        'L5_pyramidal': (p['gbar_evprox_late_L5Pyr'], 5., sigma_late),
        'L5_basket': (p['gbar_evprox_late_L5Basket'], 5., sigma_late),
        'lamtha_space': 3.,
        'loc': 'proximal',
    }

    # this needs to create many feeds
    # (amplitude, delay, mu, sigma). ordered this way to preserve compatibility
    p_unique['extgauss'] = {
        'stim': 'gaussian',
        'L2_basket': (p['L2Basket_Gauss_A_weight'], 1., p['L2Basket_Gauss_mu'], p['L2Basket_Gauss_sigma']),
        'L2_pyramidal': (p['L2Pyr_Gauss_A_weight'], 0.1, p['L2Pyr_Gauss_mu'], p['L2Pyr_Gauss_sigma']),
        'L5_basket': (p['L5Basket_Gauss_A_weight'], 1., p['L5Basket_Gauss_mu'], p['L5Basket_Gauss_sigma']),
        'L5_pyramidal': (p['L5Pyr_Gauss_A_weight'], 1., p['L5Pyr_Gauss_mu'], p['L5Pyr_Gauss_sigma']),
        'lamtha': 100.,
        'loc': 'proximal',
    }

    # define T_pois as 0 or -1 to reset automatically to tstop
    # (written back into p, so the caller's dict sees the resolved value)
    if p['T_pois'] in (0, -1):
        p['T_pois'] = tstop

    # Poisson distributed inputs to proximal
    p_unique['extpois'] = {
        'stim': 'poisson',
        'L2_basket': (p['L2Basket_Pois_A_weight'], 1., p['L2Basket_Pois_lamtha']),
        'L2_pyramidal': (p['L2Pyr_Pois_A_weight'], 0.1, p['L2Pyr_Pois_lamtha']),
        'L5_basket': (p['L5Basket_Pois_A_weight'], 1., p['L5Basket_Pois_lamtha']),
        'L5_pyramidal': (p['L5Pyr_Pois_A_weight'], 1., p['L5Pyr_Pois_lamtha']),
        'lamtha_space': 100.,
        't_interval': (p['t0_pois'], p['T_pois']),
        'loc': 'proximal',
    }

    return p_ext, p_unique
# qnd function to add feeds if they are sensible
def feed_validate(p_ext, d, tstop):
    """Normalise a rhythmic feed dict and append it to ``p_ext`` if it is usable.

    Parameters
    ----------
    p_ext : list
        Feed dicts accepted so far; extended in place.
    d : dict
        Feed parameters.  Must contain ``'t0'``, ``'tstop'``, ``'stdev'``,
        ``'distribution'``, ``'L2Pyr_ampa'`` and ``'L5Pyr_ampa'``; entries
        named ``<cell>_<receptor>`` hold ``(weight, delay)`` tuples.
        Modified in place.
    tstop : float
        Simulation stop time.

    Returns
    -------
    list
        ``p_ext``, with ``d`` appended only when ``d['t0'] < tstop``.
    """
    # only append if t0 is less than simulation tstop
    if tstop <= d['t0']:
        return p_ext

    # a feed may not outlast the simulation
    if d['tstop'] > tstop:
        d['tstop'] = tstop

    # (weight, delay) entries are keyed '<cell>_<receptor>', e.g. 'L2Pyr_ampa'.
    # Matching on the part before the first underscore is what makes the two
    # adjustments below apply to those keys; testing key.endswith('Pyr')
    # directly never matches a receptor-suffixed key.
    weight_keys = [key for key in d
                   if key.split('_')[0].endswith(('Pyr', 'Basket'))]

    # if stdev is zero, increase synaptic weights 5 fold to make
    # single input equivalent to 5 simultaneous input to prevent spiking
    if not d['stdev'] and d['distribution'] != 'uniform':
        for key in weight_keys:
            weight, delay = d[key]
            d[key] = (weight * 5., delay)

    # if L5 delay is -1, use same delays as L2 unless
    # L2 delay is 0.1 in which case use 1.
    if d['L5Pyr_ampa'][1] == -1:
        l2_delay = d['L2Pyr_ampa'][1]
        l5_delay = l2_delay if l2_delay != 0.1 else 1.
        for key in weight_keys:
            if key.startswith('L5'):
                d[key] = (d[key][0], l5_delay)

    p_ext.append(d)
    return p_ext
# create Network class on each node of the sim
class NetworkOnNode():
    """Network of cells and external inputs as built on one node of the sim.

    Parameters
    ----------
    p : dict
        Simulation parameters.  This class reads ``'dt'``, ``'N_pyr_x'`` and
        ``'N_pyr_y'`` directly and hands ``p`` on to ``create_pext`` and to
        the cell constructors.

    NOTE(review): several statements in this class (the build sequence in
    ``__init__``, the list/accumulator updates and a few ``else`` branches)
    were absent from the text under review and have been restored from the
    comments that announce them.  Each restored statement is marked below
    and should be confirmed against a pristine copy of the file.
    """

    # initial seg.v for cell types that use one value for every section
    _V_INIT = {
        'L2_pyramidal': -71.46,
        'L2_basket': -64.9737,
        'L5_basket': -64.9737,
    }

    # initial seg.v for the L5 pyramidal sections that deviate from -72.
    _L5PYR_V_INIT = {
        'L5Pyr_apical_1': -71.32,
        'L5Pyr_apical_2': -69.08,
        'L5Pyr_apical_tuft': -67.30,
    }

    def __init__(self, p):
        # set the params internally for this net
        # better than passing it around
        self.p = p

        # Number of time points
        # Originally used to create the empty vec for synaptic currents,
        # ensuring that they exist on this node irrespective of whether
        # or not cells of relevant type actually do
        self.N_t = np.arange(0., nrn.tstop, self.p['dt']).size + 1

        # Create a nrn.Vector() with size 1xself.N_t, zero'd
        self.current = {
            'L5Pyr_soma': nrn.Vector(self.N_t, 0),
            'L2Pyr_soma': nrn.Vector(self.N_t, 0),
        }

        # int variables for grid of pyramidal cells (for now in both L2 and L5)
        self.gridpyr = {
            'x': self.p['N_pyr_x'],
            'y': self.p['N_pyr_y'],
        }

        # Parallel stuff
        self.pc = nrn.ParallelContext()
        self.n_hosts = int(self.pc.nhost())
        self.rank = int(self.pc.id())
        self.N_src = 0

        # numbers of sources
        self.N = {}

        # init self.N_cells
        self.N_cells = 0

        # zdiff is expressed as a positive DEPTH of L5 relative to L2
        # this is a deviation from the original, where L5 was defined at 0
        # this should not change interlaminar weight/delay calculations
        self.zdiff = 1307.4

        # params of external inputs in p_ext
        # Global number of external inputs ... automatic counting makes more sense
        # p_unique represent ext inputs that are going to go to each cell
        self.p_ext, self.p_unique = create_pext(self.p, nrn.tstop)
        self.N_extinput = len(self.p_ext)

        # Source list of names
        # in particular order (cells, extinput, alpha names of unique inputs)
        self.src_list_new = self.__create_src_list()

        # cell position lists, also will give counts: must be known by ALL nodes
        # extinput positions are all located at origin.
        # sort of a hack bc of redundancy
        self.pos_dict = dict.fromkeys(self.src_list_new)

        # NOTE(review): every self.__*() call from here to the end of
        # __init__ is restored from the comment directly above it.

        # create coords in pos_dict for all cells first
        self.__create_coords_pyr()
        self.__create_coords_basket()
        self.__count_cells()

        # create coords for all other sources
        self.__create_coords_extinput()

        # count external sources
        self.__count_extsrcs()

        # create dictionary of GIDs according to cell type
        # global dictionary of gid and cell type
        self.gid_dict = {}
        self.__create_gid_dict()

        # assign gid to hosts, creates list of gids for this node in __gid_list
        # __gid_list length is number of cells assigned to this id()
        self.__gid_list = []
        self.__gid_assign()

        # create cells (and create self.origin in create_cells_pyr())
        self.cells_list = []
        self.extinput_list = []

        # external unique input list dictionary, one empty list per input name
        self.ext_list = {key: [] for key in self.p_unique}

        # create sources and init
        self.__create_all_src()
        self.__state_init()

        # parallel network connector
        self.__parnet_connect()

        # set to record spikes
        self.spiketimes = nrn.Vector()
        self.spikegids = nrn.Vector()
        self.__record_spikes()

    # aggregate recording all the somatic voltages for pyr
    def aggregate_currents(self):
        """ this method must be run post-integration

        Walks the pyramidal cells in ``self.cells_list`` and folds each
        entry of ``cell.dict_currents`` into the matching vector in
        ``self.current``.
        """
        target_by_celltype = {
            'L5_pyramidal': 'L5Pyr_soma',
            'L2_pyramidal': 'L2Pyr_soma',
        }

        for cell in self.cells_list:
            target = target_by_celltype.get(cell.celltype)
            if target is None:
                continue

            # iterate over somatic currents, assumes this list exists
            # is guaranteed in L5Pyr()
            for I_soma in cell.dict_currents.values():
                # in parallel, each node has its own Net()
                # NOTE(review): accumulation statement restored; the loop
                # bodies were empty in the text under review.
                self.current[target].add(I_soma)

    # reverse lookup of gid to type
    def gid_to_type(self, gid):
        """Return the ``gid_dict`` key whose range contains ``gid``.

        Returns ``None`` when no range contains it.
        """
        for gidtype, gids in self.gid_dict.items():
            if gid in gids:
                return gidtype
        return None

    # cell counting routine
    def __count_cells(self):
        """Record per-type cell counts in ``self.N`` and total them in ``self.N_cells``."""
        # cellname list is used *only* for this purpose for now
        for src in self.cellname_list:
            # if it's a cell, then add the number to total number of cells
            self.N[src] = len(self.pos_dict[src])
            self.N_cells += self.N[src]

    # count all of the external sources
    def __count_extsrcs(self):
        """ general counting method requires pos_dict is correct for each source
            and that all sources are represented
        """
        # all src numbers are based off of length of pos_dict entry
        # generally done here in lieu of upstream changes
        for src in self.extname_list:
            self.N[src] = len(self.pos_dict[src])

    def __register_cell(self, gid, cell):
        """Store ``cell``, register it with the ParallelContext, run its IClamps."""
        self.cells_list.append(cell)

        # creates a NetCon object internally to Neuron
        self.pc.cell(gid, cell.connect_to_target(None))

        # run the IClamp function here
        # create_all_IClamp() is defined in L2Pyr (etc)
        # NOTE(review): call restored from the comments above; the argument
        # is assumed to be the parameter dict - confirm.
        cell.create_all_IClamp(self.p)

    # parallel create cells AND external inputs (feeds) - see note
    def __create_all_src(self):
        """ these are spike SOURCES but cells are also targets
            external inputs are not targets
        """
        # loop through gids on this node
        for gid in self.__gid_list:
            # check existence of gid with Neuron
            if not self.pc.gid_exists(gid):
                print("GID does not exist. See Cell()")
                continue

            # get type of cell and pos via gid
            # now should be valid for ext inputs
            src_type = self.gid_to_type(gid)
            type_pos_ind = gid - self.gid_dict[src_type][0]
            pos = self.pos_dict[src_type][type_pos_ind]

            # figure out which cell type is assoc with the gid
            # create cells based on loc property
            if src_type == 'L2_pyramidal':
                self.__register_cell(gid, L2Pyr(pos, self.p))

            elif src_type == 'L5_pyramidal':
                self.__register_cell(gid, L5Pyr(pos, self.p))

            elif src_type == 'L2_basket':
                # NOTE(review): basket construction restored; without it the
                # previously created cell would be registered under this gid.
                # Constructor arguments are assumed - confirm.
                self.__register_cell(gid, L2Basket(pos))

            elif src_type == 'L5_basket':
                # NOTE(review): same restoration as for L2_basket.
                self.__register_cell(gid, L5Basket(pos))

            elif src_type == 'extinput':
                # to find param index, take difference between REAL gid
                # here and gid start point of the items
                p_ind = gid - self.gid_dict['extinput'][0]

                # now use the param index in the params and create
                # the cell and artificial NetCon
                self.extinput_list.append(feed.ParFeedAll(src_type, None, self.p_ext[p_ind], gid))
                self.pc.cell(gid, self.extinput_list[-1].connect_to_target())

            elif src_type in self.p_unique:
                gid_post = gid - self.gid_dict[src_type][0]
                cell_type = self.gid_to_type(gid_post)

                # create dictionary entry, append to list
                self.ext_list[src_type].append(feed.ParFeedAll(src_type, cell_type, self.p_unique[src_type], gid))
                self.pc.cell(gid, self.ext_list[src_type][-1].connect_to_target())

            else:
                print("None of these types in Net()")

    # Creates cells and grid
    def __create_coords_pyr(self):
        """ pyr grid is the immutable grid, origin now calculated in relation to feed
        """
        xs = np.arange(self.gridpyr['x'])
        ys = np.arange(self.gridpyr['y'])

        # create list of tuples/coords, (x, y, z)
        self.pos_dict['L2_pyramidal'] = list(it.product(xs, ys, [0]))
        self.pos_dict['L5_pyramidal'] = list(it.product(xs, ys, [self.zdiff]))

    # create basket cell coords based on pyr grid
    def __create_coords_basket(self):
        """Fill ``pos_dict`` for both basket layers from the pyramidal grid size."""
        # define relevant x spacings for basket cells
        xzero = np.arange(0, self.gridpyr['x'], 3)
        xone = np.arange(1, self.gridpyr['x'], 3)

        # split even and odd y vals
        yeven = np.arange(0, self.gridpyr['y'], 2)
        yodd = np.arange(1, self.gridpyr['y'], 2)

        # create general list of x,y coords and sort it by y
        coords = list(it.product(xzero, yeven)) + list(it.product(xone, yodd))
        coords_sorted = sorted(coords, key=lambda pos: pos[1])

        # append the z value for position for L2 and L5
        self.pos_dict['L2_basket'] = [pos_xy + (0,) for pos_xy in coords_sorted]
        self.pos_dict['L5_basket'] = [pos_xy + (self.zdiff,) for pos_xy in coords_sorted]

    # creates origin AND creates external input coords
    def __create_coords_extinput(self):
        """Set ``self.origin`` and place every external source at it."""
        xs = np.arange(self.gridpyr['x'])
        ys = np.arange(self.gridpyr['y'])

        # origin's z component isn't really used in calculating
        # distance functions from origin
        # these must be ints!
        origin_x = xs[int((len(xs) - 1) / 2)]
        origin_y = ys[int((len(ys) - 1) / 2)]
        origin_z = np.floor(self.zdiff / 2)
        self.origin = (origin_x, origin_y, origin_z)

        self.pos_dict['extinput'] = [self.origin] * self.N_extinput

        # at this time, each of the unique inputs is per cell
        for key in self.p_unique:
            # create the pos_dict for all the sources
            self.pos_dict[key] = [self.origin] * self.N_cells

    # creates the immutable source list along with corresponding numbers of cells
    def __create_src_list(self):
        """Set ``cellname_list`` / ``extname_list`` and return them concatenated."""
        # base source list of tuples, name and number, in this order
        # NOTE(review): the list contents were absent from the text under
        # review.  The four names are the cell types used elsewhere in this
        # class; their ORDER fixes the GID ranges and is assumed - confirm.
        self.cellname_list = [
            'L2_basket',
            'L2_pyramidal',
            'L5_basket',
            'L5_pyramidal',
        ]

        # add the legacy extinput here
        # NOTE(review): entry restored; 'extinput' is looked up in pos_dict
        # and gid_dict by other methods of this class.
        self.extname_list = ['extinput']

        # grab the keys for the unique set of inputs and sort the names
        # append them to the src list along with the number of cells
        self.extname_list += sorted(self.p_unique.keys())

        # return one final source list
        return self.cellname_list + self.extname_list

    # creates gid dicts and pos_lists
    def __create_gid_dict(self):
        """Assign each source in ``src_list_new`` a contiguous GID range."""
        # initialize gid index gid_ind to start at 0
        gid_ind = [0]

        # append a new gid_ind based on previous and next cell count
        # order is guaranteed by self.src_list_new
        for src in self.src_list_new:
            # query the N dict for that number and append here to gid_ind, based on previous entry
            # NOTE(review): append restored; gid_ind[i + 1] is read below.
            gid_ind.append(gid_ind[-1] + self.N[src])

            # accumulate total source count
            self.N_src += self.N[src]

        # now actually assign the ranges
        for i, src in enumerate(self.src_list_new):
            self.gid_dict[src] = np.arange(gid_ind[i], gid_ind[i + 1])

    # this happens on EACH node
    def __gid_assign(self):
        """ creates self.__gid_list for THIS node
        """
        # NOTE(review): the three appends and the final sort are restored;
        # without them __gid_list stays empty and nothing is created.

        # round robin assignment of gids
        for gid in range(self.rank, self.N_cells, self.n_hosts):
            # set the cell gid
            self.pc.set_gid2node(gid, self.rank)
            self.__gid_list.append(gid)

            # now to do the cell-specific external input gids on the same proc
            # these are guaranteed to exist because all of these inputs were created
            # for each cell
            for key in self.p_unique:
                gid_input = gid + self.gid_dict[key][0]
                self.pc.set_gid2node(gid_input, self.rank)
                self.__gid_list.append(gid_input)

        # legacy handling of the external inputs
        # NOT perfectly balanced
        for gid_base in range(self.rank, self.N_extinput, self.n_hosts):
            # shift the gid_base to the extinput gid
            gid = gid_base + self.gid_dict['extinput'][0]

            # set as usual
            self.pc.set_gid2node(gid, self.rank)
            self.__gid_list.append(gid)

        # extremely important to get the gids in the right order
        self.__gid_list.sort()

    # connections in parallel - search for gids on nodes
    def __parnet_connect(self):
        """ this NODE is aware of its cells as targets
            for each syn, return list of source GIDs.
            for each item in the list, do a:
            nc = pc.gid_connect(source_gid, target_syn), weight,delay
            Both for synapses AND for external inputs
        """
        # loop over target zipped gids and cells
        # cells_list has NO extinputs anyway. also no extgausses
        for gid, cell in zip(self.__gid_list, self.cells_list):
            # ignore iteration over inputs, since they are NOT targets
            # (compared by value: identity comparison against a string
            # literal is not a reliable equality test)
            if not self.pc.gid_exists(gid) or self.gid_to_type(gid) == 'extinput':
                continue

            # for each gid, find all the other cells connected to it, based on gid
            # this MUST be defined in EACH class of cell in self.cells_list
            # parconnect receives connections from other cells
            # parreceive receives connections from external inputs
            cell.parconnect(gid, self.gid_dict, self.pos_dict, self.p)
            cell.parreceive(gid, self.gid_dict, self.pos_dict, self.p_ext)

            # now do the unique inputs specific to these cells
            # parreceive_ext receives connections from UNIQUE external inputs
            for input_name, p_type in self.p_unique.items():
                cell.parreceive_ext(input_name, gid, self.gid_dict, self.pos_dict, p_type)

    # setup spike recording for this node
    def __record_spikes(self):
        """Record spikes of every gid on this node into the shared vectors."""
        # iterate through gids on this node and
        # set to record spikes in spike time vec and id vec
        # agnostic to type of source, will sort that out later
        for gid in self.__gid_list:
            if self.pc.gid_exists(gid):
                self.pc.spike_record(gid, self.spiketimes, self.spikegids)

    # initializes the state closer to baseline
    def __state_init(self):
        """Set ``seg.v`` on every segment of every cell to a per-type start value."""
        for cell in self.cells_list:
            seclist = nrn.SectionList()
            # NOTE(review): the statement that fills seclist was absent from
            # the text under review, which left the loop below with nothing
            # to visit.  Assumes each cell exposes its soma as cell.soma -
            # confirm.
            seclist.wholetree(sec=cell.soma)

            for sect in seclist:
                for seg in sect:
                    if cell.celltype == 'L5_pyramidal':
                        # named apical sections get their own value,
                        # everything else in the cell starts at -72.
                        seg.v = self._L5PYR_V_INIT.get(sect.name(), -72.)
                    elif cell.celltype in self._V_INIT:
                        seg.v = self._V_INIT[cell.celltype]