import numpy as np
from sklearn import mixture
def hvsd(x):
    """Return the positive part of ``x``: ``x`` where ``x > 0``, else 0.

    Despite its name this is not the Heaviside step itself but the input
    multiplied by the step (a rectification). Works element-wise on
    scalars and array-likes and always returns a floating-point result.

    Written with ``np.maximum`` rather than ``0.5*(1+sign(x))*x`` so that
    ``-inf`` maps to 0 instead of NaN (``0 * -inf``).
    """
    return np.maximum(x, 0.)
def gaussian(x, mean, std):
    """Evaluate the normalised Gaussian density at ``x``.

    Parameters
    ----------
    x : scalar or array-like
        Point(s) at which the density is evaluated.
    mean : float
        Centre of the distribution.
    std : float
        Standard deviation of the distribution.

    Returns
    -------
    Same shape as ``x``: ``exp(-(x-mean)^2 / (2 std^2)) / (sqrt(2 pi) std)``.
    """
    exponent = -(x - mean)**2 / 2. / std**2
    sqrt_two_pi = np.sqrt(2. * np.pi)
    return np.exp(exponent) / sqrt_two_pi / std
def fit_3gaussians(Vm, n=1000, ninit=3, bound1=-90, bound2=-35,
                   means_init=((-80,), (-65,), (-50,))):
    """Fit a three-component spherical Gaussian mixture to ``Vm``.

    Only samples strictly between ``bound1`` and ``bound2`` are used.

    Parameters
    ----------
    Vm : array
        Samples to fit (presumably membrane potential in mV, given the
        default bounds -- confirm against caller).
    n : int
        Maximum number of EM iterations (``max_iter``).
    ninit : int
        Number of initialisations performed (``n_init``).
    bound1, bound2 : float
        Lower and upper exclusive bounds selecting the samples to fit.
    means_init : sequence of 1-tuples
        Initial component means handed to ``GaussianMixture``; the
        default reproduces the previously hard-coded values.

    Returns
    -------
    (weights, means, stds) : one entry per component; ``stds`` is the
    square root of the fitted spherical covariances.
    """
    clf = mixture.GaussianMixture(n_components=3, max_iter=n, n_init=ninit,
                                  means_init=means_init,
                                  covariance_type='spherical')
    in_bounds = (Vm > bound1) & (Vm < bound2)
    # scikit-learn expects a 2D (n_samples, n_features) array
    clf.fit(np.asarray(Vm[in_bounds]).reshape(-1, 1))
    return clf.weights_, clf.means_.flatten(), np.sqrt(clf.covariances_)
def fit_2gaussians(Vm, n=1000, ninit=3, bound1=-90, bound2=-35, means_init=((-80,), (-50,))):
    """Fit a two-component spherical Gaussian mixture to ``Vm``.

    Only samples strictly between ``bound1`` and ``bound2`` are used.

    Parameters
    ----------
    Vm : array
        Samples to fit (presumably membrane potential in mV, given the
        default bounds -- confirm against caller).
    n : int
        Maximum number of EM iterations (``max_iter``).
    ninit : int
        Number of initialisations performed (``n_init``).
    bound1, bound2 : float
        Lower and upper exclusive bounds selecting the samples to fit.
    means_init : sequence of 1-tuples
        Initial component means. This argument used to be accepted but
        silently ignored; it is now forwarded to ``GaussianMixture``.

    Returns
    -------
    (weights, means, stds) : one entry per component; ``stds`` is the
    square root of the fitted spherical covariances.
    """
    clf = mixture.GaussianMixture(n_components=2, max_iter=n, n_init=ninit,
                                  means_init=means_init,
                                  covariance_type='spherical')
    in_bounds = (Vm > bound1) & (Vm < bound2)
    # scikit-learn expects a 2D (n_samples, n_features) array
    clf.fit(np.asarray(Vm[in_bounds]).reshape(-1, 1))
    return clf.weights_, clf.means_.flatten(), np.sqrt(clf.covariances_)
