from neuron import h, rxd
from neuron.units import ms, mV, µm, mM
import plotnine as p9
import pandas as pd
import multiprocessing as mp
import math
from functools import lru_cache
# NEURON's standard run library (provides h.continuerun, used in main()).
h.load_file("stdrun.hoc")

# Initial calcium source: a uniform block on [start_x, stop_x], centered at
# 76.5 µm, the midpoint of the 153 µm axons built in main().
start_x, stop_x = 70 * µm, 83 * µm
initial_concentration = 1 * mM  # concentration inside the source block
# Diffusion coefficient, shared by rxd.Species(d=...) and the analytic solution
# (presumably in NEURON's µm²/ms convention -- confirm).
D = 1
tstop = 50 * ms  # simulation end time
def fundamental_solution(t, D, distance):
    """Return the 1D diffusion (heat) kernel.

    This is the concentration found at ``distance`` from a unit point source
    after time ``t`` with diffusion coefficient ``D``:

        exp(-distance**2 / (4 D t)) / sqrt(4 pi D t)

    ``D * t`` must be positive; otherwise ``math`` raises.
    """
    normalization = 1 / math.sqrt(4 * math.pi * D * t)
    decay = math.exp(-(distance ** 2) / (4 * D * t))
    return normalization * decay
@lru_cache(maxsize=None)
def solution_from_interval(
    x, t, D, start_x, stop_x, initial_concentration, interval_steps=100
):
    """Solution of diffusion on an infinite line from a finite interval source.

    The initial condition is ``initial_concentration`` uniformly on
    [start_x, stop_x] and zero elsewhere. The exact solution is the
    convolution of that profile with ``fundamental_solution``. Here the
    integral is approximated with the midpoint rule on ``interval_steps``
    equal subintervals.

    Results are memoized with an unbounded ``lru_cache``, so every argument
    must be hashable.

    Raises:
        ValueError: if ``start_x`` is not less than ``stop_x``, or if
            ``interval_steps`` is less than 1.
    """
    # Explicit checks rather than ``assert``, so validation survives ``python -O``.
    if not start_x < stop_x:
        raise ValueError(f"start_x ({start_x}) must be less than stop_x ({stop_x})")
    if interval_steps < 1:
        raise ValueError(f"interval_steps must be at least 1, got {interval_steps}")
    dx = (stop_x - start_x) / interval_steps
    # midpoint of subinterval i is start_x + (i + 0.5) * dx
    return initial_concentration * sum(
        dx * fundamental_solution(t, D, x - (start_x + (i + 0.5) * dx))
        for i in range(interval_steps)
    )
def solution_from_interval_with_boundaries(
    x, t, D, start_x, stop_x, initial_concentration, interval_steps=100, length=153
):
    """Interval-source diffusion on the finite line [0, length] with reflecting ends.

    ``length`` defaults to 153 (µm), the total length of the simulated axons.
    The reflecting (no-flux, mass-conserving) ends are modelled with the
    method of images:

    * The source [start_x, stop_x] is mirrored across x = 0 onto
      [-stop_x, -start_x].
    * It is mirrored across x = length onto
      [2 * length - stop_x, 2 * length - start_x].

    Each image carries the same concentration as the real source.

    In principle the exact solution is an infinite sum, because every image
    is reflected again by the opposite end. Only the two first-order images
    are kept. This is accurate while the diffusion length sqrt(4 * D * t) is
    small compared with the distance from the source to each end.
    """
    real_source = solution_from_interval(
        x, t, D, start_x, stop_x, initial_concentration, interval_steps
    )
    # reflection across x = 0 maps y -> -y
    left_image = solution_from_interval(
        x, t, D, -stop_x, -start_x, initial_concentration, interval_steps
    )
    # reflection across x = length maps y -> 2 * length - y
    right_image = solution_from_interval(
        x,
        t,
        D,
        2 * length - stop_x,
        2 * length - start_x,
        initial_concentration,
        interval_steps,
    )
    return real_source + left_image + right_image
