from Na12HMMModel_TF import *
import matplotlib.pyplot as plt
import numpy as np
from neuron import h
import efel
efel.api.setDoubleSetting('Threshold', -40)
import pandas as pd
import math
from scipy.signal import find_peaks
def get_sim_volt_values(sim,mut_name,rec_extra = False,dt = 0.005,stim_amp = 0.5):
    """Configure the stimulus on ``sim.l5mdl``, run the model and return its traces.

    Args:
        sim: Object exposing a settable ``dt`` attribute and an ``l5mdl`` model
            with ``init_stim(amp=...)`` and ``run_model(dt=..., rec_extra=...)``.
        mut_name: Accepted for call-site symmetry; not used by this function.
        rec_extra: Currently ignored -- see the NOTE below.
        dt: Time step stored on ``sim`` and forwarded to ``run_model``.
        stim_amp: Amplitude forwarded to ``init_stim``.

    Returns:
        Tuple ``(Vm, t, extra_vms, I, stim)``. Note the order differs from the
        order in which ``run_model`` yields the values.

    Side effects:
        Sets ``sim.dt`` and, when extra recordings are taken, ``sim.extra_vms``.
    """
    sim.dt = dt
    # NOTE(review): the caller-supplied flag is overwritten here, so extra
    # recordings are always requested and the early-return branch below is
    # never taken. Kept as-is because existing callers may rely on always
    # receiving a populated ``extra_vms``.
    rec_extra = True
    model = sim.l5mdl
    model.init_stim(amp=stim_amp)

    if not rec_extra:
        Vm, I, t, stim = model.run_model(dt=dt)
        return Vm, t, {}, I, stim

    Vm, I, t, stim, extra_vms = model.run_model(dt=dt, rec_extra=rec_extra)
    sim.extra_vms = extra_vms
    return Vm, t, extra_vms, I, stim
def get_sim_volt_valuesTF(sim,mut_name,dt = 0.005,stim_amp = 0.5):
    """Configure the stimulus on ``sim.l5mdl`` and run ``run_sim_model``.

    Args:
        sim: Object exposing a settable ``dt`` attribute and an ``l5mdl`` model
            with ``init_stim(amp=...)`` and ``run_sim_model(dt=...)``.
        mut_name: Accepted for call-site symmetry; not used by this function.
        dt: Time step stored on ``sim`` and forwarded to ``run_sim_model``.
        stim_amp: Amplitude forwarded to ``init_stim``.

    Returns:
        Tuple ``(Vm, t, I, stim)``. The fifth value produced by
        ``run_sim_model`` is discarded.
    """
    sim.dt = dt
    model = sim.l5mdl
    model.init_stim(amp=stim_amp)
    # Extra recordings are never requested here, so there is a single code path.
    Vm, I, t, stim, _unused = model.run_sim_model(dt=dt)
    return Vm, t, I, stim
def _find_dvdt_shoulder(filtered_dvdt, skip=5):
    """Pick the flattest falling sample after the first peak of a dV/dt trace.

    Candidates are the samples lying at least ``skip`` samples past the first
    local maximum of ``filtered_dvdt`` whose incoming first difference is
    negative. (Later peaks can only contribute subsets of that same set, so
    only the first peak needs to be examined.) The first and last candidates
    are excluded from scoring, and each remaining candidate is scored by
    ``|slope| + |change in slope|``; the lowest score wins.

    Returns:
        ``(candidate_idx, candidate_slopes, best)`` where ``candidate_idx`` are
        indices into ``filtered_dvdt`` (one element longer than
        ``candidate_slopes``; ``candidate_idx[j]`` pairs with
        ``candidate_slopes[j]``) and ``best`` is the winning position.

    Raises:
        ValueError: if fewer than three candidates exist.
    """
    slope = np.diff(filtered_dvdt)
    peaks, _ = find_peaks(filtered_dvdt)
    # With no peak at all, start past the end so that nothing is selected.
    first = peaks[0] + skip if len(peaks) else len(slope)
    falling = np.where(slope[first:] < 0)[0] + first + 1
    falling_slopes = slope[falling - 1]
    if len(falling) < 3:
        raise ValueError(
            "Not enough falling dV/dt samples after the first peak to locate a shoulder"
        )
    candidate_idx = falling[1:]
    candidate_slopes = falling_slopes[1:-1]
    slope_change = np.diff(falling_slopes)[1:]
    best = int(np.argmin(np.abs(slope_change) + np.abs(candidate_slopes)))
    return candidate_idx, candidate_slopes, best


