import math
from collections import namedtuple
import pprint
import numpy as np
from scipy import optimize
from . import utilities, detect, vartype
from .signal_smooth import smooth
from ajustador.helpers.loggingsystem import getlogger
import logging
# Module-level logger obtained from ajustador's logging helper; its level is
# fixed to INFO here.
logger = getlogger(__name__)
logger.setLevel(logging.INFO)
def _plot_line(ax, ranges, value, label, color, zorder=3):
for (a,b) in ranges:
ax.plot([a, b], [float(value)]*2,
label=label, color=color, linestyle='-', zorder=zorder)
label = None # we only need one line labeled
if isinstance(value, vartype.vartype):
ax.plot([a, b], [value.x - 3*value.dev]*2,
color, linestyle='--', zorder=zorder)
ax.plot([a, b], [value.x + 3*value.dev]*2,
color, linestyle='--', zorder=zorder)
def _plot_spike(ax, wave, spikes, i, bottom=None, spike_bounds=None, lmargin=0.0010, rmargin=0.0015):
ax.set_xlim(spikes.x[i] - lmargin, spikes.x[i] + rmargin)
ax.plot(wave.x, wave.y, label='recording')
if bottom is not None:
ax.vlines(spikes.x[i:i+1], bottom, spikes.y, 'r', zorder=3)
if spike_bounds is not None:
ax.axvspan(spike_bounds[i].left, spike_bounds[i].right,
alpha=0.3, color='cyan')
def plural(n, word):
    """Return '<n> <word>', adding a trailing 's' to `word` unless n == 1."""
    suffix = 's' if n != 1 else ''
    return '{count} {noun}{suffix}'.format(count=n, noun=word, suffix=suffix)
class Feature:
    """Base class for a group of measurements computed from a recording.

    Subclasses declare tuples of attribute names:

    * ``requires`` -- attributes read from the wrapped object ``self._obj``,
    * ``provides`` -- attributes this feature computes (listed by `report`),
    * ``array_attributes`` / ``mean_attributes`` -- subsets of ``provides``;
      in this class only ``mean_attributes`` is read (`report_attr` appends
      the average for those).
    """
    requires = ()
    provides = ()
    array_attributes = ()
    mean_attributes = ()

    def __init__(self, obj):
        self._obj = obj

    def plot(self, figure=None):
        """Plot the raw recording on a single axes of `figure` and return the axes.

        A new pyplot figure is created when `figure` is None.
        """
        if figure is None:
            from matplotlib import pyplot
            figure = pyplot.figure()
        wave = self._obj.wave
        ax = figure.add_subplot(111)
        ax.plot(wave.x, wave.y, label='recording')
        ax.set_xlabel('time / s')
        ax.set_ylabel('membrane potential / V')
        return ax

    def spike_plot(self, figure=None, max_spikes=20,
                   bottom=None, spike_bounds=None,
                   lmargin=0.0010, rmargin=0.0015, rowsize=None):
        """Draw one zoomed subplot per spike (at most `max_spikes`) in a grid.

        `rowsize` is the number of columns (3 for fewer than 19 spikes, 5
        otherwise, when not given). All subplots share the y axis of the
        first one. Returns the list of created axes.
        """
        if figure is None:
            from matplotlib import pyplot
            figure = pyplot.figure()
        wave = self._obj.wave
        spikes = self._obj.spikes
        spike_count = min(len(spikes), max_spikes)
        if rowsize is None:
            rowsize = 3 if spike_count < 19 else 5
        rows = math.ceil(spike_count / rowsize)
        columns = min(spike_count, rowsize)
        axes = []
        sharey = None
        for i in range(spike_count):
            ax = figure.add_subplot(rows, columns, i+1, sharey=sharey)
            if i == 0:
                sharey = ax
            else:
                # Must be a boolean: matplotlib >= 3 treats any non-empty
                # string (including 'off') as True and keeps the labels.
                ax.tick_params(labelleft=False)
            axes.append(ax)
            _plot_spike(ax, wave, spikes, i, bottom=bottom, spike_bounds=spike_bounds,
                        lmargin=lmargin, rmargin=rmargin)
        figure.autofmt_xdate()
        return axes

    def report_attr(self, name):
        """Return a '<name> = <value>' text rendering of attribute `name`.

        Sequences are printed one element per line, aligned under the first.
        For names in `mean_attributes`, an extra line with the average is added.
        """
        val = getattr(self, name)
        prefix = '{} = '.format(name)
        if hasattr(val, 'report'):
            ans = val.report(prefix=prefix)
        elif isinstance(val, np.ndarray) and hasattr(val, 'dev'):
            ans = vartype.vartype.format_array(val, prefix=prefix)
        elif hasattr(val, '__len__'):
            joiner = '\n' + len(prefix)*' '
            ans = prefix + joiner.join(str(x) for x in val)
        else:
            ans = prefix + str(val)
        if name in self.mean_attributes and hasattr(val, '__len__'):
            mean = vartype.vartype.average(val)
            ans += '\n{:{}} = {}'.format('', len(name), mean)
        return ans

    def report(self):
        """Return the `report_attr` text of every attribute in `provides`, one per block."""
        return '\n'.join(self.report_attr(name) for name in self.provides)
