from neuron import h, rxd
from neuron.units import mV, ms
import matplotlib.pyplot as plt

# Load NEURON's stdrun.hoc before any simulation is set up; presumably this is
# what provides h.continuerun used in run_sim -- TODO confirm.
h.load_file("stdrun.hoc")
import time
import multiprocessing


def run_sim(tstop, dim, d):
    """Simulate a current-clamped soma with rxd-managed sodium and summarize it.

    A single section named "soma" (L = diam = 10) with the ``hh`` mechanism is
    created, intracellular sodium is declared as an rxd species on it, and a
    current clamp of amplitude 0.1 is applied at the section midpoint for the
    whole run.

    Args:
        tstop: Time at which the simulation stops; also used as the clamp
            duration.
        dim: Value forwarded to ``rxd.set_solve_type(dimension=...)``.
        d: Diffusion constant given to the sodium species.

    Returns:
        A dict with the wall-clock duration of setup plus simulation
        ("time"), node statistics gathered after the run ("num_nodes",
        "total_na", "surface_area", "volume"), the species' diffusion
        constant ("na_d"), the recorded vectors ("t", "v", "sodium"), and the
        requested "dimension".
    """
    wall_start = time.perf_counter()

    # Geometry and membrane mechanism.
    soma = h.Section(name="soma")
    soma.L = soma.diam = 10
    soma.insert(h.hh)

    # Reaction-diffusion setup: the solve type is chosen before the region
    # and species are declared, as in the original ordering.
    rxd.set_solve_type(dimension=dim)
    cyt = rxd.Region([soma], nrn_region="i", name="cyt")
    na = rxd.Species(cyt, name="na", charge=1, initial=h.nai0_na_ion, d=d)

    midpoint = soma(0.5)

    # Constant current injection spanning the entire simulation.
    stim = h.IClamp(midpoint)
    stim.amp = 0.1
    stim.delay = 0 * ms
    stim.dur = tstop

    # Recorders for time, membrane potential and intracellular sodium.
    t_rec = h.Vector().record(h._ref_t)
    v_rec = h.Vector().record(midpoint._ref_v)
    nai_rec = h.Vector().record(midpoint._ref_nai)

    h.finitialize(-65 * mV)
    h.continuerun(tstop)

    # Stop the clock before the (untimed) summary statistics are gathered.
    wall_stop = time.perf_counter()

    nodes = na.nodes
    node_count = len(nodes)
    sodium_amount = sum(node.volume * node.concentration for node in nodes)
    total_area = sum(node.surface_area for node in nodes)
    total_volume = sum(node.volume for node in nodes)

    return {
        "time": wall_stop - wall_start,
        "num_nodes": node_count,
        "total_na": sodium_amount,
        "surface_area": total_area,
        "volume": total_volume,
        "na_d": na.d,
        "t": t_rec,
        "v": v_rec,
        "sodium": nai_rec,
        "dimension": dim,
    }



def main():
    """Run the sodium simulations in parallel, then plot and report them.

    Five 3D simulations with different sodium diffusion constants and one 1D
    simulation are dispatched to a process pool via ``run_sim``. For each
    result the membrane potential and sodium traces are drawn on the main
    axes and on an inset zoom, and a text summary is printed. The figure is
    saved to "response_to_currents.pdf" and then shown.
    """
    tstop = 100 * ms

    fig = plt.figure(figsize=(6, 6))
    voltage_axis = fig.add_subplot(2, 1, 1)
    voltage_axis_zoom = voltage_axis.inset_axes([0.01, 0.38, 0.3, 0.6])
    sodium_axis = fig.add_subplot(2, 1, 2)
    sodium_axis_zoom = sodium_axis.inset_axes([0.01, 0.38, 0.3, 0.6])

    # One (tstop, dimension, diffusion constant) argument tuple per simulation.
    jobs = [
        (tstop, 3, 1e-4),
        (tstop, 3, 1e-3),
        (tstop, 3, 1e-2),
        (tstop, 3, 1e-1),
        (tstop, 3, 1),
        (tstop, 1, 0),
    ]

    # starmap blocks until every result is available; the context manager
    # then shuts the worker processes down instead of leaking them.
    with multiprocessing.Pool() as pool:
        results = pool.starmap(run_sim, jobs)

    for data in results:
        # 3D runs get default line styling labelled by diffusion constant;
        # the 1D reference run is drawn as a dashed black line.
        if data["dimension"] == 3:
            fmt, label = (), f"D={data['na_d']}"
        else:
            fmt, label = ("k--",), "1D"

        for axis in (voltage_axis, voltage_axis_zoom):
            axis.plot(data["t"], data["v"], *fmt, label=label)
        for axis in (sodium_axis, sodium_axis_zoom):
            axis.plot(data["t"], data["sodium"], *fmt, label=label)

        print(
            f"""
dimension = {data["dimension"]}
na.d = {data["na_d"]}
len(na.nodes) = {data["num_nodes"]}
total(na) = {data["total_na"]}
surface area = {data["surface_area"]}
volume = {data["volume"]}
elapsed time = {data["time"]}
"""
        )

        if data["dimension"] == 1:
            print("(surface area interpreted differently for 1D (no edge faces), but total current flux is the same)")

    sodium_axis.legend(ncol=1, loc="upper right", facecolor="white", framealpha=1)

    voltage_axis.set_xlim(0, tstop)
    voltage_axis.set_xticklabels([])
    sodium_axis.set_xlim(0, tstop)
    sodium_axis.set_xlabel("Time (ms)")
    sodium_axis.set_ylabel("[Na$^+$] (mM)")
    sodium_axis.set_ylim(10, 15)
    voltage_axis.set_ylim(-80, 40)
    voltage_axis.set_ylabel("Membrane potential (mV)")

    # Inset windows: fixed view limits, no ticks, and a marker on the parent
    # axes showing which region each inset magnifies.
    voltage_axis_zoom.set_xlim(79, 84)
    voltage_axis_zoom.set_ylim(-75, 30)
    sodium_axis_zoom.set_xlim(68, 76)
    sodium_axis_zoom.set_ylim(10.2, 11.4)
    voltage_axis_zoom.set_xticks([])
    voltage_axis_zoom.set_yticks([])
    sodium_axis_zoom.set_xticks([])
    sodium_axis_zoom.set_yticks([])
    voltage_axis.indicate_inset_zoom(voltage_axis_zoom, edgecolor="black")
    sodium_axis.indicate_inset_zoom(sodium_axis_zoom, edgecolor="black")

    # subplot labeling based on https://stackoverflow.com/questions/18344939/matplotlib-panel-label-out-of-the-box-above-the-ylabel
    voltage_axis.text(-0.1, 1.15, "A", fontweight="bold", va="top", ha="right", transform=voltage_axis.transAxes, fontsize=16)
    sodium_axis.text(-0.1, 1.15, "B", fontweight="bold", va="top", ha="right", transform=sodium_axis.transAxes, fontsize=16)

    fig.savefig("response_to_currents.pdf")

    plt.show()


# Run only when executed as a script: multiprocessing workers may re-import
# this module, and the guard keeps them from calling main() again.
if __name__ == "__main__":
    main()