def _plot_dvdt_shoulder(mut_name, filtered_dvdt, filtered_volt, candidate_idx,
                        candidate_slopes, best):
    """Save a diagnostic PDF of the shoulder search as ``{mut_name}_dvdt_slopes.pdf``."""
    shoulder = candidate_idx[best]
    # NOTE(review): the marker's y value is a voltage drawn on the dV/dt axis,
    # as in the original plot -- confirm this is intended.
    shoulder_volt = filtered_volt[shoulder]
    fig, ax1 = plt.subplots()
    try:
        ax1.plot(filtered_dvdt, label='Filtered DVDT')
        ax1.set_xlabel('Index')
        ax1.set_ylabel('DVDT')
        ax1.axvline(x=shoulder, color='r', linestyle='--', label='Least Change in Slope')
        ax1.scatter(shoulder, shoulder_volt, color='r', label='Least Change Point')
        ax1.legend(loc='upper right')
        ax1.annotate(f'dvdt: {filtered_dvdt[shoulder]:.2f}',
                     xy=(shoulder, shoulder_volt),
                     xytext=(shoulder + 5, shoulder_volt + 5))
        ax2 = ax1.twinx()
        # candidate_idx is one element longer than candidate_slopes; dropping
        # the last index keeps every slope plotted at its own sample.
        ax2.plot(candidate_idx[:-1], candidate_slopes, label='Negative Slopes',
                 color='g', linewidth=0.2)
        ax2.set_ylabel('Negative Slopes')
        ax2.legend(loc='lower left')
        plt.title('Filtered Voltage Segment with Least Change in Slope')
        fig.savefig(f'{mut_name}_dvdt_slopes.pdf')
    finally:
        # Without this every call leaves an open figure behind.
        plt.close(fig)


