import pandas as pd
from ipfx import feature_extractor
from scipy.stats import mode, pearsonr
from scipy import interpolate
from scipy.spatial import distance
from scipy.signal import find_peaks
import ot
from brian2 import pF, pA, nS, mV, NeuronGroup, pamp, run, second, StateMonitor, ms, TimedArray, size, nan, array, reshape, \
shape, volt, siemens, amp
try:
### these are some libraries for spike train assesment not needed if you are not calling spike dist
from elephant.spike_train_dissimilarity import victor_purpura_dist, van_rossum_dist
from neo.core import SpikeTrain
import quantities as pq
except:
print('Spike distance lib import failed')
from scipy.optimize import curve_fit
import numpy as np
from brian2 import plot
def detect_spike_times(dataX, dataY, dataC, sweeps=None, dvdt=7, swidth=10, speak=-10):
    """Detect spikes in each sweep with IPFX and return their peak times.

    Requires IPFX (Allen Institute); works with data loaded from ABF and NWB.

    Args:
        dataX, dataY, dataC: 2-D arrays indexed ``[sweep, sample]`` holding
            the time, voltage and command-current traces.
        sweeps: sweep indices to analyse; all sweeps when None.
        dvdt: passed to IPFX ``SpikeFeatureExtractor`` as ``dv_cutoff``.
        swidth: divided by 1000 (presumably ms -> s) and passed to IPFX as
            ``max_interval``.
        speak: passed to IPFX as ``min_peak``.

    Returns:
        A list with one entry per analysed sweep: a numpy array of the IPFX
        ``peak_t`` values, or an empty list when no spikes were found or
        IPFX failed on that sweep.
    """
    # Plain division (not /=) so an ndarray argument is never modified in place.
    max_interval = swidth / 1000
    if sweeps is None:
        sweep_list = np.arange(dataX.shape[0])
    else:
        sweep_list = np.asarray(sweeps)
    spikedect = feature_extractor.SpikeFeatureExtractor(
        filter=0, dv_cutoff=dvdt, max_interval=max_interval, min_peak=speak)
    spike_list = []
    for sweep in sweep_list:
        sweepX = dataX[sweep, :]
        sweepY = dataY[sweep, :]
        sweepC = dataC[sweep, :]
        try:
            spikes_in_sweep = spikedect.process(sweepX, sweepY, sweepC)  # returns a DataFrame
        except Exception:
            # Best effort: a sweep IPFX cannot process is reported as spike-free.
            spikes_in_sweep = pd.DataFrame()
        if spikes_in_sweep.empty:
            spike_list.append([])
        else:
            spike_list.append(spikes_in_sweep['peak_t'].to_numpy())
    return spike_list
def compute_threshold(dataX, dataY, dataC, sweeps=None, dvdt=20):
    """Return the mean IPFX spike threshold of the first sweep that contains spikes.

    Requires IPFX (Allen Institute). Modified version by smestern:
    install using pip install git+https://github.com/smestern/ipfx.git
    Works with data loaded from ABF and NWB.

    Args:
        dataX, dataY, dataC: 2-D arrays indexed ``[sweep, sample]`` holding
            the time, voltage and command-current traces.
        sweeps: sweep indices to scan, in order; all sweeps when None.
        dvdt: passed to IPFX ``SpikeFeatureExtractor`` as ``dv_cutoff``.

    Returns:
        ``np.nanmean`` of the ``threshold_v`` column for the first scanned
        sweep with at least one detected spike, or ``np.nan`` if no scanned
        sweep contains spikes.
    """
    if sweeps is None:
        sweep_list = np.arange(dataX.shape[0])
    else:
        sweep_list = np.asarray(sweeps)
    spikedect = feature_extractor.SpikeFeatureExtractor(filter=0, dv_cutoff=dvdt)
    for sweep in sweep_list:
        spikes_in_sweep = spikedect.process(
            dataX[sweep, :], dataY[sweep, :], dataC[sweep, :])  # returns a DataFrame
        if not spikes_in_sweep.empty:
            # Only the first spiking sweep is used, so stop scanning here.
            return np.nanmean(spikes_in_sweep['threshold_v'].to_numpy())
    return np.nan
def compute_dt(dataX):
    """Return the sampling interval in milliseconds.

    The interval is taken from the first two time samples of the first sweep
    of ``dataX`` (indexed ``[sweep, sample]``), assumed to be in seconds.
    """
    step_seconds = dataX[0, 1] - dataX[0, 0]
    return step_seconds * 1000  # s -> ms
