# class definitions to generate an MRG axon, extracellular electrode and trial objects
from neuron import h
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import LinearNDInterpolator
import os
def line(x, a, b):
    """Straight line ``a*x + b`` (used as a fit function for MRG parameters)."""
    slope_term = a * x
    return slope_term + b
def logfn(x, h, k):
    """Logarithmic fit function ``h*ln(x) + k``.

    Note: the parameter ``h`` shadows the module-level NEURON handle inside
    this function only.
    """
    scaled_log = h * np.log(x)
    return scaled_log + k
def logreg(x, L, x0, k, b):
    """Shifted logistic fit function ``L / (1 + exp(-k*(x - x0))) + b``."""
    denominator = 1.0 + np.exp(-k * (x - x0))
    return L / denominator + b
# NEURON global settings applied when this module is imported:
# temperature (degrees Celsius) and integration time step (ms).
h.celsius, h.dt = 37.0, 0.005
class axon():
    """Class to create an axon of passage.

    Structure and electrical properties are obtained from the MRG model
    modified for sensory neurons in Gaines et al. Each internode is built as
    node - MYSA - FLUT - 6 x STIN - FLUT - MYSA, and the axon ends on a node,
    so an axon with ``axonnodes`` nodes has ``axonnodes - 1`` internodes.

    Every NEURON section created is recorded as one row of ``self.sectionDF``
    (columns nodeIndex, sectionIndex, object, sectionType, x, y, z, L).

    NOTE
    ----
    Instantiating the axon object DOES NOT generate positions for the
    sections (x, y and z are all 0). Object function ``setXYZpos`` MUST be run
    on the axon object AFTER it has been instantiated.
    """

    # MRG parameters for the discrete fiber diameters of the original model:
    # fiberD: (g, axonD, nodeD, paraD1, paraD2, deltax, paralength2, nl)
    _MRG_DISCRETE_PARAMS = {
        5.7: (0.605, 3.4, 1.9, 1.9, 3.4, 500.0, 35.0, 80.0),
        7.3: (0.630, 4.6, 2.4, 2.4, 4.6, 750.0, 38.0, 100.0),
        8.7: (0.661, 5.8, 2.8, 2.8, 5.8, 1000.0, 40.0, 110.0),
        10.0: (0.690, 6.9, 3.3, 3.3, 6.9, 1150.0, 46.0, 120.0),
        11.5: (0.700, 8.1, 3.7, 3.7, 8.1, 1250.0, 50.0, 130.0),
        12.8: (0.719, 9.2, 4.2, 4.2, 9.2, 1350.0, 54.0, 135.0),
        14.0: (0.739, 10.4, 4.7, 4.7, 10.4, 1400.0, 56.0, 140.0),
        15.0: (0.767, 11.5, 5.0, 5.0, 11.5, 1450.0, 58.0, 145.0),
        16.0: (0.791, 12.7, 5.5, 5.5, 12.7, 1500.0, 60.0, 150.0),
    }

    def __init__(self, axonnodes, fiberD, pos=(0, 0, 0)):
        """
        Parameters
        ----------
        axonnodes : int
            number of nodes to include in the axon (must be >= 1)
        fiberD : float
            fiber diameter; assigned directly to NEURON ``diam`` (um) of the
            myelinated sections and used to derive all MRG parameters
        pos : tuple
            (x,y,z) position of endpoint of axon. Not used by the constructor;
            positions are assigned by ``setXYZpos``.

        Raises
        ------
        ValueError
            if ``axonnodes`` < 1
        """
        if axonnodes < 1:
            raise ValueError('axonnodes must be >= 1, got %r' % (axonnodes,))
        # electrical parameters
        self.rhoa = 0.7e6  # [Ohm-um] specific axoplasmic resistance
        self.mycm = 0.1  # uF/cm2/lamella membrane
        self.mygm = 0.001  # S/cm2/lamella membrane
        self.timestep = h.dt
        # morphological parameters
        self.axonnodes = axonnodes
        self.paranodes1 = 2 * (axonnodes - 1)  # MYSA sections
        self.paranodes2 = 2 * (axonnodes - 1)  # FLUT sections
        self.axoninter = 6 * (axonnodes - 1)  # STIN sections
        self.numNodes = axonnodes
        self.numSections = self.axonnodes + self.paranodes1 + self.paranodes2 + self.axoninter
        self.fiberD = fiberD
        self.paralength1 = 3.0
        self.nodelength = 1.0
        self.space_p1 = 0.002
        self.space_p2 = 0.004
        self.space_i = 0.004
        self.sectionDF = pd.DataFrame(columns=['nodeIndex', 'sectionIndex', 'object', 'sectionType', 'x', 'y', 'z', 'L'])
        self.getModelParamsContinuous(fiberD)
        # construct axon: (axonnodes - 1) internodes, each starting on a node,
        # followed by one terminating node
        parentSection = None
        for iNode in range(axonnodes - 1):
            parentSection = self.createMRGfiber(iNode, parentSection=parentSection)
        # index taken from axonnodes (not the loop variable) so a single-node
        # axon, where the loop never runs, still works
        self.createNodeSection(axonnodes - 1, parentSection=parentSection)
        # setup recording
        self.membraneV = []

    def _appendSectionRow(self, sectionType, sec):
        """Record a newly created section as a new row of ``self.sectionDF``.

        nodeIndex is the number of nodes created so far, so the internodal
        sections built after node i are labelled i + 1. sectionIndex is the
        number of sections of the same type created so far.
        """
        types = self.sectionDF['sectionType']
        row = {'sectionType': sectionType,
               'nodeIndex': int((types == 'node').sum()),
               'sectionIndex': int((types == sectionType).sum()),
               'x': 0.0, 'y': 0.0, 'z': 0.0, 'L': sec.L, 'object': sec}
        newRow = pd.DataFrame([row], columns=self.sectionDF.columns)
        # DataFrame.append was removed in pandas 2.0; concat is the replacement
        if self.sectionDF.empty:
            self.sectionDF = newRow
        else:
            self.sectionDF = pd.concat([self.sectionDF, newRow], ignore_index=True)

    def createMRGfiber(self, nodeNum, parentSection=None):
        """Create node ``nodeNum`` followed by one complete internode.

        Builds node -> MYSA -> FLUT -> 6 x STIN -> FLUT -> MYSA, each section
        connected to the previous one. The node is connected to
        ``parentSection`` when one is given.

        Returns
        -------
        the last MYSA section, to be used as the parent of the next node
        """
        node = self.createNodeSection(nodeNum)
        section = self.createMYSAsection(node)
        section = self.createFLUTsection(section)
        for _ in range(6):
            section = self.createSTINsection(parentSection=section)
        section = self.createFLUTsection(section)
        lastMYSA = self.createMYSAsection(section)
        if parentSection:
            node.connect(parentSection)
        return lastMYSA

    def createNodeSection(self, nodeNum, parentSection=None):
        """Create a node-of-Ranvier section.

        The first (index 0) and last (index axonnodes - 1) nodes get a passive
        membrane; all other nodes get the ``node_sensory`` mechanism. Every node
        gets the ``extracellular`` mechanism.
        """
        sec = h.Section(name='node')
        sec.nseg = 1
        sec.L = self.nodelength
        sec.Ra = self.rhoa / 10000  # Ohm-um -> Ohm-cm
        sec.cm = 2.0
        sec.diam = self.nodeD
        if nodeNum == 0 or nodeNum == self.axonnodes - 1:
            sec.insert('pas')
            sec.g_pas = 0.0001
            sec.e_pas = -80
        else:
            sec.insert('node_sensory')
        # xc/xg are hoc nicknames for xc[0]/xg[0], the layer immediately
        # adjacent to the membrane; xc[1]/xg[1] are the outer layer.
        sec.insert('extracellular')
        sec.xraxial[0] = self.Rpn0
        sec.xg[0] = 1e10  # very large conductance: no myelin layer at the node
        sec.xc[0] = 0.0
        self._appendSectionRow('node', sec)
        if parentSection:
            sec.connect(parentSection)
        return sec

    def _createMyelinatedSection(self, name, length, innerD, mechanism, xraxial, parentSection):
        """Create a myelinated (MYSA / FLUT / STIN) section and record it.

        ``name`` is used both as the NEURON section name and as the
        sectionType in ``sectionDF``; ``innerD`` is the axon diameter inside
        the myelin for this section type.
        """
        sec = h.Section(name=name)
        sec.nseg = 1
        sec.L = length
        # diam is the outer fiberD, so Ra is scaled by (fiberD/innerD)^2 to give
        # the axial resistance of the inner axoplasm; /10000: Ohm-um -> Ohm-cm
        sec.Ra = self.rhoa * (1 / (innerD / self.fiberD) ** 2.0) / 10000
        sec.cm = 2.0 * innerD / self.fiberD
        sec.diam = self.fiberD
        sec.insert(mechanism)
        sec.insert('extracellular')
        sec.xraxial[0] = xraxial
        # per-lamella-membrane myelin values divided by 2 * nl lamella membranes
        sec.xg[0] = self.mygm / (self.nl * 2)
        sec.xc[0] = self.mycm / (self.nl * 2)
        self._appendSectionRow(name, sec)
        if parentSection:
            sec.connect(parentSection)
        return sec

    def createMYSAsection(self, parentSection=None):
        """Create a MYSA (myelin attachment) section; returns the section."""
        return self._createMyelinatedSection('MYSA', self.paralength1, self.paraD1,
                                             'mysa_sensory', self.Rpn1, parentSection)

    def createFLUTsection(self, parentSection=None):
        """Create a FLUT (paranode main) section; returns the section."""
        return self._createMyelinatedSection('FLUT', self.paralength2, self.paraD2,
                                             'flut_sensory', self.Rpn2, parentSection)

    def createSTINsection(self, parentSection=None):
        """Create a STIN (internode) section; returns the section."""
        return self._createMyelinatedSection('STIN', self.interlength, self.axonD,
                                             'stin_sensory', self.Rpx, parentSection)

    def _straightX(self, xStart):
        """Midpoint x of every section for a straight axon whose +x end is at
        ``xStart`` and which runs toward -x in sectionDF order."""
        lengths = self.sectionDF['L']
        return xStart - (lengths.cumsum() - lengths / 2)

    def setXYZpos(self, pos, bent=False, peak_index=None, xpos_set=False, xshift=None):
        """Assign x, y, z coordinates (section midpoints) to every section.

        Sections are laid along the x axis with section 0 at the largest x.

        Parameters
        ----------
        pos : tuple
            (x, y, z). y and z are the transverse position of the axon (and of
            the bend point when ``bent`` is True). pos[0] is only used when
            ``xpos_set`` is True.
        bent : bool
            if True, bend the axon at one section so that the two limbs run in
            straight lines toward tube end points at x = +/-3200. Axons sticking
            out of the DRG area are harmless as long as the axon is long enough
            to avoid terminal effects, so this is optional.
        peak_index : int, optional
            sectionDF row at which to bend (only used when ``bent``). If not an
            integer, a random section with -500 < x < 500 is chosen.
        xpos_set : bool
            if True, pos[0] is the x coordinate of the +x end of the axon.
        xshift : number, optional
            used when ``xpos_set`` is False: the +x end is placed at
            ``deltax * axonnodes / 2 + xshift``. If None, xshift is drawn
            uniformly from [-deltax/2, deltax/2). Stored with the resulting
            end position in ``self.xshift`` / ``self.xpos``.

        Returns
        -------
        True

        Raises
        ------
        ValueError
            if ``bent`` is True, ``peak_index`` is not an integer and no
            section lies within -500 < x < 500.
        """
        if xpos_set:
            self.sectionDF['x'] = self._straightX(pos[0])
        else:
            if xshift is None:
                # random jitter so the middle node is not exactly centered
                xshift = (np.random.random() - 0.5) * self.deltax
            xpos = 0 + (self.deltax * (self.axonnodes / 2)) + xshift
            self.sectionDF['x'] = self._straightX(xpos)
            self.xpos = xpos
            self.xshift = xshift
        # float columns so the bend below can write non-integer values
        self.sectionDF['y'] = pd.Series(pos[1], index=self.sectionDF.index, dtype=float)
        self.sectionDF['z'] = pd.Series(pos[2], index=self.sectionDF.index, dtype=float)
        if bent:
            self._bendAxon(pos, peak_index)
        return True

    @staticmethod
    def _tubeCoordinate(value):
        """Transverse tube coordinate: +/-600 when |value| >= 750, else value."""
        if value >= 750:
            return 600
        if value <= -750:
            return -600
        return value

    @staticmethod
    def _layLimb(limb, tube, bendPoint, segShift, towardPositiveX):
        """Lay the sections of ``limb`` on the line from ``bendPoint`` to ``tube``.

        Distances are accumulated outward from the bend section (half its
        length is ``segShift``). Sections past the tube end point in x keep the
        tube's y and z. Modifies and returns ``limb``.
        """
        bx, by, bz = bendPoint
        rho = np.sqrt((bx - tube[0]) ** 2 + (by - tube[1]) ** 2 + (bz - tube[2]) ** 2)
        a = (bx - tube[0]) / rho
        b = (by - tube[1]) / rho
        c = (bz - tube[2]) / rho
        if (tube[1] == by) and (tube[2] == bz):
            gamma = np.pi / 2
        else:
            gamma = np.abs(np.arctan((tube[0] - bx) / np.sqrt((tube[1] - by) ** 2 + (tube[2] - bz) ** 2)))
        xScale = np.round(np.sin(gamma), 4)
        if towardPositiveX:
            # limb precedes the bend in sectionDF: accumulate lengths backwards
            dist = segShift + limb.loc[::-1, 'L'].cumsum() - limb['L'] / 2
            limb['x'] = bx + dist * xScale
        else:
            dist = segShift + limb['L'].cumsum() - limb['L'] / 2
            limb['x'] = bx - dist * xScale
        limb['y'] = tube[1] + b / a * (limb['x'] - tube[0])
        limb['z'] = tube[2] + c / a * (limb['x'] - tube[0])
        inTube = (limb['x'] >= tube[0]) if towardPositiveX else (limb['x'] <= tube[0])
        limb.loc[inTube, 'z'] = tube[2]
        limb.loc[inTube, 'y'] = tube[1]
        return limb

    def _bendAxon(self, pos, peak_index):
        """Bend the straight axon at one section toward two tube end points.

        The bend section is placed at (its current x, pos[1], pos[2]); sections
        before it go toward (3200, yTube, zTube), sections after it toward
        (-3200, yTube, zTube).
        """
        df = self.sectionDF
        if isinstance(peak_index, (int, np.integer)):
            cutoff = peak_index
        else:
            candidates = df.index[(df['x'] > -500) & (df['x'] < 500)]
            if len(candidates) == 0:
                raise ValueError('bent=True without peak_index requires a section with -500 < x < 500')
            cutoff = candidates[np.random.randint(len(candidates))]
        bendX = df.loc[cutoff, 'x']
        bendPoint = (bendX, pos[1], pos[2])
        segShift = df.loc[cutoff, 'L'] / 2
        yTube = self._tubeCoordinate(pos[1])
        zTube = self._tubeCoordinate(pos[2])
        # iloc slices with a label cutoff: relies on sectionDF's 0..n-1 index
        peripheralLimb = self._layLimb(df.iloc[:cutoff].copy(), (3200, yTube, zTube),
                                       bendPoint, segShift, towardPositiveX=True)
        centralLimb = self._layLimb(df.iloc[cutoff + 1:].copy(), (-3200, yTube, zTube),
                                    bendPoint, segShift, towardPositiveX=False)
        for limb in (peripheralLimb, centralLimb):
            if not limb.empty:
                df.loc[limb.index, ['x', 'y', 'z']] = limb[['x', 'y', 'z']].to_numpy()
        df.loc[cutoff, 'x'] = bendX
        df.loc[cutoff, 'y'] = pos[1]
        df.loc[cutoff, 'z'] = pos[2]

    def setXYZpos_curve(self, pos):
        """Alternative layout: straight axon whose ends follow the DRG exterior.

        DOES NOT GET USED FOR PAPER RESULTS.
        The axon starts at x = pos[0] and runs toward -x at (pos[1], pos[2]).
        When sqrt(y^2 + z^2) >= 700, sections with 3200 > |x| >= x_cutoff are
        moved onto the ellipse of semi-axes 3700 (x) and 1500 (radial), and
        sections with |x| >= 3200 onto radius 700. x_cutoff is NaN (so no
        section is curved) when the radius exceeds 1500.
        """
        df = self.sectionDF
        df['x'] = self._straightX(pos[0])
        df['y'] = pd.Series(pos[1], index=df.index, dtype=float)
        df['z'] = pd.Series(pos[2], index=df.index, dtype=float)
        R_axon = np.sqrt(pos[1] ** 2 + pos[2] ** 2)
        theta = np.pi / 2 if pos[2] == 0 else np.abs(np.arctan(pos[1] / pos[2]))
        if R_axon >= 700:
            x = df['x']
            x_cutoff = 3700 * np.sqrt(1 - (R_axon / 1500) ** 2)
            curveIdx = ((x >= x_cutoff) & (x < 3200)) | ((x <= -x_cutoff) & (x > -3200))
            radius = 1500 * np.sqrt(1 - (df.loc[curveIdx, 'x'] / 3700) ** 2)
            df.loc[curveIdx, 'y'] = radius * np.round(np.sin(theta), 4)
            df.loc[curveIdx, 'z'] = radius * np.round(np.cos(theta), 4)
            tubeIdx = (x >= 3200) | (x <= -3200)
            df.loc[tubeIdx, 'y'] = 700 * np.round(np.sin(theta), 4)
            df.loc[tubeIdx, 'z'] = 700 * np.round(np.cos(theta), 4)
        return True

    def getSectionFromDF(self, sType, nodeIdx=None, returnObj=True):
        """
        Parameters:
        sType: str
            string associated with section type: 'node', 'MYSA', 'FLUT', OR 'STIN'
        nodeIdx: int (Python or numpy integer)
            index of the node with which the section is associated. Internodal
            sections built after node i carry nodeIndex i + 1.
        returnObj: boolean
            True if the section object is desired
            False if the matching rows of the dataframe are desired

        Raises:
        ValueError
            if neither a usable nodeIdx nor sType is given, or if returnObj is
            True and not exactly one row matches (e.g. 'MYSA' with a nodeIdx
            matches two sections).
        """
        types = self.sectionDF['sectionType']
        if isinstance(nodeIdx, (int, np.integer)):
            mask = (types == sType) & (self.sectionDF['nodeIndex'] == nodeIdx)
        elif not nodeIdx and sType:
            mask = types == sType
        else:
            raise ValueError('missing input argument for branch and/or node index')
        dfRow = self.sectionDF[mask]
        if returnObj:
            return dfRow['object'].item()
        return dfRow

    def _periaxonalResistance(self, diam, space):
        """Periaxonal resistance (xraxial) of an annulus of width ``space``
        around a cylinder of diameter ``diam``."""
        return (self.rhoa * .01) / (np.pi * ((((diam / 2) + space) ** 2) - ((diam / 2) ** 2)))

    def _setDerivedParams(self):
        """Compute periaxonal resistances and STIN length from the diameters
        and lengths already set on the object."""
        self.Rpn0 = self._periaxonalResistance(self.nodeD, self.space_p1)
        self.Rpn1 = self._periaxonalResistance(self.paraD1, self.space_p1)
        self.Rpn2 = self._periaxonalResistance(self.paraD2, self.space_p2)
        self.Rpx = self._periaxonalResistance(self.axonD, self.space_i)
        # one internode (deltax) = node + 2 MYSA + 2 FLUT + 6 STIN
        self.interlength = (self.deltax - self.nodelength - (2 * self.paralength1) - (2 * self.paralength2)) / 6

    def getModelParamsContinuous(self, fiberD):
        """Generate interpolated model parameters for any fiberD value.

        Uses functions fitted to the discrete values defined in the MRG model.
        This is the parameter set used by ``__init__``.
        """
        self.g = line(fiberD, 0.01716804, 0.5075587)
        self.axonD = line(fiberD, 0.88904883, -1.9104369)
        # nodeD == paraD1 and axonD == paraD2 for every discrete MRG diameter,
        # hence the shared fits
        self.nodeD = line(fiberD, 0.34490792, -0.14841106)
        self.paraD1 = line(fiberD, 0.34490792, -0.14841106)
        self.paraD2 = line(fiberD, 0.88904883, -1.9104369)
        self.deltax = logreg(fiberD, 3.79906687e+03, 2.13820902e+00, 2.48122018e-01, -2.19548067e+03)
        self.paralength2 = logreg(fiberD, 30.77203038, 10.53182692, 0.42725082, 31.47653035)
        self.nl = int(round(logfn(fiberD, 65.89739004, -32.66582976)))
        self._setDerivedParams()

    def getModelParameters(self, fiberD):
        """ALTERNATE PARAMETER SETTING - BEST IF USED FOR DISCRETE fiberD VALUES DEFINED IN MRG MODEL.

        Uses the tabulated MRG parameters when fiberD equals one of the MRG
        diameters (5.7, 7.3, 8.7, 10.0, 11.5, 12.8, 14.0, 15.0, 16.0), otherwise
        the interpolation functions from Gaines et al.
        """
        params = self._MRG_DISCRETE_PARAMS.get(fiberD)
        if params is not None:
            (self.g, self.axonD, self.nodeD, self.paraD1, self.paraD2,
             self.deltax, self.paralength2, self.nl) = params
        else:
            self.g = 0.0172 * fiberD + 0.5076
            self.axonD = 0.889 * fiberD - 1.9104
            self.nodeD = 0.3449 * fiberD - 0.1484
            self.paraD1 = 0.3527 * fiberD - 0.1804
            self.paraD2 = 0.889 * fiberD - 1.9104
            self.deltax = 969.3 * np.log(fiberD) - 1144.6
            self.paralength2 = 2.5811 * fiberD + 19.59
            self.nl = (65.897 * np.log(fiberD) - 32.666)
        self._setDerivedParams()

    def plotMembraneV(self):
        """Plot ``self.membraneV`` against time and show the figure.

        Assumes one sample per ``h.dt`` starting at t = 0.
        """
        # sample i is at i*dt; linspace(0, n*dt, n) would stretch the axis
        t_vec = np.arange(len(self.membraneV)) * h.dt
        plt.plot(t_vec, self.membraneV)
        plt.xlabel('time (ms)')
        plt.ylabel('mV')
        plt.show()