import numpy as np
from neuron import h
from . import neuronFunctions as nfx
def findSites(soma, dist, method='hoc', dends=None, incDiam=False):
    """
    Find the sections below `soma` that contain the point at path distance `dist` from soma(1).

    If method = 'hoc', searches soma.subtree() (every section connected to soma).
    If method = 'struct', searches `dends` instead: a list of sections that must be in soma.subtree().
    The reference section `soma` itself is never reported as a site.

    A section qualifies when its proximal end lies strictly closer than `dist` and its distal
    end lies at or beyond `dist`, i.e. proxDist < dist <= proxDist + L.

    Parameters
    ----------
    soma : section
        Reference section; all distances are measured from soma(1).
    dist : float
        Requested path distance from soma(1), in the units of section.L (µm in NEURON).
    method : {'hoc', 'struct'}
        Which set of sections to search (see above).
    dends : list of sections, optional
        Sections to search when method == 'struct'.
    incDiam : bool
        If True, every length added to the post/pre totals is multiplied by pi*diam of its
        section, so the totals become lateral membrane-area sums (pi*diam*L) instead of lengths.

    Returns
    -------
    outSection : list of sections containing the requested distance
    outSegment : ndarray, normalized position x (0-1) of the site within each section
                 (e.g. a section spanning 90-110 µm with dist = 100 µm gives 0.5)
    outPost : ndarray, (weighted) dendritic length distal to the site: the remainder of the
              site's own section plus every section below it
    outPre : ndarray, (weighted) length of the parent section's subtree excluding the parent
             itself, i.e. the site's section, its sisters, and all of their descendants
    outDistBranch : ndarray, path distance from the start (x=0) of the site's section to the site
    outDiam : ndarray, `diam` attribute of each site's section
    Returns None (after printing a message) if `method` is not recognized.
    """
    # Get the set of candidate sections
    if method == 'hoc':
        # Use generic hoc method
        tree = soma.subtree()
    elif method == 'struct':
        # Use structure method (requires 4th input)
        tree = dends
    else:
        print('Did not recognize method')
        return None

    def weight(sec):
        # Optional diameter weighting: weight * length == lateral membrane area when incDiam is set
        return np.pi * sec.diam if incDiam else 1

    N = len(tree)
    isDistance = np.zeros(N, dtype='bool')
    nSegment = np.zeros(N)
    prevDistance = np.zeros(N)  # length of dendrite after the previous branch point
    postDistance = np.zeros(N)  # length of dendrite after the requested site
    distFromBranch = np.zeros(N)  # distance of site from previous branch point
    diamAtBranch = np.zeros(N)  # diam of the site's section
    for n, sec in enumerate(tree):
        if sec == soma:
            # h.distance(soma(0), soma(1)) == soma.L, which would wrongly place the reference
            # section "beyond" soma(1); when soma is the root it also has no parent segment,
            # so the pre-branch measurement below would crash. It is never a valid site.
            continue
        proxDist = h.distance(sec(0), soma(1))
        L = sec.L
        if not (proxDist < dist <= proxDist + L):
            continue
        isDistance[n] = True  # this section spans the requested distance from the soma
        x = (dist - proxDist) / L  # normalized position of the requested distance
        nSegment[n] = x
        siteDist = h.distance(sec(x), sec(0))  # distance from previous branch point
        distFromBranch[n] = siteDist
        diamAtBranch[n] = sec.diam
        # Total dendritic length after the site
        for ct in sec.subtree():
            if ct == sec:
                # only the part of the site's own section distal to the site
                postDistance[n] += weight(ct) * (L - siteDist)
            else:
                postDistance[n] += ct.L * weight(ct)  # all children count in full
        # Total dendritic length after the previous branch point: everything below the parent
        # except the parent itself (which subtree() includes)
        parentSec = sec.parentseg().sec
        for ct in parentSec.subtree():
            if ct != parentSec:
                prevDistance[n] += ct.L * weight(ct)
    outSection = [tree[i] for i in np.where(isDistance)[0]]
    outSegment = nSegment[isDistance]
    outPost = postDistance[isDistance]
    outPre = prevDistance[isDistance]
    outDistBranch = distFromBranch[isDistance]
    outDiam = diamAtBranch[isDistance]
    return outSection, outSegment, outPost, outPre, outDistBranch, outDiam