def compute_steady_hyp(dataY, dataC, ind=(0, 1)):
    """Mean voltage over all sweeps between two stimulus change points.

    Change points are found with ``find_stim_changes`` on the first sweep's
    command current ``dataC[0, :]``.

    Args:
        dataY: 2-D voltage array indexed ``[sweep, sample]``.
        dataC: 2-D command-current array indexed ``[sweep, sample]``.
        ind: pair of positions into the list of change points that delimit
            the averaging window (default: the first two change points).
            A tuple is used instead of a list so the default is immutable.

    Returns:
        ``np.nanmean`` of ``dataY`` over all sweeps within that window.
    """
    stim_index = find_stim_changes(dataC[0, :])
    start = stim_index[ind[0]]
    stop = stim_index[ind[1]]
    return np.nanmean(dataY[:, start:stop])
def compute_rmp(dataY, dataC):
    """Estimate resting membrane potential as the mean pre-stimulus voltage.

    The baseline window is taken from the first sweep's command current. It
    runs from the start of the trace up to, but excluding, the last sample
    before the first non-zero stimulus sample. It is applied to every sweep
    of ``dataY``.

    Args:
        dataY: 2-D voltage array indexed ``[sweep, sample]``.
        dataC: 2-D command-current array indexed ``[sweep, sample]``.

    Returns:
        ``np.nanmean`` of the baseline voltage over all sweeps. If the first
        sweep has no non-zero stimulus, the whole trace is used. Returns
        ``np.nan`` when the stimulus starts too early to leave any baseline
        samples.
    """
    stim_samples = np.flatnonzero(dataC[0, :])
    if stim_samples.size == 0:
        # No stimulus at all: the whole recording is baseline.
        baseline_end = dataY.shape[1]
    else:
        # Drop the last sample before onset to stay clear of the step edge.
        baseline_end = stim_samples[0] - 1
    if baseline_end <= 0:
        # A negative slice end would silently average almost the whole trace.
        return np.nan
    return np.nanmean(dataY[:, :baseline_end])
def find_stim_changes(dataI):
    """Return the indices ``i`` at which ``dataI[i + 1] != dataI[i]``.

    Each index is the last sample before a change in the stimulus trace.
    """
    steps = np.diff(dataI)
    change_idx, *_ = np.nonzero(steps)
    return change_idx
def find_downward(dataI):
    """Return the index of the sample just before the first downward step in ``dataI``.

    Raises IndexError if the trace never decreases.
    """
    falling = np.diff(dataI) < 0
    return np.nonzero(falling)[0][0]
def exp_decay_1p(t, a, b1, alphaFast):
    """One-phase exponential: ``a + b1 * (1 - exp(-t / alphaFast))``.

    ``a`` is the value at t = 0, ``b1`` the amplitude approached as t grows,
    and ``alphaFast`` the time constant (same units as ``t``).
    """
    decay = np.exp(-t / alphaFast)
    return a + b1 * (1 - decay)
def exp_decay_factor(dataT, dataV, dataI, time_aft=50, plot=False, sag=True):
    """Fit a one-phase exponential at the first downward current step and return its time constant.

    The fit window starts at the first downward step of ``dataI``. It ends
    ``time_aft`` percent (capped at 100) of the way from that step to the
    largest upward step (first occurrence). When ``sag`` is True, the end is
    moved to 99 % of the way to the voltage minimum inside that window, so
    only the phase before the minimum is fitted.

    The ``a`` parameter of ``exp_decay_1p`` is bounded to within +/-0.5 of the
    window's maximum voltage. Both bound and data are divided by 1000 before
    fitting (presumably mV -> V; confirm with caller).

    Args:
        dataT: 1-D time vector.
        dataV: 1-D voltage trace aligned with dataT.
        dataI: 1-D command-current trace aligned with dataT.
        time_aft: percentage of the step used for the initial fit window.
        plot: if True, draw data and fit in matplotlib figure 2 and pause 3 s.
        sag: if True, stop the fit just before the voltage minimum.

    Returns:
        The fitted time constant (``alphaFast``, in the units of ``dataT``),
        or 0 if the window cannot be located or the fit fails.
    """
    try:
        time_aft = time_aft / 100
        if time_aft > 1:
            time_aft = 1
        diff_I = np.diff(dataI)
        downwardinfl = np.nonzero(np.where(diff_I < 0, diff_I, 0))[0][0]
        end_index = downwardinfl + int((np.argmax(diff_I) - downwardinfl) * time_aft)
        upperC = np.amax(dataV[downwardinfl:end_index])
        if sag:
            minpoint = np.argmin(dataV[downwardinfl:end_index])
            end_index = downwardinfl + int(.99 * minpoint)
        t1 = dataT[downwardinfl:end_index] - dataT[downwardinfl]
        v1 = dataV[downwardinfl:end_index] / 1000
        curve, _ = curve_fit(
            exp_decay_1p, t1, v1, maxfev=500000,
            bounds=([(upperC - 0.5) / 1000, -np.inf, 0], [(upperC + 0.5) / 1000, np.inf, np.inf]),
            xtol=None)
    except Exception:
        # Best effort: a missing step, empty window or failed fit yields 0.
        return 0
    tau = curve[2]
    if plot:
        # Imported lazily; matplotlib is only needed when plotting is requested.
        import matplotlib.pyplot as plt
        plt.figure(2)
        plt.clf()
        plt.plot(t1, v1, label='Data')
        plt.plot(t1, exp_decay_1p(t1, *curve), label='1 phase fit')
        plt.legend()
        plt.pause(3)
    return tau