def get_features(sim,prefix=None,mut_name = 'na12WT',rec_extra=True):
    """Run one simulation and extract spike and dV/dt features into a DataFrame.

    Steps:
      1. Run the model through ``get_sim_volt_values``.
      2. Compute eFEL features on the returned voltage trace.
      3. Take the voltage window between the end of the 3rd and the end of the
         5th inter-spike interval (counted from ``stim_start``), differentiate
         it and derive dV/dt peak, threshold and shoulder features.
      4. Optionally add the spike counts of the ``ais``, ``nexus`` and
         ``dist_dend`` recordings.
      5. Write the one-row table to ``{prefix}_efel.csv`` and save a diagnostic
         plot to ``{mut_name}_dvdt_slopes.pdf``.

    Args:
        sim: Simulation object accepted by ``get_sim_volt_values``.
        prefix: Prefix of the CSV file name. NOTE(review): with the default
            ``None`` the file is literally named ``None_efel.csv``.
        mut_name: Label stored in the ``Type`` column and used in the PDF name.
        rec_extra: Whether to add the per-compartment spike counts.

    Returns:
        pandas.DataFrame with a single row of features.

    Raises:
        ValueError: if fewer than four inter-spike intervals are available or
            no dV/dt shoulder can be located.
        IndexError: if the analysis window contains no dV/dt peak above 100.
    """
    print("running routine")
    dt = 0.005  # presumably ms, matching the time base after the *1000 below -- TODO confirm
    stim_start = 200  # 100 original
    stim_end = 1900  # 800 original

    Vm, t, extra_vms, _, __ = get_sim_volt_values(sim, mut_name, rec_extra=rec_extra, dt=dt)
    trace = {'T': t * 1000, 'V': Vm, 'stim_start': [stim_start], 'stim_end': [stim_end]}

    feature_list = ['AP_height', 'AP_width', 'AP1_peak', 'AP1_width', 'Spikecount', 'all_ISI_values']
    features = efel.getFeatureValues([trace], feature_list)
    feats = features[0]

    isi_values = feats['all_ISI_values']
    feats['ISI mean'] = 0 if isi_values is None else isi_values.mean()
    feats['AP_height'] = feats['AP_height'].mean()
    # NOTE(review): this takes the first AP's width only (``.mean()`` of a
    # scalar is the scalar) -- confirm a mean over all APs was not intended.
    feats['AP_width'] = feats['AP_width'][0].mean()
    feats['AP1_peak'] = feats['AP1_peak'][0]
    feats['AP1_width'] = feats['AP1_width'][0]
    feats['Spikecount'] = feats['Spikecount'][0]
    spike_count = feats['Spikecount']

    if isi_values is None or len(isi_values) < 4:
        raise ValueError("There were not enough spikes to define the dV/dt analysis window")

    # Sample indices, assuming spiking starts at stim_start.
    median_spike = int(math.floor(spike_count / 2)) + 1
    start = int((stim_start + isi_values[0:median_spike - 1].sum()) / dt)  # reported only
    start2 = int((stim_start + isi_values[0:5].sum()) / dt)
    start4 = int((stim_start + isi_values[0:3].sum()) / dt)
    end = 380000  # reported only

    volt_segment = Vm[start4:start2]
    dvdt = np.gradient(volt_segment) / dt

    # Shoulder search runs on the fast-rising samples only.
    fast = np.where(dvdt > 50)[0]
    filtered_dvdt = dvdt[fast]
    filtered_volt_segment = volt_segment[fast]
    candidate_idx, candidate_slopes, best = _find_dvdt_shoulder(filtered_dvdt)
    dvdt_at_least_change = filtered_dvdt[candidate_idx[best]]
    _plot_dvdt_shoulder(mut_name, filtered_dvdt, filtered_volt_segment,
                        candidate_idx, candidate_slopes, best)

    curr_peaks_indices, curr_peaks_values = find_peaks(dvdt, height=100)
    print(f'start: {start}, start2:{start2}, end: {end}')
    print(f'volt_segment: {volt_segment}')
    print(f'dvdt: {dvdt}')
    print(f'curr_peaks_indices: {curr_peaks_indices}')
    print(f'curr_peaks_values: {curr_peaks_values}')

    peak_heights = curr_peaks_values['peak_heights']
    feats['dvdt Peak1 Height'] = peak_heights[0]
    feats['dvdt Peak1 Voltage'] = volt_segment[curr_peaks_indices[0]]
    feats['dvdt Peak2 Height'] = peak_heights[-1]
    feats['dvdt Peak2 Voltage'] = volt_segment[curr_peaks_indices[-1]]
    # A peak above 100 was just found, so a sample above 1 always exists here.
    threshold_volt = volt_segment[np.where(dvdt > 1)[0][0]]
    feats['dvdt Threshold'] = threshold_volt
    feats['dvdt Peak2 Shoulder'] = dvdt_at_least_change
    feats['dvdt Threshold_DEBUG'] = threshold_volt  # duplicate column kept for output compatibility

    if rec_extra:
        compartments = (
            ('ais', 'ais spikecount'),
            ('nexus', 'nex spikecount'),
            ('dist_dend', 'disdend spikecount'),
        )
        for key, column in compartments:
            trace['V'] = extra_vms[key]
            feats[column] = efel.getFeatureValues([trace], ['Spikecount'])[0]['Spikecount'][0]

    table = pd.DataFrame(features)
    table = table.drop(columns=['all_ISI_values'])
    table.insert(0, 'Type', mut_name)
    with open(f'{prefix}_efel.csv', 'w') as f:
        table.to_csv(f, index=False)
    return table
# Batch driver: run get_features for every entry in mut_names and append each
# resulting row to efel_features.csv. Any exception is printed and stored in
# mut_not_found, keyed by the mutant name, and the loop continues.
mut_names = ['']
mut_not_found = {}
feature_row=None
for mut_name in mut_names:
    try:
        # NOTE(review): get_features as defined in this file takes a required
        # `sim` argument and names its label parameter `mut_name`, not
        # `mutant_name`. As written this call raises TypeError, which the
        # handler below records -- so nothing is ever appended to the CSV.
        # Confirm which simulation object should be passed here.
        feature_row = get_features(mutant_name=mut_name)
        with open('efel_features.csv', 'a') as f:
            # The header is emitted only when the append position is 0.
            # The original author's remark is kept below -- TODO confirm
            # whether it still applies.
            feature_row.to_csv(f, header=f.tell()==0,index=False) #bug needs an existing file
    except Exception as e:
        print(e)
        mut_not_found[mut_name] = e