def measurePrePostDistance(section, segment):
    """
    Measure dendritic length distal to each site and below each site's previous branch point.

    Parameters
    ----------
    section : sequence of sections
    segment : sequence of normalized positions (0-1), indexed in parallel with `section`

    Returns
    -------
    postDistance : ndarray, total length of section.subtree() that lies distal to the site,
                   i.e. the whole subtree minus the part of the site's own section between
                   its start (x=0) and the site
    prevDistance : ndarray, total length of the parent section's subtree excluding the parent
                   itself (the site's section, its sisters, and all of their descendants)
    """
    N = len(section)
    prevDistance = np.zeros(N)
    postDistance = np.zeros(N)
    for n, cSection in enumerate(section):
        cSegment = segment[n]
        # Length of the site's own section that is proximal to the site; it must be removed
        # from the subtree total (previously the distal remainder was subtracted instead,
        # which left the proximal part in and contradicted findSites/measureConvolvedBranching).
        proximalLength = h.distance(cSection(cSegment), cSection(0))
        postDistance[n] = sum(ct.L for ct in cSection.subtree()) - proximalLength
        # Dendritic length after the previous branch point (parent excluded from its own subtree)
        parentSec = cSection.parentseg().sec
        prevDistance[n] = sum(ct.L for ct in parentSec.subtree() if ct != parentSec)
    return postDistance, prevDistance
def recordSites(section, segment, recordVariable='_ref_v'):
    """
    Create a recording Vector at every (section, segment) pair, plus one Vector for time.

    `section` and `segment` are paired element-wise (pairing stops at the shorter one).
    `recordVariable` names the segment reference attribute to record; the default
    '_ref_v' records membrane voltage.

    Returns (list of recording Vectors in input order, time Vector).
    """
    timeVec = h.Vector()
    timeVec.record(h._ref_t)
    recordings = []
    for sec, x in zip(section, segment):
        vec = h.Vector()
        recordings.append(vec)
        vec.record(getattr(sec(x), recordVariable))
    return recordings, timeVec
def injectSites(section, segment, stim=None, amplitude=-0.1, delay=50, dur=50,
                tstop=101, v_init=-76, celsius=37):
    """
    Inject current at each site in turn (one simulation per site) and record the local voltage.

    For every (section, segment) pair a clamp is attached at the site with nfx.attachCC,
    nfx.simulate is run, and the site's voltage trace is copied into a numpy array.

    Parameters
    ----------
    section, segment : parallel sequences of sections and normalized positions (0-1);
        pairing stops at the shorter one
    stim : returned unchanged when no sites are given; otherwise replaced by the clamp
        created for the last site
    amplitude, delay, dur : passed to nfx.attachCC as amp/delay/dur (defaults -0.1, 50, 50)
    tstop, v_init, celsius : passed to nfx.simulate (defaults 101, -76, 37)

    Returns
    -------
    vsection : list of numpy arrays, the voltage trace recorded at each site
    tv : Vector recording h.t; after the loop it holds what the last simulation recorded
    stim : the object returned by the last nfx.attachCC call (or the `stim` argument)

    NOTE(review): rebinding `stim` drops the Python reference to the previous clamp;
    whether earlier clamps stay in the model depends on nfx.attachCC — confirm.
    """
    # Always record time stamps
    tv = h.Vector()
    tv.record(h._ref_t)
    # One Vector is re-pointed at each site; its contents are copied after every run
    vrecord = h.Vector()
    vsection = []
    for sec, seg in zip(section, segment):
        vrecord.record(sec(seg)._ref_v)
        stim = nfx.attachCC(section=sec, delay=delay, dur=dur, amp=amplitude, loc=seg)
        nfx.simulate(tstop=tstop, v_init=v_init, celsius=celsius)
        vsection.append(np.array(vrecord))
    return vsection, tv, stim