def _save_comparison_plots(data, dx, ics_partial_volume_resolution):
    """Write this run's concentration, signed-error and relative-error PDFs."""
    suffix = f"dx-{dx}-ics-{ics_partial_volume_resolution}"

    # concentration profile
    p9.options.figure_size = (4, 1)
    g = (
        p9.ggplot(data, p9.aes(x="x", y="cai", color="id"))
        + p9.geom_line()
        + p9.labs(x="Position (µm)", y="Ca$^{2+}$ concentration (mM)")
    )
    g.save(f"comparison-to-truth-{suffix}.pdf")

    # signed error of the volume-averaged concentration
    p9.options.figure_size = (4, 3)
    g = (
        p9.ggplot(data, p9.aes(x="x", y="abs_error", color="id"))
        + p9.geom_line()
        + p9.labs(x="Position (µm)", y="Absolute error (mM)", title=f"dx={dx}")
    )
    g.save(f"comparison-to-truth-abs-error-{suffix}.pdf")

    # relative error (the figure size carries over from the previous plot)
    g = (
        p9.ggplot(data, p9.aes(x="x", y="rel_error", color="id"))
        + p9.geom_line()
        + p9.labs(x="Position (µm)", y="relative error", title=f"dx={dx}")
    )
    g.save(f"comparison-to-truth-rel-error-{suffix}.pdf")


def main(dx=0.25, ics_partial_volume_resolution=2):
    """Simulate ca diffusion on four 153 µm axons and compare to the analytic solution.

    The axons differ only in which parts use 3D versus 1D discretization:

    * axon1: just 1D
    * axon2 (a|b|c): 3D in the middle, 1D on the edges
    * axon3: 3D everywhere
    * axon4 (a|b|c): 3D on the edges, 1D in the middle

    Calcium starts at ``initial_concentration`` strictly between ``start_x``
    and ``stop_x`` and is simulated until ``tstop``. Three PDF plots are
    written to the working directory.

    Args:
        dx: voxel size passed to ``rxd.Region`` for the 3D parts.
        ics_partial_volume_resolution: value assigned to
            ``rxd.options.ics_partial_volume_resolution`` (NEURON's default
            is 2).

    Returns:
        A pandas DataFrame with one row per rxd node. It holds the node's
        position, volume, simulated and analytic concentration, and the
        derived error columns.
    """
    print("starting setup...")
    # set accuracy of volumes
    rxd.options.ics_partial_volume_resolution = ics_partial_volume_resolution
    # small time step (half of NEURON's default)... needed for numerical stability with small dx
    h.dt = 0.0125 * ms

    # construct 4 axons; axon2 and axon4 are three 51 µm sections joined end to end
    axon1 = h.Section(name="axon1")
    axon2a = h.Section(name="axon2a")
    axon2b = h.Section(name="axon2b")
    axon2c = h.Section(name="axon2c")
    axon3 = h.Section(name="axon3")
    axon4a = h.Section(name="axon4a")
    axon4b = h.Section(name="axon4b")
    axon4c = h.Section(name="axon4c")
    axon2b.connect(axon2a)
    axon2c.connect(axon2b)
    axon4b.connect(axon4a)
    axon4c.connect(axon4b)
    axon3.L = axon1.L = 153 * µm
    axon2a.L = axon2b.L = axon2c.L = axon4a.L = axon4b.L = axon4c.L = 51 * µm

    all_secs = [axon1, axon2a, axon2b, axon2c, axon3, axon4a, axon4b, axon4c]
    # ca mass is summed over axon2 before and after the run (presumably a
    # mass-conservation check across the 1D/3D interfaces)
    axon2_secs = [axon2a, axon2b, axon2c]
    for sec in all_secs:
        sec.diam = 2 * µm
        sec.nseg = int(sec.L) * 2  # 0.5 µm segments in the 1D parts
    rxd.set_solve_type([axon2b, axon3, axon4a, axon4c], dimension=3)

    def init_concentration(node):
        """Initialization rule: the source concentration inside (start_x, stop_x), else 0."""
        return initial_concentration if start_x < node.x3d < stop_x else 0

    # set up the model
    cytosol = rxd.Region(all_secs, name="cytosol", nrn_region="i", dx=dx)
    ca = rxd.Species(cytosol, name="ca", d=D, charge=2, initial=init_concentration)

    def mass_in(secs):
        """Sum of concentration * volume over the ca nodes lying in ``secs``."""
        return sum(
            node.concentration * node.volume for node in ca.nodes if node.sec in secs
        )

    def exact_cai(x, t):
        """Analytic concentration at position ``x`` and time ``t``."""
        return solution_from_interval_with_boundaries(
            x, t, D, start_x, stop_x, initial_concentration
        )

    # run the simulation
    print("initializing...")
    h.finitialize(-65 * mV)
    initial_mass = mass_in(axon2_secs)
    print(
        "volume in 1D part:",
        sum(node.volume for node in ca.nodes if node.sec in [axon2a, axon2c]),
    )
    print(
        "volume in 3D part:",
        sum(node.volume for node in ca.nodes if node.sec in [axon2b]),
    )
    print("initial mass:", initial_mass)
    print("running...")
    h.continuerun(tstop)
    ending_mass = mass_in(axon2_secs)
    print("ending mass:", ending_mass)
    print(f"true concentration at x=76.5: {exact_cai(76.5 * µm, tstop)}")
    print(f"true concentration 1/3 of the way in: {exact_cai(51 * µm, tstop)}")
    # both ends of the difference are taken relative to tstop (previously the
    # second term was pinned to 50 ms regardless of tstop)
    print(
        "change over 1 timestep in true concentration 1/3 of the way in: "
        f"{exact_cai(51 * µm, tstop + h.dt) - exact_cai(51 * µm, tstop)}"
    )

    # prepare the data
    # note: axon1, axon2, axon3, and axon4 all have same number of segments
    sec2id = {
        axon1: "1D",
        axon2a: "3D middle",
        axon2b: "3D middle",
        axon2c: "3D middle",
        axon3: "full 3D",
        axon4a: "3D edges",
        axon4b: "3D edges",
        axon4c: "3D edges",
    }
    all_nodes = [node for sec in all_secs for node in ca.nodes(sec)]
    data = pd.DataFrame(
        {
            "x": [node.x3d for node in all_nodes],
            "vol": [node.volume for node in all_nodes],
            "cai": [node.concentration for node in all_nodes],
            "id": [sec2id[node.sec] for node in all_nodes],
            "true_cai": [exact_cai(node.x3d, h.t) for node in all_nodes],
        }
    )
    data["id"] = data["id"].astype("category")
    data["dx"] = dx
    data["ics_partial_volume_resolution"] = ics_partial_volume_resolution
    data["mass"] = data["vol"] * data["cai"]
    # Volume-weighted mean concentration over all nodes sharing an (x, id)
    # pair. For 3D parts these are presumably the voxels of one cross-section.
    by_position = data.groupby(["x", "id"], observed=True)
    data["average_cai"] = (
        by_position.mass.transform("sum") / by_position.vol.transform("sum")
    )
    # "abs_error" is a signed difference (simulated - analytic), not |.|
    data["abs_error"] = data["average_cai"] - data["true_cai"]
    data["unaveraged_abs_error"] = data["cai"] - data["true_cai"]
    data["rel_error"] = data["abs_error"] / data["true_cai"]

    _save_comparison_plots(data, dx, ics_partial_volume_resolution)
    return data
