import numpy as np
import pylab
import random
import matplotlib.pyplot as plt
import os
from mpi4py import MPI
from scipy import interpolate
class NeuralNetwork:
from neuron import h
def __init__(self):
    """Set default parameters, attach to MPI and load the hoc model.

    EES starts disabled and every recruitment fraction starts at 0;
    ``runSimulation`` overwrites them. The constructor then opens
    ``NeuralNetwork.hoc`` in NEURON and initialises the simulator at
    ``t = 0``.
    """
    # EES state: switched off, no frequency chosen yet
    self.EES, self.EES_freqeqncy = 0, 0

    # Weights handed to set_Mn_stim and to set_IA_stim / set_II_stim
    self.weightStimMn, self.weightStimAff = 20, 10

    # MPI communicator, number of processes and rank of this process
    world = MPI.COMM_WORLD
    self.comm = world
    self.sizeComm = world.Get_size()
    self.rank = world.Get_rank()

    # Fractions of Ia fibers, II fibers and motoneurons receiving EES
    self.percFibersIa_GM = self.percFibersII_GM = self.percMn_GM = 0
    self.percFibersIa_TA = self.percFibersII_TA = self.percMn_TA = 0

    # Initializing: load the hoc model and reset the simulator
    hoc = self.h
    hoc.xopen("NeuralNetwork.hoc")
    hoc('{t=pc.set_maxstep(0.5)}')
    hoc.stdinit()
    hoc.t = 0
def __del__(self):
    """Shut the parallel hoc session down when the object is destroyed.

    The hoc statements are issued in a fixed order: release the workers,
    close the parallel context, take a final timestamp and print the
    timing summary. NEURON is then asked to quit.
    """
    hoc = self.h
    shutdown_statements = (
        '{pc.runworker()}',
        '{pc.done()}',
        't2 = startsw() // timer',
        # presumably t0 and t1 are set by the hoc model file -- confirm
        'print "model setup time ", t1-t0, " run time ", t2-t1, " total ", t2-t0',
    )
    for statement in shutdown_statements:
        hoc(statement)
    hoc.quit()
# Compute indexes in the different hosts
def computeInd(self, nCells):
    """Return how many of ``nCells`` cells fall on this host.

    The cells are split evenly over ``h.Nhost`` hosts; hosts whose
    ``h.PcID`` is smaller than the remainder take one extra cell.
    """
    n_hosts = self.h.Nhost
    even_share = nCells // n_hosts
    leftover = nCells % n_hosts
    if self.h.PcID < leftover:
        return even_share + 1
    return even_share
#EES stimulation Ia fibers
def set_IA_stim(self, percFlex, percExt, w):
    """Give weight ``w`` to a random subset of the EES -> Ia fiber connections.

    Parameters
    ----------
    percFlex, percExt : float
        Fraction (0..1) of the flexor / extensor entries of
        ``h.stimIafEES_Flex`` / ``h.stimIafEES_Ext`` that receive ``w``.
    w : number
        Value written to ``weight[0]`` of every selected entry. Entries
        that are not selected are left untouched, so callers reset the
        whole population with ``set_IA_stim(1, 1, 0)`` first.

    Every host must call this method: it uses MPI collectives
    (``gather`` / ``bcast``) to share out the fractional remainders.
    """

    def activate(stims, fraction):
        # hoc returns count() as a float; sizes and indices must be ints
        n_available = int(stims.count())
        requested = fraction * n_available
        n_active = int(requested)

        # The fractional parts lost on each host are summed on rank 0,
        # rounded, and redistributed among the hosts by computeInd.
        remainders = self.comm.gather(requested - n_active, root=0)
        n_extra = []
        if self.rank == 0:
            n_extra = round(sum(remainders))
        n_extra = self.comm.bcast(n_extra, root=0)
        n_active += int(self.computeInd(n_extra))

        # Never select more entries than exist (or a negative number)
        n_active = max(0, min(n_active, n_available))

        # Randomize excited Iaf: shuffle a mask holding n_active selected
        # slots; random.shuffle only depends on the length, so the draw is
        # the same as shuffling an array of ones and zeros.
        selected = [True] * n_active + [False] * (n_available - n_active)
        random.shuffle(selected)
        for index, is_selected in enumerate(selected):
            if is_selected:
                stims.object(index).weight[0] = w

    # Activating Flexors, then Extensors (the order fixes the MPI call order)
    activate(self.h.stimIafEES_Flex, percFlex)
    activate(self.h.stimIafEES_Ext, percExt)
#EES stimulation II fibers
def set_II_stim(self, percFlex, percExt, w):
    """Give weight ``w`` to the first EES -> II fiber connections.

    ``percFlex`` / ``percExt`` are the fractions of the entries of
    ``h.stimIIfEES_Flex`` / ``h.stimIIfEES_Ext`` to activate; the
    selected entries are always the leading ones of each list.
    The fractional remainders of all hosts are summed on rank 0,
    rounded and shared out through ``computeInd``, so every host has
    to call this method (it uses MPI ``gather`` and ``bcast``).
    """

    def recruit(population, fraction):
        whole = int(fraction * population.count())
        lost = fraction * population.count() - whole
        all_lost = self.comm.gather(lost, root=0)
        surplus = []
        if self.rank == 0:
            surplus = round(sum(all_lost))
        surplus = self.comm.bcast(surplus, root=0)
        whole += int(self.computeInd(surplus))
        for k in range(whole):
            population.object(k).weight[0] = w

    # Activating Flexors
    recruit(self.h.stimIIfEES_Flex, percFlex)
    # Activating Extensors
    recruit(self.h.stimIIfEES_Ext, percExt)
