import math
import os
import sys

import numpy as np
from neuron import h
from scipy.signal import butter, lfilter

from stimStrat import KFS
def firingRate(spikeTrain, threshold=-50, min_amplitude=40):
    """Count the spikes in a membrane-potential trace.

    An excursion is a run of samples above ``threshold``. It ends at the
    first sample at or below ``threshold``. The excursion counts as a spike
    only if its peak exceeds that first below-threshold sample by more than
    ``min_amplitude``, so small above-threshold bumps are rejected. An
    excursion that is still above ``threshold`` when the trace ends is not
    counted.

    Args:
        spikeTrain: iterable of voltage samples. In this script these are
            presumably mV, since they are recorded from a NEURON section's v.
        threshold: detection threshold, in the same units (default -50).
        min_amplitude: minimum drop from the excursion's peak to its first
            below-threshold sample for it to count as a spike (default 40).

    Returns:
        int: number of spikes detected.
    """
    countSpike = 0
    # Highest sample of the current above-threshold excursion; None while below.
    # Only the running maximum is needed, so the samples are not stored.
    peak = None
    for sp in spikeTrain:
        if sp > threshold:
            peak = sp if peak is None else max(peak, sp)
        elif peak is not None:
            # The excursion just ended. peak > threshold >= sp, so the drop is positive.
            if peak - sp > min_amplitude:
                countSpike += 1
            peak = None
    return countSpike
# load NEURON GUI (nrngui.hoc). It presumably also provides the standard run
# system (h.run, h.v_init, h.tstop) used later in this script.
h.load_file("nrngui.hoc")
# location [x, y, z] of the stimulating electrode (please adjust the 3D
# coordinate to suit your case). Later in this script only e_3D[1] and
# e_3D[2] are read, to rank fibres by their distance to the electrode in the
# y-z plane.
e_3D = [5000, 400, -300]
# build nerve. This presumably creates the fibre sections that are iterated
# over with h.allsec() below (confirm in sciaticNerveBuilder.hoc).
h.load_file("nerve/sciaticNerveBuilder.hoc")
"""
# Uncommment to automatically run the toolkit if you don't have the rx_xtra_interpolated.txt already in the directory.
# import toolkit header
import autoToolkit as tk
# define arguments
path2server = 'C:\\Program Files\\COMSOL\\COMSOL56\\Multiphysics\\bin\\win64'
path2mph = 'C:\\Program Files\\COMSOL\\COMSOL56\\Multiphysics\\mli'
simBox_3D = [5000, 0, 0]
simBox_size = 12000
nerve_3D = [-50, 0, 0]
nerve_R = 2100
nerve_L = 10100
substrate_3D = [5000, 0, 340]
substrate_W = 500
substrate_L = 4000
substrate_D = 30
e_R = 10
fasc_3D = []
fasc_R = []
fasc_L = 10060
with open('../nerve/fasciclesInfo.txt', 'r') as f:
for line in f.readlines():
info = line.split(',')
fasc_3D.append([float(info[0]),int(info[1]),int(info[2])])
fasc_R.append(float(info[3]))
# optional arguments
rotate_deg = -50
simBox_G = 1.45
nerve_G = 0.01
fasc_G = 0.0517
mesh_size = 3
e_type = "hexapolar"
e2e_dist = None
# call the function to automate pipeline
tk.pipeline(path2server, path2mph, simBox_3D, simBox_size, nerve_3D, nerve_R, nerve_L, fasc_3D, fasc_R, fasc_L, \
substrate_3D, substrate_W, substrate_L, substrate_D, e_R, \
e_type=e_type, e2e_dist=e2e_dist, rotate_deg=rotate_deg, simBox_G=simBox_G, \
nerve_G=nerve_G, fasc_G=fasc_G, mesh_size=mesh_size)
"""
# set transfer resistances between the fibres and the electrode
h.load_file("setrx.hoc")
# attach electrode
h.load_file("attachStim.hoc")
# Rank fibres by their distance to the electrode in the y-z plane.
#
# All sections of one fibre map to a single fibre name such as "AFibre11".
# The distance comes from the first 3D point of the first section seen for
# that fibre. x is ignored.
# Results (module-level):
#   dist      - {fibre name: distance to electrode}
#   fibres_od - fibre names sorted from nearest to farthest
#   Nsec      - number of sections visited, including the electrode section
dist = dict()
Nsec = 0
for sec in h.allsec():
    Nsec += 1
    # ignore section for electrode
    if str(sec) == 'sElec':
        continue
    # get fibre's ID
    secName = ''.join(e for e in str(sec) if e.isalnum())  # e.g. "AFibreBuilder11MYSA0"
    fibreType = secName[:6]  # e.g. "AFibre"
    # The fibre index is formed from the digits in the first 17 characters,
    # e.g. "AFibreBuilder11MY" -> 11.
    # NOTE(review): this assumes the section-name suffix puts no digits
    # within the first 17 characters.
    fibreIndex = int(''.join(c for c in secName[:17] if c.isdigit()))
    fibreName = fibreType + str(fibreIndex)
    # Only the first section of each fibre is used. The dict lookup is O(1);
    # the old list lookup was O(#fibres) for every section.
    if fibreName in dist:
        continue
    # get fibre's distance to electrode
    dist[fibreName] = math.hypot(sec.y3d(0) - e_3D[1], sec.z3d(0) - e_3D[2])
