import os
os.environ['MKL_NUM_THREADS'] = "1"
import logging
from typing import List, Tuple, Dict, Union
import numpy as np
from collections import OrderedDict
from scipy.integrate import solve_ivp
import time
class AlonsoMarderModel(object):
    """
    AlonsoMarderModel provides a class for the Alonso-Marder neuronal model described in:
    https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6395073/

    The model combines eight ionic currents (Na, CaT, CaS, H, Kd, KCa, A and leak) with
    intracellular calcium dynamics. ``update_state_variables`` is the right-hand side passed to
    ``solve_ivp``; its state vector is ordered like ``self.state_vars_labels``.
    """

    def __init__(self, injected_current=None, initial_conditions=None, reversal_potentials=None, conductances=None,
                 inf_alphas=None, inf_betas=None, tau_cs=None, tau_ds=None, tau_as=None, tau_bs=None,
                 tau_a2s=None, tau_b2s=None, tau_ca=None, spike_threshold=None):
        """
        Build a model instance. Every argument is optional; a falsy value (``None``, ``0``, ``{}``)
        selects the built-in default defined below.

        :param injected_current: externally-applied current (nA)
        :param initial_conditions: mapping of state-variable name -> initial value
        :param reversal_potentials: mapping 'E_<channel>' -> reversal potential (mV)
        :param conductances: mapping 'g_<channel>' -> maximal conductance (uS)
        :param inf_alphas: per-gate voltage offsets of the steady-state curves (mV)
        :param inf_betas: per-gate slopes of the steady-state curves (mV)
        :param tau_cs: per-gate ``c`` coefficients of the time-constant curves
        :param tau_ds: per-gate ``d`` coefficients of the time-constant curves
        :param tau_as: per-gate ``a`` voltage offsets of the time-constant curves (mV)
        :param tau_bs: per-gate ``b`` slopes of the time-constant curves (mV)
        :param tau_a2s: second voltage offsets, used only by the CaS time constants (mV)
        :param tau_b2s: second slopes, used only by the CaS time constants (mV)
        :param tau_ca: time constant of the intracellular calcium relaxation
        :param spike_threshold: thresholds used by ``convert_trace_to_spike_characteristics_tonic``
        """
        self.voltage_trace = None
        self.spike_times = None
        self.time_steps = None
        self.constants = {
            'reftemp_celsius': 10.0,  # degC, reference temperature for Q10 scaling
            'temp_celsius': 10.0,  # degC, simulation temperature
            'gas_constant': 8.314 * pow(10, 3),  # Ideal gas constant (*10^3 so potentials come out in mV)
            'base_temp_kelvin': 273.15,  # offset converting degC to Kelvin
            'z': 2.0,  # valence of calcium ions
            'faraday': 96485.33,  # Faraday's constant
            'ca_conc_extracellular': 3000.0,  # uM, extracellular Ca concentration
            'ca_conv_factor': 0.94,  # factor converting Ca current into a Ca concentration change
            'ca_conc_background': 0.05,  # uM, level the intracellular Ca concentration relaxes to
            'tau_ca_conc_intracellular': tau_ca or 6.535e2,  # time constant of intracellular Ca relaxation
            'C': 10.0,  # nF / cm^2, membrane capacitance
            'I_app': injected_current or 0.0,  # nA, externally-applied current
        }
        self.initial_conditions = initial_conditions or OrderedDict({
            'V': -51.,  # mV, membrane voltage
            'm_Na': 0.,  # sodium activation variable
            'h_Na': 0.,  # sodium inactivation variable
            'm_CaT': 0.,  # low-threshold calcium activation variable
            'h_CaT': 0.,  # low-threshold calcium inactivation variable
            'm_CaS': 0.,  # slow calcium activation variable
            'h_CaS': 0.,  # slow calcium inactivation variable
            'm_H': 0.,  # hyperpolarization-activated cation activation variable
            'h_H': 1.,  # hyperpolarization-activated cation inactivation variable (held constant)
            'm_Kd': 0.,  # potassium activation variable
            'h_Kd': 1.,  # potassium inactivation variable (held constant)
            'm_KCa': 0.,  # calcium-dependent potassium activation variable
            'h_KCa': 1.,  # calcium-dependent potassium inactivation variable (held constant)
            'm_A': 0.,  # transient potassium activation variable
            'h_A': 0.,  # transient potassium inactivation variable
            'm_L': 1.,  # leak channel activation variable (held constant)
            'h_L': 1.,  # leak channel inactivation variable (held constant)
            'ca_conc_intracellular': 5.,  # uM, intracellular Ca concentration initial condition
        })
        self.spike_threshold = spike_threshold or {
            'spike_threshold': self.initial_conditions["V"] + 15.0,  # not read by this class
            'threshold_spike': -15.,  # upward crossings of this level are counted as spikes
            'threshold_mid_upper': -35.,  # first mid-level for downward crossings
            'threshold_mid_lower': -45.,  # second mid-level for downward crossings
            'threshold_slow_wave': -49.999,  # downward crossings of this level are counted in num_sw
        }
        self.reversal_potentials = reversal_potentials or {
            'E_L': -50.,  # mV, leak reversal potential
            'E_Na': 30.,  # mV, sodium reversal potential
            'E_CaT': self._calculate_calcium_rev_potential(
                self.initial_conditions['ca_conc_intracellular'],
                self.constants['temp_celsius']),  # mV, low-threshold calcium reversal potential
            'E_CaS': self._calculate_calcium_rev_potential(
                self.initial_conditions['ca_conc_intracellular'],
                self.constants['temp_celsius']),  # mV, slow calcium reversal potential
            'E_Kd': -80.,  # mV, potassium reversal potential
            'E_KCa': -80.,  # mV, calcium-dependent potassium reversal potential
            'E_A': -80.,  # mV, transient potassium reversal potential
            'E_H': -20.,  # mV, hyperpolarization-activated cation reversal potential
        }
        self.conductances = conductances or {
            'g_Na': 1.0764e3,  # uS, transient sodium conductance
            'g_CaT': 6.4056e0,  # uS, low-threshold calcium conductance
            'g_CaS': 1.0048e1,  # uS, slow calcium conductance
            'g_A': 8.0384e0,  # uS, transient potassium conductance
            'g_KCa': 1.7584e1,  # uS, calcium-dependent potassium conductance
            'g_Kd': 1.240928e2,  # uS, potassium conductance
            'g_H': 1.1304e-1,  # uS, hyperpolarization-activated cation conductance
            'g_L': 1.7584e-1,  # uS, leak conductance
        }
        # Steady-state gate curves: x_inf(V) = 1 / (1 + exp((V + alpha) / beta)), alpha and beta in mV.
        # For m_KCa the curve is additionally multiplied by Ca / (Ca + 3).
        self.inf_alphas = inf_alphas or {
            'm_Na': 25.5,  # sodium activation variable
            'h_Na': 48.9,  # sodium inactivation variable
            'm_CaT': 27.1,  # low-threshold calcium activation variable
            'h_CaT': 32.1,  # low-threshold calcium inactivation variable
            'm_CaS': 33.0,  # slow calcium activation variable
            'h_CaS': 60.0,  # slow calcium inactivation variable
            'm_H': 70.0,  # hyperpolarization-activated cation activation variable
            'm_Kd': 12.3,  # potassium activation variable
            'm_KCa': 28.3,  # calcium-dependent potassium activation variable
            'm_A': 27.2,  # transient potassium activation variable
            'h_A': 56.9,  # transient potassium inactivation variable
        }
        self.inf_betas = inf_betas or {
            'm_Na': -5.29,  # sodium activation variable
            'h_Na': 5.18,  # sodium inactivation variable
            'm_CaT': -7.20,  # low-threshold calcium activation variable
            'h_CaT': 5.50,  # low-threshold calcium inactivation variable
            'm_CaS': -8.1,  # slow calcium activation variable
            'h_CaS': 6.20,  # slow calcium inactivation variable
            'm_H': 6.0,  # hyperpolarization-activated cation activation variable
            'm_Kd': -11.8,  # potassium activation variable
            'm_KCa': -12.6,  # calcium-dependent potassium activation variable
            'm_A': -8.70,  # transient potassium activation variable
            'h_A': 4.90,  # transient potassium inactivation variable
        }
        # Gate time constants: tau(V) = c - d / (1 + exp((V + a) / b)).
        # CaS gates use c + d / (exp((V + a) / b) + exp((V + a2) / b2)) instead, and
        # tau(h_Na) is the product of the 'h_Na_0' and 'h_Na_1' curves.
        self.tau_cs = tau_cs or {
            'm_Na': 1.32,  # sodium activation variable
            'h_Na_0': 0.0,  # sodium inactivation variable, first factor
            'h_Na_1': 1.50,  # sodium inactivation variable, second factor
            'm_CaT': 21.7,  # low-threshold calcium activation variable
            'h_CaT': 105.0,  # low-threshold calcium inactivation variable
            'm_CaS': 1.40,  # slow calcium activation variable
            'h_CaS': 60.0,  # slow calcium inactivation variable
            'm_H': 272.0,  # hyperpolarization-activated cation activation variable
            'm_Kd': 7.20,  # potassium activation variable
            'm_KCa': 90.3,  # calcium-dependent potassium activation variable
            'm_A': 11.6,  # transient potassium activation variable
            'h_A': 38.6,  # transient potassium inactivation variable
        }
        self.tau_ds = tau_ds or {
            'm_Na': 1.26,  # sodium activation variable
            'h_Na_0': -0.67,  # sodium inactivation variable, first factor
            'h_Na_1': -1.00,  # sodium inactivation variable, second factor
            'm_CaT': 21.3,  # low-threshold calcium activation variable
            'h_CaT': 89.8,  # low-threshold calcium inactivation variable
            'm_CaS': 7.00,  # slow calcium activation variable
            'h_CaS': 150.0,  # slow calcium inactivation variable
            'm_H': -1499.0,  # hyperpolarization-activated cation activation variable
            'm_Kd': 6.40,  # potassium activation variable
            'm_KCa': 75.1,  # calcium-dependent potassium activation variable
            'm_A': 10.4,  # transient potassium activation variable
            'h_A': 29.2,  # transient potassium inactivation variable
        }
        self.tau_as = tau_as or {
            'm_Na': 120.0,  # sodium activation variable
            'h_Na_0': 62.9,  # sodium inactivation variable, first factor
            'h_Na_1': 34.9,  # sodium inactivation variable, second factor
            'm_CaT': 68.1,  # low-threshold calcium activation variable
            'h_CaT': 55.0,  # low-threshold calcium inactivation variable
            'm_CaS': 27.0,  # slow calcium activation variable
            'h_CaS': 55.0,  # slow calcium inactivation variable
            'm_H': 42.2,  # hyperpolarization-activated cation activation variable
            'm_Kd': 28.3,  # potassium activation variable
            'm_KCa': 46.0,  # calcium-dependent potassium activation variable
            'm_A': 32.9,  # transient potassium activation variable
            'h_A': 38.9,  # transient potassium inactivation variable
        }
        self.tau_bs = tau_bs or {
            'm_Na': -25.0,  # sodium activation variable
            'h_Na_0': -10.0,  # sodium inactivation variable, first factor
            'h_Na_1': 3.60,  # sodium inactivation variable, second factor
            'm_CaT': -20.5,  # low-threshold calcium activation variable
            'h_CaT': -16.9,  # low-threshold calcium inactivation variable
            'm_CaS': 10.0,  # slow calcium activation variable
            'h_CaS': 9.00,  # slow calcium inactivation variable
            'm_H': -8.73,  # hyperpolarization-activated cation activation variable
            'm_Kd': -19.2,  # potassium activation variable
            'm_KCa': -22.7,  # calcium-dependent potassium activation variable
            'm_A': -15.2,  # transient potassium activation variable
            'h_A': -26.5,  # transient potassium inactivation variable
        }
        self.tau_a2s = tau_a2s or {
            'm_CaS': 70.0,
            'h_CaS': 65.0,
        }
        self.tau_b2s = tau_b2s or {
            'm_CaS': -13.0,
            'h_CaS': -16.0,
        }
        # Within this class the 'i_<channel>' entries are used as the exponent of the activation
        # gate m in _calculate_channel_current, and 'tau_Ca' scales the calcium time constant.
        self.q10s = {
            'i_Na': 3.,
            'i_CaT': 3.,
            'i_CaS': 3.,
            'i_H': 1.,
            'i_Kd': 4.,
            'i_KCa': 4.,
            'i_A': 3.,
            'i_L': 1.,
            'g_Na': 1.,
            'm_Na': 1.,
            'h_Na': 1.,
            'g_CaT': 1.,
            'm_CaT': 1.,
            'h_CaT': 1.,
            'g_CaS': 1.,
            'm_CaS': 1.,
            'h_CaS': 1.,
            'g_A': 1.,
            'm_A': 1.,
            'h_A': 1.,
            'g_KCa': 1.,
            'm_KCa': 1.,
            'h_KCa': 1.,
            'g_Kd': 1.,
            'm_Kd': 1.,
            'h_Kd': 1.,
            'g_H': 1.,
            'm_H': 1.,
            'h_H': 1.,
            'g_L': 1.,
            'tau_Ca': 1.,
        }
        self.channel_types = ('Na', 'CaT', 'CaS', 'H', 'Kd', 'KCa', 'A', 'L')
        self.channel_currents = {}
        # Gates that are not integrated; get_state_var always reports them as 1.0.
        self.state_vars_constant = ('m_L', 'h_H', 'h_Kd', 'h_KCa', 'h_L')
        self.state_vars_labels = self.get_state_vars_labels()
        self.state_vars = [self.initial_conditions[key] for key in self.state_vars_labels]

    def get_state_var(self, key: str) -> float:
        """
        Get a state variable.

        :param key: name of the channel state variable (or 'V' / 'ca_conc_intracellular')
        :return: the current value of the given state variable; 1.0 for non-integrated gates
        """
        if key in self.state_vars_constant:
            return 1.0
        index_state_var = self.state_vars_labels.index(key)
        return self.state_vars[index_state_var]

    def get_state_vars_labels(self) -> List[str]:
        """
        Get the labels of the integrated state variables, in solver order.

        :return: ['V', 'ca_conc_intracellular'] followed by the non-constant m gates, then the
            non-constant h gates, each in ``self.channel_types`` order
        """
        constant = set(self.state_vars_constant)
        state_vars_m = [f'm_{ch}' for ch in self.channel_types if f'm_{ch}' not in constant]
        state_vars_h = [f'h_{ch}' for ch in self.channel_types if f'h_{ch}' not in constant]
        return ['V', 'ca_conc_intracellular'] + state_vars_m + state_vars_h

    def _calculate_calcium_rev_potential(self, ca_conc_intracellular, temp_celsius):
        """
        Compute the calcium reversal potential with the Nernst equation,
        E = RT / (zF) * ln(Ca_out / Ca_in), using the extracellular concentration in ``self.constants``.

        :param ca_conc_intracellular: calcium intracellular concentration (same units as Ca_out, uM)
        :param temp_celsius: temperature in degC
        :return: calcium reversal potential (mV)
        """
        rtzf_term = self.constants['gas_constant'] * (self.constants['base_temp_kelvin'] + temp_celsius)
        rtzf_term /= (self.constants['z'] * self.constants['faraday'])
        # The Nernst equation uses the natural logarithm (log10 would need an extra ln(10) factor).
        return rtzf_term * np.log(self.constants['ca_conc_extracellular'] / ca_conc_intracellular)

    def _calculate_normal_inf_response(self, key: str, voltage: float) -> float:
        """
        Calculate the steady-state value 1 / (1 + exp((V + alpha) / beta)) of a gate.

        :param key: name of the channel's state variable
        :param voltage: current voltage
        :return: steady-state (infinite) response
        """
        return 1. / (1. + np.exp((voltage + self.inf_alphas[key]) / self.inf_betas[key]))

    def _calculate_kca_inf_response(self, key: str, voltage: float) -> float:
        """
        Calculate the calcium-dependent steady-state value of a KCa gate:
        Ca / (Ca + 3) / (1 + exp((V + alpha) / beta)).

        :param key: name of the channel's state variable
        :param voltage: current voltage
        :return: steady-state (infinite) response
        """
        ca_conc = self.get_state_var('ca_conc_intracellular')
        left_term = ca_conc / (ca_conc + 3.0)
        return left_term / (1. + np.exp((voltage + self.inf_alphas[key]) / self.inf_betas[key]))

    def _calculate_inf_response(self, key: str, voltage: float) -> float:
        """
        Compute the steady-state (infinite) response of a channel's state variable.

        :param key: the name of a channel's state variable
        :param voltage: current voltage
        :return: the current steady-state value of the state variable
        """
        if 'KCa' in key:
            return self._calculate_kca_inf_response(key, voltage)
        return self._calculate_normal_inf_response(key, voltage)

    def get_current_voltage(self) -> float:
        """
        Get the membrane voltage from the state variables.

        :return: current voltage
        """
        index_voltage = self.state_vars_labels.index('V')
        return self.state_vars[index_voltage]

    def _calculate_normal_tau(self, key: str, voltage: float) -> float:
        """
        Calculate a time constant as c - d / (1 + exp((V + a) / b)).

        :param key: name of the channel's state variable (key into the tau_* dicts)
        :param voltage: current voltage
        :return: time constant
        """
        timeconst = self.tau_cs[key]
        timeconst -= self.tau_ds[key] / (1. + np.exp((voltage + self.tau_as[key]) / self.tau_bs[key]))
        return timeconst

    def _calculate_cas_tau(self, key: str, voltage: float) -> float:
        """
        Calculate a CaS time constant as c + d / (exp((V + a) / b) + exp((V + a2) / b2)).

        :param key: name of the channel's state variable
        :param voltage: current voltage
        :return: time constant
        """
        div_term = np.exp((voltage + self.tau_as[key]) / self.tau_bs[key])
        div_term += np.exp((voltage + self.tau_a2s[key]) / self.tau_b2s[key])
        return self.tau_cs[key] + self.tau_ds[key] / div_term

    def _calculate_double_tau(self, key: str, voltage: float) -> float:
        """
        Calculate a time constant as the product of two normal tau curves, '<key>_0' and '<key>_1'.

        :param key: name of the channel's state variable
        :param voltage: current voltage
        :return: time constant
        """
        return self._calculate_normal_tau(f'{key}_0', voltage) * self._calculate_normal_tau(f'{key}_1', voltage)

    def _calculate_tau(self, key: str) -> float:
        """
        Calculate the time constant for a channel's state variable at the current voltage.

        :param key: the name of a channel's state variable
        :return: the current time constant of the state variable
        """
        voltage = self.get_current_voltage()
        if 'h_Na' in key:
            return self._calculate_double_tau(key, voltage)
        elif 'CaS' in key:
            return self._calculate_cas_tau(key, voltage)
        return self._calculate_normal_tau(key, voltage)

    def scale_time(self, key: str, value_to_scale: float) -> float:
        """
        Scale a time constant by the Q10 stored under ``key``:
        value * q10 ** (-(temp - reftemp) / 10).

        :param key: key into ``self.q10s``
        :param value_to_scale: time constant to scale
        :return: scaled value
        """
        temp = self.constants['temp_celsius']
        reftemp = self.constants['reftemp_celsius']
        return value_to_scale * pow(self.q10s[key], -(temp - reftemp) / 10.0)

    def _calculate_channel_current(self, channel: str) -> float:
        """
        Calculate the current through one channel type: g * m**p * h * (V - E).

        :param channel: ionic channel name (an entry of ``self.channel_types``)
        :return: channel current
        """
        g = self.conductances[f'g_{channel}']
        e_rev = self.reversal_potentials[f'E_{channel}']
        h = self.get_state_var(f'h_{channel}')
        m = self.get_state_var(f'm_{channel}')
        voltage = self.get_current_voltage()
        # The activation exponent p is stored in q10s under 'i_<channel>'.
        exponent = self.q10s[f'i_{channel}']
        return g * pow(m, exponent) * h * (voltage - e_rev)

    def _calculate_dvdt(self, channel_currents: dict) -> float:
        """
        Calculate dV/dt = (I_app - sum of channel currents) / C.

        :param channel_currents: mapping channel -> current
        :return: derivative of the membrane voltage
        """
        return (-sum(channel_currents.values()) + self.constants['I_app']) / self.constants['C']

    def _calculate_dstate_dt(self, key: str) -> float:
        """
        Calculate dx/dt = (x_inf - x) / tau_x for a gating variable.

        :param key: name of the channel state variable
        :return: derivative of the state variable
        """
        state_var = self.get_state_var(key)
        voltage = self.get_current_voltage()
        inf_state_var = self._calculate_inf_response(key, voltage)
        # NOTE: gate time constants are not Q10-scaled; only the calcium time constant is.
        tau_state_var = self._calculate_tau(key)
        return (inf_state_var - state_var) / tau_state_var

    def _calculate_dca_conc_intracellular_dt(self, channel_currents: dict) -> float:
        """
        Calculate the derivative of the intracellular calcium concentration:
        (-f * (I_CaT + I_CaS) + Ca_background - Ca) / tau_Ca.

        :param channel_currents: mapping channel -> current
        :return: derivative of the intracellular calcium concentration
        """
        ca_conc_intra = self.get_state_var('ca_conc_intracellular')
        outcalc = -self.constants['ca_conv_factor']
        outcalc *= (channel_currents['CaT'] + channel_currents['CaS'])
        outcalc += self.constants['ca_conc_background'] - ca_conc_intra
        outcalc /= self.scale_time('tau_Ca', self.constants['tau_ca_conc_intracellular'])
        return outcalc

    def _calculate_dstate_variable(self, key, channel_currents: dict) -> float:
        """
        Dispatch to the derivative function for one state variable.

        :param key: name of the state variable
        :param channel_currents: mapping channel -> current
        :return: derivative of the state variable
        """
        if key == 'V':
            return self._calculate_dvdt(channel_currents)
        elif key == 'ca_conc_intracellular':
            return self._calculate_dca_conc_intracellular_dt(channel_currents)
        elif key in self.state_vars_constant:
            return np.float64(0.0)
        return self._calculate_dstate_dt(key)

    # noinspection PyUnusedLocal
    def update_state_variables(self, t: float, y: np.ndarray) -> np.ndarray:
        """
        Right-hand side dy/dt for solve_ivp or another ODE solver.

        Side effects: stores ``y`` in ``self.state_vars``, refreshes the calcium reversal
        potentials and stores the channel currents in ``self.channel_currents``.

        :param t: current time, ignored
        :param y: current values of the state variables, ordered like ``self.state_vars_labels``
        :return: derivatives of the state variables, in the same order
        """
        # y is already ordered like state_vars_labels, so no per-label lookup is needed.
        self.state_vars = list(y)
        self._update_ca_rev_potential()
        self.channel_currents = {ch: self._calculate_channel_current(ch)
                                 for ch in self.channel_types}
        return np.array([self._calculate_dstate_variable(key, self.channel_currents)
                         for key in self.state_vars_labels])

    def get_state_vars_and_labels(self) -> Tuple[List[float], List[str]]:
        """
        Get the current values of the state variables and their labels.

        :return: a Tuple of: a list of the values of the state variables, and a list of the labels
        """
        return self.state_vars, self.state_vars_labels

    def get_initial_conditions(self) -> Tuple[List[float], List[str]]:
        """
        Get the initial conditions.

        :return: a Tuple of: a list of the initial values of the state variables, and a list of the labels
        """
        return [self.initial_conditions[key] for key in self.state_vars_labels], self.state_vars_labels

    def _update_ca_rev_potential(self) -> None:
        """
        Recompute E_CaT and E_CaS from the current intracellular calcium concentration.

        :return: None
        """
        new_ca_rev_potential = self._calculate_calcium_rev_potential(
            self.get_state_var('ca_conc_intracellular'),
            self.constants['temp_celsius'])
        self.reversal_potentials['E_CaT'] = new_ca_rev_potential  # mV, low-threshold calcium reversal potential
        self.reversal_potentials['E_CaS'] = new_ca_rev_potential  # mV, slow calcium reversal potential

    def run_simulation(self, time_steps: np.ndarray) -> Dict[str, Union[np.ndarray, bool, float]]:
        """
        Integrate the model with the BDF solver and characterise the resulting voltage trace.

        :param time_steps: times at which the solution is reported; the first and last define the span
        :return: 't': time steps, 'y': voltage trace, plus the entries returned by
            ``convert_trace_to_spike_characteristics_tonic``
        """
        self.time_steps = time_steps
        init_cond = self.get_initial_conditions()[0]
        sol = solve_ivp(self.update_state_variables,
                        [self.time_steps[0], self.time_steps[-1]],
                        init_cond, method="BDF", t_eval=self.time_steps)
        if not sol.success:
            # The trace is truncated at the failure point; characterise what was computed.
            logging.warning("AMM: integration failed: %s", sol.message)
        self.voltage_trace = sol.y[0]
        logging.debug("AMM: computing trace characteristics")
        # sol.t equals t_eval on success and matches the (possibly shorter) trace on failure.
        dict_characteristics = self.convert_trace_to_spike_characteristics_tonic(self.voltage_trace,
                                                                              sol.t,
                                                                              self.spike_threshold)
        logging.debug("AMM: returning")
        return {"t": self.time_steps,
                "y": self.voltage_trace,
                **dict_characteristics}

    @staticmethod
    def convert_trace_to_spike_times_upward(voltage: np.ndarray, times: np.ndarray, threshold: float) -> np.ndarray:
        """
        Return the times at which the voltage trace crosses ``threshold`` going upward.

        Sample ``i`` is a crossing when ``voltage[i] >= threshold`` and ``voltage[i - 1] < threshold``.
        The first sample has no predecessor and is never a crossing. Only the first
        ``min(len(voltage), len(times))`` samples are considered.

        :param voltage: array of voltages (mV)
        :param times: array of time values
        :param threshold: threshold for spike activation
        :return: an array of crossing times
        """
        voltage = np.asarray(voltage)
        times = np.asarray(times)
        n_samples = min(voltage.size, times.size)
        voltage = voltage[:n_samples]
        crossing = (voltage[1:] >= threshold) & (voltage[:-1] < threshold)
        # +1 because crossing[k] describes sample k + 1.
        return times[np.flatnonzero(crossing) + 1]

    @staticmethod
    def convert_trace_to_spike_times_downward(voltage: np.ndarray, times: np.ndarray, threshold: float) -> np.ndarray:
        """
        Return the times of the last sample at or above ``threshold`` before the trace falls below it.

        Sample ``i`` is reported when ``voltage[i] >= threshold`` and ``voltage[i + 1] < threshold``.
        Only the first ``min(len(voltage), len(times))`` samples are considered.

        :param voltage: array of voltages (mV)
        :param times: array of time values
        :param threshold: threshold level
        :return: an array of crossing times
        """
        voltage = np.asarray(voltage)
        times = np.asarray(times)
        n_samples = min(voltage.size, times.size)
        voltage = voltage[:n_samples]
        crossing = (voltage[:-1] >= threshold) & (voltage[1:] < threshold)
        return times[np.flatnonzero(crossing)]

    @staticmethod
    def calculate_num_spikes_per_burst(spike_times, temporal_interval: Union[int, float]) -> np.ndarray:
        """
        Count the spikes in each burst. A new burst starts whenever consecutive spikes are more
        than ``temporal_interval`` apart.

        :param spike_times: sorted spike times
        :param temporal_interval: maximal intra-burst inter-spike interval
        :return: number of spikes in each burst, in order (empty for no spikes)
        """
        spike_times = np.asarray(spike_times)
        if spike_times.size == 0:
            return np.array([], dtype=int)
        burst_end_loc = np.append(np.flatnonzero(np.diff(spike_times) > temporal_interval), spike_times.size - 1)
        # Prepend -1 so the first burst (indices 0..burst_end_loc[0]) counts burst_end_loc[0] + 1 spikes.
        return np.diff(np.insert(burst_end_loc, 0, -1))

    @staticmethod
    def clean_up_num_spikes_per_burst(num_spikes: np.ndarray) -> np.ndarray:
        """
        Drop bursts that consist of a single spike.

        :param num_spikes: number of spikes in each burst
        :return: the entries of ``num_spikes`` greater than 1
        """
        return num_spikes[num_spikes > 1]

    @staticmethod
    def calculate_num_interbursts(spike_times, temporal_interval):
        """
        Locate the first and last spike of every burst. A new burst starts whenever consecutive
        spikes are more than ``temporal_interval`` apart.

        :param spike_times: sorted spike times
        :param temporal_interval: maximal intra-burst inter-spike interval
        :return: dict with 'bs_loc' (indices of burst-start spikes) and 'be_loc'
            (indices of burst-end spikes)
        """
        gap_loc = np.flatnonzero(np.diff(spike_times) > temporal_interval)
        burst_start_loc = np.insert(gap_loc + 1, 0, 0)
        burst_end_loc = np.append(gap_loc, len(spike_times) - 1)
        return {'bs_loc': burst_start_loc,
                'be_loc': burst_end_loc
                }

    @staticmethod
    def _failed_spike_characteristics(times_before_th_downward: np.ndarray,
                                      times_before_th_downward2: np.ndarray) -> Dict[str, np.ndarray]:
        """Penalty result used when the trace has too few spikes or bursts to characterise."""
        return {'burst_frequency_mean': 1e5,
                'duty_cycle_mean': 1e5,
                'times_before_th_downward': times_before_th_downward,
                'times_before_th_downward2': times_before_th_downward2,
                'num_sw': 0.0,
                'e_lag': 0.0,
                'num_spikes': np.array([]),
                'num_mid': np.array([]),
                'spike_per_burst': np.array([]),
                'spike_per_burst_mid': np.array([]),
                }

    def convert_trace_to_spike_characteristics_tonic(self, voltage: np.ndarray, times: np.ndarray,
                                                     threshold: dict) -> Dict[str, np.ndarray]:
        """
        Compute burst statistics of a voltage trace.

        If fewer than two measurable bursts are found, the penalty dict (means of 1e5) is returned.

        :param voltage: voltage trace (mV)
        :param times: time values matching ``voltage``
        :param threshold: dict with 'threshold_spike', 'threshold_mid_upper',
            'threshold_mid_lower' and 'threshold_slow_wave'
        :return: dict of burst characteristics
        """
        times_before_th_upward = self.convert_trace_to_spike_times_upward(voltage, times, threshold['threshold_spike'])
        times_before_th_downward = self.convert_trace_to_spike_times_downward(voltage, times,
                                                                              threshold['threshold_mid_upper'])
        times_before_th_downward2 = self.convert_trace_to_spike_times_downward(voltage, times,
                                                                               threshold['threshold_mid_lower'])
        # Spikes closer together than this (in the units of `times`) belong to the same burst.
        temporal_interval = 100.
        spike_times = times_before_th_upward
        num_bursts = self.calculate_num_interbursts(spike_times, temporal_interval)
        # TODO: use log-barrier function instead of if
        if spike_times.size < 1 or times_before_th_downward.size < 1 or times_before_th_downward2.size < 1 \
                or num_bursts['bs_loc'].size < 2 or num_bursts['be_loc'].size < 2:
            return self._failed_spike_characteristics(times_before_th_downward, times_before_th_downward2)
        bs = spike_times[num_bursts['bs_loc']]
        be = spike_times[num_bursts['be_loc']]
        # The last burst starts and ends on the same spike: drop that single-spike burst.
        if bs[-1] == be[-1]:
            bs = bs[:-1]
            be = be[:-1]
        if bs.size < 2:
            # A period needs two bursts; otherwise every mean below would be taken over nothing (NaN).
            return self._failed_spike_characteristics(times_before_th_downward, times_before_th_downward2)
        burst_duration = be - bs
        period = np.diff(bs)
        burst_frequency = 1 / np.diff(be)
        duty_cycle = burst_duration[:period.size] / period
        burst_frequency_mean = np.mean(burst_frequency)  # <fb>
        duty_cycle_mean = np.mean(duty_cycle)  # <dc>
        # Mean absolute lag between upward spike crossings and mid-level downward crossings, paired
        # in order; only as many pairs as both arrays provide so mismatched counts cannot break broadcasting.
        n_pairs = min(times_before_th_upward.size, times_before_th_downward.size)
        e_lag = np.mean(np.abs(times_before_th_upward[:n_pairs] - times_before_th_downward[:n_pairs]))
        num_sw = np.size(self.convert_trace_to_spike_times_downward(voltage, times, threshold['threshold_slow_wave']))
        num_spike_per_burst = self.clean_up_num_spikes_per_burst(
            self.calculate_num_spikes_per_burst(spike_times, temporal_interval))
        nspike_mid1 = self.clean_up_num_spikes_per_burst(
            self.calculate_num_spikes_per_burst(times_before_th_downward, temporal_interval))
        nspike_mid2 = self.clean_up_num_spikes_per_burst(
            self.calculate_num_spikes_per_burst(times_before_th_downward2, temporal_interval))
        spike_per_burst_mid = (nspike_mid1, nspike_mid2)
        return {'burst_frequency_mean': burst_frequency_mean,
                'duty_cycle_mean': duty_cycle_mean,
                'times_before_th_downward': times_before_th_downward,
                'times_before_th_downward2': times_before_th_downward2,
                'num_sw': num_sw,
                'e_lag': e_lag,
                'burst_frequency': burst_frequency,
                'duty_cycle': duty_cycle,
                'spike_per_burst': num_spike_per_burst,
                'spike_per_burst_mid': spike_per_burst_mid,
                }