def injectAlphaSites(section, segment, syn=None, onset=5, tau=2, gmax=0.1, tstop=25,
                     v_init=-76, celsius=35, soma=None):
    """
    Activate an alpha synapse at each site in turn (one simulation per site) and record the
    voltage at the site and at the soma.

    For every (section, segment) pair a synapse is attached with nfx.attachAlpha, nfx.simulate
    is run, and the site and somatic voltage traces are copied into numpy arrays.

    Parameters
    ----------
    section, segment : parallel sequences of sections and normalized positions (0-1);
        pairing stops at the shorter one
    syn : returned unchanged when no sites are given; otherwise replaced by the synapse
        created for the last site
    onset, tau, gmax : passed to nfx.attachAlpha (defaults 5, 2, 0.1)
    tstop, v_init, celsius : passed to nfx.simulate (defaults 25, -76, 35)
    soma : section whose middle (0.5) is recorded as the somatic voltage; defaults to the
        hoc-level section `h.soma`

    Returns
    -------
    vsection : list of numpy arrays, voltage at each site
    vsoma : list of numpy arrays, somatic voltage during each site's simulation
    tv : Vector recording h.t; after the loop it holds what the last simulation recorded
    syn : the object returned by the last nfx.attachAlpha call (or the `syn` argument)

    NOTE(review): rebinding `syn` drops the Python reference to the previous synapse;
    whether earlier synapses stay in the model depends on nfx.attachAlpha — confirm.
    """
    # Always record time stamps
    tv = h.Vector()
    tv.record(h._ref_t)
    # Vectors are re-pointed at each site; their contents are copied after every run
    vrecord = h.Vector()
    vsomaRec = h.Vector()
    vsection = []
    vsoma = []
    for sec, seg in zip(section, segment):
        vrecord.record(sec(seg)._ref_v)
        # resolve the somatic section lazily, as before, so h.soma is only needed when used
        somaSec = h.soma if soma is None else soma
        vsomaRec.record(somaSec(0.5)._ref_v)
        syn = nfx.attachAlpha(section=sec, seg=seg, onset=onset, tau=tau, gmax=gmax)
        nfx.simulate(tstop=tstop, v_init=v_init, celsius=celsius)
        vsection.append(np.array(vrecord))
        vsoma.append(np.array(vsomaRec))
    return vsection, vsoma, tv, syn
def recordBranchPointDivision(section):
    """
    Record voltage just after the branch point on each target section and on its sister.

    For every target section, its parent must have exactly two children; the other child
    is the sister. On both branches, voltage is recorded at the segments selected by
    nfx.returnSegment(nseg, 1) and nfx.returnSegment(nseg, 2). Per the original comment these
    are presumably the first and second segments after the branch point — depends on
    nfx.returnSegment. The axial resistance of one segment is computed for each branch.

    Returns
    -------
    tv : Vector recording h.t
    targetBranch, sisterBranch : lists with one entry per target section, each
        [Vector at segment 1, Vector at segment 2, axial resistance of one segment]
    Returns None (after printing a message) if a parent does not have exactly 2 children.
    """
    def segmentAxialResistance(sec):
        # 4*Ra/(pi*diam**2) is the axial resistance per unit length. The 1e4 factor converts
        # Ra from ohm*cm to ohm*um and 1e-6 converts ohm to Mohm, assuming NEURON's standard
        # units (Ra in ohm*cm, L and diam in um).
        axialResistancePerLength = 4 * sec.Ra * 1e4 / (np.pi * sec.diam**2)
        segmentLength = sec.L / sec.nseg
        return 1e-6 * axialResistancePerLength * segmentLength

    def recordFirstTwoSegments(sec):
        # Two voltage recordings near the proximal end plus the segment axial resistance
        first = h.Vector()
        first.record(sec(nfx.returnSegment(sec.nseg, 1))._ref_v)
        second = h.Vector()
        second.record(sec(nfx.returnSegment(sec.nseg, 2))._ref_v)
        return [first, second, segmentAxialResistance(sec)]

    tv = h.Vector()
    tv.record(h._ref_t)
    targetBranch = []
    sisterBranch = []
    for sec in section:
        targetBranch.append(recordFirstTwoSegments(sec))
        sref = h.SectionRef(sec=sec.parentseg().sec)
        if sref.nchild() != 2:
            print('The parent of {0} did not have exactly 2 children! Exiting prematurely.'.format(sec))
            return
        # The sister is whichever of the two children is not the target
        sisterSection = sref.child[1] if sref.child[0] == sec else sref.child[0]
        sisterBranch.append(recordFirstTwoSegments(sisterSection))
    return tv, targetBranch, sisterBranch
