import multiprocessing
import random
import sqlite3
import pandas as pd
from neuron import h, rxd
from neuron.units import mV, ms

h.load_file("stdrun.hoc")

THRESHOLD_CONCENTRATION = 0.5
NUM_ORIENTATIONS = 100

try:
    with sqlite3.connect("wave_time_3d.db") as conn:
        old_data = pd.read_sql("SELECT * FROM data", conn)
except:
    old_data = pd.DataFrame({"theta": [], "phi": [], "alpha": [], "dx": []})

def on_stopevent():
    h.stoprun = True


def save_data(theta, phi, dx, alpha, length, diam, speed, error, sim_time):
    # connect to the database (or create it if it doesn't exist)
    conn = sqlite3.connect("wave_time_3d.db")
    c = conn.cursor()

    # ensure the table exists
    c.execute(
        """
        CREATE TABLE IF NOT EXISTS data (
            theta REAL,
            phi REAL,
            dx REAL,
            alpha REAL,
            length REAL,
            diam REAL,
            speed REAL,
            relative_error REAL,
            sim_time REAL
        )
        """
    )

    # store the data
    c.execute(
        "INSERT INTO data VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
        (theta, phi, dx, alpha, length, diam, speed, error, sim_time),
    )
    conn.commit()
    conn.close()


def run_sim(theta, phi, dx, alpha=0.25, L=251, diam=2):
    # theta, phi are polar angle and azimuthal angle, respectively
    # per ISO 80000-2:2019... this is physics style not math convention
    # theta \in [0, \pi), phi \in [0, 2*pi)
    import time
    import numpy as np

    start = time.perf_counter()

    # setup the model geometry
    dend = h.Section(name="dend")
    dend.pt3dadd(0, 0, 0, diam)
    dend.pt3dadd(
        L * h.cos(phi) * h.sin(theta),
        L * h.sin(phi) * h.sin(theta),
        L * h.cos(theta),
        diam,
    )
    dend.nseg = 251

    # setup the model kinetics
    cyt = rxd.Region([dend], name="cyt", dx=dx, nrn_region="i")
    c = rxd.Species(
        cyt, d=1, name="c", initial=lambda node: 1 if node.x * dend.L < 50 else 0
    )
    wave_reaction = rxd.Rate(c, -c * (1 - c) * (alpha - c))

    # integration options
    rxd.nthread(4)
    rxd.set_solve_type(dimension=3)

    # the locations we'll monitor
    pt1 = dend(100 / dend.L)
    pt2 = dend(200 / dend.L)
    distance = h.distance(pt1, pt2)

    # monitor concentration timecourses at the points above
    threshold = h.ref(THRESHOLD_CONCENTRATION)
    c_pt1 = h.Vector().record(pt1._ref_ci)
    c_pt2 = h.Vector().record(pt2._ref_ci)
    t = h.Vector().record(h._ref_t)

    # stop simulation when pt2 crosses the threshold
    ste = h.StateTransitionEvent(1)
    ste.transition(0, 0, pt2._ref_ci, threshold, on_stopevent)

    def on_finitialize():
        ste.state(0)

    fih = h.FInitializeHandler(on_finitialize)

    # use variable step integration
    h.CVode().active(True)
    h.CVode().atol(1e-6)

    # actually run the simulation
    h.finitialize(-65 * mV)
    h.continuerun(3000 * ms)
    print(f"end time: {h.t}")

    # interpolate to estimate the crossing times
    pt1_crossing_time = np.interp(THRESHOLD_CONCENTRATION, c_pt1, t)
    pt2_crossing_time = np.interp(THRESHOLD_CONCENTRATION, c_pt2, t)

    measured_speed = distance / (pt2_crossing_time - pt1_crossing_time)
    expected_speed = 2 ** 0.5 * (0.5 - alpha)
    speed_error = abs(1 - measured_speed / expected_speed)

    print(f"pt1_crossing_time = {pt1_crossing_time}")
    print(f"pt2_crossing_time = {pt2_crossing_time}")
    print(f"speed = {measured_speed}")
    print(f"expected speed = {expected_speed}")
    print(f"relative error = {speed_error}")

    finished = time.perf_counter()
    print(f"elapsed time = {finished - start} s")
    save_data(
        theta, phi, dx, alpha, L, diam, measured_speed, speed_error, finished - start
    )


if __name__ == "__main__":
    # ensure deterministic randomness
    random.seed(1)

    # pick random orientations
    orientations = [
        (random.random() * h.PI, random.random() * 2 * h.PI)
        for _ in range(NUM_ORIENTATIONS)
    ]

    # do the parameter study
    for dx in [2 ** -1, 2 ** -2, 2 ** -3, 2 ** -4, 1, 2 ** -5]:
        for alpha in [0.25, 0.15, 0.35]:
            for theta, phi in orientations:
                if any((old_data["dx"] == dx) & (old_data["alpha"] == alpha) & (old_data["theta"] == theta) & (old_data["phi"] == phi)):
                    print(f"Skipping: dx={dx}, alpha={alpha}, theta={theta}, phi={phi}")
                else:
                    print(f"Running: dx={dx}, alpha={alpha}, theta={theta}, phi={phi}")
                    p = multiprocessing.Process(
                        target=run_sim, args=(theta, phi, dx, alpha)
                    )
                    p.start()
                    p.join()