def determine_thresholds(weights, means, stds, down_state_security=1.):
    """Derive the two state thresholds from a fitted Gaussian mixture.

    The component with the lowest mean and the one with the highest mean
    are selected, whatever order the mixture returned them in and
    whether it has two or three components.

    Parameters
    ----------
    weights : array-like
        Mixture weights; accepted so that the output of the fitting
        functions can be unpacked directly, but not used.
    means, stds : array-like
        Mean and standard deviation of each component.
    down_state_security : float
        Offset added to the lower threshold.

    Returns
    -------
    (lower, upper) : ``mean_low + 2*alpha*std_low + down_state_security``
    and ``mean_high - 2*alpha*std_high``, where ``alpha`` in [0, 1) grows
    with the separation between the two components.
    """
    means = np.asarray(means).ravel()
    stds = np.asarray(stds).ravel()
    # Global extrema: correct even when the components come back unsorted.
    i_low, i_high = np.argmin(means), np.argmax(means)

    # Gap left between the two components once 2 std are removed on each
    # side (and the security offset); negative when they overlap.
    gap = (means[i_high] - 2.*stds[i_high]
           - means[i_low] - 2.*stds[i_low] - down_state_security)
    # Rectify the gap so overlapping components give alpha = 0, i.e. the
    # thresholds collapse onto the means; alpha saturates towards 1 as
    # the gap grows (scale constant: 5).
    alpha = 1. - np.exp(-np.maximum(gap, 0.)/5.)

    lower = means[i_low] + 2.*alpha*stds[i_low] + down_state_security
    upper = means[i_high] - 2.*alpha*stds[i_high]
    return lower, upper
def loop_over_sliding_window(data, window_size=5., window_update=2.5):
    """Compute time-varying thresholds by fitting successive windows.

    ``data`` must provide ``'t'``, ``'Vm'`` and ``'dt'``. A window of
    ``window_size`` is moved in steps of ``window_update`` (same unit as
    ``data['t']``); the number of steps is derived from ``data['t'][-1]``.
    Each window is fitted with three Gaussians, falling back to two when
    a ``ValueError`` is raised, and the resulting thresholds are written
    from the window start to the end of the trace, so later windows
    overwrite the tail left by earlier ones.

    Returns the lower and upper threshold arrays, both the length of
    ``data['t']``. The lower one includes the ``down_state_security``
    offset (1 by default) applied by ``determine_thresholds``.
    """
    time, vm = data['t'], data['Vm']
    # window length and window shift, expressed in samples
    samples_per_window = int(window_size/data['dt'])
    samples_per_shift = int(window_update/data['dt'])
    # number of window positions that fit in the recording
    n_positions = int(time[-1]/window_update) - int(window_size/window_update)

    threshold1, threshold2 = 0.*time, 0.*time
    for position in range(n_positions):
        start = position*samples_per_shift
        stop = start + samples_per_window
        segment = vm[start:stop]
        try:
            lower, upper = determine_thresholds(*fit_3gaussians(segment))
        except ValueError:  # means overfitting
            lower, upper = determine_thresholds(*fit_2gaussians(segment))
        # constant fill from the window start to the end of the trace
        threshold1[start:] = lower
        threshold2[start:] = upper
    return threshold1, threshold2
if __name__ == '__main__':
    # Demo: load one recording, compute the sliding thresholds and plot a
    # 20-unit excerpt together with the detected state transitions.
    import sys
    sys.path.append('../..')
    from data_analysis.IO.load_data import load_file, get_formated_data
    import state_classification

    data = get_formated_data('/Users/yzerlaut/DATA/Exps_Ste_and_Yann/2016_12_6/16_48_19_VM-FEEDBACK--OSTIM-AT-VARIOUS-DELAYS.bin')
    t1, t2 = loop_over_sliding_window(data)
    UD_transitions, DU_transitions = state_classification.get_transition_times(
        data['t'], data['Vm'], t1, t2)

    import matplotlib.pylab as plt

    T0, T1 = 50., 70.
    zoom = (data['t'] > T0) & (data['t'] < T1)
    t_zoom = data['t'][zoom]

    # trace in black, lower threshold in red, upper threshold in blue
    for trace, style in ((data['Vm'], 'k-'), (t1, 'r-'), (t2, 'b-')):
        plt.plot(t_zoom, trace[zoom], style)

    # vertical markers spanning the full Vm range at each transition
    vm_span = [data['Vm'].min(), data['Vm'].max()]
    for transitions, style in ((UD_transitions, 'r-'), (DU_transitions, 'b-')):
        in_view = (transitions > T0) & (transitions < T1)
        for tt in transitions[in_view]:
            plt.plot([tt, tt], vm_span, style, lw=2)

    plt.show()