if __name__ == "__main__":
    # (dx, ics_partial_volume_resolution) configurations to compare.
    configs = [(0.25, 2), (0.25, 6), (0.125, 2)]
    # maxtasksperchild=1 together with chunksize=1 runs every configuration in
    # its own fresh worker process. NEURON's process-global model state
    # (sections, rxd species/regions, rxd.options, h.dt) left behind by one
    # run therefore cannot leak into another.
    with mp.Pool(maxtasksperchild=1) as pool:
        data = pd.concat(pool.starmap(main, configs, chunksize=1))

    # keep the numeric dx for filtering; the categorical copy drives linetype
    data["original_dx"] = data["dx"]
    data["dx"] = data["dx"].astype("category")

    p9.options.figure_size = (4, 3)
    g = (
        p9.ggplot(
            data[data["ics_partial_volume_resolution"] == 2],
            p9.aes(x="x", y="abs_error", color="id", linetype="dx"),
        )
        + p9.geom_line()
        + p9.labs(x="Position (µm)", y="Signed absolute error (mM)")
    )
    g.save("comparison-to-truth-abs-error-by-dx.pdf")

    data["ics_partial_volume_resolution"] = data["ics_partial_volume_resolution"].astype("category")
    g = (
        p9.ggplot(
            data[data["original_dx"] == 0.25],
            p9.aes(x="x", y="abs_error", color="id", linetype="ics_partial_volume_resolution"),
        )
        + p9.geom_line()
        + p9.labs(x="Position (µm)", y="Signed absolute error (mM)")
    )
    g.save("comparison-to-truth-abs-error-by-ics.pdf")

    # 'precision' alone is ambiguous on modern pandas (it also matches
    # 'styler.format.precision') and raises OptionError.
    pd.set_option("display.precision", 12)
    # observed=True: the keys are categorical, and only combinations that
    # were actually run should be reported (no all-NaN rows).
    summary = data.groupby(["dx", "ics_partial_volume_resolution", "id"], observed=True)
    print("Averaged concentration error:")
    print(summary.abs_error.agg(["max", "min"]))
    print("Unaveraged:")
    print(summary.unaveraged_abs_error.agg(["max", "min"]))