import math
import json
import argparse
import os
import numpy as np
from neuron import h, rxd
from neuron.units import hour, day, μm, ms, mV, mM
from neuron.rxd.node import Node3D
from matplotlib import pyplot as plt
import sqlite3
import pandas as pd

# Turn on matplotlib interactive mode.
plt.ion()
# Load NEURON's standard run library; the script entry point below relies on
# h.continuerun from it.
h.load_file("stdrun.hoc")
# Global rxd options for the hybrid 1D/3D model.
# NOTE(review): the negative ics_distance_threshold and "all" 3D concentration
# nodes look like deliberate tuning for the voxelised spines -- confirm their
# exact meaning against the rxd documentation before changing them.
rxd.options.ics_distance_threshold = -1e-15
rxd.options.concentration_nodes_3d = "all"


def boxoff(ax):
    """Give *ax* an open "box-off" look.

    The top and right spines are hidden and ticks are kept only on the
    bottom (x) and left (y) axes.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    xaxis = ax.get_xaxis()
    xaxis.tick_bottom()
    yaxis = ax.get_yaxis()
    yaxis.tick_left()


# Module-level counter used by Dendrite.__init__ to give the dendrite sections
# of each Dendrite instance unique names (dendL0/dendC0/dendR0, dendL1, ...).
dendcount = 0


class Dendrite:
    """
    A 1D or hybrid 1D/3D model of a spiny dendrite. The dendrite is divided
    into three sections (left, centre and right); if dx is provided, the
    central section and the spines flagged for 3D (the central spine and,
    when present, the spines at ypos = +/-Ls) are simulated in 3D, and
    everything else in 1D.

    Attributes
    ----------
    dend            list of 3 sections: the left, centre and right part of
                    the dendrite.
    prot            rxd.Species representing the proteins synthesized by the
                    active polyribosomes in the spine heads.
    cyt             rxd.Region covering every section in allsec.
    allsec          list of all sections (dendrite, spine necks, spine heads).
    allhead         list of spine-head sections; allhead[0] is the central
                    spine at ypos = 0.
    active_spines   neck and head sections of every non-central spine; their
                    protein diffusion coefficient is set to D2.
    secs3d          sections flagged for 3D simulation (only used if dx is
                    given).
    Nspines         number of spines added so far.
    """

    def __init__(
        self,
        Ls,
        N=3,
        dend_diam=5 * μm,
        dend_pad=0 * μm,
        neck_diam=0.2 * μm,
        neck_length=2 * μm,
        head_diam=1 * μm,
        head_length=1 * μm,
        D1=1e-3 * μm ** 2 / ms,
        D2=1e-3 * μm ** 2 / ms,
        lambd=60 * μm,
        ifactor=1.25,
        k=0.215e-4,
        cdis=2.0 * mM,
        kn=300,
        dx=None,
        load_init=None,
        nsegs=(25, 5),
    ):
        """
        Parameters
        ----------
        Ls          (float) the distance between spines (in μm).
        N           (int)   controls the number of spines. Spines are placed
                            at ypos = 0 and ypos = +/-i * Ls for
                            i = 1 .. N - 1, i.e. 2*N - 1 spines in total.
                            NOTE(review): the dendrite length
                            (2 * N + 1) * Ls and the N >= 1 check suggest
                            N spines either side (2*N + 1 in total) may have
                            been intended -- confirm before relying on the
                            spine count.
        dend_diam   (float) the diameter (in μm) of the dendrite, default 5μm.
        dend_pad    (float) additional space on the ends of dendrite with no
                            spines. The total length will be
                            ceil((2 * N + 1) * Ls + 2 * dend_pad).
        neck_diam   (float) the diameter (in μm) of the spine neck, default
                            0.2μm.
        neck_length (float) the length of the spine neck (in μm), default
                            2μm.
        head_diam   (float) the diameter of the spine head (in μm), default
                            1μm.
        head_length (float) the length of the spine head (in μm), default
                            1μm.
        D1          (float) the diffusion coefficient (μm**2/ms) of proteins
                            in the dendrite and the central spine,
                            default 1e-3.
        D2          (float) the diffusion coefficient (μm**2/ms) of proteins
                            in the non-central spines (active_spines),
                            default 1e-3.
        lambd       (float) the length constant for the protein (in μm),
                            lambd=(D1/K)**0.5 where K is the protein
                            degradation rate, default 60μm.
        ifactor     (float) multiplicative factor applied to the production
                            rate k, default 1.25.
        k           (float) protein synthesis rate default 0.215e-4 mM/ms
        cdis        (float) the threshold concentration (mM) for Hill function
                            representing protein synthesis in the spine heads,
                            default 2mM.
        kn          (int)   Hill coefficient for the function representing
                            protein synthesis, default 300.
        dx          (float) optional voxel size for 3D simulation, if dx is
                            None (or 0) the simulation uses 1D only.
        load_init   (str)   optional, path to a json file (as written by
                            `save`) with the concentration for each segment
                            to use as its initial value. If load_init is
                            None, the heads of every non-central spine start
                            at 2 * cdis and everything else starts at 0.
        nsegs       (tuple) pair of integers for the number of segments to use
                            for the spine necks and spine heads, default
                            (25, 5).

        Raises
        ------
        Exception   if N < 1.
        """

        global dendcount

        # validate before creating any NEURON sections
        if N < 1:
            raise Exception(
                "Dendrite must have at least 1 spine either side of the potentiated central spine"
            )

        # split the dendrite into Left (1D) Center (1D or 3D) and Right (1D) sections
        dend = [
            h.Section(name="dendL%i" % dendcount),
            h.Section(name="dendC%i" % dendcount),
            h.Section(name="dendR%i" % dendcount),
        ]
        dendcount += 1

        # parameters
        self.dend = dend
        self.dend_pad = dend_pad
        self.length = math.ceil(Ls * (2 * N + 1) + 2 * dend_pad)
        self.dend_diam = dend_diam
        self.neck_l = neck_length
        self.neck_diam = neck_diam
        self.Ls = Ls
        self.k = k
        self.ifactor = ifactor
        self.kn = kn
        self.D1 = D1
        self.D2 = D2
        self.lambd = lambd
        self.head_diam = head_diam
        self.head_l = head_length
        # a small offset to ensure 3D sections are joined together
        self.tx = dx / 10.0 if dx else 0
        self.dx = dx
        self.cdis = cdis
        # every section in the model, and the subset flagged for 3D
        self.allsec = dend.copy()
        self.secs3d = [dend[1]]

        # create the dendrite along the y axis, centred on y = 0; the central
        # section is 3 * Ls long and the outer two share the remaining length
        dendx = 0
        dendy = np.round(-self.length / 2.0, 2)
        len3d = 3 * Ls
        lens = [(self.length - len3d) / 2.0, len3d, (self.length - len3d) / 2.0]
        for sec, length in zip(dend, lens):
            sec.pt3dclear()
            sec.nseg = max(1, int(2 * length + (2 * length / 3) % 3))
            sec.pt3dadd(dendx, dendy, 0, dend_diam)

            # add additional point for 3d/1d connection
            if sec not in self.secs3d and dx is not None:
                sec.pt3dadd(dendx, np.round(dendy + dx, 2), 0, dend_diam)
            dendy = np.round(dendy + length, 2)

            # add additional point for 3d/1d connection
            if sec not in self.secs3d and dx is not None:
                sec.pt3dadd(dendx, np.round(dendy - dx, 2), 0, dend_diam)
            sec.pt3dadd(dendx, dendy, 0, dend_diam)

        dend[1].connect(dend[0](1), 0)
        dend[2].connect(dend[1](1), 0)

        # add spines: the central one and those at +/-Ls are flagged for 3D
        self.Nspines = 0
        self.allhead = []
        spines = []
        self.add_spine(ypos=0, nsegs=nsegs, use_3d=True)
        for i in range(1, N):
            self.add_spine(
                ypos=-float(i) * Ls, nsegs=nsegs, use_3d=(i == 1), spinelist=spines
            )
            self.add_spine(
                ypos=float(i) * Ls, nsegs=nsegs, use_3d=(i == 1), spinelist=spines
            )

        self.active_spines = spines
        # use 3D on the flagged sections if dx was specified
        if dx:
            # 3D sim
            rxd.set_solve_type(self.secs3d, dimension=3)
            self.cyt = rxd.Region(self.allsec, nrn_region="i", dx=dx)
        else:
            self.cyt = rxd.Region(self.allsec, name="cyt", nrn_region="i")

        # load a previous solution if one was provided, otherwise start with
        # every spine head except the central one (allhead[0]) at 2 * cdis
        if load_init:
            self.load(load_init)
            init = lambda nd: self._loaded_vals[repr(nd.segment)]
        else:
            init = lambda nd: 2 * cdis if nd.segment.sec in self.allhead[1:] else 0

        # define the species
        pp = rxd.Species(self.cyt, d=D1, name="prot", charge=1, initial=init)

        # NOTE(review): head_vols (1 / head volume per spine head, 0 elsewhere)
        # is computed but never used -- in_head below is 1/0, not 1/volume.
        # Kept because pp.nodes(...) may trigger rxd initialisation side
        # effects; confirm whether it should be removed or feed in_head.
        head_vols = {}
        for sec in self.allsec:
            if sec in self.allhead:
                if sec in self.secs3d and dx is not None:
                    head_vols[sec] = 1.0 / sum(pp.nodes(sec).volume)
                else:
                    head_vols[sec] = 1.0 / sum(self.cyt.geometry.volumes1d(sec))
            else:
                head_vols[sec] = 0

        # parameter that is 1 on nodes in spine heads and 0 elsewhere, so
        # protein production only happens in the heads
        self.in_head = rxd.Parameter(
            self.cyt, name="vol", initial=lambda nd: 1 if nd.sec in self.allhead else 0
        )

        # degradation occurs everywhere, with rate K = D1 / lambd**2
        K = D1 / lambd / lambd
        self.degrad = rxd.Rate(pp, -K * pp)

        # production: rate k scaled by ifactor and restricted to spine heads,
        # gated by a Hill function of the local concentration (threshold cdis,
        # coefficient kn)
        k_adj = k * ifactor * self.in_head
        self.prod = rxd.Rate(pp, k_adj * pp ** kn / (pp ** kn + cdis ** kn))
        self.prot = pp

        # non-central spines diffuse with D2; the dendrite keeps D1
        for sec in self.active_spines:
            for nd in pp.nodes(sec):
                nd.d = D2
        for sec in self.dend:
            for nd in pp.nodes(sec):
                nd.d = D1

    def dend_seg_by_ypos(self, ypos):
        """Return the dendrite segment at position ypos (μm, 0 = centre).

        Returns None if ypos lies beyond the right end of the dendrite.
        """

        pos = self.length / 2.0 + ypos
        for sec in self.dend:
            if sec.L >= pos:
                return sec(pos / sec.L)
            pos -= sec.L

    def dend_ypos_by_seg(self, seg):
        """Return the position of seg, measured in μm from the left end of
        the dendrite (the 0-end of dend[0]).

        Returns None if seg does not belong to one of the dendrite sections.
        """

        offset = 0
        for sec in self.dend:
            # compare sections directly so segments at any x are accepted,
            # not only those at segment centres
            if seg.sec == sec:
                return offset + seg.x * sec.L
            offset += sec.L

    def add_spine(self, ypos, nsegs, use_3d, spinelist=None):
        """Add a spine (neck + head) to the dendrite at position ypos.

        Arguments
        ---------
        ypos (float)        position along the dendrite (μm, 0 = centre).
        nsegs (tuple)       (neck nseg, head nseg).
        use_3d (bool)       flag the spine's sections for 3D simulation.
        spinelist (list)    optional, the neck and head are appended to it.

        The spine is only connected explicitly (neck 0-end to the dendrite,
        head to neck 1-end) when it is simulated in 1D; 3D spines rely on
        their 3D points overlapping the dendrite (offset by tx).
        """

        # create the sections
        head = h.Section(name="head%i" % self.Nspines)
        neck = h.Section(name="neck%i" % self.Nspines)
        self.Nspines += 1

        # 1D set geometry
        head.L = self.head_l
        head.diam = self.head_diam
        neck.L = self.neck_l
        neck.diam = self.neck_diam
        neck.nseg = nsegs[0]
        head.nseg = nsegs[1]

        # depth to ensure the spine is in the dendrite
        depth = (
            self.dend_diam - (self.dend_diam ** 2 - self.neck_diam ** 2) ** (0.5)
        ) / 2.0

        # 3D set geometry (spines project along x from the dendrite surface)
        neckx0 = self.dend_diam / 2.0 - self.tx - depth
        neckx1 = neckx0 + self.neck_l
        headx0 = neckx1 - self.tx
        headx1 = neckx1 + self.head_l

        neck.pt3dclear()
        neck.pt3dadd(neckx0, ypos, 0, self.neck_diam)
        neck.pt3dadd(neckx1, ypos, 0, self.neck_diam)

        head.pt3dclear()
        head.pt3dadd(headx0, ypos, 0, self.neck_diam)
        head.pt3dadd(headx0, ypos, 0, self.head_diam)
        head.pt3dadd(headx1, ypos, 0, self.head_diam)

        # store the sections
        self.allsec.append(neck)
        self.allsec.append(head)
        self.allhead.append(head)
        if spinelist is not None:
            spinelist.append(neck)
            spinelist.append(head)

        if use_3d:
            self.secs3d += [neck, head]

        # connect them if using 1D
        if not self.dx or not use_3d:
            head.connect(neck(1))
            neck.connect(self.dend_seg_by_ypos(ypos), 0)

    def plot1d(self, fun=np.sum):
        """Plot the protein concentration in the dendrite segments against
        position (μm, 0 = centre) on the current matplotlib axes.

        `fun` is unused and kept only for backward compatibility.
        """

        dend_x = [
            self.dend_ypos_by_seg(seg) - self.length / 2.0
            for sec in self.dend
            for seg in sec
        ]
        dend_pp = [seg.proti for sec in self.dend for seg in sec]
        plt.plot(dend_x, dend_pp, label="%1.2fms" % h.t)
        plt.xlabel("x (μm)")
        plt.ylabel("concentration (mM)")

    def implots(self, dr=None, vmin=None, vmax=None):
        """Heat plot the 3D concentrations in a new figure.

        Arguments
        ---------
        dr (str)        optional, 'x', 'y', or 'z' to average over the given
                        direction, if not specified all three are shown in
                        subplots

        vmin (float)    optional, the minimum concentration of the plot
        vmax (float)    optional, the maximum concentration of the plot.

        Returns
        -------
        the matplotlib figure that was created.
        """

        drlookup = {"x": 0, "y": 1, "z": 2}
        r = self.cyt
        # voxel grid; positions without a 3D node stay NaN and are ignored
        # by nanmean
        data = np.nan * np.ones((max(r._xs) + 1, max(r._ys) + 1, max(r._zs) + 1))
        for nd in self.prot.nodes:
            if isinstance(nd, Node3D):
                data[nd._i, nd._j, nd._k] = nd.value
        fig = plt.figure(dpi=200)
        if dr:
            plt.imshow(
                np.nanmean(data, drlookup[dr]), aspect="auto", vmin=vmin, vmax=vmax
            )
            plt.colorbar()
        else:
            for i in range(3):
                plt.subplot(3, 1, 1 + i)
                plt.imshow(np.nanmean(data, i), aspect="auto", vmin=vmin, vmax=vmax)
                plt.colorbar()
        return fig

    def save(self, filename):
        """Save the segment concentrations to a json file (filename), keyed
        by repr(segment), that can be used to initialize the model via
        load_init."""

        vals = dict()
        for sec in self.allsec:
            for seg in sec:
                vals[repr(seg)] = seg.proti
        with open(filename, "w") as fp:
            json.dump(vals, fp)

    def load(self, filename):
        """Load the segment concentrations from a json file (filename) into
        self._loaded_vals."""

        with open(filename, "r") as fp:
            self._loaded_vals = json.load(fp)

    def save_concentrations(self, filename):
        """Save the protein concentration of every rxd node to filename
        (numpy .npy format)."""

        np.save(filename, np.array(self.prot.nodes.value))

    def save_sql(self, dbfilename):
        """Append a one-row summary (parameters, central-spine concentration
        and an estimated neck flux) to the table "data" in the SQLite
        database dbfilename."""

        # nodes at the dendrite end of the first non-central spine's neck
        neck_nd = self.prot.nodes(self.active_spines[0](1e-4))[0]
        dend_nd = self.prot.nodes(self.active_spines[0](0))[0]

        neck = neck_nd.sec
        neck_dx = neck.L / neck.nseg
        area = np.pi * (neck.diam / 2) ** 2
        rate_out = 2 * neck_nd.d * area / (neck_nd.volume * neck_dx)
        rate_in = rate_out * neck_nd.volume / dend_nd.volume
        flux = rate_out * neck_nd.value - rate_in * dend_nd.value
        head_conc = self.allhead[0](0.5).proti
        data = pd.DataFrame(
            {
                "Ls": self.Ls,
                "N": (self.Nspines + 1) / 2,
                "dend_diam": self.dend_diam,
                "dend_pad": self.dend_pad,
                "neck_diam": self.neck_diam,
                "head_diam": self.head_diam,
                "head_length": self.head_l,
                "neck_length": self.neck_l,
                "D1": [self.D1],
                "D2": [self.D2],
                "lambda": self.lambd,
                "k": self.k,
                "ifactor": self.ifactor,
                "cdis": self.cdis,
                "kn": self.kn,
                "dx": self.dx if self.dx is not None else 0,
                "active": head_conc > self.cdis,
                "conc": head_conc,
                "flux": flux,
            }
        )
        with sqlite3.connect(dbfilename) as conn:
            data.to_sql("data", conn, if_exists="append", index=False)


if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(
        description="""Run hybrid 1D/3D spines simulation"""
    )
    parser.add_argument(
        "--N",
        nargs="?",
        type=int,
        default=5,
        help="""number of spines either side of the central spine -- default 5""",
    )
    parser.add_argument(
        "--Ls",
        nargs="?",
        type=float,
        default=19,
        help="""length between spines -- default 19um""",
    )
    parser.add_argument(
        "--D2",
        nargs="?",
        type=float,
        default=1e-3,
        help="""Diffusion in potentiated spines -- default 1e-3um^2/ms""",
    )
    parser.add_argument(
        "--lambd",
        nargs="?",
        type=float,
        default=60,
        help="""length scale -- default 60um""",
    )
    parser.add_argument(
        "--k",
        nargs="?",
        type=float,
        default=0.215e-4,
        help="""protein production rate -- default 0.215e-4""",
    )
    parser.add_argument(
        "--dx",
        nargs="?",
        type=float,
        default=None,
        help="""3D discretization  -- default use 1D only""",
    )
    parser.add_argument(
        "--neck_length",
        nargs="?",
        type=float,
        default=2,
        help="""spine neck length  -- default 2um""",
    )
    parser.add_argument(
        "--initial",
        nargs="?",
        type=str,
        default=None,
        help="""json file for initial concentrations by segment""",
    )
    parser.add_argument(
        "--output",
        nargs="?",
        type=str,
        default=None,
        help="""output filename for summary data""",
    )
    try:
        args = parser.parse_args()
    except SystemExit as err:
        # argparse raises SystemExit for --help (status 0) and for invalid
        # arguments (status 2). Leave via os._exit as before (it skips
        # interpreter cleanup), but flush the help/usage text first --
        # os._exit does not flush stdio buffers -- and keep argparse's status.
        sys.stdout.flush()
        sys.stderr.flush()
        code = err.code
        if code is None:
            code = 0
        elif not isinstance(code, int):
            code = 1
        os._exit(code)

    # critical value 19.475um (or 21.585um if the production rate is 25% greater)
    Ls = args.Ls
    dx = args.dx
    N = args.N
    D2 = args.D2
    init_file = args.initial
    lambd = args.lambd
    rxd.nthread(4)

    cv = h.CVode()
    cv.active(True)
    cv.atol(1e-7)
    dendrite = Dendrite(
        N=N,
        Ls=Ls,
        dx=dx,
        D2=D2,
        k=args.k,
        neck_length=args.neck_length,
        lambd=lambd,
        load_init=init_file,
    )
    t_vec = h.Vector().record(h._ref_t)
    c_vec = h.Vector().record(dendrite.allhead[0](0.5)._ref_proti)

    # initialize and run, one day at a time, until the dendrite centre and the
    # central spine head stop changing (or 9 days have been simulated)
    h.finitialize(-70 * mV)
    dconc = dendrite.dend[1](0.5).proti
    hconc = dendrite.allhead[0](0.5).proti
    for i in range(1, 10):
        h.continuerun(i * day)
        if (
            abs(dconc - dendrite.dend[1](0.5).proti) < 1e-12
            and abs(hconc - dendrite.allhead[0](0.5).proti) < 1e-12
        ):
            break
        else:
            dconc = dendrite.dend[1](0.5).proti
            hconc = dendrite.allhead[0](0.5).proti
    else:
        print("did not converge")

    if args.output:
        dendrite.save_sql(args.output)

    # save the results
    if dx:
        prefix = "spines_hybrid_N_%i_Ls_%1.2f_dx_%1.2f" % (N, Ls, dx)
        dendrite.save_concentrations(prefix + ".npy")
        dendrite.plot1d()
        plt.savefig(prefix + "_1D_plot.png")
        plt.close()

        dendrite.implots()
        plt.savefig(prefix + "_heatmap.png")

    else:
        prefix = "spines_1D_N_%i_Ls_%1.2f" % (N, Ls)
        dendrite.save_concentrations(prefix + ".npy")
        dendrite.plot1d()
        plt.savefig(prefix + "_1D_plot.png")
        plt.close()