def compute_sag(dataT,dataV,dataI, time_aft=50):
    """Measure the voltage at an extremum early in a current step relative to its steady state.

    The step is located from the first rising and first falling change in
    ``dataI``. ``vm`` is the mean voltage over the 100 ms before step offset
    (or the last 5 samples if the step is shorter than that). The extremum
    is searched from step onset up to ``time_aft`` percent of the step
    duration before offset.

    Args:
        dataT: 1-D time vector; its spacing is treated as seconds (see the
            ``0.100/dt`` window below).
        dataV: 1-D voltage trace aligned with dataT.
        dataI: 1-D command-current trace aligned with dataT. It must contain
            at least one rising and one falling change, otherwise IndexError
            is raised.
        time_aft: percentage (capped at 100) of the step duration, counted
            back from offset, that is excluded from the extremum search.

    Returns:
        (sag_diff, vm): voltage at the selected extremum minus ``vm``, and
        ``vm`` itself.
    """
    # Extremum selector: index 0 -> argmin, index 1 -> argmax.
    min_max = [np.argmin, np.argmax]
    find = 0
    time_aft = time_aft / 100  # percent -> fraction
    if time_aft > 1:
        time_aft = 1
    diff_I = np.diff(dataI)
    # Index of the sample just before the first rising / first falling change.
    upwardinfl = np.nonzero(np.where(diff_I>0, diff_I, 0))[0][0]
    downwardinfl = np.nonzero(np.where(diff_I<0, diff_I, 0))[0][0]
    if upwardinfl < downwardinfl: #if its depolarizing then swap them
        # After this, downwardinfl is the step onset and upwardinfl the step
        # offset for either polarity.
        temp = downwardinfl
        downwardinfl = upwardinfl
        upwardinfl = temp
    else:
        pass
    # NOTE(review): as laid out here, find is always 1, so argmax is always used
    # and the argmin entry of min_max is never selected. For a hyperpolarizing
    # step this picks the window maximum rather than a sag trough. The source's
    # indentation was lost, so whether this line belonged inside one of the
    # branches above cannot be determined from here. Confirm the intended polarity handling.
    find = 1
    dt = dataT[1] - dataT[0] #in s
    # Steady-state window starts 100 ms before step offset.
    end_index = upwardinfl - int(0.100/dt)
    # Extremum search ends time_aft of the step duration before offset.
    end_index2 = upwardinfl - int((upwardinfl - downwardinfl) * time_aft)
    if end_index<downwardinfl:
        # Step shorter than 100 ms: fall back to the last 5 samples before offset.
        end_index = upwardinfl - 5
    vm = np.nanmean(dataV[end_index:upwardinfl])
    min_point = downwardinfl + min_max[find](dataV[downwardinfl:end_index2])
    avg_min = np.nanmean(dataV[min_point])  # a single sample, so this equals dataV[min_point]
    sag_diff = avg_min - vm
    return sag_diff, vm
