from __future__ import print_function, division
import collections
import enum
import numpy as np
import pandas as pd
from . import vartype
from ajustador.helpers.loggingsystem import getlogger
import logging
logger = getlogger(__name__)
logger.setLevel(logging.INFO)
class ErrorCalc(enum.IntEnum):
normal = 1
relative = 2
"If 'b' (measurement) is 0, limit to this value"
RELATIVE_MAX_RATIO = 10
NAN_REPLACEMENT = 1.5
def sub_mes_dev(reca, recb):
    '''Difference of the x values of reca and recb, with deviations
    combined as the root of the sum of squares.
    '''
logger.debug("{} {}".format(type(reca), type(recb)))
if isinstance(reca, vartype.vartype):
assert reca == vartype.vartype.nan
return vartype.vartype.nan
if isinstance(recb, vartype.vartype):
assert recb == vartype.vartype.nan
return vartype.vartype.nan
if len(reca) == 0 or len(recb) == 0:
return vartype.vartype.nan
if hasattr(reca, 'x'):
xy = (reca.x - recb.x, (reca.dev**2 + recb.dev**2)**0.5)
if isinstance(reca, vartype.vartype):
return vartype.vartype(*xy)
else:
return np.rec.fromarrays(xy, names='x,dev')
else:
return reca - recb
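# A worked illustration of the quadrature rule above, assuming reca and recb
# are record arrays with fields x and dev (e.g. 5±3 vs 4±4):
#
#     a = np.rec.fromarrays(([5.0], [3.0]), names='x,dev')
#     b = np.rec.fromarrays(([4.0], [4.0]), names='x,dev')
#     d = sub_mes_dev(a, b)
#     # d.x   -> [1.0]    (5 - 4)
#     # d.dev -> [5.0]    sqrt(3**2 + 4**2)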
def _select(a, b, which=None):
    '''a -> simulation records, b -> measurement records, which -> boolean filter on b.
    Returns the records of a and b paired up by matching injection current.
    Note: if the filter condition is not satisfied by any value, indexing
    the result will return a nan.'''
if which is not None:
bsel = b[which]
else:
bsel = b
    fitting = np.abs(a.injection[:, None] - bsel.injection) < 1e-12
logger.debug("{}".format(fitting))
ind1, ind2 = np.where(fitting)
logger.debug("{} {}".format(ind1, ind2))
return a[ind1], bsel[ind2]
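# A minimal sketch of the injection matching, assuming record arrays with an
# `injection` field (values in amperes):
#
#     a = np.rec.fromarrays(([10e-12, 20e-12],), names='injection')
#     b = np.rec.fromarrays(([20e-12, 30e-12],), names='injection')
#     m1, m2 = _select(a, b)
#     # m1.injection -> [2e-11], m2.injection -> [2e-11]: only the matching
#     # 20 pA pair survives.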
def relative_diff_single(a, b, extra=0):
x = getattr(a, 'x', a)
y = getattr(b, 'x', b)
    # the measurement (y) is the yardstick; the abs(x) term caps the ratio at
    # RELATIVE_MAX_RATIO when the measurement is 0
    base = abs(y) + abs(x) / RELATIVE_MAX_RATIO
    ## `np.where(base > 0, abs(x-y)/base, base)` returns an element-wise array:
    ## the second argument, `abs(x-y)/base`, wherever the first argument is
    ## True, and the third argument wherever it is False, so if x and y are
    ## both zero, a difference of zero is returned rather than a NaN.
    return (np.where(base > 0, abs(x-y)/base, base)
            + RELATIVE_MAX_RATIO * extra)
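# Worked values for the capped relative error (plain floats, not vartype
# records):
#
#     relative_diff_single(1.0, 1.0)   # -> 0.0
#     relative_diff_single(1.0, 0.0)   # -> 10.0: measurement is 0, so the
#                                      #    ratio is capped at RELATIVE_MAX_RATIO
#     relative_diff_single(0.0, 0.0)   # -> 0.0, not NaN, thanks to np.where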
def relative_diff(a, b):
"""A difference between a and b using b as the yardstick
.. math::
W = |a - b| / (|b| + |a| * RELATIVE_MAX_RATIO)
w = rms(W)
"""
n1, n2 = len(a), len(b)
if n1 == n2 == 0:
return np.array([])
    if n1 < n2:
        b = b[:n1]
    elif n1 > n2:
        a = a[:n2]
return relative_diff_single(a, b, extra=abs(n1 - n2))
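# Length mismatches are penalized through the `extra` term: the longer input
# is truncated and every surviving element is offset by RELATIVE_MAX_RATIO per
# missing element. A small sketch:
#
#     relative_diff(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0]))
#     # -> array([10., 10.]): the overlapping values agree, but one element
#     #    is missing, adding RELATIVE_MAX_RATIO * 1 to every difference.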
def _evaluate(a, b, error=ErrorCalc.relative):
    '''Calculate the RMS of one of the two difference types, selected by the error flag.
    Differences are calculated between simulation and measurement values.
    '''
if error == ErrorCalc.normal:
diff = sub_mes_dev(a, b)
ans = vartype.array_rms(diff, nan_replacement=NAN_REPLACEMENT)
elif error == ErrorCalc.relative:
diff = relative_diff(a, b)
ans = vartype.array_rms(diff, nan_replacement=NAN_REPLACEMENT)
else:
assert False, error
if np.isnan(ans): # Is this check really needed?
return NAN_REPLACEMENT
else:
return ans
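# A minimal sketch, assuming vartype.array_rms is a plain RMS with NaN
# substitution: identical inputs give a fitness of 0.
#
#     _evaluate(np.array([1.0, 2.0]), np.array([1.0, 2.0]))
#     # -> 0.0 with the default ErrorCalc.relative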
def _evaluate_single(a, b, error=ErrorCalc.relative):
if error == ErrorCalc.normal:
ans = float(abs(a - b))
elif error == ErrorCalc.relative:
ans = relative_diff_single(a, b)
else:
raise AssertionError
if np.isnan(ans):
return NAN_REPLACEMENT
else:
return ans
def response_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
"Similarity of response to hyperpolarizing injection"
m1, m2 = _select(sim, measurement, measurement.spike_count < 1)
return _evaluate(m1.response, m2.response, error=error)
def response_variance_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
'''Variance of steady state response for non-spiking responses'''
m1, m2 = _select(sim, measurement, measurement.spike_count < 1)
return _evaluate(m1.steady.dev, m2.steady.dev, error=error)
def baseline_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
"Similarity of baselines"
m1, m2 = _select(sim, measurement)
return _evaluate(m1.baseline, m2.baseline, error=error)
def baseline_pre_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
"Similarity of baselines"
m1, m2 = _select(sim, measurement)
return _evaluate(m1.baseline_pre, m2.baseline_pre, error=error)
def baseline_post_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
"Similarity of baselines"
m1, m2 = _select(sim, measurement)
return _evaluate(m1.baseline_post, m2.baseline_post, error=error)
def rectification_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.injection <= -10e-12)
return _evaluate(m1.rectification, m2.rectification, error=error)
# This should be calculated for positive current injection, even if no spike; maybe only if no spike.
def charging_curve_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.injection > 0)
if len(m2) == 0:
return vartype.vartype.nan
return _evaluate(m1.charging_curve_halfheight, m2.charging_curve_halfheight,
error=error)
def post_injection_curve_tau_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
"Similarity of time constants fit to post injection curve"
m1, m2 = _select(sim, measurement)
return _evaluate(m1.post_injection_curve_tau, m2.post_injection_curve_tau, error=error)
def charging_curve_time_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
    m1, m2 = _select(sim, measurement, measurement.injection > 0)
if len(m2) == 0:
return vartype.vartype.nan
return _evaluate(m1.charging_curve_tau, m2.charging_curve_tau, error=error)
def charging_curve_full_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
    '''Pointwise comparison of the full charging curves for positive current injection.'''
m1, m2 = _select(sim, measurement, measurement.injection > 0)
diffs = [ahp_curve_compare(wave1.charging_curve, wave2.charging_curve)
for wave1, wave2 in zip(m1, m2)]
if not diffs:
return 0
#assert 0 <= min(diffs) <= 1, diffs
#assert 0 <= max(diffs) <= 1, diffs
diffs = np.array(diffs)
if full:
return diffs
else:
return vartype.array_rms(diffs, nan_replacement=NAN_REPLACEMENT)
#alternatively, could do falling curve for positive current injection if no spike
def falling_curve_time_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.injection <= -10e-12)
if len(m2) == 0:
return vartype.vartype.nan
return _evaluate(m1.falling_curve_tau, m2.falling_curve_tau, error=error)
def mean_isi_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.spike_count >= 2)
if len(m2) == 0:
return vartype.vartype.nan
return _evaluate(m1.mean_isi, m2.mean_isi, error=error)
def isi_spread_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.spike_count >= 2)
if len(m2) == 0:
return vartype.vartype.nan
return _evaluate(m1.isi_spread, m2.isi_spread, error=error)
def _measurement_to_spikes(meas):
frames = [pd.DataFrame(wave.spikes) for wave in meas]
for frame, wave in zip(frames, meas):
frame['injection'] = wave.injection
frame.reset_index(inplace=True)
frame.set_index(['index', 'injection'], inplace=True)
return pd.concat(frames)
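# Hedged sketch of the resulting layout, assuming each wave.spikes is a record
# array with an 'x' field of spike times: the concatenated frame is indexed by
# (spike number, injection), e.g.
#
#                             x
#     index  injection
#     0      1.5e-10      0.21
#     1      1.5e-10      0.35
#     0      2.0e-10      0.18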
def spike_time_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.spike_count >= 2)
if len(m1) == 0:
m1, m2 = _select(measurement, sim, sim.spike_count >= 2)
if len(m1) == 0:
            # Neither trace is spiking, so spike timing cannot be determined
            # (a good thing). If both are non-spiking (rare but possible),
            # spike_time_fitness cannot be improved, so report a perfect score.
            return 0
spikes1 = _measurement_to_spikes(m1)
spikes2 = _measurement_to_spikes(m2)
# spikes1, spikes2 are pandas DataFrames indexed by injection level and spike number
# The align method below will insert nans for missing spikes (only in the data frame missing a spike; the other dataframe will preserve its times)
    spikes1, spikes2 = spikes1.align(spikes2, axis=0)
# Missing spikes contribute error scaled by injection_interval (could be left simply as NaNs and handled by NAN_REPLACEMENT)
spikes1.fillna(sim[0].injection_interval, inplace=True)
spikes2.fillna(sim[0].injection_interval, inplace=True)
return _evaluate(spikes1['x'], spikes2['x'], error=error)
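# The align step above relies on standard pandas semantics; a minimal sketch
# with hypothetical spike-time frames:
#
#     df1 = pd.DataFrame({'x': [0.1, 0.2]}, index=[0, 1])
#     df2 = pd.DataFrame({'x': [0.1]}, index=[0])
#     a1, a2 = df1.align(df2, axis=0)
#     # a2 gains a row at index 1 holding NaN; a1 is unchanged. The NaN is
#     # then filled with injection_interval so the missing spike costs error.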
def spike_count_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement)
return _evaluate(m1.spike_count, m2.spike_count, error=error)
def spike_latency_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.spike_count >= 1)
return _evaluate(m1.spike_latency, m2.spike_latency, error=error)
def spike_width_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
return _evaluate_single(sim.mean_spike_width, measurement.mean_spike_width,
error=error)
def spike_height_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
return _evaluate_single(sim.mean_spike_height, measurement.mean_spike_height,
error=error)
def spike_threshold_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
return _evaluate_single(sim.mean_spike_threshold, measurement.mean_spike_threshold, error=error)
def spike_ahp_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
m1, m2 = _select(sim, measurement, measurement.spike_count >= 1)
# Just ignore any extra spikes. Let's assume that most spikes
# and AHPs are similar, and that we're using a different fitness
# function to compare spike counts.
left = m1.spike_ahp
right = m2.spike_ahp
if len(left) < len(right):
right = right[:len(left)]
elif len(left) > len(right):
left = left[:len(right)]
    return _evaluate(left, right, error=error)
def interpolate(wave1, wave2):
"Interpolate wave1 to wave2.x"
y = np.interp(wave2.x, wave1.x, wave1.y, left=np.nan, right=np.nan)
return np.rec.fromarrays((wave2.x, y), names='x,y')
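# np.interp with left/right set to NaN marks non-overlapping regions instead
# of extrapolating; a quick sketch:
#
#     np.interp([0.5, 2.0], [0.0, 1.0], [10.0, 20.0], left=np.nan, right=np.nan)
#     # -> array([15., nan]): 0.5 interpolates, 2.0 lies outside wave1.x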
def ahp_curve_centered(wave, i):
windows = wave.spike_ahp_window
if i >= len(windows):
return None
cut = windows[i]
ahp_y = wave.spike_ahp[i]
ahp_x = wave.spike_ahp_position[i]
return cut.relative_to(ahp_x.x, ahp_y.x)
def ahp_curve_compare(cut1, cut2):
"""Returns a number from [0, 1] which compares how close they are.
0 means the same, 1 means very different.
"""
    assert not (cut1 is None and cut2 is None)
    if cut1 is None or cut2 is None:
        return 1
    if len(cut1.x) == 0 or len(cut2.x) == 0:
return 1
cut1i = interpolate(cut1, cut2)
diff = np.tanh((cut1i.y - cut2.y) / cut2.y)
return vartype.array_rms(diff, nan_replacement=np.nanmax(diff))
def _pick_spikes(wave1, wave2):
n = max(wave1.spike_count, wave2.spike_count)
# let's compare max 10 spikes
if n <= 10:
return range(n)
else:
return np.linspace(0, n-1, 10, dtype=int)
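# For more than 10 spikes, np.linspace picks an evenly spread sample of
# indices, always including the first and last spike:
#
#     np.linspace(0, 24, 10, dtype=int)
#     # -> array([ 0,  2,  5,  8, 10, 13, 16, 18, 21, 24])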
def ahp_curve_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
    '''RMS of AHP-shape differences over a sample of matched spikes.
    '''
m1, m2 = _select(sim, measurement,
sim.spike_count + measurement.spike_count > 0)
diffs = [ahp_curve_compare(ahp_curve_centered(wave1, i),
ahp_curve_centered(wave2, i))
for wave1, wave2 in zip(m1, m2)
for i in _pick_spikes(wave1, wave2)]
if not diffs:
return 0
assert 0 <= min(diffs) <= 1, diffs
assert 0 <= max(diffs) <= 1, diffs
diffs = np.array(diffs)
if full:
return diffs
else:
return vartype.array_rms(diffs, nan_replacement=NAN_REPLACEMENT)
class WaveHistogram:
"""Compute the difference between cumulative histograms of two waves
Since the x step might be different, we need to scale to the same
    range. This is done by computing a frequency histogram, which
    abstracts away the number of points in either wave.
"""
def __init__(self, wave1, wave2, left=-np.inf, right=+np.inf):
self.wave1 = wave1
self.wave2 = wave2
self.left = left
self.right = right
def x1(self):
return self.wave1.x[(self.wave1.x >= self.left) & (self.wave1.x <= self.right)]
def x2(self):
return self.wave2.x[(self.wave2.x >= self.left) & (self.wave2.x <= self.right)]
def y1(self):
return self.wave1.y[(self.wave1.x >= self.left) & (self.wave1.x <= self.right)]
def y2(self):
return self.wave2.y[(self.wave2.x >= self.left) & (self.wave2.x <= self.right)]
def hist(self, bins, y, cumulative=True):
hist = np.histogram(y, bins=bins, density=True)[0]
hist /= hist.sum()
if cumulative:
return np.cumsum(hist)
else:
return hist
def bins(self, n=50):
y1, y2 = self.y1(), self.y2()
low = min(y1.min(), y2.min())
high = max(y1.max(), y2.max())
return np.linspace(low, high, n)
def diff(self, full=False):
bins = self.bins()
hist1 = self.hist(bins, self.y1())
hist2 = self.hist(bins, self.y2())
diff = (hist2 - hist1) * bins.ptp()
if full:
return diff
else:
            # we return something that is approximately the area between the CDFs
return np.abs(diff).sum()
def plot(self, figure):
from matplotlib import pyplot
ax1 = figure.add_subplot(121)
ax2 = figure.add_subplot(122)
ax1.plot(self.x1(), self.y1(), label='recording 1', color='blue')
ax1.plot(self.x2(), self.y2(), label='recording 2', color='red')
bins = self.bins()
hist1 = self.hist(bins, self.y1())
hist2 = self.hist(bins, self.y2())
diff = self.diff(full=True)
height = bins.ptp() / bins.size
ax2.barh(bins[:-1], hist1, height=height, alpha=0.2, color='blue')
ax2.barh(bins[:-1], hist2, height=height, alpha=0.2, color='red')
bars = ax2.barh(bins[:-1], left=hist1, width=hist2-hist1,
height=height, color='none', edgecolor='black')
for bar in bars:
bar.set_hatch('x')
ax2.set_title('cumulative histograms\ndiff={}'.format(np.abs(diff).sum()))
ax2.yaxis.set_major_formatter(pyplot.NullFormatter())
figure.tight_layout()
return ax1, ax2
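# A hedged usage sketch with synthetic waves (np.rec arrays standing in for
# recorded traces):
#
#     w1 = np.rec.fromarrays((np.linspace(0, 1, 100),
#                             np.sin(np.linspace(0, 6, 100))), names='x,y')
#     w2 = np.rec.fromarrays((np.linspace(0, 1, 80),
#                             np.sin(np.linspace(0, 6, 80) + 0.3)), names='x,y')
#     wh = WaveHistogram(w1, w2, left=0.0, right=1.0)
#     wh.diff()            # scalar, roughly the area between the two CDFs
#     wh.diff(full=True)   # per-bin differences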
def spike_range_y_histogram_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
"""Match histograms of y-values in spiking regions
This returns an rms of `WaveHistogram.diff` over the injection
    region. Waves are filtered to have at least one spike between
    the pair. This is done to make this fitness function sensitive to
    depolarization block. Otherwise, the result would be dominated by
    baseline mismatches and response mismatches.
    `baseline_post_fitness` and `response_fitness` are better suited
    to detect mismatches in other regions.
"""
m1, m2 = _select(sim, measurement)
diffs = np.array([WaveHistogram(wave1.wave, wave2.wave,
wave1.injection_start, wave1.injection_end).diff()
for wave1, wave2 in zip(m1, m2)
if max(wave1.spike_count, wave2.spike_count) > 0])
if full:
return diffs
else:
return vartype.array_rms(diffs, nan_replacement=NAN_REPLACEMENT)
# Used in work-aju.py; somebody might use this.
def hyperpol_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
a = response_fitness(sim, measurement, error=error)
b1 = baseline_pre_fitness(sim, measurement, error=error)
b2 = baseline_post_fitness(sim, measurement, error=error)
c = rectification_fitness(sim, measurement, error=error)
d = falling_curve_time_fitness(sim, measurement, error=error)
e = spike_count_fitness(sim, measurement, error=error)
if error == ErrorCalc.normal:
arr = np.array([a, b1/5, b2/5, c*4, d/20, e])
else:
arr = np.array([a, b1, b2, c, d, e])
if full:
return arr
else:
return vartype.array_rms(arr, nan_replacement=NAN_REPLACEMENT)
def spike_fitness(sim, measurement, full=False, error=ErrorCalc.relative):
a = mean_isi_fitness(sim, measurement, error=error)
b = spike_latency_fitness(sim, measurement, error=error)
c = spike_width_fitness(sim, measurement, error=error)
d = spike_height_fitness(sim, measurement, error=error)
e = spike_ahp_fitness(sim, measurement, error=error)
f = spike_time_fitness(sim, measurement, error=error)
arr = np.array([a, b, c, d, e, f])
if full:
return arr
else:
return vartype.array_rms(arr, nan_replacement=NAN_REPLACEMENT)
class combined_fitness:
"""Basic weighted combinations of fitness functions
"""
presets = {
'empty' : collections.OrderedDict(),
'new_combined_fitness' : collections.OrderedDict(
response=1,
baseline_pre=1,
baseline_post=1,
rectification=1,
falling_curve_time=1,
spike_time=1,
spike_width=1,
spike_height=1,
spike_latency=1,
spike_ahp=1,
ahp_curve=1,
spike_range_y_histogram=1),
'simple_combined_fitness' : collections.OrderedDict(
response=1,
baseline=1,
rectification=1,
falling_curve_time=1,
mean_isi=1,
spike_latency=1,
spike_height=1,
spike_width=1,
spike_ahp=1,
spike_count=1,
isi_spread=1),
}
@staticmethod
def fitness_by_name(name):
''' Gets "_fitness" postfix functions using name. Returns a function object.'''
return globals()[name + '_fitness']
def __init__(self,
preset='new_combined_fitness',
*,
error=ErrorCalc.relative,
extra=None,
**kwargs):
"""Creates a weighted combination of features.
preset can be used to pick one of the starting sets of fitness
functions and their weights. To modify the weight for one of
the "known" functions from this module, the new weight can be
passed as a keyword argument:
        >>> combined_fitness('new_combined_fitness', spike_latency=2.5)
Arbitrary fitness functions can be given as (weight, function) pairs
in extra:
>>> def fitness1(sim, measurement, full=False, error=ErrorCalc.relative):
... return 5
>>> f = combined_fitness('empty',
... extra={fitness1 : 0.5})
>>> print(f.report('a', 'b'))
fitness1=0.5*5=2.5
total: 2.5
"""
self.error = error
weights = self.presets[preset].copy()
weights.update(kwargs)
        pairs1 = [(w, self.fitness_by_name(k))  # (weight, function) pairs resolved from "*_fitness" names
for k, w in weights.items()]
pairs2 = [(w, k) for k, w in extra.items()] if extra else []
if set(f for w,f in pairs1).intersection(set(f for w,f in pairs2)):
raise ValueError('"known" function specified in extra')
self.pairs = pairs1 + pairs2
def _parts(self, sim, measurement, *, full=False):
for w, func in self.pairs:
if w or full:
yield (w, func(sim, measurement, error=self.error), func.__name__)
def __call__(self, sim, measurement, full=False):
# Computes feature fitnesses using _parts for one trace.
parts = [(feature_name, w*NAN_REPLACEMENT if r == vartype.vartype.nan else w*r) for w, r, feature_name in self._parts(sim, measurement)]
for feature_name, value in parts:
logger.debug("{} {}".format(feature_name, value))
            if np.isnan(value):
                logger.warning("Feature: {} fitness: {} Check feature declaration in 'combined_fitness'!".format(feature_name, value))
arr = np.array([p[1] for p in parts])
if full:
return arr
else:
            # RMS across all features (the combined fitness metric)
return vartype.array_rms(arr, nan_replacement=NAN_REPLACEMENT)
@property
def __name__(self):
return self.__class__.__name__
def report(self, sim, measurement, *, full=False):
parts = [(w, NAN_REPLACEMENT if r is vartype.vartype.nan else r, name) for w, r, name in self._parts(sim, measurement, full=full)]
desc = '\n'.join('{}={}*{:.2g}={:.2g}'.format(name, w, r, w*r)
for w, r, name in parts)
total = desc + '\n' + 'total: {:.02g}'.format(self.__call__(sim, measurement))
return total
def fit_sort(group, measurement, fitness):
w = np.array([fitness(sim, measurement) for sim in group])
w[np.isnan(w)] = np.inf
return np.array(group)[w.argsort()]
def fit_finished(fitness, cutoff=0.01, window=10):
isdf = isinstance(fitness, pd.DataFrame)
if not isdf:
fitness = pd.DataFrame(fitness)
    dev = fitness.rolling(window).var() ** 0.5
quit = dev / dev.max() < cutoff
if isdf:
return quit
else:
return quit.values.flatten()
def find_best(group, measurement, fitness):
w = np.array([fitness(sim, measurement) for sim in group])
w[np.isnan(w)] = np.inf
return group[w.argmin()]
def find_multi_best(group, measurement, fitness,
similarity=.10,
debug=False, full=False):
best = np.empty(0, dtype=object)
scores = np.empty(0)
for sim in group:
score = fitness(sim, measurement, full=1)
# ignore misfits
if np.isnan(score).any():
if debug:
print('dropping for nans:', sim)
continue
        for i, key in enumerate(best):
            if (scores[i] < score).all():
                # an already kept solution dominates this one on every feature
                if debug:
                    print('dropping worse:', sim)
                break
        else:
            # nothing kept dominates sim: keep it, and drop any kept
            # solutions that sim dominates
            if scores.size:
                dominates = (score < scores).all(axis=1)
                if debug:
                    print('dropping', dominates.sum(), 'dominated')
                best = np.hstack((best[~dominates], [sim]))
                scores = np.vstack((scores[~dominates], score))
            else:
                best = np.hstack([sim])
                scores = score[None, :]
if similarity:
# sort by rms
total = (scores ** 2).sum(axis=1)
order = total.argsort()
best = best[order]
scores = scores[order]
worse = np.empty_like(best, dtype=bool)
for i in reversed(range(best.size)):
similar = scores[i] - scores[:i]
worse[i] = ((similar ** 2).sum(axis=1) < total[i] * similarity).any()
        scores = scores[~worse]
        best = best[~worse]
if full:
return best, scores
else:
return best
def normalize_dimensions(vect):
mean = np.mean(vect, axis=0)
radius = np.ptp(vect, axis=0) / 2
    trivial = radius == 0 # ignore non-variable parameters
    radius = np.where(trivial, 1, radius) # avoid dividing by zero; those columns are dropped below
    return ((vect - mean) / radius).T[~trivial].T
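# A worked sketch: constant columns are detected and dropped, the rest are
# scaled to the range [-1, 1]:
#
#     vect = np.array([[0.0, 1.0, 5.0],
#                      [2.0, 1.0, 9.0]])
#     normalize_dimensions(vect)
#     # -> array([[-1., -1.],
#     #           [ 1.,  1.]])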
find_nonsimilar_result = collections.namedtuple('find_nonsimilar_result', 'group scores params')
def find_nonsimilar(group, measurement, fitness,
similarity=.10):
from . import analysis
what = group[0].params.keys()
params, scores = analysis.convert_to_values(group, measurement, fitness, *what, full=1)
group = np.array(group, dtype=object)
scores = np.array(scores)
scores[np.isnan(scores)] = np.inf
total = (scores ** 2).sum(axis=1)
order = total.argsort()
group = group[order]
scores = scores[order]
params = params[order]
normalized = normalize_dimensions(params)
    # mark near-duplicates in normalized parameter space
duplicate = np.zeros_like(group, dtype=bool)
for i in range(group.size - 1):
if not duplicate[i]: # ignore the ones already ignored
diff = ((normalized[i + 1:] - normalized[i])**2).sum(axis=1)**0.5
duplicate[i + 1:] |= diff < similarity
    return find_nonsimilar_result(group[~duplicate], scores[~duplicate], params[~duplicate])