# Simulation of tripartite synapse model
# Tiina Manninen and Ausra Saudargiene
# Reference: Tiina Manninen, Ausra Saudargiene, and Marja-Leena Linne. Astrocyte-mediated spike-timing-dependent
# long-term depression modulates synaptic properties in the developing cortex. PLoS Comput Biol, 2020.

# -----------------------------------------------------------------------

from datetime import datetime
import os
from preneuron import Pre
from postneuron import Post
from astrocyte import Astro
from tqdm import tqdm
from scipy import io
from collections import defaultdict


def state_var_to_be_saved(pre, post, astro):

    return{
        "Ca_CaNHVA_pre": pre["Ca_CaNHVA_pre"],
        "Ca_NMDAR_pre": pre["Ca_NMDAR_pre"],
        "CaN_pre": pre["CaN_pre"],
        "Glu_syncleft": post["Glu_syncleft"],
        "Prel_pre": pre["Prel_pre"],
        "RA2_O_pre":  pre["RA2_O_pre"],
        "Rrel_pre": pre["Rrel_pre"],
        "V_pre": pre["V_pre"],
        "AG_post": post["AG_post"],
        "Ca_post": post["Ca_post"],
        "Ca_ER_post": post["Ca_ER_post"],
        "Ca_DAG_GaGTP_PLC_post": post["Ca_DAG_GaGTP_PLC_post"],
        "Ca_DAG_PLC_post": post["Ca_DAG_PLC_post"],
        "Ca_GaGTP_PLC_post": post["Ca_GaGTP_PLC_post"],
        "Ca_PLC_post": post["Ca_PLC_post"],
        "DAG_post": post["DAG_post"],
        "GaGTP_PLC_post": post["GaGTP_PLC_post"],
        "h_IP3R_post": post["h_IP3R_post"],
        "IP3_post": post["IP3_post"],
        "PLC_post": post["PLC_post"],
        "V_dend_post": post["V_dend_post"],
        "V_soma_post": post["V_soma_post"],
        "Ca_astro": astro["Ca_astro"],
        "Glu_extsyn": astro["Glu_extsyn"],
        "h_astro": astro["h_astro"],
        "IP3_astro": astro["IP3_astro"],
        "Rrel_astro": astro["Rrel_astro"]
           }