#EES stimulation Mns
def set_Mn_stim(self, percFlex, percExt, w):
    """Give weight ``w`` to the first EES -> motoneuron connections.

    ``percFlex`` applies to ``h.stimMn_Flex`` and ``percExt`` to
    ``h.stimMn_Ext``; in both lists the leading entries are the ones
    activated. Must be called on every host because the rounding
    remainders are exchanged with MPI ``gather`` / ``bcast``.
    """
    # Flexors are processed before extensors
    targets = (
        (self.h.stimMn_Flex, percFlex),
        (self.h.stimMn_Ext, percExt),
    )
    for stims, fraction in targets:
        n_on = int(fraction * stims.count())
        residue = fraction * stims.count() - n_on

        # Sum the fractional residues on rank 0 and hand the rounded
        # total back to the hosts, split by computeInd
        residues = self.comm.gather(residue, root=0)
        shared = []
        if self.rank == 0:
            shared = round(sum(residues))
        shared = self.comm.bcast(shared, root=0)
        n_on += int(self.computeInd(shared))

        for k in range(n_on):
            stims.object(k).weight[0] = w
#Set the EES frequency (freq in Hz)
def set_EES_freq(self, freq):
    """Set the EES frequency.

    Parameters
    ----------
    freq : number
        Stimulation frequency in Hz; must be non-zero.

    The cell with gid ``h.nCell * 2`` gets ``interval = 1000 / freq``.
    Only the host on which that gid exists changes anything, so the
    method can be called on every host.
    """
    ees_gid = self.h.nCell * 2
    if self.h.pc.gid_exists(ees_gid):
        ees_cell = self.h.pc.gid2cell(ees_gid)
        # 1000.0 forces true division: under Python 2 an integer frequency
        # would otherwise be truncated (30 Hz -> 33 instead of 33.33)
        ees_cell.interval = 1000.0 / freq
# Set Iaf natural firing rate
def set_IA_natural_firing(self, FiringRateFl, FiringRateExt, w):
    """Select the Ia stimulation population matching the mean firing rate.

    For flexors and then extensors, the weights of the four populations
    (``stimIafNat_*_50/_40/_30/_20``) are first set to 0. The population
    chosen by the mean of the firing rates then gets weight ``w``:

    * mean >= 50      -> ``_50`` population, and the interval of every
      locally owned Ia fiber cell is set to ``1000 / rate`` of that fiber
    * 35 < mean < 50  -> ``_40`` population
    * 25 < mean <= 35 -> ``_30`` population
    * 15 < mean <= 25 -> ``_20`` population
    * otherwise       -> all four populations stay at weight 0

    ``FiringRateFl`` / ``FiringRateExt`` are indexed per fiber
    (``h.nIAf`` entries are read in the first case).
    """
    hoc = self.h

    def assign(population, value):
        for k in range(int(population.count())):
            population.object(k).weight[0] = value

    def drive(pop50, pop40, pop30, pop20, rates, first_gid):
        # Setting to 0 the weight of all populations
        for population in (pop50, pop40, pop30, pop20):
            assign(population, 0)

        mean_rate = np.mean(rates)
        if mean_rate >= 50:
            chosen = pop50
        elif mean_rate > 35:
            chosen = pop40
        elif mean_rate > 25:
            chosen = pop30
        elif mean_rate > 15:
            chosen = pop20
        else:
            chosen = None

        # set the weight to the selected stim population
        if chosen is not None:
            assign(chosen, w)

        # set the firing rate (only done for the highest band)
        if chosen is pop50:
            for k in range(int(hoc.nIAf)):
                if hoc.pc.gid_exists(first_gid + k):
                    fiber = hoc.pc.gid2cell(first_gid + k)
                    fiber.interval = 1000 / rates[k]

    # Flexors
    drive(hoc.stimIafNat_Flex_50, hoc.stimIafNat_Flex_40,
          hoc.stimIafNat_Flex_30, hoc.stimIafNat_Flex_20,
          FiringRateFl, hoc.Ind_IaFibStimFlex)
    # Extensor
    drive(hoc.stimIafNat_Ext_50, hoc.stimIafNat_Ext_40,
          hoc.stimIafNat_Ext_30, hoc.stimIafNat_Ext_20,
          FiringRateExt, hoc.Ind_IaFibStimExt)