def measureConvolvedBranching(section, segment, lengthConstant):
    """
    Exponentially weighted average of downstream dendritic length around each ROI.

    For each ROI (section[i] at normalized position segment[i]) a set of (offset, length)
    pairs is collected:
      * the ROI itself: offset 0, length = its section's subtree length distal to the ROI;
      * every section on the path from the ROI's section toward the root, excluding the
        root section itself: offset = path distance from the ROI to that section's start
        (x=0), length = total length of that section's subtree.
    The result for the ROI is sum(length * w) / sum(w) with w = exp(-offset / lengthConstant).

    Returns a list with one value per ROI.
    """
    results = []
    for roiSec, roiX in zip(section, segment):
        roi = roiSec(roiX)

        # ROI entry: whole subtree minus the part of the ROI's section proximal to the ROI
        offsets = [0.0]
        distalLength = 0
        for child in roiSec.subtree():
            distalLength += child.L
            if child == roiSec:
                distalLength -= h.distance(roi, roiSec(0))
        lengths = [distalLength]

        # One entry per branch point walking toward the root (the root itself is skipped)
        branchSec = roiSec
        while True:
            offsets.append(h.distance(roi, branchSec(0)))
            lengths.append(sum(child.L for child in branchSec.subtree()))
            branchSec = branchSec.parentseg().sec
            if branchSec.parentseg() is None:
                break

        weights = np.exp(-np.array(offsets) / lengthConstant)
        results.append(np.dot(lengths, weights) / np.sum(weights))
    return results
def measureLocalBranching(section, segment, soma, lengthConstant):
    """
    Exponentially discounted total dendritic length around each site.

    For each site (section[i] at normalized position segment[i]), every section of
    soma.subtree() except `soma` contributes the integral of exp(-d / lengthConstant) over
    the path distances d (measured from the site) that the section spans. The site's own
    section is split into its part proximal to the site and its part distal to the site,
    each starting at distance 0.

    Returns a list with one value per site.
    """
    def decayIntegral(closer, further):
        # Integral of exp(-d/lengthConstant) from `closer` to `further`. Written directly
        # rather than as (average * length), so zero-length pieces (site at x=0 or x=1,
        # zero-length sections) contribute 0 instead of raising ZeroDivisionError.
        return lengthConstant * (np.exp(-closer / lengthConstant) - np.exp(-further / lengthConstant))

    somaTree = soma.subtree()
    localBranching = []
    for sec, seg in zip(section, segment):
        site = sec(seg)
        currentDistance = 0.0
        for ct in somaTree:
            if ct == sec:
                # distal and proximal parts of the site's own section
                currentDistance += decayIntegral(0.0, h.distance(site, sec(1)))
                currentDistance += decayIntegral(0.0, h.distance(site, sec(0)))
            elif ct != soma:  # the soma itself is not counted
                d1 = h.distance(site, ct(1))
                d0 = h.distance(site, ct(0))
                currentDistance += decayIntegral(min(d0, d1), max(d0, d1))
        localBranching.append(currentDistance)
    return localBranching