class SteadyState(Feature):
    """Find the baseline and injection steady states

    The range *before* `baseline_before` and *after* `baseline_after`
    is used for `baseline`.
    The range *between* `steady_after` and `steady_before` is used
    for `steady`.
    """
    requires = ('wave',
                'baseline_before', 'baseline_after',
                'steady_after', 'steady_before', 'steady_cutoff')
    provides = ('baseline', 'steady', 'response',
                'baseline_pre', 'baseline_post')
    mean_attributes = ('baseline', 'steady', 'response',
                       'baseline_pre', 'baseline_post')
    array_attributes = ('baseline', 'steady', 'response',
                        'baseline_pre', 'baseline_post')

    @staticmethod
    def _central_mean(values):
        """Mean (as a vartype) of `values` that lie between their 40th and 60th percentiles."""
        low, high = np.percentile(values, (40, 60))
        central = values[(values >= low) & (values <= high)]
        return vartype.array_mean(central)

    @property
    @utilities.once
    def baseline(self):
        """The mean voltage of the area outside of the injection interval

        Samples before `baseline_before` and after `baseline_after` are pooled.
        Only values between the 40th and 60th percentile of that pool are
        averaged, which discards both tails as outliers.

        Raises ValueError if both `baseline_before` and `baseline_after` are None.
        """
        wave = self._obj.wave
        before = self._obj.baseline_before
        after = self._obj.baseline_after
        if before is None and after is None:
            raise ValueError('cannot determine baseline')
        region = ((wave.x < before if before is not None else False) |
                  (wave.x > after if after is not None else False))
        return self._central_mean(wave.y[region])

    @property
    @utilities.once
    def baseline_pre(self):
        """The mean voltage of the area before the injection interval

        Only values between the 40th and 60th percentile of the samples
        before `baseline_before` are averaged. NaN if `baseline_before` is None.
        """
        wave = self._obj.wave
        before = self._obj.baseline_before
        if before is None:
            return vartype.vartype.nan
        return self._central_mean(wave.y[(wave.x < before)])

    @property
    @utilities.once
    def baseline_post(self):
        """The mean voltage of the area after the injection interval

        Only values between the 40th and 60th percentile of the samples
        after `baseline_after` are averaged. NaN if `baseline_after` is None.
        """
        wave = self._obj.wave
        after = self._obj.baseline_after
        if after is None:
            return vartype.vartype.nan
        return self._central_mean(wave.y[(wave.x > after)])

    @property
    @utilities.once
    def steady(self):
        """Returns mean value of wave between `steady_after` and `steady_before`.

        "Outliers", values above the `steady_cutoff` percentile (e.g. 80), are
        excluded; the cutoff is meant to exclude the spikes.
        """
        wave = self._obj.wave
        after = self._obj.steady_after
        before = self._obj.steady_before
        cutoff = self._obj.steady_cutoff
        data = wave.y[(wave.x > after) & (wave.x < before)]
        cutoff = np.percentile(data, cutoff)
        cut = data[data <= cutoff]
        return vartype.array_mean(cut)

    @property
    @utilities.once
    def response(self):
        """`steady` minus `baseline`."""
        return self.steady - self.baseline

    def plot(self, figure=None, pre_post=False):
        """Plot the recording with baseline and steady-state levels and the response arrow.

        With `pre_post`, the pre- and post-injection baselines are drawn
        separately instead of the pooled baseline.
        """
        wave = self._obj.wave
        before = self._obj.baseline_before
        after = self._obj.baseline_after
        steady_after = self._obj.steady_after
        steady_before = self._obj.steady_before
        time = wave.x[-1]
        ax = super().plot(figure)
        if not pre_post:
            # Only draw the sides whose bound is defined; a None endpoint
            # would otherwise be handed to matplotlib.
            ranges = []
            if before is not None:
                ranges.append((0, before))
            if after is not None:
                ranges.append((after, time))
            _plot_line(ax,
                       ranges,
                       self.baseline,
                       'baseline', 'k')
        else:
            if before is not None:
                _plot_line(ax,
                           [(0, before)],
                           self.baseline_pre,
                           'baseline_pre', 'k')
            if after is not None:
                _plot_line(ax,
                           [(after, time)],
                           self.baseline_post,
                           'baseline_post', 'k')
        _plot_line(ax,
                   [(steady_after, steady_before)],
                   self.steady,
                   'steady', 'r')
        ax.annotate('response',
                    xy=(time/2, self.steady.x),
                    xytext=(time/2, self.baseline.x),
                    arrowprops=dict(facecolor='black',
                                    shrink=0),
                    horizontalalignment='center', verticalalignment='bottom')
        ax.legend(loc='upper right')
        ax.figure.tight_layout()
# Result of _find_spikes: spike peak indices and the matching threshold voltages.
peak_and_threshold = namedtuple('peak_and_threshold', ['peaks', 'thresholds'])
def _find_spikes(wave, min_height=0.0, max_charge_time=0.004, charge_threshold=0.02):
    """Locate spike peaks in `wave` and estimate a threshold voltage for each.

    Peaks come from `detect.detect_peaks` and are kept only when their y
    value exceeds `min_height`. For each peak, the samples from
    `max_charge_time` before the peak up to the peak are examined. The
    threshold is the lowest y among samples whose preceding first difference
    exceeds `charge_threshold` times the largest first difference in that
    window (2% of the steepest rise by default).

    Returns peak_and_threshold(peaks, thresholds). A threshold is NaN when
    it cannot be computed for that window.
    """
    peaks = detect.detect_peaks(wave.y, P_low=0.75, P_high=0.50)
    peaks = peaks[wave.y[peaks] > min_height]
    thresholds = np.empty(peaks.size)
    for i, peak in enumerate(peaks):
        start = (wave.x >= wave.x[peak] - max_charge_time).argmax()
        y = wave.y[start:peak + 1]
        yderiv = np.diff(y)
        try:
            ythresh = charge_threshold * yderiv.max()
            thresholds[i] = y[1:][yderiv > ythresh].min()
        except ValueError:
            # max()/min() of an empty array: the window holds a single
            # sample, or no difference exceeds ythresh (e.g. nothing rises).
            thresholds[i] = np.nan
    return peak_and_threshold(peaks, thresholds)