# Set IIf natural firing rate
def set_II_natural_firing(self, FiringRateFl, FiringRateExt, w):
    """Select the II stimulation population matching the mean firing rate.

    Works on the flexor side first, then on the extensor side. All four
    populations ``stimIIfNat_*_50/_40/_30/_20`` are reset to weight 0,
    then one of them gets weight ``w`` depending on the mean rate:
    ``>= 50`` -> ``_50`` (in this case the interval of each locally owned
    II fiber cell is also set to ``1000 / rate``), ``> 35`` -> ``_40``,
    ``> 25`` -> ``_30``, ``> 15`` -> ``_20``. With a mean of 15 or less
    every population stays at 0.
    """
    hoc = self.h

    def set_weights(population, value):
        for k in range(int(population.count())):
            population.object(k).weight[0] = value

    def drive(populations, rates, first_gid):
        # Setting to 0 the weight of all populations
        for population in populations:
            set_weights(population, 0)

        mean_rate = np.mean(rates)
        if mean_rate >= 50:
            # set the weight to the selected stim population
            set_weights(populations[0], w)
            # set the firing rate
            for k in range(int(hoc.nIIf)):
                if hoc.pc.gid_exists(first_gid + k):
                    hoc.pc.gid2cell(first_gid + k).interval = 1000 / rates[k]
            return

        # Lower bands: exclusive lower bound paired with its population
        for lower_bound, population in zip((35, 25, 15), populations[1:]):
            if mean_rate > lower_bound:
                set_weights(population, w)
                return

    # Flexors
    drive(
        (hoc.stimIIfNat_Flex_50, hoc.stimIIfNat_Flex_40,
         hoc.stimIIfNat_Flex_30, hoc.stimIIfNat_Flex_20),
        FiringRateFl,
        hoc.Ind_IIFibStimFlex,
    )
    # Extensor
    drive(
        (hoc.stimIIfNat_Ext_50, hoc.stimIIfNat_Ext_40,
         hoc.stimIIfNat_Ext_30, hoc.stimIIfNat_Ext_20),
        FiringRateExt,
        hoc.Ind_IIFibStimExt,
    )
