import time
import multiprocessing
import sqlite3
import pandas as pd
from neuron import h, rxd
from neuron.units import mV, ms, um, mM
# SQLite database that accumulates timing results (table "data").
DB_FILENAME = "thread_scaling.db"
# SWC morphology reconstruction intended for ``Cell``.
SWC_FILENAME = "B4-CA1-L-D63x1zACR3_1.CNG.swc.txt"
# Number of timed repetitions per (nthread, morphology, kinetics, dx) configuration.
# NOTE(review): run_sim relies on NUM_RUNS but no definition was present; 5 is a
# placeholder value — confirm against the intended benchmark protocol.
NUM_RUNS = 5
class Cell:
    """Reconstructed neuron morphology read from an SWC file.

    ``Import3d_GUI.instantiate(self)`` populates section lists on this object
    (such as ``soma`` and ``all``); ``all`` is what the kinetics builders pass
    to ``rxd.Region``.

    Attributes:
        start: first soma section; kinetics seed their initial stimulus here.
        name: morphology label stored in the results database.
        dx: rxd spatial discretization used by every region.
    """

    def __init__(self, dx, swc_filename=None):
        """Load the morphology.

        Args:
            dx: rxd spatial discretization.
            swc_filename: SWC file to read; defaults to ``SWC_FILENAME``.
        """
        if swc_filename is None:
            swc_filename = SWC_FILENAME
        # The Import3d classes come from a hoc library; load_file only loads
        # a file the first time it is requested.
        h.load_file("import3d.hoc")
        reader = h.Import3d_SWC_read()
        reader.input(swc_filename)
        i3d = h.Import3d_GUI(reader, False)
        # Creates the section lists (self.soma, self.all, ...) on this object.
        i3d.instantiate(self)
        self.start = self.soma[0]
        self.name = "cell"
        self.dx = dx
class Cylinder:
    """Simple test geometry: two 25 um long, 1 um diameter sections.

    Attributes:
        all: the two sections; passed to ``rxd.Region`` by the kinetics builders.
        start: section in which the kinetics seed their initial stimulus.
        name: morphology label stored in the results database.
        dx: rxd spatial discretization used by every region.
    """

    def __init__(self, dx):
        self.all = [h.Section(name=f"dend{i}") for i in range(2)]
        for sec in self.all:
            sec.L = 25 * um
            sec.diam = 1 * um
        # NOTE(review): the two sections are never connected to each other, so
        # material seeded in dend1 cannot reach dend0 — confirm whether
        # self.all[1].connect(self.all[0]) was intended.
        self.start = self.all[1]
        self.name = "cylinder"
        self.dx = dx
def diffusion_only(obj):
    """Build pure-diffusion kinetics on ``obj``'s sections.

    Args:
        obj: morphology exposing ``all`` (sections), ``start`` (section whose
            nodes start at concentration 1, all others at 0) and ``dx``.

    Returns:
        dict with ``"name"`` plus the created ``"regions"`` and ``"species"``.
        Presumably returned so the caller keeps references to the rxd
        objects alive while simulating.
    """
    cyt = rxd.Region(obj.all, name="cyt", dx=obj.dx)
    species = rxd.Species(
        cyt,
        d=1 * um**2 / ms,
        initial=lambda node: 1 if node in obj.start else 0,
    )
    return {
        "name": "diffusion",
        "regions": [cyt],
        "species": [species],
    }
def bistable(obj):
    """Build a diffusing species with cubic (bistable) reaction kinetics.

    The rate ``-c (1 mM - c)(0.3 mM - c)`` has stable states at 0 and 1 mM,
    so the 1 mM seed in ``obj.start`` can propagate as a traveling wave.

    Args:
        obj: morphology exposing ``all``, ``start`` and ``dx``.

    Returns:
        dict with ``"name"`` plus the created ``"regions"``, ``"species"`` and
        ``"reactions"``. Presumably returned so the caller keeps references
        to the rxd objects alive while simulating.
    """
    cyt = rxd.Region(obj.all, name="cyt", dx=obj.dx)
    c = rxd.Species(
        cyt,
        d=1 * um**2 / ms,
        initial=lambda node: 1 * mM if node in obj.start else 0,
    )
    rate = rxd.Rate(c, -c * (1 * mM - c) * (0.3 * mM - c))
    return {
        "name": "bistable-wave",
        "regions": [cyt],
        "species": [c],
        "reactions": [rate],
    }