class WaveRegion:
    """An inclusive index range [left_i, right_i] of a wave record array.

    `wave` is expected to expose `.x` and `.y` arrays and support slicing
    (e.g. a numpy record array with fields x, y).
    """
    def __init__(self, wave, left_i, right_i):
        self._wave = wave
        self.left_i = left_i
        self.right_i = right_i

    @property
    def left(self):
        """x coordinate of the left edge (e.g. of a spike's FWHM box)

        Halfway between sample `left_i` and its left neighbour, or x[0] when
        `left_i` is 0.
        """
        if self.left_i == 0:  # arr[-1:1] is an empty slice
            return self._wave.x[0]
        return self._wave.x[self.left_i-1:self.left_i+1].mean()

    @property
    def right(self):
        """x coordinate of the right edge (e.g. of a spike's FWHM box)

        Halfway between sample `right_i` and its right neighbour, or
        x[right_i] when it is the last sample.
        """
        return self._wave.x[self.right_i:self.right_i+2].mean()

    @property
    def width(self):
        """`right` - `left`."""
        return self.right - self.left

    @property
    def wave(self):
        """The records of the region (inclusive of both ends)."""
        return self._wave[self.left_i:self.right_i+1]

    @property
    def x(self):
        return self._wave.x[self.left_i:self.right_i+1]

    @property
    def y(self):
        return self._wave.y[self.left_i:self.right_i+1]

    def min(self):
        """Smallest y value inside the region.

        Reduces over the y component: calling .min() on the record array
        itself fails because a structured dtype cannot be reduced.
        """
        return self.y.min()

    def relative_to(self, x, y):
        """Return a new region covering the same samples, shifted by (-x, -y)."""
        new = np.rec.fromarrays((self.x - x, self.y - y), names='x,y')
        return WaveRegion(new, 0, new.size-1)

    def __str__(self):
        y = self.y
        return 'WaveRegion[{} points, x={:.04f}-{:.04f}, y={:.03f}-{:.03f}]'.format(
            self.right_i - self.left_i + 1,
            self.left, self.right,
            y.min(), y.max())

    def report(self, prefix='WaveRegion = '):
        return '{}{}'.format(prefix, self)
class Spikes(Feature):
    """Find the position and height of spikes
    """
    requires = ('wave', 'injection_interval', 'injection_start', 'injection_end')
    provides = ('spike_i', 'spikes', 'spike_count',
                'spike_threshold', 'mean_spike_threshold',
                'mean_isi', 'isi_spread',
                'spike_latency',
                'spike_bounds',
                'spike_height', 'spike_width',
                'mean_spike_height',  # TODO: is it OK to have mean_spike_height as
                                      # here and as an aggregated attribute?
                )
    array_attributes = ('spike_count',
                        'spike_height', 'spike_width',
                        'mean_isi', 'isi_spread',
                        'spike_latency', 'spike_threshold')
    mean_attributes = ('spike_height', 'spike_width', 'spike_threshold')

    @property
    @utilities.once
    def spike_i_and_threshold(self):
        "Spike peak indices and thresholds, as a peak_and_threshold tuple"
        return _find_spikes(self._obj.wave)

    @property
    def spike_i(self):
        "Indices of spike maximums in the wave.x, wave.y arrays"
        return self.spike_i_and_threshold.peaks

    @property
    def spike_threshold(self):
        "Threshold voltage of each spike (NaN where it could not be determined)"
        return self.spike_i_and_threshold.thresholds

    @property
    @utilities.once
    def spikes(self):
        "An array with .x and .y components marking the spike maximums"
        return self._obj.wave[self.spike_i]

    @property
    def spike_count(self):
        "The number of spikes"
        return len(self.spike_i)

    # Uncertainty (passed as the vartype deviation) used for mean_isi when
    # there are too few spikes to estimate it from the data.
    mean_isi_fallback_variance = 0.001

    @property
    @utilities.once
    def mean_isi(self):
        """The mean interval between spikes

        Defined as:

        * :math:`<x_{i+1} - x_i>`, if there are at least three spikes,
        * :math:`x_1 - x_0`, if there are exactly two spikes,
        * the length of the depolarization interval otherwise (`injection_interval`)

        If there are fewer than three spikes, the uncertainty is fixed as
        `mean_isi_fallback_variance`.
        """
        fallback = self.mean_isi_fallback_variance
        if self.spike_count > 2:
            return vartype.array_mean(np.diff(self.spikes.x))
        elif self.spike_count == 2:
            d = self.spikes.x[1] - self.spikes.x[0]
            return vartype.vartype(d, fallback)
        else:
            return vartype.vartype(self._obj.injection_interval, fallback)

    @property
    @utilities.once
    def isi_spread(self):
        """The difference between the largest and smallest inter-spike intervals

        Only defined when `spike_count` is at least 3; NaN otherwise.
        """
        if len(self.spikes) > 2:
            # np.ptp instead of the ndarray.ptp method, removed in NumPy 2.0.
            return np.ptp(np.diff(self.spikes.x))
        else:
            return np.nan

    @property
    @utilities.once
    def spike_latency(self):
        """Latency from injection start until the first spike

        Without spikes, the full injection duration
        (`injection_end` - `injection_start`) is returned instead.
        """
        # TODO: add spike_latency to plot
        if len(self.spikes) > 0:
            return self.spikes[0].x - self._obj.injection_start
        else:
            return self._obj.injection_end - self._obj.injection_start

    @property
    @utilities.once
    def spike_bounds(self):
        """The FWHM box of each spike, as a list of WaveRegion

        Starting at the peak, the region is widened in both directions while
        y stays above the midpoint between the spike's threshold and its peak.
        The search does not go below index 1 or beyond index size-2.
        """
        _, thresholds = self.spike_i_and_threshold
        ans = []
        y = self._obj.wave.y
        halfheight = (self.spikes.y - thresholds) / 2 + thresholds
        for i, k in enumerate(self.spike_i):
            beg = end = k
            while beg > 1 and y[beg - 1] > halfheight[i]:
                beg -= 1
            while end + 2 < y.size and y[end + 1] > halfheight[i]:
                end += 1
            ans.append(WaveRegion(self._obj.wave, beg, end))
        return ans

    @property
    @utilities.once
    def spike_height(self):
        "The difference between spike peaks and spike threshold"
        _, thresholds = self.spike_i_and_threshold
        return self.spikes.y - thresholds

    @property
    @utilities.once
    def spike_width(self):
        "The width of each spike's FWHM box"
        return np.array([bounds.width for bounds in self.spike_bounds])

    @property
    @utilities.once
    def mean_spike_height(self):
        "The mean absolute position of spike vertices"
        # TODO: is the variance too big?
        return vartype.array_mean(self.spikes.y)

    @property
    @utilities.once
    def mean_spike_threshold(self):
        "The mean threshold voltage of the spikes"
        # The spike records only carry x/y fields; thresholds live in
        # self.spike_threshold.
        return vartype.array_mean(self.spike_threshold)

    def plot(self, figure=None):
        """Plot the recording with a vertical line at every spike maximum.

        Also draws the mean spike height over the steady interval and, when
        there is at least one spike, an inset zoomed onto the first spike.
        """
        from . import drawing_util
        wave = self._obj.wave
        ax = super().plot(figure)
        bottom = -0.06  # self._obj.steady.x
        # Doing the "proper" thing makes the plot hard to read
        vline1 = ax.vlines(self.spikes.x, bottom, self.spikes.y, 'r',
                           label='timing of spike maximum')
        ax.text(0.05, 0.5, plural(self.spike_count, 'spike'),
                horizontalalignment='left',
                transform=ax.transAxes)
        _plot_line(ax,
                   [(self._obj.steady_after, self._obj.steady_before)],
                   self.mean_spike_height,
                   'spike_height', 'y', zorder=0)
        ax.legend(loc='upper left',
                  handler_map={vline1: drawing_util.HandlerVLineCollection()})
        ax.figure.tight_layout()
        if self.spike_count > 0:
            ax2 = ax.figure.add_axes([.7, .45, .25, .4])
            # Booleans, not 'off': matplotlib >= 3 treats strings as True.
            ax2.tick_params(labelbottom=False, labelleft=False)
            ax2.set_title('first spike', fontsize='smaller')
            _plot_spike(ax2, wave, self.spikes, i=0,
                        bottom=bottom, spike_bounds=self.spike_bounds)

    def spike_plot(self, figure=None, **kwargs):
        """Per-spike grid (see Feature.spike_plot) annotated with FWHM and threshold.

        Margins around each spike are derived from the widest FWHM. Returns
        the list of axes.
        """
        if self.spike_count == 0:
            # No spikes: spike_width.max() would raise on an empty array.
            return super().spike_plot(figure=figure, **kwargs)
        spike_bounds = self.spike_bounds
        thresholds = self.spike_threshold
        height = self.spike_height
        widest = self.spike_width.max()
        axes = super().spike_plot(figure=figure,
                                  spike_bounds=spike_bounds,
                                  lmargin=widest, rmargin=widest * 2,
                                  **kwargs)
        for i, ax in enumerate(axes):
            y = thresholds[i] + height[i] / 2
            ax.annotate('FWHM',
                        xy=(spike_bounds[i].left, y),
                        xytext=(spike_bounds[i].right, y),
                        arrowprops=dict(facecolor='black',
                                        shrink=0),
                        verticalalignment='bottom')
            ax.axhline(thresholds[i], color='green', linestyle='--', linewidth=0.3)
        return axes