# Run simulation
def runSimulation(self,frequency=40,name="",amplitude="optimal"):
# Run one dynamic simulation: load the afferent firing-rate tables, choose the EES
# recruitment fractions, advance the model while updating the afferent drive, then
# turn the motoneuron spikes into flexor/extensor EMG traces that are plotted and
# saved under '../Results/'.
#   frequency -- EES frequency in Hz (handed to set_EES_freq); the EES set-up that
#                follows `if frequency>0:` is skipped otherwise
#   name      -- suffix appended to the '../afferentsData/fr_*' file names
#   amplitude -- "optimal" to load precomputed recruitment fractions, or a number
#                strictly between 0 and 600 to look the fractions up in the
#                recruitment tables
# NOTE(review): this copy of the file carries no indentation, so where a statement
# could belong to more than one enclosing block the comments below say so instead
# of guessing.
#Initializing
self.h('{t=pc.set_maxstep(0.5)}')
self.h.finitialize(0)
self.h.stdinit()
self.h.t=0
dt=5 # simulation step
changeParam = 20 # change parameters every 20 ms
self.h.tstop= 7500 # 7.5 s, simulation max time
#loading the FIRING RATES of IA and II fibers, updates at 50 Hz
#matrix[nIAf][nTIME]
Ia_MG=np.loadtxt('../afferentsData/fr_Ia_GM'+str(name)+'.txt')
Ia_TA=np.loadtxt('../afferentsData/fr_Ia_TA'+str(name)+'.txt')
II_MG=np.loadtxt('../afferentsData/fr_II_GM'+str(name)+'.txt')
II_TA=np.loadtxt('../afferentsData/fr_II_TA'+str(name)+'.txt')
# number of time samples, taken from the length of the second row of the Ia TA table
nTIME=Ia_TA[1].size
# Recruitment fractions. Further down the TA values are passed as the flexor
# argument and the GM values as the extensor argument of set_*_stim.
if amplitude == "optimal":
# entries 0, 1, 2 of each file are used as the Ia, II and Mn fractions
optimalRecGm=np.loadtxt('../Recruitment_data/OptimalAmpRecrGM_IaIIMnCur.txt')
suboptimalRecTa=np.loadtxt('../Recruitment_data/SuboptimalAmpRecrTa_IaIIMnCur.txt')
self.percFibersIa_GM= optimalRecGm[0]
self.percFibersII_GM= optimalRecGm[1]
self.percMn_GM = optimalRecGm[2]
self.percFibersIa_TA= suboptimalRecTa[0]
self.percFibersII_TA= suboptimalRecTa[1]
self.percMn_TA = suboptimalRecTa[2]
# Numeric amplitude: pick the entry of 20 evenly spaced values over [0, 600] that is
# closest to the request; each fraction is the table value at that index divided by
# the table maximum.
# NOTE(review): an amplitude matching neither branch leaves the self.perc*
# attributes at their previous values.
elif amplitude > 0 and amplitude <600:
availableCurrents = np.linspace(0,600,20)
temp = abs(availableCurrents-amplitude)
indx = temp.argmin()
recIa_MG=np.loadtxt('../Recruitment_data/GM_full_S1_wire1')
recII_MG=np.loadtxt('../Recruitment_data/GM_full_ii_S1_wire1')
recMn_MG=np.loadtxt('../Recruitment_data/MGM_full_S1_wire1')
recIa_TA=np.loadtxt('../Recruitment_data/TA_full_S1_wire1')
recII_TA=np.loadtxt('../Recruitment_data/TA_full_ii_S1_wire1')
recMn_TA=np.loadtxt('../Recruitment_data/MTA_full_S1_wire1')
self.percFibersIa_GM= recIa_MG[indx]/max(recIa_MG)
self.percFibersII_GM= recII_MG[indx]/max(recII_MG)
self.percMn_GM = recMn_MG[indx]/max(recMn_MG)
self.percFibersIa_TA= recIa_TA[indx]/max(recIa_TA)
self.percFibersII_TA= recII_TA[indx]/max(recII_TA)
self.percMn_TA = recMn_TA[indx]/max(recMn_TA)
# weight 0 on 100% of every population: clears any EES weights set earlier
self.set_IA_stim(1,1,0)
self.set_II_stim(1,1,0)
self.set_Mn_stim(1,1,0)
if frequency>0:
self.EES=1
self.EES_freqeqncy = frequency
# apply the recruitment fractions with the afferent / motoneuron weights
self.set_IA_stim(self.percFibersIa_TA,self.percFibersIa_GM, self.weightStimAff)
self.set_II_stim(self.percFibersII_TA,self.percFibersII_GM, self.weightStimAff)
self.set_Mn_stim(self.percMn_TA, self.percMn_GM, self.weightStimMn)
# Setting EES frequency
# NOTE(review): set_EES_freq tests gid ownership itself, so it presumably runs on
# every host and not only under the PcID==0 test; whether the two summary prints
# below are restricted to host 0 cannot be told from this copy -- confirm against
# the original indentation.
if self.h.PcID ==0:
print "EES set at "+str(self.EES_freqeqncy)+" Hz"
self.set_EES_freq(self.EES_freqeqncy)
print "\t{:.2f}".format(self.percFibersIa_GM*100)+"% of GM Ia fibers, "+"{:.2f}".format(self.percFibersII_GM*100)+"% of GM II fibers and "+"{:.2f}".format(self.percMn_GM*100)+"% of GM Mn receive EES "
print "\t{:.2f}".format(self.percFibersIa_TA*100)+"% of TA Ia fibers, "+"{:.2f}".format(self.percFibersII_TA*100)+"% of TA II fibers and "+"{:.2f}".format(self.percMn_TA*100)+"% of TA Mn receive EES "
"""
MAIN RUNNING LOOP
"""
# The loop runs until tstop is reached or the firing-rate tables run out of columns.
# Once changeParam ms (minus a 0.01 ms tolerance) have elapsed since the last update,
# column j of the tables is pushed to the afferents (TA -> flexor argument,
# MG -> extensor argument). The solver is advanced by dt ms per psolve call.
# NOTE(review): whether the progress print belongs to the update branch or to every
# loop iteration cannot be told from this copy.
t_old = self.h.t
j=0
while (self.h.t<self.h.tstop and j< nTIME):
if (self.h.t>=t_old+changeParam-0.01):
self.set_IA_natural_firing(Ia_TA[:,j],Ia_MG[:,j],self.weightStimAff)
self.set_II_natural_firing(II_TA[:,j],II_MG[:,j],self.weightStimAff)
j+=1
t_old = self.h.t
if(self.h.PcID==0):
print "\t{:.1f}".format(self.h.t*100/self.h.tstop)+"%"
self.h.pc.psolve(self.h.t+dt)
"""
EXTRACT EMG
"""
# apListToMatrix uses MPI gather/bcast, so it presumably has to run on every host;
# the second rank test suggests only the post-processing is limited to rank 0.
if self.rank==0:
print "\nExctracting cells firings..."
AP_Flex = self.apListToMatrix(self.h.nMN,self.h.AP_MN_Flex)
AP_Ext = self.apListToMatrix(self.h.nMN,self.h.AP_MN_Ext)
if self.rank==0:
# tstop is in ms; extract_firings and synth_EMG take the duration in seconds
firings_Flex = self.extract_firings(AP_Flex, self.h.tstop/1000)
firings_Ext = self.extract_firings(AP_Ext, self.h.tstop/1000)
print "\nComputing EMG signals..."
EMG_Flex = self.synth_EMG(firings_Flex, self.h.tstop/1000)
EMG_Ext = self.synth_EMG(firings_Ext, self.h.tstop/1000)
# plotting
f1, ax_1 = plt.subplots(2, figsize=(20, 9), dpi=80, facecolor='w', edgecolor='k', sharex=True, sharey=True)
ax_1[0].plot(EMG_Flex,'-',label='Flexor EMG')
ax_1[1].plot(EMG_Ext,'-',label='Extensor EMG')
ax_1[0].legend(loc='upper left')
ax_1[1].legend(loc='upper left')
ax_1[0].set_title("Estimated EMG - EES frequency: "+str(frequency)+ "Hz - EES amplitude: "+str(amplitude))
# create the output folder on first use, then save the figure and both traces
if not os.path.exists('../Results/'):
os.makedirs('../Results/')
f1.savefig("../Results/DynamicSimulation_EMG_EES_fr_"+str(int(frequency))+"Hz_amp_"+str(amplitude)+".pdf")
np.savetxt("../Results/DynamicSimulation_EMG_Flex_EES_fr_"+str(int(frequency))+"Hz_amp_"+str(amplitude)+'.txt',EMG_Flex,delimiter='')
np.savetxt("../Results/DynamicSimulation_EMG_Ext_EES_fr_"+str(int(frequency))+"Hz_amp_"+str(amplitude)+'.txt',EMG_Ext,delimiter='')
plt.show()
# gather and transform the vector of vectors of AP into a matrix in process 0
def apListToMatrix(self, cellType, apList):
    """Collect per-cell spike-time vectors into one matrix on rank 0.

    Parameters
    ----------
    cellType : number
        Total number of cells of this type over all hosts; the number
        handled locally is obtained with ``computeInd``.
    apList : sequence
        One spike-time vector per local cell, each exposing ``size()``
        and element access through ``.x[k]``.

    Returns
    -------
    numpy.ndarray or None
        One row per cell and one column per spike, padded with -1 up to
        the largest spike count found on any host. With a single process
        the local matrix is returned; with several processes rank 0 gets
        the rows of all hosts stacked in rank order and every other rank
        gets ``None``.

    Every host must call this method: it uses MPI ``gather`` / ``bcast``.
    """
    # hoc and computeInd hand back floats; array sizes must be ints
    n_local = int(self.computeInd(cellType))
    spikes_per_cell = [int(apList[z].size()) for z in range(n_local)]

    # A host may own no cell of this type: max() of an empty list raises
    local_max = max(spikes_per_cell) if spikes_per_cell else 0

    # Agree on a common number of columns so the matrices can be stacked
    host_maxima = self.comm.gather(local_max, root=0)
    n_columns = None
    if self.rank == 0:
        n_columns = max(host_maxima)
    n_columns = int(self.comm.bcast(n_columns, root=0))

    # -1 marks "no spike" in the unused trailing columns
    apXhost = -1 * np.ones([n_local, n_columns])
    for z, n_spikes in enumerate(spikes_per_cell):
        spike_times = apList[z]
        for k in range(n_spikes):
            apXhost[z, k] = spike_times.x[k]

    if self.sizeComm <= 1:
        return apXhost

    gathered = self.comm.gather(apXhost, root=0)
    if self.rank != 0:
        return None
    return np.concatenate(gathered)