def _branchOrderBelow(sec, rootSec):
    """Count parent hops from `sec` up to `rootSec` (0 when `sec` is `rootSec`)."""
    order = 0
    while sec != rootSec:
        order += 1
        sec = sec.parentseg().sec
    return order


def _discountedSubtreeLength(rootSec, reference, lengthConstant, method):
    """Sum each section's length in rootSec.subtree(), weighted by the `method` discount.

    Distances d are measured from the segment `reference`, which must lie outside rootSec's
    subtree. Then every section tt spans [h.distance(reference, tt(0)), h.distance(reference, tt(1))],
    an interval of length tt.L, and its contribution is the closed-form integral of the weight
    over that interval. No division by the interval length is needed, so zero-length sections
    contribute 0 instead of dividing by zero.
    """
    total = 0.0
    for tt in rootSec.subtree():
        start = h.distance(reference, tt(0))
        stop = h.distance(reference, tt(1))
        if method == 'exponential':
            # integral of exp(-d/lambda)
            total += lengthConstant * (np.exp(-start / lengthConstant) - np.exp(-stop / lengthConstant))
        elif method == 'linear':
            # integral of d/lambda
            total += (stop**2 - start**2) / (2 * lengthConstant)
        elif method == 'sigmoid':
            # integral of 1 - a/(a + exp(-b*d - c)) == (1/b) * [log(a/(a + exp(-b*d - c)))]
            a, b, c = lengthConstant
            logStart = np.log(a / (a + np.exp(-b * start - c)))
            logStop = np.log(a / (a + np.exp(-b * stop - c)))
            total += (logStop - logStart) / b
        else:  # 'order': weight is constant over a section and depends on its branch order
            a, b, c = lengthConstant
            order = _branchOrderBelow(tt, rootSec)
            total += tt.L * (1 - 1 / (1 + a * np.exp(b * order + c)))
    return total


def measureDiscountedMorphRatio(section, lengthConstant, method='exponential'):
    """
    For each target section, compute the discounted dendritic length of its subtree and of its
    sister's subtree (the other child of the shared parent).

    Distances d are measured from the end (x=1) of the shared parent section. Each section
    contributes the integral of a weight w(d) over the distances it spans:
      'exponential' : w(d) = exp(-d / lengthConstant)
      'linear'      : w(d) = d / lengthConstant
      'sigmoid'     : lengthConstant = (a, b, c); w(d) = 1 - a / (a + exp(-b*d - c))
      'order'       : lengthConstant = (a, b, c); w = 1 - 1 / (1 + a*exp(b*k + c)), constant over
                      a section whose branch order below the subtree's root is k (k = 0 for the root)

    Returns a list of [targetLength, sisterLength] pairs, one per target section.
    Returns None (after printing a message) if `method` is not recognized, if 'sigmoid'/'order'
    is given a lengthConstant without 3 terms, or if a parent does not have exactly 2 children.
    """
    if method not in ('exponential', 'linear', 'sigmoid', 'order'):
        print('Did not recognize method')
        return
    if method in ('sigmoid', 'order') and len(lengthConstant) != 3:
        print('For {0}, must provide 3 terms, exiting now.'.format(method))
        return
    discountedLength = []
    for sec in section:
        parentSec = sec.parentseg().sec
        parentSecRef = h.SectionRef(sec=parentSec)  # Get parent section reference
        if parentSecRef.nchild() != 2:
            print('The parent of {0} did not have exactly 2 children! Exiting prematurely.'.format(sec))
            return
        # The sister is whichever of the two children is not the target
        sisSec = parentSecRef.child[1] if parentSecRef.child[0] == sec else parentSecRef.child[0]
        branchPoint = parentSec(1)
        discountedLength.append([
            _discountedSubtreeLength(sec, branchPoint, lengthConstant, method),
            _discountedSubtreeLength(sisSec, branchPoint, lengthConstant, method),
        ])
    return discountedLength