from matplotlib import pyplot
import random
from datetime import datetime
import pickle
from neuron import h, gui
class Cell(object):
def __init__(self):
self.synlist = []
self.createSections()
self.buildTopology()
self.defineGeometry()
self.defineBiophysics()
self.createSynapses()
self.nclist = []
def createSections(self):
pass
def buildTopology(self):
pass
def defineGeometry(self):
"""Set the 3D geometry of the cell."""
self.soma.L = 18.8
self.soma.diam = 18.8
self.soma.Ra = 123.0
def defineBiophysics(self):
pass
def createSynapses(self):
"""Add an exponentially decaying synapse """
syn = h.ExpSyn(self.soma(0.5))
syn.tau = 2
syn.e = 0
self.synlist.append(syn) # synlist is defined in Cell
def associateGid (self):
pc.set_gid2node(self.gid, idhost)
nc = h.NetCon(self.soma(0.5)._ref_v, None, sec=self.soma)
nc.threshold = 10
pc.cell(self.gid, nc)
del nc # discard netcon
def createNetcon(self, thresh=10):
""" created netcon to record spikes """
nc = h.NetCon(self.soma(0.5)._ref_v, None, sec = self.soma)
nc.threshold = thresh
return nc
def createStim(self, number=1e9, start=1, noise=0.5, rate=50, weight=1, delay=5):
self.stim = h.NetStim()
self.stim.number = number
self.stim.start = start
self.stim.noise = noise
self.stim.interval = 1000.0/rate
self.ncstim = h.NetCon(self.stim, self.synlist[0])
self.ncstim.delay = delay
self.ncstim.weight[0] = noise # NetCon weight is a vector.
def connect2Source(self, sourceCell, thresh=10):
"""Make a new NetCon with the source cell's membrane
potential at the soma as the source (i.e. the spike detector)
onto the target (i.e. a synapse on this cell)."""
nc = h.NetCon(sourceCell.soma(1)._ref_v, self.synlist[0], sec = sourceCell.soma)
nc.threshold = thresh
return nc
def setRecording(self):
"""Set soma, dendrite, and time recording vectors on the cell. """
self.soma_v_vec = h.Vector() # Membrane potential vector at soma
self.tVec = h.Vector() # Time stamp vector
self.soma_v_vec.record(self.soma(0.5)._ref_v)
self.tVec.record(h._ref_t)
def plotTraces(self):
"""Plot the recorded traces"""
pyplot.figure(figsize=(8,4)) # Default figsize is (8,6)
somaPlot = pyplot.plot(self.tVec, self.soma_v_vec, color='black')
pyplot.legend(somaPlot, ['soma'])
pyplot.xlabel('time (ms)')
pyplot.ylabel('mV')
pyplot.title('Cell %d voltage trace'%(self.gid))
pyplot.show()
#pyplot.savefig('traces')
class HHCell(Cell):
"""HH cell: A soma with active channels"""
def createSections(self):
"""Create the sections of the cell."""
self.soma = h.Section(name='soma', cell=self)
def defineBiophysics(self):
"""Assign the membrane properties across the cell."""
# Insert active Hodgkin-Huxley current in the soma
self.soma.insert('hh')
self.soma.gnabar_hh = 0.12 # Sodium conductance in S/cm2
self.soma.gkbar_hh = 0.036 # Potassium conductance in S/cm2
self.soma.gl_hh = 0.003 # Leak conductance in S/cm2
self.soma.el_hh = -70 # Reversal potential in mV
class Net:
"""Creates Network of N neurons (using parallelContext)
Connectivity and stimulation params provided as arguments
Also ncludes methods to gather and plot spikes
"""
def __init__(self, N=5, cellType=HHCell, connParams={}, stimParams={}):
"""
N: Number of cells.
cellType: class of cell type
connParams: dict of connectivity params
stimParams: dict of stimulation params
"""
self.cellType = cellType
self.N = N # number of cells
self.connParams = connParams # connectivity params
self.stimParams = stimParams # backgroudn stim params
self.cells = [] # List of Cell objects in the net
self.nclist = [] # List of NetCon in the net
self.tVec = h.Vector() # spike time of all cells
self.idVec = h.Vector() # cell ids of spike times
self.createNet() # Actually build the net
def createNet(self):
"""Create, layout, and connect N cells."""
self.setGids() #### set global ids (gids), used to connect cells
self.createCells()
self.connectCells()
self.createStims()
def setGids(self):
self.gidList = []
#### Round-robin counting. Each host as an id from 0 to nhost - 1.
for i in range(idhost, self.N, nhost):
self.gidList.append(i)
def createCells(self):
"""Create and layout cells (in this host) in the network."""
self.cells = []
for i in self.gidList: #### Loop over cells in this node/host
cell = self.cellType() # dynamically create cell object
self.cells.append(cell) # add cell object to net cell list
cell.gid = i # assign gid (can be any unique integer)
cell.associateGid() # associated gid to each cell
pc.spike_record(cell.gid, self.tVec, self.idVec) # Record spikes of this cell
print 'Created cell %d on host %d out of %d'%(i, idhost, nhost)
def connectCells(self):
"""Connect cells"""
connType = self.connParams['type']
if connType == 'rand':
weight = self.connParams['weight']
delayMean = self.connParams['delayMean']
delayVar = self.connParams['delayVar']
delayMin = self.connParams['delayMin']
maxConnsPerCell = self.connParams['maxConnsPerCell']
self.nclist = []
## create random delays
random.seed(randSeed) # Reset random number generator
randDelays = [max(delayMin, random.gauss(delayMean, delayVar)) for pre in range(maxConnsPerCell*self.N)] # select random delays based on mean and var params
## loop over postsyn gids in this host
for postCell in self.cells:
preGids = [gid for gid in self.gidList if gid != postCell.gid] # get list of presyn cell gids (omit post to prevent self connection)
randPreGids = random.sample(preGids, random.randint(0, min(maxConnsPerCell, len(preGids))))
for preGid in randPreGids: # for each presyn cell
nc = pc.gid_connect(preGid, postCell.synlist[0]) # create NetCon by associating pre gid to post synapse
nc.weight[0] = weight
nc.delay = randDelays.pop()
nc.threshold = 10
self.nclist.append((preGid,postCell.gid,nc))
postCell.nclist.append((preGid,nc))
print 'Created conn between pregid %d and postgid %d on host %d'%(preGid, postCell.gid, idhost)
def createStims(self):
"""Connect a spiking generator to the first cell to get
the network going."""
#### Only continue if the first cell is not on this host
for cell in self.cells:
cell.createStim(
noise=self.stimParams['noise'],
rate=self.stimParams['rate'],
weight=self.stimParams['weight'],
delay=self.stimParams['delay'])
print 'Created stim on cell %d on host %d'%(cell.gid, idhost)
def gatherSpikes(self):
"""Gather spikes from all nodes/hosts"""
if idhost==0: print 'Gathering spikes ...'
data = [None]*nhost
data[0] = {'tVec': self.tVec, 'idVec': self.idVec}
pc.barrier()
gather=pc.py_alltoall(data)
pc.barrier()
self.tVecAll = []
self.idVecAll = []
if idhost==0:
for d in gather:
self.tVecAll.extend(list(d['tVec']))
self.idVecAll.extend(list(d['idVec']))
def plotRaster(self):
print 'Plotting raster ...'
pyplot.figure()
pyplot.scatter(self.tVecAll,self.idVecAll,marker=".",s=1,color='blue')
pyplot.xlabel('Time (ms)')
pyplot.ylabel('Cell ID')
pyplot.title('Raster Plot of Network with 500 HH Cells')
pyplot.xlim(0,max(self.tVecAll))
pyplot.ylim(0,self.N)
pyplot.show()
#pyplot.savefig('raster')
def saveData(self):
print 'Savind data ...'
dataSave = {'N': self.N, 'connParams': self.connParams, 'stimParams': self.stimParams, 'tVec': self.tVecAll, 'idVec': self.idVec}
with open('output.pkl', 'wb') as f:
pickle.dump(dataSave, f)
#### New ParallelContext object
pc = h.ParallelContext()
pc.set_maxstep(10)
idhost = int(pc.id())
nhost = int(pc.nhost())
# set randomizer seed
randSeed = 1
# create network
net = Net(N=500, cellType=HHCell,
connParams={'type': 'rand', 'weight': 0.004, 'delayMean': 13.0, 'delayVar': 1.4, 'delayMin': 0.2, 'maxConnsPerCell': 20},
stimParams={'rate': 50, 'noise': 0.5, 'weight': 50, 'delay':5})
# set voltage recording for cell 0 in net
net.cells[0].setRecording()
# run sim and gather spikes
h.stdinit()
h.dt = 0.025
duration = 1000.0
if idhost==0:
print 'Running sim...'
startTime = datetime.now() # store sim start time
pc.psolve(duration) # actually run the sim in all nodes
if idhost==0:
runTime = (datetime.now() - startTime).total_seconds() # calculate run time
print "Run time for %d sec sim = %.2f sec"%(int(duration/1000.0), runTime)
net.gatherSpikes() # gather spikes from all nodes onto master node
# plot net raster, save net data and plot cell 0 traces
if idhost==0:
net.plotRaster()
net.saveData()
net.cells[0].plotTraces()
pc.done()