# extract firings from a matrix/vector of AP event
def extract_firings(self, AP, nSec):  # AP in ms
    """Convert spike times into a binary firing matrix sampled at 5 kHz.

    Parameters
    ----------
    AP : numpy.ndarray
        Spike times in ms. A 2-D array holds one row per cell and one
        column per spike; a 1-D array is treated as one spike per row.
    nSec : float
        Duration covered by the output, in seconds.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(AP.shape[0], int(5000 * nSec))`` with a 1 in
        the sample ``j`` for which ``j*dt <= t < (j+1)*dt`` (dt = 0.2 ms).
        Times outside ``[0, nSec)`` (for instance -1 padding) are ignored.
    """
    sampling_rate = 5000.
    dt = 1000. / sampling_rate
    n_samples = int(sampling_rate * nSec)
    firings = np.zeros([AP.shape[0], n_samples])

    def sample_of(times):
        """Return (mask of usable times, sample index of each usable time)."""
        times = np.asarray(times, dtype=float)
        inside = (times >= 0) & (times < n_samples * dt)
        hits = times[inside]
        index = np.floor(hits / dt).astype(int)
        # The division can land one bin off at a bin edge; re-check against
        # the exact edges j*dt and (j+1)*dt and correct by one if needed.
        index[index * dt > hits] -= 1
        index[(index + 1) * dt <= hits] += 1
        return inside, index

    # check whether we have more AP for each cell or not
    if len(AP.shape) == 2:
        for row in range(AP.shape[0]):
            _, index = sample_of(AP[row])
            firings[row, index] = 1
    elif len(AP.shape) == 1:
        inside, index = sample_of(AP)
        firings[np.nonzero(inside)[0], index] = 1
    return firings
# synthesise the EMG from a firings matrix
def synth_EMG(self, firings, nSec):  # AP in ms
    """Synthesise an EMG trace from a binary firing matrix.

    Parameters
    ----------
    firings : numpy.ndarray
        2-D array, one row per unit and one column per 5 kHz sample;
        an entry equal to 1 marks a spike.
    nSec : float
        Duration of the signal in seconds.

    Returns
    -------
    numpy.ndarray
        1-D signal of ``int(5000 * nSec)`` samples: the sum, over all
        spikes, of the unit's action-potential waveform placed 2 ms after
        the spike. All zeros when there are no units or no spikes.

    The waveform length and amplitude of each unit are drawn with
    ``random.gauss``, so the result depends on the state of ``random``.
    """
    sampling_rate = 5000.
    dt = 1000. / sampling_rate
    # Sizes and slice bounds must be ints; NumPy rejects floats here
    n_samples = int(sampling_rate * nSec)
    delay = int(2 / dt)

    n_units = firings.shape[0]
    if n_units == 0:
        # nothing to draw and max() below would fail on an empty list
        return np.zeros(n_samples)

    # MUAP duration between 5-10ms (Day et al 2001) -> 7.5 +-1
    mean_len = int(7.5 / dt)
    std_len = int(1 / dt)
    # All lengths are drawn before all amplitudes, as before. A length is
    # kept at 1 or more so an extreme draw cannot produce an invalid size.
    lengths = [max(1, int(mean_len + random.gauss(0, std_len))) for _ in range(n_units)]
    gains = [abs(1 + random.gauss(0, 0.2)) for _ in range(n_units)]

    # Room after the last sample for a waveform started by a late spike
    EMG = np.zeros(n_samples + max(lengths) + delay)
    logBase = 1.05

    for unit in range(n_units):
        spike_samples = np.nonzero(firings[unit, :n_samples] == 1)[0]
        if spike_samples.size == 0:
            continue

        # create MUAP shape: decaying envelope times a chirp-like sine
        length = lengths[unit]
        n40perc = int(length * 0.4)
        n60perc = length - n40perc
        envelope = (1 - (np.linspace(0, 1, length) ** 2)) * np.concatenate(
            (np.ones(n40perc), 1 / np.linspace(1, 3, n60perc)))
        phase = np.log(np.linspace(1, logBase ** (4 * np.pi), length)) / np.log(logBase)
        EMG_unit = gains[unit] * envelope * np.sin(phase)

        for j in spike_samples:
            start = int(j) + delay
            EMG[start:start + length] += EMG_unit
    return EMG[:n_samples]