def membrane_resistance(dataT, dataV, dataI):
    """Estimate input resistance (ohms) from the first downward current step.

    The baseline is the mean voltage up to 100 samples before the step. The
    response is the mean voltage from 100 samples after the step to 100
    samples before the midpoint between the step and the largest upward
    change. The voltage difference is forced negative and divided by the
    current just after the step.

    Args:
        dataT: unused; kept for interface compatibility.
        dataV: 1-D voltage trace in mV.
        dataI: 1-D command-current trace in pA, aligned with dataV.

    Returns:
        Resistance in ohms (V = IR -> R = V / I), or ``np.nan`` if no
        downward step is found or the computation fails.
    """
    try:
        diff_I = np.diff(dataI)
        downwardinfl = np.nonzero(np.where(diff_I < 0, diff_I, 0))[0][0]
        end_index = downwardinfl + int((np.argmax(diff_I) - downwardinfl) / 2)
        # 100-sample guards keep both windows away from the step transients.
        upperC = np.mean(dataV[:downwardinfl - 100])
        lowerC = np.mean(dataV[downwardinfl + 100:end_index - 100])
        delta_v = -1 * np.abs(upperC - lowerC)
        I_lower = dataI[downwardinfl + 1]
        v_ = delta_v / 1000  # mV -> V
        I_ = I_lower / 1e12  # pA -> A
        return v_ / I_
    except Exception:
        # Best effort: traces without a usable step yield NaN.
        return np.nan
def membrane_resistance_subt(dataT, dataV, dataI):
    """Estimate input resistance in gigaohms from a V-I regression across sweeps.

    For each sweep, the response is the voltage at ``compute_sag``'s selected
    extremum (``vm + sag_diff``) minus the mean voltage before the first
    stimulus change. The stimulus is the current one sample after that
    change. The slope of response (mV) against stimulus (pA) is returned.

    Args:
        dataT: 2-D time array indexed ``[sweep, sample]``.
        dataV: 2-D voltage array (mV) indexed ``[sweep, sample]``.
        dataI: 2-D command-current array (pA) indexed ``[sweep, sample]``.

    Returns:
        Regression slope in GOhm (1 mV / 1 pA = 1e9 Ohm = 1 GOhm).
    """
    # Previously referenced linregress/ohm/Gohm without importing them; the
    # mV/pA slope is already in GOhm, so no unit objects are needed.
    from scipy.stats import linregress
    resp_data = []
    stim_data = []
    for i, sweep in enumerate(dataV):
        sag_diff, steady_v = compute_sag(dataT[i, :], sweep, dataI[i, :])
        changes = find_stim_changes(dataI[i, :])
        baseline = np.mean(sweep[:changes[0]])
        stim_data.append(dataI[i, changes[0] + 1])
        resp_data.append((steady_v + sag_diff) - baseline)
    return linregress(np.asarray(stim_data, dtype=float), np.asarray(resp_data, dtype=float)).slope
def mem_cap(resist, tau_1p):
    """Membrane capacitance from the membrane time constant: tau = R*C, so C = tau / R.

    The result is in farads when ``tau_1p`` is in seconds and ``resist`` in ohms.
    """
    capacitance = tau_1p / resist
    return capacitance
def create_atf(data, filename="output.atf", rate=20000):
    """Save a stimulus waveform array as an ATF 1.0 file.

    Writes a fixed header for a single episodic-stimulation signal ("IN 0"),
    followed by one "time<TAB>value" row per sample. Time is
    ``index / rate``, and both columns are formatted to 5 decimals. The file
    has no trailing newline. The filename is printed when done.

    Args:
        data: iterable of stimulus sample values.
        filename: destination path (overwritten if it exists).
        rate: sampling rate in samples per second.
    """
    # NOTE(review): header fields are space-separated while data rows are
    # tab-separated; confirm the target reader accepts this.
    ATF_HEADER = """
ATF 1.0
8 2
"AcquisitionMode=Episodic Stimulation"
"Comment="
"YTop=2000"
"YBottom=-2000"
"SyncTimeUnits=20"
"SweepStartTimesMS=0.000"
"SignalsExported=IN 0"
"Signals=" "IN 0"
"Time (s)" "Trace #1"
""".strip()
    # Build rows in a list and join once; repeated += on a growing string is quadratic.
    rows = ["%.05f\t%.05f" % (i / rate, val) for i, val in enumerate(data)]
    with open(filename, 'w') as f:
        f.write("\n".join([ATF_HEADER] + rows))
    print("wrote", filename)
