import numpy as np
from ajustador import fitnesses
import measurements1
wnames, waves = zip(*measurements1.waves.items())
fitness_list = [
fitnesses.response_fitness,
fitnesses.baseline_fitness,
fitnesses.rectification_fitness,
fitnesses.charging_curve_fitness,
fitnesses.falling_curve_time_fitness,
fitnesses.mean_isi_fitness,
fitnesses.spike_time_fitness,
fitnesses.spike_count_fitness,
fitnesses.spike_latency_fitness,
fitnesses.spike_width_fitness,
fitnesses.spike_height_fitness,
fitnesses.spike_ahp_fitness,
fitnesses.hyperpol_fitness,
fitnesses.spike_fitness,
fitnesses.simple_combined_fitness,
]
import pytest
@pytest.mark.parametrize("w2", waves, ids=wnames)
@pytest.mark.parametrize("w1", waves, ids=wnames)
@pytest.mark.parametrize("fitness", fitness_list, ids=[f.__name__ for f in fitness_list])
def test_basics(w1, w2, fitness):
y = fitness(w1, w2)
if np.isnan(float(y)):
return
same = w1 is w2
disjoint = not (w1.injection[:,None] == w2.injection).any()
repeats = (np.diff(w1.injection) < 1e-14).any()
if same or disjoint or repeats:
assert y >= 0
else:
assert y > 0