def cawave(obj,
           caDiff=0.08,
           ip3Diff=1.41,
           cac_init=1.e-4,
           ip3_init=0.1,
           gip3r=12040,
           gserca=0.3913,
           gleak=6.020,
           kserca=0.1,
           kip3=0.15,
           kact=0.4,
           ip3rtau=2000,
           fc=0.8,
           fe=0.2,
           average_ca_inside=0.0017):
    """Build an IP3-triggered calcium wave model with cytosol and ER.

    Cytosol (volume fraction ``fc``) and ER (fraction ``fe``) share ``obj.all``
    and are coupled by a distributed membrane that carries a SERCA pump, a
    leak and an IP3 receptor whose gate relaxes with time constant
    ``ip3rtau``. IP3 starts at 2 mM in ``obj.start`` and at ``ip3_init``
    elsewhere.

    Args:
        obj: morphology exposing ``all``, ``start`` and ``dx``.
        Remaining arguments are model parameters (diffusion coefficients,
        initial concentrations, conductances and affinities).

    Returns:
        dict with ``"name"`` and lists ``"regions"``, ``"species"``, ``"misc"``
        (intermediate rxd expressions) and ``"reactions"``. Presumably returned
        so the caller keeps references to every rxd object alive.
    """
    # ER calcium chosen so the volume-weighted mean of cytosolic and ER
    # calcium equals average_ca_inside.
    cae_init = (average_ca_inside - cac_init * fc) / fe

    cyt = rxd.Region(obj.all, nrn_region='i',
                     geometry=rxd.FractionalVolume(fc, surface_fraction=1),
                     dx=obj.dx)
    er = rxd.Region(obj.all, geometry=rxd.FractionalVolume(fe), dx=obj.dx)
    cyt_er_membrane = rxd.Region(obj.all, geometry=rxd.DistributedBoundary(1),
                                 dx=obj.dx)

    ca = rxd.Species([cyt, er], d=caDiff, name='ca', charge=2,
                     initial=lambda node: cac_init if node in cyt else cae_init,
                     atolscale=1e-6)
    ip3 = rxd.Species(cyt, d=ip3Diff,
                      initial=lambda node: 2 * mM if node in obj.start else ip3_init)
    ip3r_gate_state = rxd.State(cyt_er_membrane, initial=0.8)

    h_gate = ip3r_gate_state[cyt_er_membrane]
    # IP3R activation by IP3 and cytosolic Ca; the 1000x factor presumably
    # converts mM to uM for kact — confirm.
    minf = ip3[cyt] * 1000. * ca[cyt] / (ip3[cyt] + kip3) / (1000. * ca[cyt] + kact)
    k = gip3r * (minf * h_gate) ** 3

    serca = rxd.MultiCompartmentReaction(
        ca[cyt], ca[er], gserca / ((kserca / (1000. * ca[cyt])) ** 2 + 1),
        membrane=cyt_er_membrane, custom_dynamics=True)
    leak = rxd.MultiCompartmentReaction(ca[er], ca[cyt], gleak, gleak,
                                        membrane=cyt_er_membrane)
    ip3r = rxd.MultiCompartmentReaction(ca[er], ca[cyt], k, k,
                                        membrane=cyt_er_membrane)
    ip3rg = rxd.Rate(h_gate, (1. / (1 + 1000. * ca[cyt] / (0.3)) - h_gate) / ip3rtau)

    return {
        "name": "cawave",
        "regions": [cyt, er, cyt_er_membrane],
        "species": [ca, ip3, ip3r_gate_state],
        "misc": [h_gate, minf, k],
        "reactions": [serca, leak, ip3r, ip3rg],
    }