def plot_adex_state(adex_state_monitor):
    """
    Plot an AdEx recording in matplotlib figure 12 as three panels:
    v over time, the w-v phase plane, and w over time.
    Adapted from https://github.com/EPFL-LCN/neuronaldynamics-exercises/
    Args:
        adex_state_monitor (StateMonitor): monitor that recorded "v" and "w";
            only the first recorded neuron (index 0) is plotted.
    """
    import matplotlib.pyplot as plt
    time_ms = adex_state_monitor.t / ms
    v_mv = adex_state_monitor.v[0] / mV
    w_pa = adex_state_monitor.w[0] / pA
    # (subplot position, x data, y data, x label, y label, title)
    panels = (
        (1, time_ms, v_mv, "t [ms]", "u [mV]", "Membrane potential"),
        (2, v_mv, w_pa, "u [mV]", "w [pAmp]", "Phase plane representation"),
        (3, time_ms, w_pa, "t [ms]", "w [pAmp]", "Adaptation current"),
    )
    plt.figure(num=12, figsize=(10,10))
    plt.clf()
    for position, x_vals, y_vals, x_label, y_label, panel_title in panels:
        plt.subplot(2, 2, position)
        plt.plot(x_vals, y_vals, lw=2)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(panel_title)
def compute_sse(y, yhat):
    """Sum of squared errors between ``y`` and ``yhat``."""
    residual = y - yhat
    return np.sum(np.square(residual))
def compute_mse(y, yhat):
    """Mean squared error between ``y`` and ``yhat``."""
    residual = y - yhat
    return np.mean(np.square(residual))
def compute_se(y, yhat):
    """Element-wise squared error between ``y`` and ``yhat``."""
    residual = y - yhat
    return np.square(residual)
def compute_emd_1d(y, yhat):
    """Optimal-transport distance between two equal-length 1-D histograms.

    Both inputs are treated as weights on the same bin positions,
    ``np.linspace(0, len(y), len(y))``. They are normalised to sum to 1 and
    passed to ``ot.emd2_1d`` (default metric). A histogram whose sum is below
    1e-9 (e.g. all zeros) first gets 0.001 added to every bin, so it can be
    normalised without dividing by zero. The same guard now applies to both
    inputs.

    Args:
        y, yhat: array-likes of equal length (lists are accepted).

    Returns:
        The value returned by ``ot.emd2_1d``.

    Raises:
        ValueError: if ``y`` and ``yhat`` differ in length.
    """
    y = np.asarray(y, dtype=float)
    yhat = np.asarray(yhat, dtype=float)
    if len(y) != len(yhat):
        raise ValueError("y and yhat must be the same length")
    # Shared bin positions for both histograms.
    x = np.linspace(0, len(y), len(y))
    if np.sum(y) < 1e-9:
        y = y + 0.001
    if np.sum(yhat) < 1e-9:
        yhat = yhat + 0.001
    y = y / np.sum(y)
    yhat = yhat / np.sum(yhat)
    return ot.emd2_1d(x, x, y, yhat)
def equal_array_size_1d(array1, array2, method='append', append_val=0):
    """Make two 1-D arrays the same length.

    Args:
        array1, array2: 1-D numpy arrays (only ``shape[0]`` is compared).
        method: how to reconcile the lengths:
            'append' -- pad the shorter array with ``append_val``;
            'trunc'  -- cut the longer array to the shorter length;
            'interp' -- extend the shorter array by linear extrapolation of
                        its samples, taken as evenly spaced at positions
                        0..n-1 (needs at least 2 samples).
            Any other value returns both arrays unchanged.
        append_val: fill value used by 'append'.

    Returns:
        (array1, array2), with equal lengths for the recognised methods.
    """
    size1 = array1.shape[0]
    size2 = array2.shape[0]
    if size1 == size2:
        return array1, array2
    if method == 'append':
        if size1 > size2:
            array2 = np.hstack((array2, np.full(size1 - size2, append_val)))
        else:
            array1 = np.hstack((array1, np.full(size2 - size1, append_val)))
    elif method == 'trunc':
        if size1 > size2:
            array1 = array1[:size2]
        else:
            array2 = array2[:size1]
    elif method == 'interp':
        # Fixed: the array1-shorter case used to write the result into array2.
        if size1 > size2:
            array2 = _extrapolate_to_length(array2, size1)
        else:
            array1 = _extrapolate_to_length(array1, size2)
    return array1, array2


