import efel
import numpy as np
import math
import matplotlib.pyplot as plt
stim_start = 100#ms
stim_end = 400#ms
decay_start_after_stim =1
decay_end_after_stim = 10
#feature_list = ['Spikecount','time_to_first_spike', 'min_AHP_indices','peak_voltage','mean_AP_amplitude', 'AHP_depth','AP_begin_indices', 'spike_half_width', 'min_AHP_indices','AP_amplitude','min_AHP_values','voltage_base','steady_state_voltage_stimend']
#feature_list = ['voltage_base','steady_state_voltage_stimend','decay_time_constant_after_stim','sag_amplitude','ohmic_input_resistance','voltage_after_stim','voltage_deflection','voltage_deflection_vb_ssse','sag_ratio1','sag_ratio2']
#feature_list = ['voltage_base','steady_state_voltage_stimend','voltage_after_stim','ohmic_input_resistance']
all_feature_list = ['voltage_base','AP_amplitude','voltage_after_stim','peak_time','spike_half_width','AHP_depth','chi']
feature_list = ['voltage_base','AP_amplitude','voltage_after_stim','peak_time','spike_half_width','AHP_depth']
feature_number = len(feature_list)
efel.api.setDoubleSetting('Threshold', -30)
def get_chi(orig_volts,volts,times):
chi = []
for curr_volts in volts:
chi.append(np.sum(np.sqrt(np.square(np.subtract(orig_volts,curr_volts))))/len(times))
return chi
def eval(target_volts_list, data_volts_list,times):
def diff_lists(lis1, lis2):
if lis1 is None or lis2 is None:
return 1000
len1, len2 = len(lis1), len(lis2)
if len1 > len2:
lis2 = np.concatenate((lis2, np.zeros(len1 - len2)), axis=0)
if len2 > len1:
lis1 = np.concatenate((lis1, np.zeros(len2 - len1)), axis=0)
# print(np.sqrt(((lis1 - lis2)**2).mean()))
# print('\n')
return np.sqrt(((lis1 - lis2)**2).mean())
all_features = []
curr_trace_target = {}
curr_trace_target['T'] = times
curr_trace_target['V'] = target_volts_list[0]
curr_trace_target['stim_start'] = [stim_start]
curr_trace_target['stim_end'] = [stim_end]
# curr_trace_target['stimulus_current'] = [-0.40]
curr_trace_target['decay_start_after_stim'] = [decay_start_after_stim]
curr_trace_target['decay_end_after_stim'] = [decay_end_after_stim]
#feature = efel.getFeatureValues([curr_trace_target], feature_list)
traces = [curr_trace_target]
nan_inds_bol = np.isnan(data_volts_list).any(axis=1)
nan_inds = [i for i, x in enumerate(nan_inds_bol) if x]
data_volts_list = np.delete(data_volts_list,nan_inds,axis=0)
for i in range(len(data_volts_list)):
curr_trace_data = {}
curr_trace_data['T'] = times
curr_trace_data['V'] = data_volts_list[i]
curr_trace_data['stim_start'] = [stim_start]
curr_trace_data['stim_end'] = [stim_end]
# curr_trace_data['stimulus_current'] = [-0.40]
curr_trace_data['decay_start_after_stim'] = [decay_start_after_stim]
curr_trace_data['decay_end_after_stim'] = [decay_end_after_stim]
traces.append(curr_trace_data)
print('in efel before getting features')
traces_results = efel.getFeatureValues(traces, feature_list)
print('in efel after getting features')
if 'chi' in all_feature_list:
all_chis = get_chi(target_volts_list[0],data_volts_list,times)
for i in range(len(data_volts_list)):
curr_feature_list=[]
f_counter = 0
for feature_name in all_feature_list:
if feature_name is not 'chi':
diff_feature = diff_lists(traces_results[0][feature_name], traces_results[i+1][feature_name])
# diff_feature = diff_feature * weights[f_counter]
if math.isnan(diff_feature):
diff_feature = 10000
# if diff_feature == 0:
# print('i is' +str(i) + 'feature is ' + feature_name)
# plt.plot(data_volts_list[i],'r')
# plt.plot(target_volts_list[0],'b')
# plt.show()
else:
diff_feature = all_chis[i]
if math.isnan(diff_feature):
diff_feature = 10000
curr_feature_list.append(diff_feature)
f_counter +=1
all_features.append(tuple(curr_feature_list))
all_features = np.array(all_features)
res = []
counter = 0
for ind in nan_inds_bol:
if ind:
res.append(np.zeros(len(all_feature_list))+100000)
else:
res.append(all_features[counter])
counter +=1
print(['best indvs ',res[0]])
return res