'''
written by Robert Egger (robert.egger@tuebingen.mpg.de)
(c) 2013-2015 Max Planck Society
'''
#from neuron import h, nrn
import numpy as np
import synapse
from collections import Sequence
import neuron
nrn = neuron.nrn
h = neuron.h
class Cell(object):
'''
Cell object providing morphological information
and hoc interface
'''
def __init__(self):
'''
Constructor:
self.id = None
self.soma = None
TODO: implement trees in python to avoid
NEURON section stack problems that may occur
during use of SectionLists
tree and branches are set up by CellParser
self.tree = None
branches are all processes attached to soma
self.branches = {}
structures are all processes with different labels
(e.g., Dendrite, ApicalDendrite, ApicalTuft, Myelin etc..
self.structures = {}
simply list of all sections
self.sections = []
self.synapses = {}
TODO: this should be read from parameter file (e_pas)
self.E = -70.0
'''
self.id = None
self.soma = None
# TODO: implement trees in python to avoid
# NEURON section stack problems that may occur
# during use of SectionLists
# tree and branches are set up by CellParser
self.tree = None
# branches are all processes attached to soma
self.branches = {}
# structures are all processes with different labels
# (e.g., Dendrite, ApicalDendrite, ApicalTuft, Myelin etc..
self.structures = {}
# simply list of all sections
self.sections = []
self.synapses = {}
# TODO: this should be read from parameter file (e_pas)
self.E = -70.0
self.changeSynParamDict = {}
def re_init_cell(self):
'''re-initialize for next simulation run'''
for sec in self.sections:
sec._re_init_vm_recording()
sec._re_init_range_var_recording()
for synType in self.synapses.keys():
for syn in self.synapses[synType]:
syn.disconnect_hoc_synapse()
def record_range_var(self, var, mech=None):
for sec in self.sections:
sec._init_range_var_recording(var, mech)
def distance_between_pts(self, sec1, x1, sec2, x2):
'''computes path length between to points given by
location x1 and x2 in sections sec1 and sec2
or by ptID x1 and x2 and section ID sec1 and sec2.
for now, use built-in NEURON method distance.
This should probably be changed at some point in the future
into something like a look-up table of pair-wise distances,
because repeated computations for 1000s of synapses are going
to be very expensive... also, only approximate, since NEURON
computes distances between centers of segments'''
if isinstance(sec1, int):
sec1 = self.sections[sec1]
sec2 = self.sections[sec2]
x1 = sec1.relPts(x1)
x2 = sec2.relPts(x2)
# set origin
silent = h.distance(0, x1, sec=sec1)
return h.distance(x2, sec=sec2)
def distance_to_soma(self, sec, x):
'''computes path length between soma and point given by
location x in sections sec or by ptID x and section ID sec'''
# assume the user knows what they're doing...
if isinstance(sec, int):
return self.distance_between_pts(self.soma.secID, 0, sec, x)
else:
return self.distance_between_pts(self.soma, 0.5, sec, x)
def add_synapse(self, secID, ptID, ptx, preType='Generic', postType='Generic'):
if not self.synapses.has_key(preType):
self.synapses[preType] = []
newSyn = synapse.Synapse(secID, ptID, ptx, preType, postType)
newSyn.coordinates = np.array(self.sections[secID].pts[ptID])
self.synapses[preType].append(newSyn)
return self.synapses[preType][-1]
def remove_synapses(self, preType=None):
if preType is None:
return
# remove all
if preType == 'All' or preType == 'all':
for synType in self.synapses.keys():
synapses = self.synapses[synType]
del synapses[:]
del self.synapses[synType]
return
# only one type
else:
try:
synapses = self.synapses[preType]
del synapses[:]
del self.synapses[preType]
except KeyError:
print 'Synapses of type ' + preType + ' not present on cell'
return
def change_synapse_parameters(self):
'''
Change parameters of synapses during simulation.
self.changeSynParamDict is dictionary of network parameter sets
with keys corresponding to event times.
This allows automatic update of parameter sets
according to their relative timing.
'''
raise NotImplementedError('Synapse parameter change does not work correctly with VecStim!')
# eventList = self.changeSynParamDict.keys()
# eventList.sort()
# print 'Cell %s: event at t = %.2f' % (self.id ,eventList[0])
# tChange = eventList[0]
# newParamDict = self.changeSynParamDict.pop(eventList[0])
# synCnt = 0
# for synType in newParamDict.keys():
# print '\tchanging parameters for synapses of type %s' % synType
# for syn in self.synapses[synType]:
# synCnt += 1
# if not syn.is_active():
# continue
# #===============================================================
# # re-compute release times in case release probability changes
# #===============================================================
# preChange = False
# for t in syn.preCell.spikeTimes:
# if t >= tChange:
# preChange = True
# break
# if preChange:
# changeBin = None
# for i in range(len(syn.releaseSite.spikeTimes)):
# if syn.releaseSite.spikeTimes[i] >= tChange:
# changeBin = i
# if changeBin is not None:
# print '\tdetermine new release times for synapse %d of type %s' % (synCnt-1, synType)
# print '\t\told VecStim: %s' % (syn.releaseSite.spikes)
# del syn.releaseSite.spikeTimes[changeBin:]
# syn.releaseSite.spikes.play()
# syn.releaseSite.spikeVec.resize(0)
# prelNew = newParamDict[synType].synapses.releaseProb
# newSpikes = []
# for t in syn.preCell.spikeTimes:
# if t >= tChange:
# if np.random.rand() < prelNew:
## syn.releaseSite.append(t)
# newSpikes.append(t)
# syn.releaseSite.spikeTimes.append(t)
# print '\t\tnew release time %.2f' % (t)
# if len(newSpikes):
# print '\t\told NetCon: %s' % (syn.netcons[0])
# print '\t\told NetCon valid: %d' % (syn.netcons[0].valid())
# del syn.netcons[0]
## syn.netcons = []
# print '\t\tcreating new VecStim'
# del syn.releaseSite.spikes
# syn.releaseSite.spikes = h.VecStim()
# print '\t\tupdating SpikeVec:'
# del syn.releaseSite.spikeVec
# syn.releaseSite.spikeVec = h.Vector(newSpikes)
# tRelStr = '\t\t'
# for t in syn.releaseSite.spikeVec:
# tRelStr += str(t)
# tRelStr += ', '
# print tRelStr
# print '\t\tactivating new VecStim %s' % (syn.releaseSite.spikes)
# syn.releaseSite.spikes.play(syn.releaseSite.spikeVec)
# print '\t\tupdated VecStim %s with %d new spike times' % (syn.releaseSite.spikes, len(newSpikes))
#
# newSyn = synapse.Synapse(syn.secID, syn.ptID, syn.x, syn.preCellType, syn.postCellType)
# newSyn.coordinates = np.array(self.sections[syn.secID].pts[syn.ptID])
# newSyn.weight = syn.weight
# newSyn.activate_hoc_syn(syn.releaseSite, syn.preCell, self, newParamDict[synType].synapses.receptors)
#
# syn.netcons = []
# syn.receptors = {}
# forget = syn
# syn = newSyn
# del forget
# #===============================================================
# # update biophysical parameters and NetCon
# #===============================================================
## for recepStr in newParamDict[synType].synapses.receptors.keys():
## recep = newParamDict[synType].synapses.receptors[recepStr]
## hocStr = 'h.'
## hocStr += recepStr
## hocStr += '(x, sec=hocSec)'
## newSyn = eval(hocStr)
## del syn.receptors[recepStr]
## syn.receptors[recepStr] = newSyn
## for paramStr in recep.parameter.keys():
## #===========================================================
## # try treating parameters as NMODL range variables,
## # then as (global) NMODL parameters
## #===========================================================
## try:
## valStr = str(recep.parameter[paramStr])
## cmd = 'syn.receptors[\'' + recepStr + '\'].' + paramStr + '=' + valStr
### print 'setting %s for synapse of type %s' % (cmd, synType)
## exec(cmd)
## except LookupError:
## cmd = paramStr + '_' + recepStr + '='
## cmd += str(recep.parameter[paramStr])
### print 'setting %s for synapse of type %s' % (cmd, synType)
## h(cmd)
## threshParam = float(recep.threshold)
## delayParam = float(recep.delay)
## newNetcon = h.NetCon(syn.releaseSite.spikes, syn.receptors[recepStr])
### print '\t\told NetCon: %s' % (syn.netcons[0])
### print '\t\told NetCon valid: %d' % (syn.netcons[0].valid())
### del syn.netcons[0]
### syn.netcons = []
## syn.netcons = [newNetcon]
## print '\t\tnew NetCon: %s' % (syn.netcons[0])
## print '\t\tnew NetCon valid: %d' % (syn.netcons[0].valid())
## syn.netcons[0].threshold = threshParam
## syn.netcons[0].delay = delayParam
## if isinstance(recep.weight, Sequence):
## for i in range(len(recep.weight)):
## syn.netcons[0].weight[i] = recep.weight[i]
## else:
## syn.netcons[0].weight[0] = recep.weight
class PySection(nrn.Section):
'''
Subclass of nrn.Section providing additional geometric
and mechanism information/ handling methods
'''
def __init__(self, name=None, cell=None, label=None):
'''
structure
self.label = label
reference to parent section
self.parent = None
connection point at parent section
self.parentx = 1.0
bounding box around 3D coordinates
self.bounds = ()
number of traced 3D coordinates
self.nrOfPts = 0
list of traced 3D coordinates
self.pts = []
list of relative position of 3D points along section
self.relPts = []
list of diameters at traced 3D coordinates
self.diamList = []
total area of all NEURON segments in this section
self.area = 0.0
list of center points of segments used during simulation
self.segPts = []
list of x values corresponding to center of each segment
for looping akin to the hoc function for(x) (without 0 and 1)
self.segx = []
list of diameters of all segments
self.segDiams = []
list of neuron Vectors recording voltage in each compartment
self.recVList = []
dict of range variables recorded
self.recordVars = {}
'''
if name is None:
name = ''
if cell is None:
nrn.Section.__init__(self)
else:
nrn.Section.__init__(self, name=name, cell=cell)
'''structure'''
self.label = label
'''reference to parent section'''
self.parent = None
'''connection point at parent section'''
self.parentx = 1.0
'''bounding box around 3D coordinates'''
self.bounds = ()
'''number of traced 3D coordinates'''
self.nrOfPts = 0
'''list of traced 3D coordinates'''
self.pts = []
'''list of relative position of 3D points along section'''
self.relPts = []
'''list of diameters at traced 3D coordinates'''
self.diamList = []
'''total area of all NEURON segments in this section'''
self.area = 0.0
'''list of center points of segments used during simulation'''
self.segPts = []
'''list of x values corresponding to center of each segment
for looping akin to the hoc function for(x) (without 0 and 1)'''
self.segx = []
'''list of diameters of all segments'''
self.segDiams = []
'''list of neuron Vectors recording voltage in each compartment'''
self.recVList = []
'''dict of range variables recorded'''
self.recordVars = {}
def set_3d_geometry(self, pts, diams):
'''
invokes NEURON 3D geometry setup
'''
if len(pts) != len(diams):
errStr = 'List of diameters does not match list of 3D points'
raise RuntimeError(errStr)
self.pts = pts
self.nrOfPts = len(pts)
self.diamList = diams
# silent output in dummy instead of stdout
dummy = h.pt3dclear(sec=self)
for i in range(len(self.pts)):
x, y, z = self.pts[i]
d = self.diamList[i]
dummy = h.pt3dadd(x, y, z, d, sec=self)
# not good! set after passive properties have been assigned
# self.nseg = self.nrOfPts
self._compute_bounds()
self._compute_relative_pts()
# execute after nr of segments has been determined
# self._init_vm_recording()
def set_segments(self, nrOfSegments):
'''
Set spatial discretization. should be used
together with biophysical parameters to produce
meaningful, yet efficient discretization.
calls private methods to initialize recordings etc.
'''
self.nseg = nrOfSegments
self._compute_seg_pts()
self._compute_seg_diameters()
self._compute_total_area()
# TODO: find a way to make this more efficient,
# i.e. allocate memory before running simulation
self._init_vm_recording()
def _compute_seg_diameters(self):
'''fill list of diameters of all segments.
Approximation for visualization purposes only.
Takes diameter at 3D point closest to segment center.'''
self.segDiams = []
for x in self.segx:
minDist = 1.0
minID = 0
for i in range(len(self.relPts)):
dist = abs(self.relPts[i] - x)
if dist < minDist:
minDist = dist
minID = i
self.segDiams.append(self.diamList[minID])
def _compute_total_area(self):
'''computes total area of all NEURON segments in this section'''
area = 0.0
dx = 1.0/self.nseg
for i in range(self.nseg):
x = (i+0.5)*dx
area += h.area(x, sec=self)
self.area = area
def _compute_bounds(self):
pts = self.pts
xMin, xMax = pts[0][0], pts[0][0]
yMin, yMax = pts[0][1], pts[0][1]
zMin, zMax = pts[0][2], pts[0][2]
for i in range(1,len(pts)):
if pts[i][0] < xMin:
xMin = pts[i][0]
if pts[i][0] > xMax:
xMax = pts[i][0]
if pts[i][1] < yMin:
yMin = pts[i][1]
if pts[i][1] > yMax:
yMax = pts[i][1]
if pts[i][2] < zMin:
zMin = pts[i][2]
if pts[i][2] > zMax:
zMax = pts[i][2]
self.bounds = xMin, xMax, yMin, yMax, zMin, zMax
def _compute_relative_pts(self):
self.relPts = [0.0]
ptLength = 0.0
pts = self.pts
for i in range(len(pts)-1):
pt1, pt2 = np.array(pts[i]), np.array(pts[i+1])
ptLength += np.sqrt(np.sum(np.square(pt1-pt2)))
x = ptLength/self.L
self.relPts.append(x)
# avoid roundoff errors:
norm = 1.0/self.relPts[-1]
for i in range(len(self.relPts)-1):
self.relPts[i] *= norm
self.relPts[-1] = 1.0
def _compute_seg_pts(self):
''' compute 3D center points of segments
by approximating section as straight line
(only for visualization) '''
if len(self.pts) > 1:
p0 = np.array(self.pts[0])
p1 = np.array(self.pts[-1])
vec = p1 - p0
dist = np.sqrt(np.dot(vec, vec))
vec /= dist
# endpoint stays the same:
segLength = dist/self.nseg
# total length stays the same (straightening of dendrite;
# however this moves branch points):
# segLength = self.L/self.nseg
for i in range(self.nseg):
segPt = p0 + (i+0.5)*segLength*vec
self.segPts.append(segPt)
self.segx.append((i+0.5)/self.nseg)
else:
self.segPts = [self.pts[0]]
def _init_vm_recording(self):
''' set up recording of Vm at every point.
TODO: recVList[0] should store voltage recorded at
intermediate node between this and parent segment?'''
# beginVec = h.Vector()
# beginVec.record(self(0)._ref_v, sec=self)
# self.recVList.append(beginVec)
for seg in self:
vmVec = h.Vector()
vmVec.record(seg._ref_v, sec=self)
self.recVList.append(vmVec)
# endVec = h.Vector()
# endVec.record(self(1)._ref_v, sec=self)
# self.recVList.append(endVec)
def _re_init_vm_recording(self):
'''resize Vm vectors to 0 to avoid NEURON segfaults'''
for vec in self.recVList:
vec.resize(0)
def _re_init_range_var_recording(self):
'''resize Vm vectors to 0 to avoid NEURON segfaults'''
for key in self.recordVars.keys():
for vec in self.recordVars[key]:
vec.resize(0)
def _init_range_var_recording(self, var, mech=None):
if mech is None:
if not var in self.recordVars.keys():
self.recordVars[var] = []
for seg in self:
vec = h.Vector()
hRef = eval('seg._ref_'+var)
vec.record(hRef, sec=self)
self.recordVars[var].append(vec)
else:
if not var in self.recordVars.keys():
key = mech+'.'+var
self.recordVars[key] = []
for seg in self:
vec = h.Vector()
hRef = eval('seg.'+mech+'._ref_'+var)
vec.record(hRef, sec=self)
self.recordVars[key].append(vec)
class PointCell(object):
'''
simple object for use as spike source
stores spike times in hoc Vector and numpy array
requires VecStim to trigger spikes at specified times
initialize with list (iterable) of spike times
'''
def __init__(self, spikeTimes=None):
'''
self.spikeTimes = list(spikeTimes)
self.spikeVec = h.Vector(spikeTimes)
self.spikes = h.VecStim()
for use as cell connecting to synapses, not directly to NetCons:
self.synapseList = None
'''
if spikeTimes is not None:
self.spikeTimes = spikeTimes
self.spikeVec = h.Vector(spikeTimes)
self.spikes = h.VecStim()
else:
self.spikeTimes = []
self.spikeVec = h.Vector()
self.spikes = h.VecStim()
self.playing = False
self.synapseList = None
def is_active(self):
return self.playing
def play(self):
'''Activate point cell'''
if self.spikeVec.size() and not self.playing:
self.spikes.play(self.spikeVec)
self.playing = True
def append(self, spikeT):
'''append additional spike time'''
self.spikeTimes.append(spikeT)
self.spikeVec.append(spikeT)
def _add_synapse_pointer(self, synapse):
if self.synapseList is None:
self.synapseList = [synapse]
else:
self.synapseList.append(synapse)
def turn_off(self):
'''M. Hines: Note that one can turn off a VecStim without destroying it
by using VecStim.play() with no args. Turn it back on by
supplying a Vector arg. Or one could resize the Vector to 0.
necessary b/c vecevent.mod does not implement ref counting'''
self.playing = False
self.spikes.play()
self.spikeTimes = []
self.spikeVec.resize(0)
class SpikeTrain(PointCell):
'''
Simple object for use as spike train source.
Pre-computes spike times according to user-provided
parameters and plays them as a regular point cell.
Computation of spike times as in NEURON NetStim.
'''
def __init__(self):
PointCell.__init__(self)
self.rand = np.random.RandomState(np.random.randint(123456, 1234567))
self.spikeInterval = None
self.noiseParam = 1.0
self.start = 0.0
self.stop = -1.0
self.playing = False
def set_interval(self, interval):
self.spikeInterval = interval
def set_noise(self, noise):
self.noiseParam = noise
def set_start(self, start):
self.start = start
def set_stop(self, stop):
self.stop = stop
def is_active(self):
return self.playing
def play(self, nSpikes=None):
'''Activate point cell'''
if self.spikeInterval is None:
raise RuntimeError('Trying to activate SpikeTrain without mean interval')
if self.stop < self.start and nSpikes is None:
errstr = 'Trying to activate SpikeTrain without number of spikes or t stop parameter!'
raise RuntimeError(errstr)
if nSpikes is not None:
lastSpike = 0.0
for i in range(nSpikes):
if i == 0:
tSpike = self.start + self._next_interval() - self.spikeInterval*(1 - self.noiseParam)
if tSpike < 0:
tSpike = 0
else:
tSpike = lastSpike + self._next_interval()
self.append(tSpike)
lastSpike = tSpike
elif self.stop > self.start:
lastSpike = 0.0
while True:
if lastSpike == 0:
tSpike = self.start + self._next_interval() - self.spikeInterval*(1 - self.noiseParam)
if tSpike < 0:
tSpike = 0
else:
tSpike = lastSpike + self._next_interval()
if tSpike > self.stop:
break
self.append(tSpike)
lastSpike = tSpike
if self.spikeVec.size() and not self.playing:
self.spikes.play(self.spikeVec)
self.playing = True
def _next_interval(self):
if self.noiseParam == 0:
return self.spikeInterval
else:
return (1 - self.noiseParam)*self.spikeInterval + self.noiseParam*self.spikeInterval*self.rand.exponential()
class SynParameterChanger():
def __init__(self, cell, synParam, t):
self.cell = cell
self.synParam = synParam
self.tEvent = t
self.cell.changeSynParamDict[self.tEvent] = self.synParam
def cvode_event(self):
h.cvode.event(self.tEvent, self.cell.change_synapse_parameters)