# recruitment curve
def computeRecruitCurve(self,network="extensor"):
# Compute the recruitment curve of one network: for a series of stimulation
# amplitudes, apply the matching recruitment fractions, simulate, and measure which
# fraction of motoneuron spikes falls in the early and in the medium-late window
# after the EES pulse. On host 0 the curves and two example EMG responses are
# plotted and saved under '../Results/'.
#   network -- "extensor" uses the GM recruitment tables; any other value is
#              handled as "flexor" (TA tables) and the variable is overwritten
# NOTE(review): this copy of the file carries no indentation, so the nesting of
# the statements below cannot be read from the text; comments describe the
# statements themselves and flag what should be confirmed.
if network=="extensor":
#loading the number of IA and II fibers activated at a given current from the FEM model results
Ia_nAct=np.loadtxt('../Recruitment_data/GM_full_S1_wire1')
II_nAct=np.loadtxt('../Recruitment_data/GM_full_ii_S1_wire1')
Mn_nAct=np.loadtxt('../Recruitment_data/MGM_full_S1_wire1')
# index of the first amplitude that is simulated
startCurr = 5
if(self.h.PcID==0):
print "Extensor h-reflex computation"
else:
#loading the number of IA and II fibers activated at a given current from the FEM model results
Ia_nAct=np.loadtxt('../Recruitment_data/TA_full_S1_wire1')
II_nAct=np.loadtxt('../Recruitment_data/TA_full_ii_S1_wire1')
Mn_nAct=np.loadtxt('../Recruitment_data/MTA_full_S1_wire1')
startCurr = 7
network = "flexor"
if(self.h.PcID==0):
print "Flexor h-reflex computation"
# number of cells of each type handled by this host
# NOTE(review): IaiInd is not used anywhere in the visible code
MnInd = int(self.computeInd(self.h.nMN))
IafInd = int(self.computeInd(self.h.nIAf))
IIfInd = int(self.computeInd(self.h.nIIf))
IaiInd = int(self.computeInd(self.h.nIAint))
# EES at 8 Hz for the whole procedure
self.set_EES_freq(8)
# one entry per amplitude in the table; entries below startCurr stay at 0
ampResponse_Early = np.zeros(Ia_nAct.size)
ampResponse_MediumLate = np.zeros(Ia_nAct.size)
currAmp = range(startCurr,Ia_nAct.size)
# amplitude indices whose spikes are kept to draw the two example EMG responses
low_curr = 9
high_curr = 13
for jj in currAmp:
# set EES to 0 on both Ia e II f
self.set_IA_stim(1,1, 0)
self.set_II_stim(1,1, 0)
self.set_Mn_stim(1,1, 0)
# fractions = table value at this amplitude divided by the table maximum
percFibersIa= Ia_nAct[jj]/max(Ia_nAct) # / n cells in FEM model
percFibersII= II_nAct[jj]/max(II_nAct)
percMn = Mn_nAct[jj]/max(Mn_nAct)
# stimulate only the selected side: set_*_stim takes (flexor, extensor, weight)
if network=="extensor":
self.set_IA_stim(0,percFibersIa, self.weightStimAff)
self.set_II_stim(0,percFibersII, self.weightStimAff)
self.set_Mn_stim(0,percMn, self.weightStimMn)
else:
self.set_IA_stim(percFibersIa, 0, self.weightStimAff)
self.set_II_stim(percFibersII, 0, self.weightStimAff)
self.set_Mn_stim(percMn, 0, self.weightStimMn)
if(self.h.PcID==0):
print "\nComputing the response on the "+network+"s MNs due to:"
print "{:.2f}".format(percFibersIa*100) + "% of the population of Ia fibers"
print "{:.2f}".format(percFibersII*100) + "% of the population of II fibers"
print "{:.2f}".format(percMn*100) + "% of the population of Mn cells\n"
# restart the simulator for this amplitude
self.h.finitialize(0)
self.h.stdinit()
self.h.t = 0
dt=0.025 # simulation step
self.h.tstop = 160 # Length of simulation
self.h.pc.psolve(self.h.t+120) # remove initialization effects
# spike counts reached so far by each local cell: the baseline for the increments
# recorded in the loop below
AP_MN_init = np.zeros(MnInd)
AP_Iaf_init = np.zeros(IafInd)
AP_IIf_init = np.zeros(IIfInd)
if network=="extensor":
for i in range(int(MnInd)):
AP_MN_init[i] = self.h.AP_MN_Ext[i].size()
for i in range(int(IafInd)):
AP_Iaf_init[i] = self.h.AP_IA_Ext[i].size()
for i in range(int(IIfInd)):
AP_IIf_init[i] = self.h.AP_II_Ext[i].size()
else :
for i in range(int(MnInd)):
AP_MN_init[i] = self.h.AP_MN_Flex[i].size()
for i in range(int(IafInd)):
AP_Iaf_init[i] = self.h.AP_IA_Flex[i].size()
for i in range(int(IIfInd)):
AP_IIf_init[i] = self.h.AP_II_Flex[i].size()
# new spikes per solver step over a 40 ms window: rows = steps, columns = local cells
AP_MN = np.zeros((int(40/dt),MnInd))
AP_Iaf = np.zeros((int(40/dt),IafInd))
AP_IIf = np.zeros((int(40/dt),IIfInd))
if(self.h.PcID==0):
print "Initialization completed...\n "
time = np.zeros(int(40/dt))
count =0
# advance one step of dt ms at a time; after each step store, per cell, how many
# spikes were added since the previous step and move the baseline forward
while (self.h.t<self.h.tstop and count < int(40/dt)):
self.h.pc.psolve(self.h.t+dt)
if network=="extensor":
for i in range(int(MnInd)):
AP_MN[count,i] = self.h.AP_MN_Ext[i].size() - AP_MN_init[i]
AP_MN_init[i] = self.h.AP_MN_Ext[i].size()
for i in range(int(IafInd)):
AP_Iaf[count,i] = self.h.AP_IA_Ext[i].size() - AP_Iaf_init[i]
AP_Iaf_init[i] = self.h.AP_IA_Ext[i].size()
for i in range(int(IIfInd)):
AP_IIf[count,i] = self.h.AP_II_Ext[i].size() - AP_IIf_init[i]
AP_IIf_init[i] = self.h.AP_II_Ext[i].size()
else :
for i in range(int(MnInd)):
AP_MN[count,i] = self.h.AP_MN_Flex[i].size() - AP_MN_init[i]
AP_MN_init[i] = self.h.AP_MN_Flex[i].size()
for i in range(int(IafInd)):
AP_Iaf[count,i] = self.h.AP_IA_Flex[i].size() - AP_Iaf_init[i]
AP_Iaf_init[i] = self.h.AP_IA_Flex[i].size()
for i in range(int(IIfInd)):
AP_IIf[count,i] = self.h.AP_II_Flex[i].size() - AP_IIf_init[i]
AP_IIf_init[i] = self.h.AP_II_Flex[i].size()
if(self.h.PcID==0):
time[count] = self.h.t
count+=1
# collect the per-host matrices on rank 0 (MPI gather, called by every host)
AP_MN = self.comm.gather(AP_MN, root=0)
AP_Iaf = self.comm.gather(AP_Iaf, root=0)
AP_IIf = self.comm.gather(AP_IIf, root=0)
if(self.h.PcID==0):
if self.sizeComm>1:
# gathering the pot in one array
# hosts are joined along axis 1, i.e. the cell axis
Global_AP_MN = np.concatenate((AP_MN[0],AP_MN[1]),axis=1)
for i in range(2,self.sizeComm):
Global_AP_MN = np.concatenate((Global_AP_MN ,AP_MN[i]),axis=1)
Global_AP_Iaf = np.concatenate((AP_Iaf[0],AP_Iaf[1]),axis=1)
for i in range(2,self.sizeComm):
Global_AP_Iaf = np.concatenate((Global_AP_Iaf ,AP_Iaf[i]),axis=1)
Global_AP_IIf = np.concatenate((AP_IIf[0],AP_IIf[1]),axis=1)
for i in range(2,self.sizeComm):
Global_AP_IIf = np.concatenate((Global_AP_IIf ,AP_IIf[i]),axis=1)
else:
Global_AP_MN = AP_MN[0]
Global_AP_Iaf= AP_Iaf[0]
Global_AP_IIf= AP_IIf[0]
# step indices at which a motoneuron fired, pooled over all motoneurons
# NOTE(review): columns 0 and 1 are read unconditionally, so this looks like it
# needs at least two cells of each type -- confirm
IndexEES_Mn = np.concatenate((np.nonzero(Global_AP_MN[:,0]),np.nonzero(Global_AP_MN[:,1])), axis=1)
for i in range(2,int(self.h.nMN)):
IndexEES_Mn = np.concatenate((IndexEES_Mn,np.nonzero(Global_AP_MN[:,i])), axis=1)
# keep only step indices larger than 1/dt, i.e. later than the first ms of the window
IndexEES_Mn=np.extract(IndexEES_Mn>1/dt,IndexEES_Mn)
# same pooling for the Ia fibers
IndexEES_Iaf = np.concatenate((np.nonzero(Global_AP_Iaf[:,0]),np.nonzero(Global_AP_Iaf[:,1])), axis=1)
for i in range(2,int(self.h.nIAf)):
IndexEES_Iaf = np.concatenate((IndexEES_Iaf,np.nonzero(Global_AP_Iaf[:,i])), axis=1)
IndexEES_Iaf=np.extract(IndexEES_Iaf>1/dt,IndexEES_Iaf)
if IndexEES_Mn.size >0:
# pulse time (ms within the window) taken as the mean Ia-fiber spike time
EESpulseTime = np.mean(IndexEES_Iaf)*dt
# motoneuron spikes up to 3.5 ms after the pulse count as early, later ones as medium-late
IndexMn_Early=np.extract(IndexEES_Mn<(EESpulseTime+3.5)/dt,IndexEES_Mn) # first EES pulse at 10 ms and second at 135 (15th ms recorded)
IndexMn_MediumLate=np.extract(IndexEES_Mn>=(EESpulseTime+3.5)/dt,IndexEES_Mn)
# response amplitude = number of spike events in the window divided by nMN
ampResponse_Early[jj] = IndexMn_Early.shape[0]/self.h.nMN
ampResponse_MediumLate[jj] = IndexMn_MediumLate.shape[0]/self.h.nMN
# to create the EMG responses
if jj==low_curr:
EESpulseTime_low = EESpulseTime
IndexMn_Early_low = IndexMn_Early
IndexMn_MediumLate_low = IndexMn_MediumLate
elif jj==high_curr:
EESpulseTime_high = EESpulseTime
IndexMn_Early_high = IndexMn_Early
IndexMn_MediumLate_high = IndexMn_MediumLate
print "\nThe amplitude of the early response is: " + "{:.2f}".format(ampResponse_Early[jj]*100) + "% "
print "The amplitude of the medium-late response is: " + "{:.2f}".format(ampResponse_MediumLate[jj]*100) + "% \n"
if(self.h.PcID==0) :
# find current threshold
# The medium-late curve is spline-interpolated on a 5x finer current axis over
# [0, 600] and negative values are clipped to 0. The threshold is the current,
# before the peak of the curve, at which the response is closest to 0.1; currents
# are then expressed as multiples of that threshold.
current = np.linspace(0,600,Ia_nAct.size)
tck = interpolate.splrep(current, ampResponse_MediumLate, s=0)
current_precise = np.linspace(0,600,current.size*5)
ampResp_ML_precise = interpolate.splev(current_precise,tck,der=0)
for i in range(ampResp_ML_precise.size):
if ampResp_ML_precise[i]<0:ampResp_ML_precise[i]=0
temp = abs(ampResp_ML_precise[:ampResp_ML_precise.argmax()]-0.1)
threshold = current_precise[temp.argmin()]
curr_thresholds = current/threshold
low_curr_thr = curr_thresholds[low_curr]
high_curr_thr = curr_thresholds[high_curr]
# Example EMG responses for the low and high amplitudes over a 0.04 s window
# NOTE(review): the *_low / *_high variables are only assigned above when jj equals
# low_curr / high_curr and motoneurons fired at that amplitude; otherwise these
# lines would presumably raise NameError -- confirm
firings = self.extract_firings(IndexMn_MediumLate_low*dt, 0.04)
EMG_ML_low = self.synth_EMG(firings, 0.04)
firings = self.extract_firings(IndexMn_MediumLate_high*dt, 0.04)
EMG_ML_high = self.synth_EMG(firings, 0.04)
firings = self.extract_firings(IndexMn_Early_low*dt, 0.04)
EMG_E_low = self.synth_EMG(firings, 0.04)
firings = self.extract_firings(IndexMn_Early_high*dt, 0.04)
EMG_E_high = self.synth_EMG(firings, 0.04)
# Figure 1: EMG responses. Times in ms are multiplied by 5 to get sample positions,
# matching the 5000 Hz sampling rate used by synth_EMG.
f1, ax_1 = plt.subplots(2, figsize=(20, 9), dpi=80, facecolor='w', edgecolor='k', sharex=True, sharey=True)
ax_1[0].plot(EMG_E_low+EMG_ML_low,'-',label='EMG response - '+"{:.1f}".format(low_curr_thr)+'x motor thr')
ax_1[0].plot([EESpulseTime_low*5,EESpulseTime_low*5],[0,50],'k',label = 'Stimulation pulse',linewidth=3.0)
ax_1[0].axvspan(EESpulseTime_low*5, (EESpulseTime_low+3.5)*5, color='r', alpha=0.25, lw=0,label='Early response')
ax_1[0].axvspan((EESpulseTime_low+3.5)*5,200, color='g', alpha=0.25, lw=0,label='Medium-late response')
ax_1[1].plot(EMG_E_high+EMG_ML_high,'-',label='EMG response - '+"{:.1f}".format(high_curr_thr)+'x motor thr')
ax_1[1].plot([EESpulseTime_high*5,EESpulseTime_high*5],[0,50],'k',label = 'Stimulation pulse',linewidth=3.0)
ax_1[1].axvspan(EESpulseTime_high*5, (EESpulseTime_high+3.5)*5, color='r', alpha=0.25, lw=0,label='Early response')
ax_1[1].axvspan((EESpulseTime_high+3.5)*5,200, color='g', alpha=0.25, lw=0,label='Medium-late response')
ax_1[0].legend(loc='upper left')
ax_1[1].legend(loc='upper left')
ax_1[0].set_title("EMG responses - " + network + " network")
# Figure 2: both response amplitudes against the current in threshold units
f2 = plt.figure()
ax_2 = f2.add_subplot(1,1,1)
ax_2.plot(curr_thresholds,ampResponse_MediumLate,label='Medium-late response')
ax_2.plot(curr_thresholds,ampResponse_Early,label='Early response')
ax_2.legend(loc='upper left')
ax_2.set_title("Recruitment curve - " + network + " network")
ax_2.set_xlabel("Stimulation amp (x motor thr)")
ax_2.set_ylabel("Response amplitude")
ax_2.grid(True)
# create the output folder on first use, then save both figures and the raw curves
if not os.path.exists('../Results/'):
os.makedirs('../Results/')
f1.savefig("../Results/EMG_responses"+network+".pdf")
f2.savefig("../Results/Recruitment_curve"+network+".pdf")
np.savetxt('../Results/current.txt',current,delimiter='')
np.savetxt('../Results/ampResponse_ER_'+network+'.txt',ampResponse_Early,delimiter='')
np.savetxt('../Results/ampResponse_MLR_'+network+'.txt',ampResponse_MediumLate,delimiter='')
plt.show()