#!/usr/bin/env python
'''
BGTCS MFM
Usage:
mfm [options] [<key>=<value>]...
Options:
-l --list List all options
-h --help Show this screen
'''
import numpy as np
import pickle
import sys
import os
from docopt import docopt
from tabulate import tabulate
from utils import progbar
from swift import aswift
from dbs import cDBS, pDBS
class MFM(object):
def __init__(self,**kwargs):
self.t = 0 #internal time counter
self.i = 0 #internal index
self._load_params(kwargs)
self._set_MFM_params()
self._set_DBS()
self.S = np.zeros((self.params['N'],20))
self.S[0,:] = [ 43.74102506, -1.15197439, 6.96276347, -22.25852135, 7.19671392,
-28.57548512, 17.26916297, 132.89911127, 9.71319243, 67.69191101,
9.57769785, -21.11198645, 5.33943222, -22.62375016, 0.54172422,
-22.23467637, 6.76173506, 143.5386694, 7.49756915, 23.59983148]
self.swift = aswift(tau_s = 1./self.params['swift_f'] * self.params['swift_c'],
tau_f = 1./self.params['swift_f'] * self.params['swift_c'] / self.params['swift_s2f'],
f = self.params['swift_f'],
fs = self.params['fs'])
self.memory = {
'amp' : np.zeros(self.params['N']),
'phase' : np.zeros(self.params['N']),
'stim' : np.zeros(self.params['N']),
}
def __str__(self):
general = ('Run Info\n'+
'--------\n'+
'length : {0} s\n'
'DD : {1}\n'
'cDBS : {2}\n'
'pDBS : {3}\n')\
.format(self.params['tstop'],self.params['DD'],self.params['cDBS'],self.params['pDBS'])
SWIFT = ('\nSWIFT Parameters\n'
'----------------\n'
'f : {:0.1f} Hz\n'
'tau_s : {:0.4f} s\n'
'tau_f : {:0.4f} s\n')\
.format(*[self.params[key] for key in ['swift_f','swift_tau_s','swift_tau_f']])
if self.params['cDBS']:
cDBS =('\ncDBS Parameters\n'
'---------------\n'
'frequency : {} Hz\n'
'amplitude : {} mA\n'
'pulse width : {} us')\
.format(*[self.params[key] for key in ['cDBS_f','cDBS_amp','cDBS_width']])
else: cDBS=''
if self.params['pDBS']:
pDBS=('\npDBS Parameters\n'
'---------------\n'
'phase thr : {} rad\n'
'stim amp : {} mA\n'
'power thr : {} dB\n'
'ref period : {}\n')\
.format(*[self.params[key] for key in ['pDBS_phase','pDBS_amp','pDBS_power_thr','pDBS_ref_period']])
else: pDBS=''
return general+SWIFT+cDBS+pDBS
def _load_params(self,kwargs):
def process_kwargs(kwargs):
for key,value in kwargs.items():
if key not in self.params:
print('Invalid keyword argument {}'.format(key))
else:
value = type(self.params[key])(value) #cast value to that of params[key]
self.params[key] = value
self.params = {}
self.params['verbose'] = True
#General parameters
self.params['dt'] = 1e-3 # s
self.params['stim_start'] = 0.0 # s
self.params['tstop'] = 50.0 # s
self.params['RunID'] = -1
#DD parameters
self.params['DD'] = True
#Stimulation parameters
self.params['stim_target'] = 'STN'
self.params['Cm'] = 1e-4 # F
#cDBS parameters
self.params['cDBS'] = False
self.params['cDBS_f'] = 130. # (Hz)
self.params['cDBS_amp'] = 3.0 # (mV)
self.params['cDBS_width'] = 60. # (us)
#pDBS parameters
self.params['pDBS'] = False
self.params['pDBS_phase'] = 2.24 # rad
self.params['pDBS_amp'] = 2.38 # rad
self.params['pDBS_width'] = 60 # us
self.params['pDBS_ref_period'] = 0.3 # s
self.params['pDBS_power_thr'] = -28.57 # dB
#SWIFT params
self.params['state_target'] = 'p1'
self.params['swift_f'] = 29 # Hz
self.params['swift_tau_s'] = 0.2397 # s - Overrides swift_c when set
self.params['swift_c'] = 10 # unitless - number of cycles per tau
self.params['swift_s2f'] = 5 # unitless - ratio of tau_s to tau_f
self._options = self.params.copy()
process_kwargs(kwargs)
#Set parameters dependent on other parameters
self.params['N'] = int(np.ceil(self.params['tstop']/self.params['dt']))
self.params['fs'] = 1./self.params['dt']
if self.params['swift_tau_s'] is None:
self.params['swift_tau_s'] = 1./self.params['swift_f'] * self.params['swift_c']
self.params['swift_tau_f'] = self.params['swift_tau_s'] / self.params['swift_s2f']
def _set_MFM_params(self):
self.phin = 15
self.noiseAmp = 0.03
self.re = 80 #mm
self.gammae = 125 #s^-1
self.alpha = 160 #s^-1
self.beta = 640 #s^-1
self.gammasq=self.gammae^2
self.alphabeta=self.alpha*self.beta
self.aPb = (self.alpha+self.beta)#/alphabeta
self.alphagamma=self.alpha*self.gammae
#Axonal Delays: from second to first (1st population type is postsynaptic)
#(ms)
self.taues = int(0.035/self.params['dt'])
self.tauis = int(0.035/self.params['dt'])
self.taud1e = int(0.002/self.params['dt'])
self.taud2e = int(0.002/self.params['dt'])
self.taud1s = int(0.002/self.params['dt'])
self.taud2s = int(0.002/self.params['dt'])
self.taup1d1 = int(0.001/self.params['dt'])
self.taup1p2 = int(0.001/self.params['dt'])
self.taup1ST = int(0.001/self.params['dt'])
self.taup2d2 = int(0.001/self.params['dt'])
self.taup2ST = int(0.001/self.params['dt'])
self.tauSTe = int(0.001/self.params['dt'])
self.tauSTp2 = int(0.001/self.params['dt'])
self.tause = int(0.050/self.params['dt'])
self.taure = int(0.050/self.params['dt'])
self.tausp1 = int(0.003/self.params['dt'])
self.tausr = int(0.002/self.params['dt'])
self.taurs = int(0.002/self.params['dt'])
self.taud2d1 = int(0.001/self.params['dt'])
#Threshold spread (mV)
self.sigmaprime = 3.8
#Connection strength (mVs)
self.vee = 1.6 #1.6 1.4 is DD
self.vie = 1.6 #1.6 1.4 is DD
self.vei = -1.9 #-1.9-1.6 is DD
self.vii = -1.9 #-1.9 -1.6 is DD
self.ves = 0.4
self.vis = 0.4
self.vd1e = 1.0 #1 #1 is normal, .5 is DD
self.vd1d1 = -0.3
self.vd1s = 0.1
self.vd2e = 0.7 #0.7 1.4 is DD
self.vd2d2 = -0.3
self.vd2s = 0.05
self.vp1d1 = -0.1
self.vp1p2 = -0.03
self.vp1ST = 0.3
self.vp2d2 = -0.3 #-0.3 -0.5 is DD
self.vp2p2 = -0.1 #-0.1 -0.07 is DD
self.vp2ST = 0.3
self.vSTe = 0.1
self.vSTp2 = -0.04
self.vse = 0.8
self.vsp1 = -0.03
self.vsr = -0.4
self.vsn = 0.5
self.vre = 0.15
self.vrs = 0.03
self.vd2d1 = 0
#Sigmoids
#Maximum Firing Rates (s^-1)
self.Qe = 300
self.Qi = 300
self.Qd1 = 65
self.Qd2 = 65
self.Qp1 = 250
self.Qp2 = 300
self.QST = 500
self.Qs = 300
self.Qr = 500
#Firing Thresholds (mV)
self.thetae = 14
self.thetai = 14
self.thetad1 = 19
self.thetad2 = 19
self.thetap1 = 10
self.thetap2 = 9 #9 8 is DD
self.thetaST = 10 #10 9 is DD
self.thetas = 13
self.thetar = 13
self.phie = 0
self.phie_dot = 1
self.Ve = 2
self.Ve_dot = 3
self.Vi = 4
self.Vi_dot = 5
self.Vd1 = 6
self.Vd1_dot = 7
self.Vd2 = 8
self.Vd2_dot = 9
self.Vp1 = 10
self.Vp1_dot = 11
self.Vp2 = 12
self.Vp2_dot = 13
self.VST = 14
self.VST_dot = 15
self.Vs = 16
self.Vs_dot = 17
self.Vr = 18
self.Vr_dot = 19
self.struct={'e' :2,
'i' :4,
'd1' :6,
'd2' :8,
'p1' :10,
'p2' :12,
'STN' :14,
's' :16,
'r' :18}
#Set parkinsonian parameters
if self.params['DD']:
self.vee = 1.4
self.vie = 1.4
self.vei = -1.6
self.vii = -1.6
self.vd1e = 0.5
self.vd2e = 1.4
self.vp2d2 = -0.5
self.vp2p2 = -0.07
self.thetap2 = 8
self.thetaST = 9
def _set_DBS(self):
if self.params['cDBS']:
self.cDBS = cDBS(dt = self.params['dt'],
f = self.params['cDBS_f'],
stim_amp = self.params['cDBS_amp'],
width = self.params['cDBS_width'],
tstart = self.params['stim_start'])
else:
self.cDBS = None
self.pDBS = pDBS(dt = self.params['dt'],
f = self.params['swift_f'],
tau_s = self.params['swift_tau_s'],
tau_f = self.params['swift_tau_f'],
phase_thr = self.params['pDBS_phase'],
ref_period = self.params['pDBS_ref_period'],
stim_amp = self.params['pDBS_amp'],
width = self.params['pDBS_width'],
power_thr = self.params['pDBS_power_thr'])
def advance(self):
def sigmoid(V,Q,theta):
return Q/(1+np.exp(-(V-theta)/3.8))
i = self.i
dSdt = np.zeros(20)
dSdt[self.phie]= self.S[i,self.phie_dot]
dSdt[self.phie_dot] = self.gammasq*(sigmoid(self.S[i,self.Ve], self.Qe, self.thetae)-\
self.S[i,self.phie])-2*self.gammae*self.S[i,self.phie_dot]
dSdt[self.Ve] = self.S[i,self.Ve_dot]
dSdt[self.Ve_dot] = self.alphagamma*(self.vee*self.S[i,self.phie]+\
self.vei*sigmoid(self.S[i,self.Vi],self.Qi, self.thetai)+\
self.ves*sigmoid(self.S[i-self.taues,self.Vs], self.Qs, self.thetas)-\
self.S[i,self.Ve])-\
self.aPb*self.S[i,self.Ve_dot]
dSdt[self.Vi] = self.S[i,self.Vi_dot]
dSdt[self.Vi_dot] = self.alphagamma*(self.vii*sigmoid(self.S[i,self.Vi], self.Qi, self.thetai)+\
self.vie*self.S[i,self.phie]+\
self.vis*sigmoid(self.S[i-self.tauis,self.Vs], self.Qs, self.thetas)-\
self.S[i,self.Vi])-\
self.aPb*self.S[i,self.Vi_dot]
dSdt[self.Vd1] = self.S[i,self.Vd1_dot]
dSdt[self.Vd1_dot] = self.alphabeta*(self.vd1e*self.S[i-self.taud1e,self.phie]+\
self.vd1s*sigmoid(self.S[i-self.taud1s,self.Vs], self.Qs, self.thetas)+\
self.vd1d1*sigmoid(self.S[i,self.Vd1], self.Qd1, self.thetad1)-\
self.S[i,self.Vd1])-\
self.aPb*self.S[i,self.Vd1_dot] #Add in SNc
dSdt[self.Vd2] = self.S[i,self.Vd2_dot]
dSdt[self.Vd2_dot] = self.alphabeta*(self.vd2e*self.S[i-self.taud2e, self.Ve]+\
self.vd2d1*sigmoid(self.S[i-self.taud2d1,self.Vd1], self.Qd1, self.thetad1)+\
self.vd2s*sigmoid(self.S[i-self.taud2s,self.Vs], self.Qs, self.thetas)+\
self.vd2d2*sigmoid(self.S[i,self.Vd2], self.Qd2, self.thetad2)-\
self.S[i,self.Vd2])-\
self.aPb*self.S[i,self.Vd2_dot] #Add in the SNc
dSdt[self.Vp1] = self.S[i,self.Vp1_dot]
dSdt[self.Vp1_dot] = self.alphabeta*(self.vp1d1*sigmoid(self.S[i-self.taup1d1,self.Vd1], self.Qd1, self.thetad1)+\
self.vp1p2*sigmoid(self.S[i-self.taup1p2,self.Vp2], self.Qp2, self.thetap2)+\
self.vp1ST*sigmoid(self.S[i-self.taup1ST,self.VST], self.QST, self.thetaST)-\
self.S[i,self.Vp1])-\
self.aPb*self.S[i,self.Vp1_dot]
dSdt[self.Vp2] = self.S[i,self.Vp2_dot]
dSdt[self.Vp2_dot] = self.alphabeta*(self.vp2d2*sigmoid(self.S[i-self.taup2d2,self.Vd2], self.Qd2, self.thetad2)+\
self.vp2p2*sigmoid(self.S[i,self.Vp2], self.Qp2, self.thetap2)+\
self.vp2ST*sigmoid(self.S[i-self.taup2ST,self.VST], self.QST, self.thetaST)-\
self.S[i,self.Vp2])-\
self.aPb*self.S[i,self.Vp2_dot]
dSdt[self.VST] = self.S[i,self.VST_dot]
dSdt[self.VST_dot] = self.alphabeta*(self.vSTp2*sigmoid(self.S[i-self.tauSTp2,self.Vp2], self.Qp2, self.thetap2)+\
self.vSTe*self.S[i-self.tauSTe,self.phie]-\
self.S[i,self.VST])-\
self.aPb*self.S[i,self.VST_dot]
dSdt[self.Vs] = self.S[i,self.Vs_dot]
dSdt[self.Vs_dot] = self.alphabeta*(self.vsp1*sigmoid(self.S[i-self.tausp1,self.Vp1], self.Qp1, self.thetap1)+\
self.vse*self.S[i-self.tause,self.phie]+\
self.vsr*sigmoid(self.S[i-self.tausr,self.Vr], self.Qr, self.thetar)+\
self.phin-\
self.S[i,self.Vs])-\
self.aPb*self.S[i,self.Vs_dot]
dSdt[self.Vr] = self.S[i,self.Vr_dot]
dSdt[self.Vr_dot] = self.alphabeta*(self.vre*self.S[i-self.taure,self.phie]+\
self.vrs*sigmoid(self.S[i-self.taurs,self.Vs], self.Qs, self.thetas)-\
self.S[i,self.Vr])-\
self.aPb*self.S[i,self.Vr_dot]
#DBS
#====================================================================================
if self.params['cDBS']:
cDBS_C = self.cDBS.advance()
self.S[i,self.struct[self.params['stim_target']]] += cDBS_C/self.params['Cm']
if cDBS_C != 0: self.memory['stim'][i+1] = cDBS_C
else:
pDBS_C = self.pDBS.advance(self.S[i,self.struct[self.params['state_target']]])
if self.params['pDBS']:
self.S[i,self.struct[self.params['stim_target']]] += pDBS_C/self.params['Cm']
if pDBS_C != 0: self.memory['stim'][i+1] = pDBS_C
self.memory['amp'][i+1] = self.pDBS.amp
self.memory['phase'][i+1] = self.pDBS.phase
#Advance
#====================================================================================
self.S[i+1,:] = self.S[i,:]+self.params['dt']*dSdt
#Noise
#====================================================================================
self.S[i+1,self.Ve] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qe*(1-sigmoid(self.S[i+1,self.Ve], 1, self.thetae))*sigmoid(self.S[i,self.Ve], 1, self.thetae)
self.S[i+1,self.Vi] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qi*(1-sigmoid(self.S[i+1,self.Vi], 1, self.thetai))*sigmoid(self.S[i,self.Vi], 1, self.thetai)
self.S[i+1,self.Vd1] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qd1*(1-sigmoid(self.S[i+1,self.Vd1], 1, self.thetad1))*sigmoid(self.S[i,self.Vd1], 1, self.thetad1)
self.S[i+1,self.Vd2] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qd2*(1-sigmoid(self.S[i+1,self.Vd2], 1, self.thetad2))*sigmoid(self.S[i,self.Vd2], 1, self.thetad2)
self.S[i+1,self.Vp1] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qp1*(1-sigmoid(self.S[i+1,self.Vp1], 1, self.thetap1))*sigmoid(self.S[i,self.Vp1], 1, self.thetap1)
self.S[i+1,self.Vp2] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qp2*(1-sigmoid(self.S[i+1,self.Vp2], 1, self.thetap2))*sigmoid(self.S[i,self.Vp2], 1, self.thetap2)
self.S[i+1,self.VST] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.QST*(1-sigmoid(self.S[i+1,self.VST], 1, self.thetaST))*sigmoid(self.S[i,self.VST], 1, self.thetaST)
self.S[i+1,self.Vs] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qs*(1-sigmoid(self.S[i+1,self.Vs], 1, self.thetas))*sigmoid(self.S[i,self.Vs], 1, self.thetas)
self.S[i+1,self.Vr] += self.noiseAmp*np.random.normal(0,1)*np.sqrt(self.params['dt'])*self.Qr*(1-sigmoid(self.S[i+1,self.Vr], 1, self.thetar))*sigmoid(self.S[i,self.Vr], 1, self.thetar)
self.i += 1
def run(self):
#if self.params['verbose']: self.progbar = ProgressBar()
if self.params['verbose']: self.progbar = progbar()
while self.i < self.params['N'] - 1:
self.advance()
#if self.params['verbose']: self.progbar.display(float(self.i)/(self.params['N']-2))
if self.params['verbose']: self.progbar.update(float(self.i)/(self.params['N']-2))
if self.params['verbose']: print()
def save(self,fname=None):
if fname == None:
if not os.path.isdir('data'):
os.makedirs('data')
if self.params['RunID'] == -1:
try:
ls = os.listdir('data')
lowest_empty = 0
found_empty = False
while not found_empty:
found_empty = True
for fname in ls:
try: fnum = int(fname[:3])
except: fnum = -1
if fnum == lowest_empty:
lowest_empty += 1
found_empty = False
self.params['RunID'] = lowest_empty
except:
self.params['RunID'] = 0
print('\nSaving data...\n RunID: {0:03d}'.format(self.params['RunID']))
fname = 'data/{0:03d}.mfm'.format(self.params['RunID'])
else:
print('\nSaving data...\n {}'.format(fname))
pickle.dump(self.__dict__,open(fname,'wb'))
def load(self,fname):
self.__dict__.update(pickle.load(open(fname,'rb')))
def plot(self,PSD_seg=0.5):
from scipy import signal
import matplotlib as mpl
import matplotlib.pyplot as plt
t = np.arange(self.params['N']) * self.params['dt']
x = self.S[:,self.struct[self.params['state_target']]][int(self.i * PSD_seg):]
f,Pxx = signal.welch(self.S[:,self.struct[self.params['state_target']]],1/self.params['dt'],nperseg=2048)
Pxx = 10*np.log10(Pxx)
fig,ax = plt.subplots()
ax.plot(f[f<100],Pxx[f<100])
ax.set_ylabel('PSD (dB/Hz)')
ax.set_xlabel('Frequency (Hz)')
plt.tight_layout()
#if self.params['pDBS']:
fig,ax = plt.subplots(4,1,sharex=True)
ax[0].plot(t,self.S[:,self.struct[self.params['state_target']]],label='state')
ax[0].plot(t,self.S[:,self.struct[self.params['stim_target']]], label='stim')
ax[1].plot(t,self.memory['amp']) #self.pDBS.mem['amp'])
ax[2].plot(t,self.memory['phase']) #self.pDBS.mem['phase'])
ax[3].vlines(t[self.memory['stim'] > 0], 0, 1)
#ax[3].plot(t,self.memory['stim']) #self.pDBS.mem['stim'])
ax[0].legend()
ax[0].set_ylabel('V')
ax[1].set_ylabel('Power (dB)')
ax[2].set_ylabel('Phase')
ax[3].set_ylabel('Stim')
ax[3].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()
@property
def options(self):
return self._options
def main():
def parse_kwargs(kwargs):
args = {}
for arg in kwargs:
key,value = arg.split('=')
if value.lower() == 'false': value = False
elif value.lower() == 'true' : value = True
else:
try: value = int(value)
except:
try: value = float(value)
except:
pass
args[key] = value
return args
args = docopt(__doc__)
if args['--list']:
print('Available options:')
headers = ['Option', 'Default']
data = sorted([(k,v) for k,v in MFM().options.items()])
print(tabulate(data, headers=headers))
print('\nFor more details, such as units, etc, look at the source of MFM._load_params().')
sys.exit()
kwargs = parse_kwargs(args['<key>=<value>'])
mfm = MFM(**kwargs)
print(mfm)
mfm.run()
mfm.save()
if __name__ == '__main__':
main()