def _extrapolate_to_length(short, target_len):
    """Extend ``short`` to ``target_len`` samples by linear extrapolation on an integer index grid."""
    n = short.shape[0]
    # Sample positions are the indices themselves, so the new points are
    # n..target_len-1 at unit spacing. The old linspace(1, n-1, n) grid
    # collapsed to duplicate x values when n == 2.
    extrap = interpolate.interp1d(np.arange(n), short, bounds_error=False, fill_value='extrapolate')
    return np.hstack((short, extrap(np.arange(n, target_len))))
def compute_spike_dist(y, yhat, t_stop=6, tau=40):
    '''
    Van Rossum distance between two spike trains.

    Requires the optional elephant / neo / quantities imports at the top of
    the file.

    Args:
        y, yhat: array-likes of spike times in seconds.
        t_stop: end time of both trains, in seconds (default 6, the value
            previously hard-coded).
        tau: van Rossum kernel time constant in ms (default 40, the value
            previously hard-coded).

    Returns:
        The off-diagonal entry of the 2x2 distance matrix returned by
        ``van_rossum_dist``.
    '''
    train1 = SpikeTrain(np.asarray(y) * pq.s, t_stop=t_stop * pq.s)
    train2 = SpikeTrain(np.asarray(yhat) * pq.s, t_stop=t_stop * pq.s)
    dist = van_rossum_dist([train1, train2], tau=tau * pq.ms)
    return dist[0, 1]
def compute_spike_dist_euc(y, yhat):
    '''
    Euclidean distance between two spike-time vectors (times in seconds).

    The shorter vector is zero-padded to the length of the longer one first.
    If both vectors are empty, the sentinel value 999 is returned.
    '''
    padded_y, padded_yhat = equal_array_size_1d(y, yhat, 'append', append_val=0)
    both_empty = len(padded_y) < 1 and len(padded_yhat) < 1
    return 999 if both_empty else distance.euclidean(padded_y, padded_yhat)
def compute_corr(y, yhat):
    """Pearson correlation coefficient between two 1-D arrays.

    The shorter array is zero-padded to the longer one's length, and NaN/inf
    values are replaced by 0 before correlating.

    Returns:
        Pearson's r. Returns 0 when r cannot be computed (pearsonr raises, or
        r is NaN, e.g. for a constant input).
    """
    y, yhat = equal_array_size_1d(y, yhat, 'append')
    y = np.nan_to_num(y, nan=0, posinf=0, neginf=0)
    yhat = np.nan_to_num(yhat, nan=0, posinf=0, neginf=0)
    try:
        # Take r only. np.amax over the (r, p-value) pair could return the p-value.
        r = pearsonr(y, yhat)[0]
    except (ValueError, TypeError):
        return 0
    return 0 if np.isnan(r) else r
def replace_nan(a):
    """Return a copy of ``a`` in which every NaN is replaced by the largest non-NaN value."""
    filled = a.copy()
    nan_mask = np.isnan(a)
    filled[nan_mask] = np.nanmax(a)
    return filled
def drop_rand_rows(a, num):
    """Return a copy of ``a`` with ``num`` randomly chosen rows removed.

    Rows are drawn without replacement from all rows, using the global NumPy
    RNG, so exactly ``min(num, a.shape[0])`` distinct rows are dropped. The
    previous code called the nonexistent ``np.random.rarandint``. Its
    ``randint(0, n - 1)`` range could never drop the last row, and repeated
    draws dropped fewer rows than requested.
    """
    n_rows = a.shape[0]
    rows_to_drop = np.random.choice(n_rows, size=min(num, n_rows), replace=False)
    return np.delete(a, rows_to_drop, axis=0)
def compute_distro_mode(x, bin=20, wrange=False):
    """Return the left edge of the most populated histogram bin of ``x``.

    Bins have width ``bin``. They start at ``min(x) - bin`` when ``wrange`` is
    True, otherwise at 0, and extend past ``max(x)``.
    """
    lower = np.amin(x) - bin if wrange else 0
    edges = np.arange(lower, np.amax(x) + bin, bin)
    counts, edges = np.histogram(x, bins=edges)
    return edges[np.argmax(counts)]