fibres_od = sorted(dist, key=dist.get)  # nearest fibre first
print(fibres_od)
# Build the stimulation waveform and assign it to the electrode.
# These constants are module-level and are reused later: the simulation
# length (h.tstop = delay+dur+last), the time step (h.dt = dt) and the number
# of electrode-nearest fibres to simulate (observeRange).
delay = 1 # ms, passed to KFS as the stimulus delay
amp = 0.1 # mA, stimulus amplitude
freq = 4e3 # Hz, frequency of the sinusoidal stimulus
dur = 50 # ms, stimulus duration
last = 1 # ms; presumably the time simulated after the stimulus ends (tstop = delay+dur+last)
dt = 0.025 # ms, simulation time step; also passed to KFS
observeRange = 1170 # number of fibres, nearest to the electrode, that are simulated
# KFS (from stimStrat) returns the waveform as (time points, amplitudes).
# NOTE(review): h.attach_stim() presumably reads h.stim_time / h.stim_amp, so
# this assignment must come before that call. Confirm in attachStim.hoc.
(h.stim_time, h.stim_amp) = KFS(delay, amp, freq, dur, last, dt)
h.attach_stim() # apply waveform to the electrode
# Detect the responses of the observeRange fibres closest to the electrode.
#
# Each fibre is simulated once. The membrane potential is recorded at the
# middle of the first section seen for that fibre; C fibres skip their
# "node0" section. Results are stored in ``resp``, keyed by fibre name, e.g.:
#
#     resp = {
#         "AFibre11": {"y": 123.4, "z": -56.7, "sr": 4},
#         ...
#     }
#
# "y"/"z" are the first 3D point of the recorded section. "sr" is the number
# of spikes that firingRate() counts in the low-pass-filtered trace.
resp = dict()
attDv = h.Vector()
detected = set()
counter = 0
# Build the set of observed fibres once. The old code sliced and scanned
# fibres_od for every section.
observed = set(fibres_od[:observeRange])
nObserved = len(observed)
# Low-pass filter that removes stimulus artefacts. It is the same for every
# fibre, so it is designed once. The trace has one sample per time step, so
# the sampling rate follows from dt in ms (0.025 ms -> 40 kHz).
b, a = butter(10, 3000, 'low', analog=False, fs=1000.0 / dt)
for sec in h.allsec():
    # ignore section for electrode
    if str(sec) == 'sElec':
        continue
    # get fibre's ID
    secName = ''.join(e for e in str(sec) if e.isalnum())  # e.g. "AFibreBuilder11MYSA0"
    fibreType = secName[:6]  # e.g. "AFibre"
    node = secName[-5:]
    # digits in the first 17 characters form the fibre index (e.g. 11)
    fibreIndex = int(''.join(c for c in secName[:17] if c.isdigit()))
    fibreName = fibreType + str(fibreIndex)
    # ignore the first node for C fibre
    if fibreType == 'CFibre' and node == 'node0':
        continue
    # only process fibres within observation range, each only once
    if fibreName not in observed or fibreName in detected:
        continue
    detected.add(fibreName)
    # Display progress as a share of the fibres to simulate. The old
    # denominator was the total section count, so it never reached 100%.
    counter += 1
    sys.stdout.write("\r processing %f percent" % (counter / nObserved * 100))
    sys.stdout.flush()
    resp[fibreName] = dict()
    # store 3D position info
    resp[fibreName]['y'] = sec.y3d(0)
    resp[fibreName]['z'] = sec.z3d(0)
    # simulate, recording the membrane potential at the middle of the section
    attDv.record(sec(0.5)._ref_v)
    vInit = -80 if fibreType == "AFibre" else -60
    h.v_init = vInit
    h.dt = dt
    # Keep the run system's steps_per_ms consistent with dt, so the standard
    # run system does not silently change dt if dt is edited above.
    h.steps_per_ms = 1.0 / dt
    h.tstop = delay + dur + last
    h.run()
    spikeTrain = np.array(attDv)
    # Remove stimulus artefacts by a filter. lfilter starts from a zero
    # state, so filter the deviation from v_init, where the trace starts.
    # The old fixed +80/-80 offset left a 20 mV step at t = 0 for C fibres.
    spikeTrainFiltered = lfilter(b, a, spikeTrain - vInit) + vInit
    # store response
    resp[fibreName]['sr'] = firingRate(spikeTrainFiltered)
# end the in-place progress line
sys.stdout.write("\n")
# Write one line per observed fibre, nearest first: "<name> <y> <z> <spike count>".
outPath = 'data/PE/hex.txt'
# Create the output directory if it is missing, so the results of a long
# simulation are not lost at the final step.
os.makedirs(os.path.dirname(outPath), exist_ok=True)
with open(outPath, 'w') as f:
    for fibre in fibres_od[:observeRange]:
        r = resp.get(fibre)
        if r is None:
            # No section of this fibre was simulated (e.g. a C fibre whose
            # only section is node0). Report it instead of raising KeyError.
            sys.stderr.write("warning: no response recorded for %s\n" % fibre)
            continue
        print('%s %g %g %g' % (fibre, r['y'], r['z'], r['sr']), file=f)