def main(T_shift):
    # Start to count the time spent in simulation
    time_start = datetime.now()
    print("Simulation started at", time_start)

    # Define the name of the result directory to be created
    path_template = "./results_post_pre_pairing_100x/%sms/"

    path = path_template % T_shift

    try:
        os.makedirs(path)
    except OSError:
        print("Creation of the directory %s failed" % path)
    else:
        print("Successfully created the directory %s " % path)

    # --------------------
    # External stimulation
    # --------------------
    stim_start = 20000          # ms; Stimulation start time
    trainlengthtime = 500000    # ms; Stimulation lasting time
    restlengthtime = 20000      # ms; Resting time after stimulation ends
    no_trains = 1               # 1; Number of trains
    pulserate = 0.2             # Hz, 1/s; Frequency of stimulus
    pulselengthtime = 10        # ms; Length of the pulse

    A_stim_post = 25            # uA/cm^2; External current amplitude to postsynaptic neuron per unit area
    A_stim_pre = 10             # uA/cm^2; External current amplitude to presynaptic neuron per unit area

    dt = 0.05                   # ms; Simulation time step
    T_end = stim_start + no_trains * (trainlengthtime + restlengthtime)  # ms; Simulation end time
    Nsteps = round(T_end / dt)  # Number of simulation steps

    t = [i * dt for i in range(Nsteps + 1)]  # ms; Time vector

    pulselength = round(pulselengthtime / dt)
    no_pulses = round(pulserate * trainlengthtime * 1e-3)
    pauselengthtime = (trainlengthtime - no_pulses * pulselengthtime) / no_pulses  # ms
    pauselength = round(pauselengthtime / dt)
    restlength = round(restlengthtime / dt)

    steps_T_shift = round(T_shift / dt)

    stim_pause_post = ([A_stim_post] * pulselength + [0] * pauselength) * no_pulses

    stim_pause_pre = ([A_stim_pre] * pulselength + [0] * pauselength) * no_pulses

    # Postsynaptic stimulus
    I_ext_post = [0] * (round(stim_start / dt) + 1) + (stim_pause_post + [0] * restlength) * no_trains

    # Presynaptic stimulus
    I_ext_pre = [0] * (round(stim_start / dt) + steps_T_shift + 1) + (stim_pause_pre + [0] * restlength) * no_trains

    # ---------------------------------
    # Presynaptic neuron initialization
    # ---------------------------------

    pre_params = Pre.get_parameters()
    pre_init = Pre.get_initial_values(pre_params)

    pre = Pre(pre_params, pre_init)

    # ----------------------------------
    # Postsynaptic neuron initialization
    # ----------------------------------

    post_params = Post.get_parameters()
    post_init = Post.get_initial_values(post_params)

    post = Post(post_params, post_init)

    # ------------------------
    # Astrocyte initialization
    # ------------------------
    astro_params = Astro.get_parameters()
    astro_init = Astro.get_initial_values(astro_params)

    astro = Astro(astro_params, astro_init)

    # -----------------------------------------
    # Initialization of Ca leak flux parameters
    # -----------------------------------------
    Ca_flux_post = post.calcium_other_fluxes()
    Ca_par_post = post.calcium_leak_parameters(Ca_flux_post["J_CaL_post"],
                                               Ca_flux_post["J_IP3R_post"],
                                               Ca_flux_post["J_NMDAR_post"],
                                               Ca_flux_post["J_PMCA_post"],
                                               Ca_flux_post["J_SERCA_post"])

    Ca_flux_astro = astro.calcium_other_fluxes()
    Ca_par_astro = astro.calcium_leak_parameters(Ca_flux_astro["J_IP3R_astro"],
                                                 Ca_flux_astro["J_SERCA_astro"])

    # ----------------------------------------------------------
    # Variables to be saved in a dictionary and then into a file
    # ----------------------------------------------------------
    state_var = state_var_to_be_saved(pre.x, post.x, astro.x)
    saved_state_var = {key: [values] for key, values in state_var.items()}
    saved_other_var = defaultdict(list)

    t_spike = 10     # ms

    for i in tqdm(range(Nsteps)):

        if i == stim_start * 1 / 2 / dt or i == stim_start * 3 / 4 / dt:

            # ---------------------------------
            # Adjusting Ca leak flux parameters
            # ---------------------------------
            Ca_flux_post = post.calcium_other_fluxes()
            Ca_par_post = post.calcium_leak_parameters(Ca_flux_post["J_CaL_post"],
                                                       Ca_flux_post["J_IP3R_post"],
                                                       Ca_flux_post["J_NMDAR_post"],
                                                       Ca_flux_post["J_PMCA_post"],
                                                       Ca_flux_post["J_SERCA_post"])

            Ca_flux_astro = astro.calcium_other_fluxes()
            Ca_par_astro = astro.calcium_leak_parameters(Ca_flux_astro["J_IP3R_astro"],
                                                         Ca_flux_astro["J_SERCA_astro"])

        # ---------------------------------------------
        # Saving old values for certain state variables
        # ---------------------------------------------
        Ca_pre_old = pre.x["Ca_CaNHVA_pre"]
        Prel_pre_old = pre.x["Prel_pre"]
        Rrel_pre_old = pre.x["Rrel_pre"]
        V_pre_old = pre.x["V_pre"]
        Ca_astro_old = astro.x["Ca_astro"]
        Rrel_astro_old = astro.x["Rrel_astro"]

        # ----------------------------------------------
        # Fraction of presynaptic Glu release inhibition
        # ----------------------------------------------
        f_pre = pre.x["X_ac_pre"] / pre_params["X_total_pre"]

        # ----------------------
        # Differential equations
        # ----------------------
        deriv_pre, other_var_pre = pre.derivative(
            astro.x["Glu_extsyn"], post.x["Glu_syncleft"], I_ext_pre[i+1])

        deriv_post, other_var_post = post.derivative(
            pre.params["f_Glu_pre"], I_ext_post[i+1], Ca_par_post["r_leakCell_post"], Ca_par_post["r_leakER_post"])

        deriv_ast, other_var_ast = astro.derivative(post.x["AG_post"], Ca_par_astro["r_leakER_astro"])

        # ----------------------------------
        # Solving the differential equations
        # ----------------------------------
        pre.solve_deriv(deriv_pre, dt)
        post.solve_deriv(deriv_post, dt)
        astro.solve_deriv(deriv_ast, dt)

        # -----------------------------------------------------
        # Updating those variables that include delta functions
        # -----------------------------------------------------

        # Counting the time from previous presynaptic spike
        if (pre.x["V_pre"] >= 0) and (V_pre_old < 0):
            t_spike = 0
        else:
            t_spike = t_spike + dt

        # Glu release from presynaptic neuron
        if (pre.x["Ca_CaNHVA_pre"] >= pre.params["C_thr_pre"]) and (t_spike < 10):
            pre.solve_deltaf(Ca_pre_old, f_pre, Prel_pre_old, Rrel_pre_old)
            post.solve_deltaf(pre, Rrel_pre_old)
            t_spike = 10

        # Glu release from astrocyte
        if (astro.x["Ca_astro"] >= astro.params["C_thr_astro"]) and (Ca_astro_old < astro.params["C_thr_astro"]):
            astro.solve_deltaf(Rrel_astro_old)

        # -----------
        # Saving data
        # -----------
        state_var = state_var_to_be_saved(pre.x, post.x, astro.x)

        for key, values in state_var.items():
            saved_state_var[key].append(values)

        other_var = {"f_pre": f_pre,
                     "Glu_NMDAR_pre": other_var_pre["Glu_NMDAR_pre"],
                     "ICaNHVA_pre": other_var_pre["ICaNHVA_pre"],
                     "ICa_NMDAR_pre": other_var_pre["ICa_NMDAR_pre"],
                     "I_AMPAR_post": other_var_post["I_AMPAR_post"],
                     "ICaLHVA_dend_post": other_var_post["ICaLHVA_dend_post"],
                     "ICaLLVA_dend_post": other_var_post["ICaLLVA_dend_post"],
                     "ICa_NMDAR_post": other_var_post["ICa_NMDAR_post"],
                     "J_CaL_post": other_var_post["J_CaL_post"],
                     "J_IP3R_post": other_var_post["J_IP3R_post"],
                     "J_leakCell_post": other_var_post["J_leakCell_post"],
                     "J_leakER_post": other_var_post["J_leakER_post"],
                     "J_NMDAR_post": other_var_post["J_NMDAR_post"],
                     "J_PMCA_post": other_var_post["J_PMCA_post"],
                     "J_SERCA_post": other_var_post["J_SERCA_post"]}

        for key, values in other_var.items():
            saved_other_var[key].append(values)

    # Saving dictionaries to mat files
    io.savemat(os.path.join(path, "state_var_results.mat"), saved_state_var)
    io.savemat(os.path.join(path, "other_var_results.mat"), saved_other_var)
    io.savemat(os.path.join(path, "time_stimuli.mat"),
               {**{"time": [tp / 1000 for tp in t]}, **{"I_ext_pre": [I_ext_pre]}, **{"I_ext_post": [I_ext_post]}})
    io.savemat(os.path.join(path, "stimulation_parameters.mat"),
               {**{"dt": dt}, **{"pulserate": pulserate}, **{"T_shift": T_shift}})

    # Simulation time
    time_end = datetime.now()
    total_time = (time_end - time_start).seconds / 60.  # min
    print("\n")
    print("Simulation finished at", time_end)
    print("Total time = {0:.2f} minutes".format(total_time))


if __name__ == "__main__":

    for T_shift in range(10, 210, 10):  # ms; Temporal difference between post and pre activation
        main(T_shift)