class AHP(Feature):
    """Find the depth of "after hyperpolarization"
    """
    requires = ('wave',
                'injection_start', 'injection_end', 'injection_interval',
                'spikes', 'spike_count', 'spike_bounds', 'spike_threshold')
    provides = ('spike_ahp_window', 'spike_ahp', 'spike_ahp_position')
    array_attributes = ('spike_ahp_window', 'spike_ahp', 'spike_ahp_position')
    mean_attributes = ('spike_ahp',)

    @property
    @utilities.once
    def spike_ahp_window(self):
        """The region after each spike in which the AHP is looked for (list of WaveRegion).

        Starts at the right edge of the spike's FWHM box. It first advances
        while y is at or above the spike threshold and still higher than
        `n_rolling_window` samples ahead. It then extends while y stays below
        threshold (and to at least 5 samples). The window never runs past the
        next spike's left edge or an injection start/end edge.
        """
        spikes = self._obj.spikes
        spike_bounds = self._obj.spike_bounds
        thresholds = self._obj.spike_threshold
        injection_start = self._obj.injection_start
        injection_end = self._obj.injection_end
        x = self._obj.wave.x
        y = self._obj.wave.y
        ans = []
        for i in range(len(spikes)):
            beg = spike_bounds[i].right_i
            # Don't allow the ahp to straddle an injection start/stop edge.
            # The ahp will be invalid anyway.
            # (np.inf rather than the np.infty alias removed in NumPy 2.0.)
            rlimit = min(spike_bounds[i+1].left if i < len(spikes)-1 else x[-1],
                         injection_start if injection_start > x[beg] else np.inf,
                         injection_end if injection_end > x[beg] else np.inf)
            w = spike_bounds[i].width
            if not np.isnan(w):
                n_rolling_window = int(w // (x[1] - x[0])) + 1
            else:
                # FIXME: consider rejecting those outright
                n_rolling_window = 5
            # if we are before the AHP, or mostly going down, advance
            while (beg < y.size - n_rolling_window and
                   y[beg] >= thresholds[i] and x[beg + 1] < rlimit and
                   y[beg] > y[beg + n_rolling_window]):
                beg += 1
            end = beg + n_rolling_window
            while (end < x.size and
                   (y[end] < thresholds[i] or end - beg < 5) and
                   x[end] < rlimit):
                end += 1
            ans.append(WaveRegion(self._obj.wave, beg, end))
        return ans

    @property
    @utilities.once
    def spike_ahp(self):
        """Returns the (averaged) minimum in y of each AHP window

        `spike_ahp_window` is used to determine the extent of the AHP.
        An average of the bottom area of the window of the width of the
        spike is used. Result: record array with fields x (mean) and dev.

        Probably this should be changed to return the difference between threshold and minimum y
        thresh=spikes.spike_threshold
        mean=vartype.array_mean(cut.y)-thresh[i], or ans[i]=mean.x-spikes.spike_threshold[i],mean.dev
        """
        windows = self.spike_ahp_window
        spike_bounds = self._obj.spike_bounds
        ans = np.empty((len(windows), 2))
        for i, window in enumerate(windows):
            w = spike_bounds[i].width
            bottom_x = window.x[window.y.argmin()]
            left = bottom_x - w/2
            right = bottom_x + w/2
            cut = window.wave[(window.x >= left) & (window.x <= right)]
            mean = vartype.array_mean(cut.y)
            ans[i] = mean.x, mean.dev
        return np.rec.fromarrays(ans.T, names='x,dev')

    @property
    @utilities.once
    def spike_ahp_position(self):
        """Returns the (averaged) x of the minimum in y of each AHP window

        `spike_ahp_window` is used to determine the extent of the AHP.
        An average of the bottom area of the window of the width of the
        spike (at least 8 samples) is used. Samples closer to the mean of
        that area get larger weights (inverse squared distance, capped at 100).
        Result: record array with fields x and dev.

        TODO: add to plot
        """
        windows = self.spike_ahp_window
        spike_bounds = self._obj.spike_bounds
        ans = np.empty((len(windows), 2))
        for i, window in enumerate(windows):
            step = window.x[1] - window.x[0]
            # Make sure that we have at least a few points in the window,
            # even if the spike is very narrow.
            w = max(spike_bounds[i].width, 8 * step)
            bottom_x = window.x[window.y.argmin()]
            left = bottom_x - w/2
            right = bottom_x + w/2
            cut = window.wave[(window.x >= left) & (window.x <= right)]
            bottom = vartype.array_mean(cut.y)
            relative = cut.y - bottom.x
            # np.ptp instead of the ndarray.ptp method, removed in NumPy 2.0.
            weights = (relative / np.ptp(relative))**-2
            weights = np.fmin(weights, 100)
            avg = (cut.x * weights).sum() / weights.sum()
            assert not np.isnan(avg)
            dev = ((cut.x-avg)**2 * weights).sum()**0.5 / weights.sum()**0.5
            assert not np.isnan(dev)
            # TODO: check the formula for dev
            ans[i] = (avg, dev)
        return np.rec.fromarrays(ans.T, names='x,dev')

    def _do_plots(self, axes):
        """Draw AHP window, threshold, AHP bottom and depth arrow for spike i on axes[i]."""
        spikes = self._obj.spikes
        spike_bounds = self._obj.spike_bounds
        thresholds = self._obj.spike_threshold
        windows = self.spike_ahp_window
        ahps = self.spike_ahp
        low, high = np.inf, -np.inf
        spike_count = len(axes)
        for i in range(spike_count):
            window = windows[i]
            x = spikes.x[i]
            width = window.right - x
            axes[i].plot(window.x, window.y, 'r', label='AHP')
            _plot_line(axes[i],
                       [(spikes[i].x - 3*spike_bounds[i].width,
                         spikes[i].x + 3*spike_bounds[i].width)],
                       thresholds[i],
                       'spike threshold', 'green')
            _plot_line(axes[i],
                       [(x, window.right)],
                       vartype.vartype(*ahps[i]),
                       'AHP bottom', 'magenta')
            axes[i].annotate('AHP',
                             xytext=(x + width/2, ahps[i].x),
                             xy=(x + width/2, thresholds[i]),
                             arrowprops=dict(facecolor='black',
                                             shrink=0),
                             horizontalalignment='center', verticalalignment='top')
            diff = abs(thresholds[i] - ahps[i].x)
            low = min(ahps[i].x - diff*0.5, thresholds[i] - diff*0.5, low)
            high = max(thresholds[i] + diff*0.5, high)
            axes[i].set_ylim(low, high)

    def plot(self, figure=None):
        """Plot the recording zoomed onto the spiking period with all AHPs marked."""
        ax = super().plot(figure)
        if self._obj.spike_count == 0:
            ax.text(0.5, 0.5, 'no spikes',
                    horizontalalignment='center',
                    transform=ax.transAxes)
        else:
            ax.set_xlim(self._obj.spikes[0].x - self._obj.injection_interval*0.05,
                        self._obj.spikes[-1].x + self._obj.injection_interval*0.05)
            self._do_plots([ax] * self._obj.spike_count)
        ax.figure.tight_layout()

    def spike_plot(self, figure=None, **kwargs):
        """Per-spike grid (see Feature.spike_plot) with each AHP marked; returns the axes."""
        spike_bounds = self._obj.spike_bounds
        axes = super().spike_plot(figure, **kwargs)
        self._do_plots(axes)
        # Feature.spike_plot creates at most max_spikes axes, so iterate over
        # the axes that exist rather than over every spike.
        for i, ax in enumerate(axes):
            l = spike_bounds[i].left
            r = self.spike_ahp_window[i].right
            diff = r - l
            ax.set_xlim(l - diff*0.15, r + diff*0.15)
        return axes
def _find_falling_curve(wave, window=20, after=0.2, before=0.6):
    """Cut the falling segment of `wave` that follows x == `after`.

    wave: record array with .x/.y. window: smoothing window length (in
    samples), also the step scale of the end search. after/before: x bounds
    of the search.

    The search is seeded at the most negative value of the smoothed
    `vartype.array_diff(wave)` between `after` and `before`, which is
    presumably the steepest descent. From there the end index is advanced
    in half-window steps while the smoothed trace keeps reaching lower
    values. The returned slice starts one sample past the first difference
    sample beyond `after` and stops just before that end index.
    """
    d = vartype.array_diff(wave)
    dd = smooth(d.y, window='hanning', window_len=window)[(d.x > after) & (d.x < before)]
    # Convert the argmin within the (after, before) subset back to an index
    # into d/wave (assumes d.x is sorted and smooth preserves the length).
    end = dd.argmin() + (d.x <= after).sum()
    sm = smooth(wave.y, window='hanning', window_len=window)
    smallest = sm[end]
    # find minimum
    # NOTE(review): `smallest` is set to sm[end] before `end` advances rather
    # than to the window minimum -- confirm this is intended.
    while (end + window < wave.size and wave[end + window].x < before
           and sm[end:end + window].min() < smallest):
        smallest = sm[end]
        end += window // 2
    # A backwards search for the start of the fall used to live here; its
    # result was never used, so the start is derived from `after` directly.
    start_override = (d.x > after).argmax()
    return wave[start_override + 1 : end]
def simple_exp(x, amp, tau):
    """Exponential decay amp * exp(-(x - x[0]) / tau), measured from the first x."""
    elapsed = x - x[0]
    return float(amp) * np.exp(-elapsed / float(tau))
def negative_exp(x, amp, tau):
    """Saturating curve amp * (1 - exp(-(x - x[0]) / tau)); 0 at x[0], tends to amp."""
    elapsed = x - x[0]
    return float(amp) * (1 - np.exp(-elapsed / float(tau)))
# Result types of _fit_falling_curve: the fitted parameters, and the fit
# (model function, parameters, goodness flag).
falling_param = namedtuple('falling_param', ['amp', 'tau'])
function_fit = namedtuple('function_fit', ['function', 'params', 'good'])
def _fit_falling_curve(ccut, baseline, steady):
    """Fit `negative_exp` to a falling voltage trace.

    ccut: record array (.x, .y) of the falling segment.
    baseline, steady: vartype levels before and after the fall. The fit is
    attempted only when ccut has at least 5 samples and steady - baseline
    is negative. The trace is fitted as y - baseline.x with starting guess
    amp=-1, tau=1.

    Returns function_fit(function, params, good). `good` is True only when
    the fitted amp is negative and tau positive. When no fit could be made,
    params holds NaN vartypes (and function is None if no fit was attempted).
    """
    nan_params = falling_param(vartype.vartype.nan, vartype.vartype.nan)
    if ccut.size < 5 or not (steady - baseline).negative:
        return function_fit(None, nan_params, False)
    func = negative_exp
    try:
        popt, pcov = optimize.curve_fit(func, ccut.x, ccut.y - baseline.x, (-1, 1))
    except RuntimeError:
        # curve_fit raises RuntimeError when the least-squares fit fails to converge.
        return function_fit(func, nan_params, False)
    # Broadcast to 2x2 so indexing works even if pcov comes back as a scalar
    # (presumably a guard for older SciPy versions).
    pcov = np.zeros((2, 2)) + pcov
    params = falling_param(vartype.vartype(popt[0], pcov[0, 0]**0.5),
                           vartype.vartype(popt[1], pcov[1, 1]**0.5))
    good = params.amp.negative and params.tau.positive
    return function_fit(func, params, good)
class FallingCurve(Feature):
    """Exponential fit of the voltage fall that follows injection start.

    `falling_curve` is the cut-out segment and `falling_curve_fit` its
    `negative_exp` fit. The amp/tau/function accessors yield NaN (or None)
    unless the fit was judged good.
    """
    requires = ('wave',
                'injection_start', 'steady_before',
                'falling_curve_window',
                'baseline', 'steady')
    provides = ('falling_curve', 'falling_curve_fit',
                'falling_curve_amp', 'falling_curve_tau',
                'falling_curve_function')
    array_attributes = ('falling_curve_amp', 'falling_curve_tau',
                        'falling_curve_function')

    @property
    @utilities.once
    def falling_curve(self):
        obj = self._obj
        return _find_falling_curve(obj.wave,
                                   window=obj.falling_curve_window,
                                   after=obj.injection_start,
                                   before=obj.steady_before)

    @property
    @utilities.once
    def falling_curve_fit(self):
        curve = self.falling_curve
        return _fit_falling_curve(curve, self._obj.baseline, self._obj.steady)

    @property
    def falling_curve_amp(self):
        fit = self.falling_curve_fit
        if not fit.good:
            return vartype.vartype.nan
        return fit.params.amp

    @property
    def falling_curve_tau(self):
        fit = self.falling_curve_fit
        if not fit.good:
            return vartype.vartype.nan
        return fit.params.tau

    @property
    def falling_curve_function(self):
        fit = self.falling_curve_fit
        if not fit.good:
            return None
        return fit.function

    def plot(self, figure=None):
        """Plot the falling segment and, if the fit is good, the fitted curve."""
        ax = super().plot(figure)
        ccut = self.falling_curve
        baseline = self._obj.baseline
        steady = self._obj.steady
        ax.plot(ccut.x, ccut.y, 'r', label='falling curve')
        ax.set_xlim(self._obj.injection_start - 0.005, ccut.x.max() + .01)
        func, popt, good = self.falling_curve_fit
        if not good:
            ax.text(0.2, 0.5, 'bad fit',
                    horizontalalignment='center',
                    transform=ax.transAxes,
                    color='red')
        else:
            fitted = baseline.x + func(ccut.x, *popt)
            ax.plot(ccut.x, fitted, 'g--', label='fitted {}'.format(func.__name__))
        ax.legend(loc='upper right')
        ax.figure.tight_layout()
class PostInjectionCurve(Feature):
    """Exponential fit of the relaxation after the injection ends.

    A falling fit is used when `steady` lies above `baseline_post`, and a
    charging fit when it lies below.
    """
    requires = ('wave',
                'injection_start', 'injection_end', 'steady_before',
                'falling_curve_window',
                'baseline_after', 'baseline_post', 'baseline', 'steady')
    provides = ('post_injection_curve', 'post_injection_curve_fit',
                'post_injection_curve_amp', 'post_injection_curve_tau',
                'post_injection_curve_function')
    array_attributes = ('post_injection_curve_amp', 'post_injection_curve_tau',
                        'post_injection_curve_function')

    @property
    @utilities.once
    def post_injection_curve(self):
        "The part of the wave after `injection_end`"
        wave = self._obj.wave
        return wave[(wave.x > self._obj.injection_end)]

    @property
    @utilities.once
    def post_injection_curve_fit(self):
        """Fit of `post_injection_curve` relaxing from `steady` toward `baseline_post`."""
        curve = self.post_injection_curve
        steady = self._obj.steady
        post = self._obj.baseline_post
        if steady > post:
            return _fit_falling_curve(curve, steady, post)
        if steady < post:
            return _fit_charging_curve(curve, steady, post)
        # Equal or incomparable (e.g. NaN) levels: nothing to fit. Return a
        # "bad fit" instead of None so the accessors below keep working.
        return function_fit(None,
                            falling_param(vartype.vartype.nan, vartype.vartype.nan),
                            False)

    @property
    def post_injection_curve_amp(self):
        fit = self.post_injection_curve_fit
        return fit.params.amp if fit.good else vartype.vartype.nan

    @property
    def post_injection_curve_tau(self):
        fit = self.post_injection_curve_fit
        return fit.params.tau if fit.good else vartype.vartype.nan

    @property
    def post_injection_curve_function(self):
        fit = self.post_injection_curve_fit
        return fit.function if fit.good else None

    def plot(self, figure=None):
        """Plot the post-injection segment and, if the fit is good, the fitted curve."""
        ax = super().plot(figure)
        ccut = self.post_injection_curve
        ax.plot(ccut.x, ccut.y, 'r', label='falling curve')
        ax.set_xlim(self._obj.injection_start - 0.005, ccut.x.max() + .01)
        fit = self.post_injection_curve_fit
        func, popt, good = fit
        if good:
            # Rebuild the curve on the reference level the fit used:
            # _fit_charging_curve fits y - y[0], while _fit_falling_curve
            # fits y - steady.x (steady is passed as its baseline argument).
            if isinstance(fit, charging_function_fit):
                offset = ccut.y[0]
            else:
                offset = self._obj.steady.x
            label = 'fitted {}'.format(func.__name__)
            ax.plot(ccut.x, offset + func(ccut.x, *popt), 'g--', label=label)
        else:
            ax.text(0.2, 0.5, 'bad fit',
                    horizontalalignment='center',
                    transform=ax.transAxes,
                    color='red')
        ax.legend(loc='upper right')
        ax.figure.tight_layout()
class Rectification(Feature):
    """Difference between the steady-state voltage and the bottom of the
    falling curve.

    `rectification` = `steady` - mean of a slice of `falling_curve` whose
    size is set by `window_len` (see the method for how it is chosen).
    """
    requires = ('injection_start',
                'steady_after', 'steady_before',
                'falling_curve', 'steady')
    provides = 'rectification',
    array_attributes = 'rectification',
    mean_attributes = 'rectification',
    # Sets the size of the averaging slice (window_len + 1 samples in the
    # common case).
    window_len = 11

    @property
    @utilities.once
    def rectification(self):
        """`steady` minus the averaged bottom of `falling_curve`.

        Returns vartype NaN when the falling curve has at most `window_len`
        samples.
        """
        ccut = self._obj.falling_curve
        steady = self._obj.steady
        if ccut.size < self.window_len + 1:
            return vartype.vartype.nan
        pos = ccut.y.argmin()
        # NOTE(review): max() makes `end` at least ccut.size-1. Unless the
        # minimum lies within the last window_len//2 samples, the averaged
        # slice is simply the last window_len+1 samples of the curve, not a
        # window around `pos`. If centring on the minimum was intended, min()
        # plus a lower clamp may have been meant -- confirm.
        end = max(pos + self.window_len//2, ccut.size-1)
        bottom = vartype.array_mean(ccut[end-self.window_len : end+self.window_len+1].y)
        return steady - bottom

    def plot(self, figure=None):
        """Plot `steady` over the steady interval and an arrow down to the
        rectification bottom; shows a message instead when rectification is NaN."""
        ax = super().plot(figure)
        ccut = self._obj.falling_curve
        after = self._obj.steady_after
        before = self._obj.steady_before
        steady = self._obj.steady
        ax.set_xlim(self._obj.injection_start - 0.005, before)
        _plot_line(ax,
                   [(after, before)],
                   steady,
                   'steady', 'r')
        right = (after + before) / 2
        bottom = steady.x - self.rectification.x
        if np.isnan(bottom):
            ax.text(0.5, 0.5, 'rectification not detected',
                    horizontalalignment='center',
                    transform=ax.transAxes,
                    color='red')
        else:
            _plot_line(ax,
                       [(after, right)],
                       bottom,
                       'rectification bottom', 'g')
            ax.annotate('rectification',
                        xytext=(right, bottom),
                        xy=(right, self._obj.steady.x),
                        arrowprops=dict(facecolor='black',
                                        shrink=0),
                        horizontalalignment='center', verticalalignment='top')
        ax.legend(loc='upper right')
        ax.figure.tight_layout()
# Result types of _fit_charging_curve: the fitted parameters, and the fit
# (model function, parameters, goodness flag).
charging_param = namedtuple('charging_param', ['amp', 'tau'])
charging_function_fit = namedtuple('charging_function_fit', ['function', 'params', 'good'])
def _fit_charging_curve(ccut, baseline, steady):
    """Fit `negative_exp` to a rising (charging) voltage trace.

    The fit is attempted only when ccut has at least 5 samples and
    steady - baseline is not negative. Both axes are taken relative to the
    first sample (x - x[0], y - y[0]) with starting guess amp=0.02,
    tau=0.02 and up to 100000 function evaluations.

    Returns charging_function_fit(function, params, good). `good` is True
    only when both fitted amp and tau are positive. When no fit could be
    made, params holds NaN vartypes (and function is None if no fit was
    attempted).
    """
    nan_params = charging_param(vartype.vartype.nan, vartype.vartype.nan)
    if ccut.size < 5 or (steady - baseline).negative:
        return charging_function_fit(None, nan_params, False)
    func = negative_exp
    try:
        popt, pcov = optimize.curve_fit(func, ccut.x - ccut.x[0], ccut.y - ccut.y[0],
                                        p0=(.02, .02), maxfev=100000)
    except RuntimeError:
        # curve_fit raises RuntimeError when the least-squares fit fails to converge.
        return charging_function_fit(func, nan_params, False)
    # Broadcast to 2x2 so indexing works even if pcov comes back as a scalar
    # (presumably a guard for older SciPy versions).
    pcov = np.zeros((2, 2)) + pcov
    params = charging_param(vartype.vartype(popt[0], pcov[0, 0]**0.5),
                            vartype.vartype(popt[1], pcov[1, 1]**0.5))
    good = params.amp.positive and params.tau.positive
    return charging_function_fit(func, params, good)
class ChargingCurve(Feature):
    """Exponential fit of the voltage rise after injection start.

    With spikes, the curve ends at the last sample before the first spike
    that is still below 95% of the rise from baseline to the first spike's
    threshold. Without spikes, the whole injection interval is used.
    """
    requires = ('wave', 'injection_start', 'steady_before',
                'baseline', 'baseline_before',
                'spikes', 'spike_count', 'spike_threshold', 'injection_end',
                'steady')
    provides = ('charging_curve_halfheight', 'charging_curve', 'charging_curve_fit',
                'charging_curve_amp', 'charging_curve_tau', 'charging_curve_function')
    array_attributes = ('charging_curve', 'charging_curve_halfheight', 'charging_curve_amp',
                        'charging_curve_tau', 'charging_curve_function')

    @property
    @utilities.once
    def charging_curve_halfheight(self):
        """Half of the rise from baseline to the first spike's threshold.

        Expressed relative to the baseline. NaN when there is no charging
        curve or no spike (and hence no threshold).
        """
        ccut = self.charging_curve
        if ccut is None or self._obj.spike_count < 1:
            return vartype.vartype.nan
        threshold = self._obj.spike_threshold[0]
        baseline = self._obj.baseline
        return (threshold - baseline) / 2

    @staticmethod
    def negative_exp(x, amp, tau):
        """amp * (1 - exp(-(x - x[0]) / tau)); kept for backward compatibility.

        The fit itself uses the module-level `negative_exp`.
        """
        return float(amp) * (1-np.exp(-(x-x[0]) / float(tau)))

    @property
    @utilities.once
    def charging_curve(self):
        """The rising segment of the wave after `injection_start` (see class docstring)."""
        wave = self._obj.wave
        injection_start = self._obj.injection_start
        injection_end = self._obj.injection_end
        baseline = self._obj.baseline.x
        if self._obj.spike_count < 1:
            return wave[(wave.x > injection_start) & (wave.x < injection_end)]
        cut = wave[(wave.x < self._obj.spikes[0].x)]
        threshold_y = 0.95 * (self._obj.spike_threshold[0] - baseline)
        # x value of the last sample before the first spike that is still
        # below 95% of the threshold rise.
        threshold_x = cut[(cut.y - baseline < threshold_y)][-1].x
        return wave[(wave.x > injection_start) & (wave.x < threshold_x)]

    @property
    @utilities.once
    def charging_curve_fit(self):
        return _fit_charging_curve(self.charging_curve, self._obj.baseline, self._obj.steady)

    @property
    def charging_curve_amp(self):
        fit = self.charging_curve_fit
        return fit.params.amp if fit.good else vartype.vartype.nan

    @property
    def charging_curve_tau(self):
        fit = self.charging_curve_fit
        return fit.params.tau if fit.good else vartype.vartype.nan

    @property
    def charging_curve_function(self):
        fit = self.charging_curve_fit
        return fit.function if fit.good else None

    def plot(self, figure=None):
        """Plot the charging curve, its half-height level and the baseline."""
        ax = super().plot(figure)
        baseline = self._obj.baseline
        ccut = self.charging_curve
        if ccut is None:
            ax.text(0.05, 0.5, 'cannot determine charging curve',
                    horizontalalignment='left',
                    transform=ax.transAxes,
                    color='red')
        else:
            ax.plot(ccut.x, ccut.y, 'r', label='charging curve')
            # Without spikes there is no first spike to end the view at.
            if self._obj.spike_count > 0:
                right = self._obj.spikes[0].x
            else:
                right = ccut.x[-1]
            ax.set_xlim(ccut.x[0] - 0.005, right)
            _plot_line(ax,
                       [(ccut.x[0], ccut.x[-1])],
                       baseline + self.charging_curve_halfheight,
                       'charging curve halfheight', 'g')
        if self._obj.baseline_before is not None:
            _plot_line(ax,
                       [(0, self._obj.baseline_before)],
                       baseline,
                       'baseline', 'k')
        ax.legend(loc='upper left')
        ax.figure.tight_layout()
# All feature classes defined in this module, gathered in one tuple
# (presumably iterated by code outside this file).
standard_features = (SteadyState, Spikes, AHP, FallingCurve,
                     Rectification, ChargingCurve, PostInjectionCurve)