def _load_previous_results(db_filename):
    """Return timings stored by earlier runs, or an empty frame if none exist.

    The empty frame carries the columns the skip check filters on, so a
    fresh database behaves like one with no matching rows.
    """
    empty = pd.DataFrame({"nthread": [], "morphology": [], "kinetics": [],
                          "dx": [], "runcount": [], "runtime": []})
    conn = sqlite3.connect(db_filename)
    try:
        # Check for the table explicitly instead of letting read_sql fail on
        # a brand-new database.
        has_table = conn.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'data'"
        ).fetchone() is not None
        if not has_table:
            return empty
        return pd.read_sql("SELECT * FROM data", conn)
    finally:
        # sqlite3's context manager only commits/rolls back; close explicitly.
        conn.close()


# Results from earlier invocations; run_sim consults this to skip
# configurations that were already timed.
old_data = _load_previous_results(DB_FILENAME)
def run_sim(nthread, morphology, kinetics, dx):
    """Time NUM_RUNS simulations of one configuration and append them to the DB.

    Launched as the target of a separate ``multiprocessing.Process`` per
    configuration by the driver.

    Args:
        nthread: number of rxd threads to simulate with.
        morphology: class constructed as ``morphology(dx)`` (e.g. Cylinder, Cell).
        kinetics: builder called with the morphology instance; returns a dict
            with at least a ``"name"`` key.
        dx: rxd spatial discretization.

    Configurations already present in ``old_data`` are skipped without running.
    """
    # stdrun.hoc provides h.continuerun; load_file only loads it once.
    h.load_file("stdrun.hoc")

    # setup the model; my_kinetics holds references to the rxd objects
    morph = morphology(dx)
    my_kinetics = kinetics(morph)
    kinetics_name = my_kinetics["name"]
    # Fall back to the lowercase class name ("cell"/"cylinder") if the
    # morphology does not define a name attribute.
    morph_name = getattr(morph, "name", type(morph).__name__.lower())

    # skip if we've already done this
    already_done = (
        (old_data["dx"] == dx)
        & (old_data["morphology"] == morph_name)
        & (old_data["kinetics"] == kinetics_name)
        & (old_data["nthread"] == nthread)
    )
    if already_done.any():
        print(f"skipping: dx: {dx}, morph: {morph_name}, kinetics: {kinetics_name}, nthread: {nthread}")
        return
    print(f"running: dx: {dx}, morph: {morph_name}, kinetics: {kinetics_name}, nthread: {nthread}")

    # Apply the thread count being benchmarked; otherwise it is only recorded.
    rxd.nthread(nthread)

    # run the sim several times, timing initialization and integration separately
    times = []
    for run in range(NUM_RUNS):
        print(f" run #{run + 1}")
        initial_time = time.perf_counter()
        h.finitialize(-65 * mV)
        start_time = time.perf_counter()
        print(f" initialization time: {start_time - initial_time}")
        h.continuerun(100 * ms)
        end_time = time.perf_counter()
        times.append(end_time - start_time)
        print(f" elapsed: {end_time - start_time} s")

    # store the data in the database (one row per run; scalars broadcast)
    data = pd.DataFrame({
        "nthread": nthread,
        "morphology": morph_name,
        "kinetics": kinetics_name,
        "dx": dx,
        "runcount": list(range(NUM_RUNS)),
        "runtime": times,
    })
    conn = sqlite3.connect(DB_FILENAME)
    try:
        with conn:  # commit on success, roll back on error
            data.to_sql("data", conn, if_exists="append", index=False)
    finally:
        conn.close()
if __name__ == "__main__":
for dx in [0.12, 0.06]:
for morphology in [Cylinder, Cell]:
for kinetics in [cawave, diffusion_only, bistable]:
for nthread in [1, 2, 3, 4, 5, 6, 7, 8]:
p = multiprocessing.Process(
target=run_sim, args=(nthread, morphology, kinetics, dx)