"""Credit to Robert McDougal for providing the h.SaveState() example code upon which the initialize() (below) is based,
and for his assistance with numerous other implementation hurdles."""
import copy
import math
from sorcery import print_args
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from neuron import h, gui
from neuron.units import ms, mV
import pprint
from datetime import datetime
#====================== General files and tools =====================
h.load_file("nrngui.hoc")
#====================== cvode =======================================
h.dt = 1.0e-01
cvode = h.CVode()
cvode.active(1)
cvode.atol(1.0e-06)
state_filename='warmed_up.dat'
v_init = -80.53 * mV
t_warmup = 3000.0 * ms
t_stop = t_warmup + 1000.0 * ms
###define a python function to restore from saved state.
def restore_state(state_filename=state_filename):
s = h.SaveState()
f = h.File(state_filename) #<---Requires that this script be run at least once, otherwise this file will not exist yet.
s.fread(f)
f.close()
s.restore()
###create a function to setup and optionally warmup and save the simulation state
def initialize(warm_up_and_save=False, state_filename=state_filename, load_state=True):
global fih
###save the warmed-up simulation (calcium equilibrated)
if warm_up_and_save is True:
h.finitialize(v_init)
h.continuerun(t_warmup)
# plt.plot(cell1.t, cell1.vsoma, label="original", linestyle='--', linewidth=3.0, color='black')
s = h.SaveState()
s.save()
f = h.File(state_filename)
s.fwrite(f)
f.close()
if load_state is True:
cvode.active(1)
fih = h.FInitializeHandler(restore_state)
class Pyramidal:
cells = [] # Class variable to store all neurons (instances) created by Pyramidal
def __init__(self, gid):
self.gid = gid
Pyramidal.cells.append(self)
self.tropical_colors = [
'#FF4500', # Coral
'#FFD700', # Gold
'#32CD32', # LimeGreen
'#20B2AA', # LightSeaGreen
'#FF69B4', # HotPink
'#FF6347', # Tomato
'#FFDAB9', # PeachPuff
'#98FB98', # PaleGreen
'#F08080', # LightCoral
'#00CED1', # DarkTurquoise
'#ADFF2F', # GreenYellow
'#87CEEB', # SkyBlue
'#FFA07A', # LightSalmon
'#B0E0E6', # PowderBlue
'#FFDEAD', # NavajoWhite
'#FFB6C1', # LightPink
'#FFA500', # Orange
'#BDB76B', # DarkKhaki
'#8A2BE2', # BlueViolet
'#66CDAA', # MediumAquamarine
'#FA8072', # Salmon
'#7FFFD4', # Aquamarine
'#F0E68C', # Khaki
'#F5DEB3', # Wheat
'#EE82EE', # Violet
'#FAFAD2', # LightGoldenrodYellow
'#D8BFD8', # Thistle
'#DDA0DD', # Plum
'#CD853F', # Peru
'#FFC0CB', # Pink
'#DA70D6', # Orchid
'#EEE8AA', # PaleGoldenrod
'#40E0D0', # Turquoise
'#FFFAF0', # FloralWhite
'#AFEEEE', # PaleTurquoise
'#F5F5DC', # Beige
'#FFEFDB', # PapayaWhip
'#FFEBCD', # BlanchedAlmond
'#FFE4C4', # Bisque
'#FFB5C5', # HotPinkLight
'#F4A460', # SandyBrown
'#D2B48C', # Tan
'#C0C0C0', # Silver
'#A52A2A', # Brown
'#9ACD32', # YellowGreen
'#FAEBD7', # AntiqueWhite
'#8B4513', # SaddleBrown
'#6B8E23', # OliveDrab
'#808080', # Gray
'#BC8F8F' # RosyBrown
]
self.v_init = v_init
self.t_warmup = t_warmup
self.t_stop = t_stop
h.load_file("import3d.hoc")
# h.load_file("models/L5PCbiophys1.hoc")
h.load_file("models/L5PCbiophys4_mixedAISNavs.hoc")
h.load_file("models/L5PCtemplate.hoc")
self.cell = h.L5PCtemplate("morphologies/cell1.asc")
self.soma = self.cell.soma[0]
# self.all = self.soma.wholetree()
### Identify apical segments for observation
# self.apic = self.cell.apic[37]
self.apic = self.cell.apic[36]
self.apic_BAC = self.cell.apic[37]
self.apic_BAC_seg = self.cell.apic[37](0.6)
### Identify axon and AIS
self.axon0 = self.cell.axon[0]
self.axon1 = self.cell.axon[1]
self.apic0 = self.cell.apic[0]
self.axon0.nseg = 31
self.axon1.nseg = 31
self.axon0.L = 30.0 #um
self.axon1.L = 30.0 #um
self.ais_segments = [seg for seg in self.axon0] + [seg for seg in self.axon1]
self.sigma = 10.0 #<---called "flatness" in Thresher.py, this parameter multiplies the argument of the tanh() function in the Nav profiles.
self.aisNavSeparation = 0.0
self.aisNavCrossover = 0.5
self.aisDeltaRightShift = 0.0
self.get_aisNavProperties()
###Passive cable: AIS in original Hay model is a dead end: there was no axon attached to it.
###Adding a passive cable is the least intrusive means of adding an axon,
###i.e. providing the proper boundary conditions at the end of the AIS, without interfering with the model tuning.
self.passive = h.Section('passive')
self.passive.Ra = 100.0 # Ohm cm
self.passive.cm = 1.0
self.passive.diam = 1.0 #um
self.passive.L = 400.0 #um
self.passive.nseg = 2*math.ceil(0.5*self.passive.L) + 1
self.passive.connect(self.ais_segments[-1].sec(1))
h.pas.insert(self.passive)
for seg in self.passive:
seg.pas.e = v_init
seg.pas.g = 0.01
self.area_soma = self.get_area_of_section(self.soma)
self.area_ais = self.get_area_of_section(self.ais_segments)
self.gNaVTotal_soma = self.get_gNaVTotal(self.soma(0.5))
self.gNaVTotal_ais = self.get_gNaVTotal(self.ais_segments[1])
### initialize current clamp with zero amplitude
self.ic = h.IClamp(self.soma(0.5)) #h.IClamp(axon1(0.5)) #h.IClamp(axon1(1.0))
self.ic.amp = 0.0 #nA
self.amplitudes = [self.ic.amp]
self.vpeak_data = {}
self.tpeak_data = {}
### Insert mechanism to record vpeak in every segment of the model.
self.all = self.soma.wholetree()
for sec in self.all:
h.peak.insert(sec)
### default parameters for generating threshold curves
self.initialize_threshold_dict()
self.IMAX = 20.0 #nA #<---for threshold finder
self.vSpikeThreshold = -70.0 #10.0<---somatic BAP criterion #mV
self.spike_recording_site = self.apic_BAC_seg #self.soma(0.5) #self.apic(0.99)
self.spikeFound = False
self.epsilon = 1.0e-4
self.i_spikeFound = None
self.vpeak = None
self.separation_vals = [0.0, 1.0]
self.crossover_vals = [0.5]
self.scaleNavTwo_vals = [1.0]
self.scaleNavTwo = 1.0
self.deltaRightShift_vals = [0.0]
self.DeltaVrs = 0.0
self.data_dict = {}
self.setup_recording_vectors()
h.define_shape()
def __str__(self):
return f"Pyramidal[{self.gid}]"
def get_area_of_section(self, section=None):
if section is not None:
area_section = 0.0
for seg in section:
area_section+=seg.area()
return area_section
def stim(self, location=False, amplitude=False, duration=False, delay=False):
###move current clamp to "location"
if location is not False:
self.ic.loc(location)
if amplitude is not False:
self.ic.amp = amplitude
else:
self.ic.amp = 0.0 #nA
if duration is not False:
self.ic.dur = duration #ms
# else:
# self.ic.dur = 1.0 * ms
if delay is not False:
self.ic.delay = delay #ms
else:
self.ic.delay = t_warmup + 100.0*ms
def run_amplitudes(self, amplitudes=None):
if amplitudes is not None:
self.amplitudes = amplitudes
import matplotlib.pyplot as plt
# Create Figure 1 and Axes 1
self.fig1, self.ax1 = plt.subplots()
for amplitude in self.amplitudes:
self.stim(amplitude=amplitude)
h.finitialize(v_init)
self.reset_peak()
h.continuerun(t_stop)
# plt.plot(self.t, self.caisoma, label=f"gnabar={self.soma(0.5).NaTs2_t.gNaTs2_tbar}")
self.ax1.plot(self.t, self.vsoma, label=f"{self.ic.amp:g}")
# plt.plot(self.t, self.currentvec)
# zoom the x-axis to view the spikes
self.ax1.set_xlim([3098.0, 3110.0])
# create legend
self.ax1.legend().set_title('1ms pulse amplitude [nA]')
# indicate stimulation site in title
self.ax1.set_title(f"Nav Separation = {self.aisNavSeparation:g}"+f" | Crossover = {self.aisNavCrossover:g}"+r' | Stimulating at '+str(self.ic.get_segment().sec.name()).replace('L5PCtemplate', '').replace('[0]', '').replace('[1]', '').replace('.', ''))
# axis labels
self.ax1.set_ylabel(r'$V_{soma} \ \ \ [mV]$', fontsize=20)
self.ax1.set_xlabel(r'$t \ \ \ [ms]$', fontsize=20)
# Show Figure 1
self.fig1.show()
def spike(self, plot=False):
self.stim(amplitude=self.ic.amp)
h.finitialize(v_init)
# self.initialize()
self.reset_peak()
h.continuerun(t_stop)
self.t_postWarmUp = np.array(self.t)[np.where(np.array(self.t)>self.t_warmup)]
self.spikeSignal = np.array(self.vSpikeRecording)[np.where(np.array(self.t)>self.t_warmup)]
self.vMax = np.max(self.spikeSignal)
# print_args(self.aisNavSeparation, self.aisNavCrossover)
print(f"x={self.aisNavSeparation:g}, 𝜿={self.aisNavCrossover:g}, I={self.ic.amp:g}")
# print_args(self.ic.amp)
print_args(self.vMax)
# print_args(self.apic_BAC_seg.peak.vpeak)
if self.vMax >= self.vSpikeThreshold:
self.spikeFound = True
self.i_spikeFound = self.ic.amp
self.vpeak = self.vMax
else:
self.spikeFound = False
if plot==True:
import matplotlib.pyplot as plt
# Create Figure 1 and Axes 1
self.fig1, self.ax1 = plt.subplots()
# plt.plot(self.t, self.caisoma, label=f"gnabar={self.soma(0.5).NaTs2_t.gNaTs2_tbar}")
# self.ax1.plot(self.t, self.vsoma, label=f"{self.ic.amp:g}")
# plt.plot(self.t, self.currentvec)
self.ax1.plot(self.t_postWarmUp, self.spikeSignal, label=f"{self.ic.amp:g}")
# zoom the x-axis to view the spikes
self.ax1.set_xlim([3098.0, 3110.0])
# create legend
self.ax1.legend().set_title('1ms pulse amplitude [nA]')
# indicate stimulation site in title
self.ax1.set_title(f"Nav Separation = {self.aisNavSeparation:g}"+f" | Crossover = {self.aisNavCrossover:g}"+r' | Stimulating at '+str(self.ic.get_segment().sec.name()).replace('L5PCtemplate', '').replace('[0]', '').replace('[1]', '').replace('.', ''))
# axis labels
self.ax1.set_ylabel(r'$V_{soma} \ \ \ [mV]$', fontsize=20)
self.ax1.set_xlabel(r'$t \ \ \ [ms]$', fontsize=20)
# Show Figure 1
self.fig1.show()
# else:
# plt.close(fig1)
return self.spikeFound
def ithresh(self, IMAX=None, epsilon=None):
"""Theshold finder"""
self.i_spikeFound = None
self.spikeFound = False
if epsilon is not None:
self.epsilon = epsilon # Tolerance for convergence
if IMAX is not None:
self.IMAX = IMAX
delta = (self.IMAX - self.epsilon)/5.1 # Initial step size
self.ic.amp = self.IMAX
while True:
if self.spike():
self.ic.amp -= delta
continue
else:
self.ic.amp += delta
if self.ic.amp <= 0.0 or self.ic.amp >= self.IMAX:
print(f"Error, no threshold found. Pulse amplitude was set to {self.ic.amp:g} upon quitting")
break
if abs(delta) < self.epsilon:
break
delta = delta / 2.0 # Update step size (bisection method)
def threshCurve(self, separation_vals=None, crossover_vals=None, deltaRightShift_vals=None, scaleNavTwo_vals=None):
if separation_vals is not None:
self.separation_vals = separation_vals
if crossover_vals is not None:
self.crossover_vals = crossover_vals
if deltaRightShift_vals is not None:
self.deltaRightShift_vals = deltaRightShift_vals
if scaleNavTwo_vals is not None:
self.scaleNavTwo_vals = scaleNavTwo_vals
self.initialize_threshold_dict()
for separation in self.separation_vals:
for crossover in self.crossover_vals:
for DeltaVrs in self.deltaRightShift_vals:
for scaleNavTwo in self.scaleNavTwo_vals:
self.NaV2scaleNavTwo(scaleNavTwo)
self.NaV2DeltaVrs(DeltaVrs)
self.set_aisNavProfile(separation=separation, crossover=crossover)
self.ithresh()
self.update_threshold_dict()
def initialize_threshold_dict(self):
self.threshold_dict = {}
self.threshold_dict['spike_recording_site'] = []
self.threshold_dict['electrode_location'] = []
self.threshold_dict['separation'] = []
self.threshold_dict['crossover'] = []
self.threshold_dict['Ithresh'] = []
self.threshold_dict['vpeak'] = []
self.threshold_dict['DeltaVrs'] = []
self.threshold_dict['scaleNavTwo'] = []
def update_threshold_dict(self):
self.threshold_dict['spike_recording_site'] += [str(self.spike_recording_site)]
self.threshold_dict['electrode_location'] += [str(self.ic.get_segment())]
self.threshold_dict['separation'] += [self.aisNavSeparation]
self.threshold_dict['crossover'] += [self.aisNavCrossover]
self.threshold_dict['Ithresh'] += [self.i_spikeFound]
self.threshold_dict['vpeak'] += [self.vpeak]
self.threshold_dict['DeltaVrs'] += [self.DeltaVrs]
self.threshold_dict['scaleNavTwo'] += [self.scaleNavTwo]
def plot_threshold_dict(self, threshold_dict=None):
if threshold_dict is not None:
data = self.threshold_dict
else:
data = threshold_dict
import matplotlib.pyplot as plt
# Create Figure 1 and Axes 1
self.fig1, self.ax1 = plt.subplots()
self.ax1.plot(data['separation'], data['Ithresh'], label='')
# indicate stimulation site in title
self.ax1.set_title(r'BackProp theshold as a function of $Na_{V}$ distribution')
# axis labels
self.ax1.set_ylabel(r'$\kappa$', fontsize=20)
self.ax1.set_xlabel(r'$x$', fontsize=20)
# Show Figure 1
self.fig1.show()
def create_recording(self, vecname, vecvar, vecsource):
"""method for recording data and automatically adding the new vector to data_dict. Adds the data vector called self.str(vecname) to the class"""
setattr(self, vecname, h.Vector().record(getattr(vecsource, '_ref_'+vecvar)))
datavec = getattr(self, vecname)
self.data_dict[vecname] = datavec
# return datavec
def setup_recording_vectors(self):
# self.t = h.Vector().record(h._ref_t)
# self.vsoma = h.Vector().record(self.soma(0.5)._ref_v)
self.create_recording(vecname='t', vecvar='t', vecsource=h)
self.create_recording(vecname='currentvec', vecvar = 'i', vecsource = self.ic)
self.create_recording(vecname='vsoma', vecvar='v', vecsource=self.soma(0.5)) #<--- creates the vector self.vsoma = h.Vector().record(self.soma(0.5)._ref_v) and adds it to self.data_dict
self.create_recording(vecname='caisoma', vecvar='cai', vecsource=self.soma(0.5))
self.create_recording(vecname='vaxon0', vecvar='v', vecsource=self.axon0(0.5))
self.create_recording(vecname='vaxon1', vecvar='v', vecsource=self.axon1(0.5))
self.create_recording(vecname='vapic', vecvar='v', vecsource=self.apic(0.99))
self.create_recording(vecname='vSpikeRecording', vecvar='v', vecsource=self.spike_recording_site)
def get_gNaVTotal(self, seg):
gNaV2 = getattr(getattr(seg, 'NaTa2_t', None), 'gNaTa2_tbar', 0)
gNaV6 = getattr(getattr(seg, 'NaTa6_t', None), 'gNaTa6_tbar', 0)
gNaVSomatic_variant = getattr(getattr(seg, 'NaTs2_t', None), 'gNaTs2_tbar', 0)
gNaVaxonal_variant = getattr(getattr(seg, 'NaTa_t', None), 'gNaTa_tbar', 0)
return gNaV2 + gNaV6 + gNaVSomatic_variant + gNaVaxonal_variant
def get_aisNavProperties(self):
self.aisLength = h.distance(self.ais_segments[0].sec(0), self.ais_segments[-1].sec(1))
self.ais_positions = []
self.ais_normalized_positions = []
self.Nav2profile = []
self.Nav6profile = []
self.aisDeltaRightShift_profile = []
for seg in self.ais_segments:
position_of_seg = h.distance(self.ais_segments[0].sec(0), seg)
self.ais_positions += [position_of_seg]
self.ais_normalized_positions += [position_of_seg/self.aisLength]
self.Nav2profile += [seg.NaTa2_t.gNaTa2_tbar]
self.Nav6profile += [seg.NaTa6_t.gNaTa6_tbar]
if seg.NaTa2_t.vRS03 == seg.NaTa6_t.vRS03:
self.aisDeltaRightShift_profile += [seg.NaTa2_t.vRS03]
else:
print(f"Error: aisDeltaRightShift values do NOT match for each NaV subtype present in AIS: seg.NaTa2_t.vRS03={seg.NaTa2_t.vRS03:g}, seg.NaTa6_t.vRS03={seg.NaTa6_t.vRS03:g}")
exit()
self.ais_gNavTotal = self.Nav2profile[0] + self.Nav6profile[0]
###Define a function to alter Nav2 right-shift
def NaV2DeltaVrs(self, DeltaVrs=False):
if DeltaVrs is not False:
self.DeltaVrs = DeltaVrs
for seg in self.ais_segments:
seg.NaTa2_t.DeltaVrs = self.DeltaVrs
def NaV2scaleNavTwo(self, scaleNavTwo=False):
if scaleNavTwo is not False:
self.scaleNavTwo = scaleNavTwo
for seg in self.ais_segments:
seg.NaTa2_t.scaleNavTwo = self.scaleNavTwo
###Define separate functions to create the density profile of each NaV subtype
def Nav2Profile_func(self, segment):
import numpy as np
s = h.distance(self.ais_segments[0].sec(0), segment)/self.aisLength
return self.ais_gNavTotal*(0.5)*(1.0 - self.aisNavSeparation*np.tanh(self.sigma*(s - self.aisNavCrossover)))
def Nav6Profile_func(self, segment):
import numpy as np
s = h.distance(self.ais_segments[0].sec(0), segment)/self.aisLength
return self.ais_gNavTotal*(0.5)*(1.0 + self.aisNavSeparation*np.tanh(self.sigma*(s - self.aisNavCrossover)))
def set_aisNavProfile(self, separation=False, crossover=False, deltaRightShift=False):
self.get_aisNavProperties()
if separation is not False:
self.aisNavSeparation = separation
if crossover is not False:
self.aisNavCrossover = crossover
if deltaRightShift is not False:
self.aisDeltaRightShift = deltaRightShift
for seg in self.ais_segments:
seg.NaTa2_t.gNaTa2_tbar = self.Nav2Profile_func(seg) #self.ais_gNavTotal
seg.NaTa6_t.gNaTa6_tbar = self.Nav6Profile_func(seg) #self.ais_gNavTotal
seg.NaTa2_t.vRS03 = self.aisDeltaRightShift
seg.NaTa6_t.vRS03 = self.aisDeltaRightShift
self.get_aisNavProperties()
def plot_aisNavProfiles(self):
import matplotlib.pyplot as plt
self.get_aisNavProperties()
# Create Figure 1 and Axes 1
self.fig1, self.ax1 = plt.subplots()
# Add data to Figure 1, Axes 1
self.ax1.plot(self.ais_positions, np.array(self.Nav2profile)+np.array(self.Nav6profile), linewidth=1.0, color='black', label=r'$\bar{g}_{Na_{V}, Total} = \bar{g}_{Na_{V}1.2}+\bar{g}_{Na_{V}1.6}$')
self.ax1.plot(self.ais_positions, self.Nav2profile, linewidth=3.0, label=r'$\bar{g}_{Na_{V}1.2}$')
self.ax1.plot(self.ais_positions, self.Nav6profile, linewidth=2.0, linestyle='--', label=r'$\bar{g}_{Na_{V}1.6}$')
# create legend
self.ax1.legend()
# create title
self.ax1.set_title(f"Nav Separation = {self.aisNavSeparation:g}"+f" | Crossover = {self.aisNavCrossover:g}")
#set y-limit for NaV density axis
self.ax1.set_ylim([0.0, self.ais_gNavTotal*1.05])
# Show Figure 1
self.fig1.show()
# Close Figure 1
# plt.close(fig1)
def get_vpeak(self):
data = {}
for sec in self.all:
# for sec in [self.soma, self.axon0, self.axon1]:
multiplier = 1.0
if 'dend' in str(sec) or 'apic' in str(sec):
multiplier = -1.0
# if 'apic[36]' in str(sec):
data[str(sec)] = [[multiplier*h.distance(self.soma(0.5), seg) for seg in sec], [seg.peak.vpeak for seg in sec]]
self.vpeak_data = data
def get_tpeak(self):
data = {}
for sec in self.all:
# for sec in [self.soma, self.axon0, self.axon1]:
multiplier = 1.0
if 'dend' in str(sec) or 'apic' in str(sec):
multiplier = -1.0
# if 'apic[36]' in str(sec):
data[str(sec)] = [[multiplier*h.distance(self.soma(0.5), seg) for seg in sec], [seg.peak.tpeak for seg in sec]]
self.tpeak_data = data
def plot_peak_everywhere(self, amplitude=None, separation=False, crossover=False, deltaRightShift=False, plot_tpeak=False, filename=None, legend1_location=None, legend2_location=None):
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
import tikzplotlib
### global formatting
label_size = 20
legend_size = 17
line_width = 3
matplotlib.rcParams['xtick.labelsize'] = label_size
matplotlib.rcParams['ytick.labelsize'] = label_size
matplotlib.rcParams['axes.labelsize'] = label_size*1.3
matplotlib.rcParams['legend.fontsize'] = legend_size
matplotlib.rcParams['legend.title_fontsize'] = legend_size
matplotlib.rcParams['lines.linewidth'] = line_width
###Create and Close Figure 1 and Figure 2, to clear any pre-existing figures:
self.fig1, self.ax1 = plt.subplots(figsize=(9,6))
plt.close(self.fig1)
if plot_tpeak==True:
self.fig2, self.ax2 = plt.subplots(figsize=(9,6))
plt.close(self.fig2)
### Apply the custom format function to the x and y axes
from matplotlib.ticker import FuncFormatter
### Define a custom function for formatting
def format_ticks(x, _):
return f"{x:g}"
# return f"{x:.1f}"
self.formatter = FuncFormatter(format_ticks)
self.fig1, self.ax1 = plt.subplots(figsize=(9,6))
self.ax1.tick_params(left=True, right=True)
self.ax1.xaxis.set_major_formatter(self.formatter)
self.ax1.yaxis.set_major_formatter(self.formatter)
if plot_tpeak==True:
self.fig2, self.ax2 = plt.subplots(figsize=(9,6))
self.ax2.tick_params(left=True, right=True)
self.ax2.xaxis.set_major_formatter(self.formatter)
self.ax2.yaxis.set_major_formatter(self.formatter)
if separation or crossover or deltaRightShift:
self.set_aisNavProfile(separation=separation, crossover=crossover, deltaRightShift=deltaRightShift)
if amplitude is None:
amplitude=8.2
self.stim(amplitude=amplitude)
h.finitialize(v_init)
self.reset_peak()
h.continuerun(t_stop)
if filename is None:
self.figname1 = 'vpeak_'+str(self.ic.amp)+'_'+str(self.ic.dur)+'_'+str(self.aisNavSeparation)+'_'+str(self.aisNavCrossover)+'_'+str(self.ic.get_segment())
self.figname2 = 'tpeak_'+str(self.ic.amp)+'_'+str(self.ic.dur)+'_'+str(self.aisNavSeparation)+'_'+str(self.aisNavCrossover)+'_'+str(self.ic.get_segment())
else:
self.figname1 = filename+'_vpeak'
self.figname2 = filename+'_tpeak'
self.get_vpeak()
print_args(self.cell.apic[37].nseg)
self.ax1.plot(-1.0*h.distance(self.soma(0.5), self.cell.apic[37](0.6)), self.cell.apic[37](0.6).peak.vpeak, linewidth=0.0, linestyle='', marker='d', markersize=10.0, color='black', alpha=0.7, label='record')#str(self.cell.apic[37])[-8:])
self.ax1.plot(-1.0*h.distance(self.soma(0.5), self.apic(1.0)), self.apic(0.99).peak.vpeak, linewidth=0.0, linestyle='', marker='p', markersize=10.0, color='black', alpha=0.7, label='bifurcation') #, label=str(self.apic)[-8:])
for sec, segdata in self.vpeak_data.items():
# if 'apic[36]' in str(sec):
# self.ax1.plot(segdata[0], segdata[1], linewidth=0.0, linestyle='', marker='s', alpha=0.7, label=str(sec))
# elif 'apic' in str(sec) and max(segdata[1])>23.0 and max(segdata[0])<-500.0:
# self.ax1.plot(segdata[0], segdata[1], linewidth=0.0, linestyle='', marker='p', alpha=0.7, label=str(sec))
# elif 'apic' in str(sec) or 'dend' in str(sec):
# self.ax1.plot(segdata[0], segdata[1], linewidth=0.0, linestyle='', marker='o', color='blue', alpha=0.3)
# else:
# self.ax1.plot(segdata[0], segdata[1], linewidth=0.0, linestyle='', marker='o', alpha=0.3, label=str(sec))
if 'axon' in str(sec):
self.ax1.plot(segdata[0], segdata[1], linewidth=0.0, linestyle='', marker='o', alpha=0.3, label='AIS_'+str(sec)[-2:-1])
else:
self.ax1.plot(segdata[0], segdata[1], linewidth=0.0, linestyle='', marker='o', alpha=0.3)
# create legend
if legend1_location is not None:
self.ax1.legend(loc=legend1_location)
else:
self.ax1.legend(loc='upper left')
self.ax1.set_ylabel(r'$V_{peak}\ [mV]$', fontsize=27)
self.ax1.set_xlabel(r'Distance to Soma $s\ [\mu m]$', fontsize=27)
# Show Figure 1
self.fig1.subplots_adjust(top=0.98, bottom=0.13)
self.ax1.set_ylim(-86.0, 54.0)
self.fig1.savefig(self.figname1+'.pdf')
self.fig1.show()
if plot_tpeak==True:
self.get_tpeak()
self.ax2.plot(-1.0*h.distance(self.soma(0.5), self.cell.apic[37](0.6)), self.cell.apic[37](0.6).peak.tpeak-self.ic.delay, linewidth=0.0, linestyle='', marker='d', markersize=10.0, color='black', alpha=0.7, label=str(self.cell.apic[37])[-8:])
self.ax2.plot(-1.0*h.distance(self.soma(0.5), self.apic(1.0)), self.apic(0.99).peak.tpeak-self.ic.delay, linewidth=0.0, linestyle='', marker='p', markersize=10.0, color='black', alpha=0.7, label=str(self.apic)[-8:])
# self.ax2.legend(loc='upper right')
for sec, segdata in self.tpeak_data.items():
# self.ax2.plot(segdata[0], np.array(segdata[1])-self.ic.delay, linewidth=0.0, linestyle='', marker='o', alpha=0.3, label=str(sec))
if 'axon' in str(sec):
self.ax2.plot(segdata[0], np.array(segdata[1])-self.ic.delay, linewidth=0.0, linestyle='', marker='o', alpha=0.3, label='AIS_'+str(sec)[-2:-1])
else:
self.ax2.plot(segdata[0], np.array(segdata[1])-self.ic.delay, linewidth=0.0, linestyle='', marker='o', alpha=0.3)
if legend2_location is not None:
self.ax2.legend(loc=legend2_location)
else:
self.ax2.legend(loc='upper right')
self.ax2.set_ylabel(r'$t_{peak}\ [ms]$', fontsize=27)
self.ax2.set_xlabel(r'Distance to Soma $s\ [\mu m]$', fontsize=27)
# Show Figure 2
self.fig2.subplots_adjust(top=0.98, bottom=0.13)
self.fig2.savefig(self.figname2+'.pdf')
self.fig2.show()
# Close Figure 1
# plt.close(fig1)
def reset_peak(self):
for sec in self.all:
for seg in sec:
seg.peak.vpeak = -1000.0
seg.peak.tpeak = -1000.0
def get_neuron_coordinates_and_colors(self):
x, y, z, diameters, colors = [], [], [], [], []
color_idx = 0
for sec in self.all:
for i in range(int(sec.n3d()) - 1): # -1 to avoid an out-of-range error
x1, y1, z1 = sec.x3d(i), sec.y3d(i), sec.z3d(i)
x2, y2, z2 = sec.x3d(i + 1), sec.y3d(i + 1), sec.z3d(i + 1)
diameter = (sec.diam3d(i) + sec.diam3d(i + 1)) / 2.0 # average diameter for this segment
x.append(np.array([x1, x2]))
y.append(np.array([y1, y2]))
z.append(np.array([z1, z2]))
diameters.append(diameter)
colors.append(self.tropical_colors[color_idx])
color_idx = (color_idx + 1) % len(self.tropical_colors)
return x, y, z, diameters, colors
def show3Dcell(self, colors=None, figsize=(10, 10), xlim=None, ylim=None, zlim=None):
from mpl_toolkits.mplot3d.art3d import Line3DCollection
x, y, z, diameters, colors = self.get_neuron_coordinates_and_colors()
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
# Remove grid and axes
ax.grid(False)
ax.axis('off')
ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
# Draw the neuron segments as pipes
for xi, yi, zi, di, color in zip(x, y, z, diameters, colors):
ax.add_collection3d(Line3DCollection([list(zip(xi, yi, zi))], linewidths=di, colors=color))
# Set axis limits if provided
if xlim:
ax.set_xlim(xlim)
else:
ax.set_xlim(min([xi.min() for xi in x]), max([xi.max() for xi in x]))
if ylim:
ax.set_ylim(ylim)
else:
ax.set_ylim(min([yi.min() for yi in y]), max([yi.max() for yi in y]))
if zlim:
ax.set_zlim(zlim)
else:
ax.set_zlim(min([zi.min() for zi in z]), max([zi.max() for zi in z]))
# Adjust aspect ratio and margins
ax.set_box_aspect([ub - lb for lb, ub in (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())])
plt.tight_layout()
plt.show()
def export_dict_to_python_script(dict_to_export):
# Use pprint to get a nicely formatted string of the dictionary
formatted_dict_str = pprint.pformat(dict_to_export, width=80, sort_dicts=True)
# Generate a file name with the current date and time
datetime_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
file_name = f'data_dict_{datetime_now}.py'
# Write to a file, including the variable assignment
with open(file_name, 'w') as file:
file.write(f"data = {formatted_dict_str}\n")
print(f'Dictionary has been saved to {file_name}')
###*******************************************************************************
###*** Demo **********************************************************************
###*******************************************************************************
if __name__ == '__main__':
###
###Create cells
###
cell1 = Pyramidal(1)
cell2 = Pyramidal(2)
###
### Initialize/resume the simulation
###
initialize(warm_up_and_save=True)
# initialize(warm_up_and_save=False)
for key, item in cell1.data_dict.items():
print(key, item)
###Display cell1 in 3D matplotlib
# cell1.show3Dcell()
cell1.set_aisNavProfile(separation=0.0, crossover=0.5)
cell1.plot_aisNavProfiles()
# exit()
# # cell1.stim(location=cell1.soma(0.5), duration=1.0)
# cell1.stim(location=cell1.ais_segments[-1], duration=1.0)
# cell1.threshCurve(separation_vals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], crossover_vals=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
# thresh_data = copy.deepcopy(cell1.threshold_dict)