def compute_corr_minus(y, yhat):
    """Return 1 minus the Pearson correlation of two 1-D arrays (a loss in [0, 2]).

    The shorter array is zero-padded to the longer one's length, and NaN/inf
    values are replaced by 0 before correlating.

    Returns:
        ``1 - r``. Returns 1 when r cannot be computed (pearsonr raises, or r
        is NaN, e.g. for a constant input).
    """
    y, yhat = equal_array_size_1d(y, yhat, 'append')
    y = np.nan_to_num(y, nan=0, posinf=0, neginf=0)
    yhat = np.nan_to_num(yhat, nan=0, posinf=0, neginf=0)
    try:
        # Take r only. np.amax over the (r, p-value) pair could return the p-value.
        r = pearsonr(y, yhat)[0]
    except (ValueError, TypeError):
        return 1
    return 1 if np.isnan(r) else 1 - r
def compute_FI(spkind, dt, dataC):
    """Instantaneous firing frequency and the matching stimulus current per sweep.

    Args:
        spkind: per-sweep sequences of spike sample indices.
        dt: sampling interval; ISIs are ``dt * diff(indices)``.
        dataC: per-sweep current traces, indexed by sweep.

    Returns:
        (f, i, isi): per-sweep lists of instantaneous frequencies (1 / ISI),
        the current sampled at each spike except the last (aligned with the
        ISIs), and the ISIs themselves.
    """
    isi = [dt * np.diff(spike_idx) for spike_idx in spkind]
    f = [np.reciprocal(intervals) for intervals in isi]
    i = [dataC[sweep][spkind[sweep][:-1]] for sweep in range(len(dataC))]
    return f, i, isi
def compute_min_stim(dataY, dataX, strt, end):
    """Minimum of ``dataY`` between the samples of ``dataX`` closest to ``strt`` and ``end``.

    The sample nearest ``end`` is excluded (half-open slice).
    """
    def nearest_index(target):
        return np.argmin(np.abs(dataX - target))

    window = dataY[nearest_index(strt):nearest_index(end)]
    return np.amin(window)
def compute_FI_curve(spike_times, time, bin=20):
    """Per-sweep firing rate and mean inter-spike interval.

    Args:
        spike_times: per-sweep sequences of spike times (presumably seconds,
            since ISIs are multiplied by 1000 to give ms).
        time: duration the spike counts are divided by.
        bin: unused.

    Returns:
        (rates, mean_isi): spike count / ``time`` per sweep, and the mean ISI
        per sweep (x1000), with 0 for sweeps that have fewer than two spikes.
    """
    counts = [len(train) for train in spike_times]
    mean_isi = [np.nanmean(np.diff(train) * 1000) if len(train) > 1 else 0
                for train in spike_times]
    return (np.hstack(counts) / time), np.hstack(mean_isi)
def compute_sweepwise_isi_hist(spike_times, time, bins=None):
    """Histogram of inter-spike intervals pooled across sweeps.

    Args:
        spike_times: per-sweep sequences of spike times; ISIs are multiplied
            by 1000 (presumably s -> ms).
        time: unused; kept for interface compatibility.
        bins: histogram bins passed to ``np.histogram``. The default
            (``None``) is 100 log-spaced edges from 1 to 1000, created per
            call instead of as a shared default array.

    Returns:
        The histogram counts. When no sweep contains spikes, all counts are
        zero (this previously raised ValueError).
    """
    if bins is None:
        bins = np.logspace(0, 3, 100)
    isi_ms = [np.diff(train) * 1000 for train in spike_times if len(train) > 0]
    # Seed with an empty array so concatenation works when no sweep has spikes.
    pooled = np.concatenate([np.empty(0)] + isi_ms)
    return np.histogram(pooled, bins=bins)[0]
def add_spikes_to_voltage(spike_times,voltmonitor, peak=33, index=0):
    """Return one recorded voltage trace (in mV) with spike samples set to ``peak``.

    Spike times and the monitor's time axis (converted to ms) are both
    rounded to whole numbers. Every trace sample whose rounded time matches
    a rounded spike time is overwritten with ``peak``. Spike times are
    presumably in ms; confirm with caller.

    Args:
        spike_times: sequence of spike times.
        voltmonitor: monitor with ``t`` and per-index ``v`` recordings.
        peak: value written at spike samples.
        index: which recorded trace to use.
    """
    traces_v = voltmonitor[index].v/mV
    if len(spike_times) > 0:
        trace_ms = np.around(voltmonitor.t/ms, decimals=0)
        spike_ms = np.around(spike_times, decimals=0)
        traces_v[np.isin(trace_ms, spike_ms)] = peak
    return traces_v