import pickle
from pathlib import Path
import numpy as np
from brian2 import Hz, second, ms
from scipy.signal import find_peaks
from sklearn import metrics
from tqdm import tqdm
def gaussian(x, mu, sig):
"""Create gaussian waveform on domain x
:param x: Values to compute over
:param mu: Gaussian center
:param sig: Gaussian variance
:return: gaussian(x)
"""
return 1. / (np.sqrt(2. * np.pi) * sig) * np.exp(-np.power((x - mu) / sig, 2.) / 2)
def infer_frequency(spike_times, t1=0.6, t2=0.3, res=0.1, prom=0.5, time_step=0.01):
"""Try to infer the frequency from the superposition method.
:param spike_times: Spike times
:param t1: Variance of the isi gaussian
:param t2: Variance of the diff gaussian
:param res: Width of the ISI bins
:param prom: Minimum prominence
:param time_step: Width of the frequency bins
:return: Inferred frequency
"""
try:
isi_bin = np.arange(0, 350, res)
isi = np.diff(spike_times)
# replace ISI with sum of gaussians at each ISI value
isi_hist_g = np.sum(np.vstack([gaussian(isi_bin, i, t1) for i in isi]), axis=0)
isi_hist_g /= isi_hist_g.sum() # normalize gaussian
pks, _ = find_peaks(isi_hist_g, prominence=prom)
pk_location_diffs = np.diff(isi_bin[pks]) # find difference between peaks (ix+n skipping value)
peak_loc_bins = np.arange(0, 15, time_step)
peak_hist_g = np.sum(np.vstack([gaussian(peak_loc_bins, pk, t2) for pk in pk_location_diffs]), axis=0)
return 1000 / (peak_loc_bins[np.argmax(peak_hist_g)])
except:
return np.nan # lots can go wrong but in all cases assume nan
def infer_amplitude(spike_times):
"""Infer a proxy for the amplitude
:param spike_times: Times of the spikes
:return: Number of spikes (proxy for amplitude)
"""
return len(spike_times)
def discriminate(d1, d2, a, b):
""" Discriminate between 2 sets of measures
:param d1: Measure 1
:param d2: Measure 2
:param a: True value of d1
:param b: True value of d2
:return: Roc of discrimination
"""
if a > b:
d1, d2 = d2, d1
# clean values from d1 and d2
# if there's any funny business both are set to 0 i.e. 50/50 disc.
d1 = np.array(d1)
d1[~np.isfinite(d1)] = 0
d1 = d1[np.isfinite(d1)].tolist()
d2 = np.array(d2)
d2[~np.isfinite(d2)] = 0
d2 = d2[np.isfinite(d2)].tolist()
score = d1 + d2 # concatenate
true_value = [0] * len(d1) + [1] * len(d2)
try:
false_positive_rate, true_positive_rate, threshold = metrics.roc_curve(list(map(int, true_value)), score,
pos_label=int(1))
auc = metrics.auc(false_positive_rate, true_positive_rate)
return auc
except:
Warning('Something went wrong')
return 0
def create_param_range(x0, x1, res):
""" Create a set of parameters.
:param x0: Lower bound
:param x1: Upper bound
:param res: Step size
:return: Parameter set
"""
return np.arange(x0, x1 + res, res)
def evaluate_amplitude_discrimination(model, *, a0, a1, res, unit, f0=100, f1=600, t_max=10 * second):
""" Test a model for amplitude discrimination
:param model: Model to run
:param a0: Min amplitude
:param a1: Max amplitude
:param res: Amplitude resolution
:param unit: Amplitude unit (amp or siemens)
:param f0: Min frequency
:param f1: Max frequency
:param t_max: Sim length
:return: Spike time, inferred amp. measure
"""
trial_amplitude = create_param_range(a0, a1, res)
model.store('initial') # set initial state
# create placeholders
spike_times = {}
infered_amplitude = {}
for amp in tqdm(trial_amplitude):
model.restore('initial') # reset initial conditions
model.set_stimulus_current(amp * unit)
model.f = np.random.uniform(f0, f1, model.neurons.N) * Hz # set a random frequency
model.run(t_max, report=None)
spike_times[amp] = model.spike_train
infered_amplitude[amp] = list(map(infer_amplitude, model.spike_train.values()))
return spike_times, infered_amplitude
def evaluate_frequency_discrimination(model, *, f0, f1, res, unit, a0=50, a1=100, t_max=10 * second,
t1=None, t2=None, prom=None, infer_res=0.25, time_step=0.05):
""" Test a model for frequency discrimination
:param model: Model to run
:param f0: Min frequency
:param f1: Max frequency
:param res: Frequency resolution
:param unit: Amplitude unit (amp or siemens)
:param a0: Min amplitude
:param a1: Max amplitude
:param t_max: Sim length
:param t1: Disc parameter t1
:param t2: Disc parameter t2
:param prom: Min prominence
:param infer_res: Disc parameter infer_res
:param time_step: Disc parameter time_Step
:return: Spike time, inferred freq. measure
"""
trial_frequency = create_param_range(f0, f1, res)
model.store('initial') # set initial state
# create placeholders
spike_times = {}
infered_frequency = {}
for freq in tqdm(trial_frequency):
model.restore('initial') # reset initial conditions
model.f = freq * Hz
model.set_stimulus_current(np.random.uniform(a0, a1, model.neurons.N) * unit)
model.run(t_max, report=None)
spike_times[freq] = model.spike_train
if t1 is not None:
infered_frequency[freq] = list(map(lambda x:
infer_frequency(x / ms, t1=t1, t2=t2, prom=prom,
res=infer_res, time_step=time_step),
model.spike_train.values()))
return spike_times, infered_frequency
def save(sim_name, sim_type, spike_times, infered_values):
""" Save a set of results
:param sim_name: Name of simulation
:param sim_type: amplitude or frequency
:param spike_times: Dict of spike times
:param infered_values: Measure for each experiment
"""
with open(Path('save_data') / f'{sim_name}_{sim_type}.pkl', 'wb') as f:
pickle.dump((spike_times, infered_values), f, protocol=pickle.HIGHEST_PROTOCOL)
def load(sim_name, sim_type):
""" Load a given sim_name and class
:param sim_name: Name of a simulation
:param sim_type: amplitude or frequency
:return: Data
"""
with open(Path('save_data') / f'{sim_name}_{sim_type}.pkl', 'rb') as f:
return pickle.load(f)
def load_amplitude(sim_name):
""" Helper function to load an amplitude sim
:param sim_name: Name of the simulatiion
:return: Data
"""
return load(sim_name, 'A')
def load_frequency(sim_name):
""" Helper function to load a frequency sim
:param sim_name: Name of the simulatiion
:return: Data
"""
return load(sim_name, 'f')
def save_amplitude(sim_name, spike_times, infered_amplitude):
""" Helper function to save an amplitude sim
:param sim_name: Name of the simulatiion
:param spike_times: Dict of spike times
:param infered_amplitude: Infered amplitude
"""
save(sim_name, 'A', spike_times, infered_amplitude)
def save_frequency(sim_name, spike_times, infered_frequency):
""" Helper function to save a frequency sim
:param sim_name: Name of the simulatiion
:param spike_times: Dict of spike times
:param infered_frequency: Infered freq
"""
save(sim_name, 'f', spike_times, infered_frequency)
def discrimination_combinations(infered_values):
""" Compute discrimination for all combinations.
:param infered_values: param:measure dict of inferred values
:return: Meshgrid for x, y (param_x, param_x) and z (discriminability)
"""
levels = list(infered_values.keys())
x = y = levels
xx, yy = np.meshgrid(x, y)
z = np.array(
[discriminate(infered_values[x], infered_values[y], x, y) for (x, y) in zip(np.ravel(xx), np.ravel(yy))])
zz = z.reshape(xx.shape)
return xx, yy, zz