# coding=utf-8
import steps.model as smod
import steps.geom as stetmesh
import steps.utilities.meshio as smeshio
import steps.rng as rng
import steps.utilities.meshio as meshio
import steps.utilities.meshctrl as meshctrl
from numpy import math
import pylab as pl
import itertools as it
import sys
import pickle

# function creating the model of calcium dynamics
# model from Denizot et al., PLOS Computational Biology, 2019.
def getModel():
    """Build the STEPS reaction-diffusion model of astrocytic Ca dynamics.

    Model from Denizot et al., PLOS Computational Biology, 2019. It contains:

    * cytosolic diffusion of Ca, IP3 and the GCaMP6s indicator;
    * first-order Ca and IP3 degradation and a zero-order Ca leak;
    * Ca buffering by GCaMP6s (free and Ca-bound forms both diffuse);
    * Ca-activated, PLC-dependent IP3 synthesis on the plasma membrane;
    * an 8-state IP3 receptor (IP3R) on the ER membrane that releases Ca
      while in the ``open_IP3R`` state.

    Returns:
        The ``smod.Model`` container. The geometry must attach its
        compartments and patches to the systems defined here by id: volume
        systems ``'vsys'`` (cytosol) and ``'er_vsys'`` (ER; no rules are
        defined on it here) and surface systems ``'ssys'`` (ER membrane) and
        ``'mb_surf'`` (plasma membrane).
    """
    mdl = smod.Model()

    # --- Species (creation order is kept stable) ---
    ca = smod.Spec('ca', mdl)
    ip3 = smod.Spec('ip3', mdl)
    plc = smod.Spec('plc', mdl)

    # GCaMP6s calcium indicator, free and Ca-bound; acts as a mobile Ca buffer.
    GCaMP6s = smod.Spec('GCaMP6s', mdl)
    ca_GCaMP6s = smod.Spec('ca_GCaMP6s', mdl)

    # IP3R states, named after what is bound: 'caa' = Ca on the activating
    # site, 'cai' = Ca on the inactivating site, 'ca2' = Ca on both sites,
    # 'ip3' = IP3. 'open' is IP3 + activating Ca, the only conducting state.
    unb_IP3R = smod.Spec('unb_IP3R', mdl)
    ip3_IP3R = smod.Spec('ip3_IP3R', mdl)
    caa_IP3R = smod.Spec('caa_IP3R', mdl)
    cai_IP3R = smod.Spec('cai_IP3R', mdl)
    open_IP3R = smod.Spec('open_IP3R', mdl)
    cai_ip3_IP3R = smod.Spec('cai_ip3_IP3R', mdl)
    ca2_IP3R = smod.Spec('ca2_IP3R', mdl)
    ca2_ip3_IP3R = smod.Spec('ca2_ip3_IP3R', mdl)

    # --- Surface and volume systems (ids are referenced by the geometry) ---
    ssys = smod.Surfsys('ssys', mdl)        # ER membrane
    mb_surf = smod.Surfsys('mb_surf', mdl)  # outer (plasma) membrane
    vsys = smod.Volsys('vsys', mdl)         # cytosol
    smod.Volsys('er_vsys', mdl)             # ER volume system; no rules here

    # --- Diffusion (STEPS diffusion constants are in m^2/s) ---
    DCST = 0.013e-9    # Ca, effective (buffered) value
    DIP3 = 0.280e-9    # IP3
    DGCAMP = 0.050e-9  # GCaMP6s, same for free and Ca-bound forms

    smod.Diff('diff_freeca', vsys, ca, DCST)
    smod.Diff('diff_ip3', vsys, ip3, DIP3)
    smod.Diff('diff_GCaMP6s', vsys, GCaMP6s, DGCAMP)
    smod.Diff('diff_ca_GCaMP6s', vsys, ca_GCaMP6s, DGCAMP)

    # --- Cytosolic Ca: degradation, leak and GCaMP6s buffering ---
    ca_deg = smod.Reac('ca_deg', vsys, lhs=[ca])      # Ca -> null
    ca_leak = smod.Reac('ca_leak', vsys, rhs=[ca])    # null -> Ca
    # Ca + GCaMP6s <-> ca_GCaMP6s
    GCaMP6s_bind_ca_f = smod.Reac('GCaMP6s_bind_ca_f', vsys,
                                  lhs=[ca, GCaMP6s], rhs=[ca_GCaMP6s])
    GCaMP6s_bind_ca_b = smod.Reac('GCaMP6s_bind_ca_b', vsys,
                                  lhs=[ca_GCaMP6s], rhs=[GCaMP6s, ca])

    # --- Cytosolic IP3: leak, degradation and Ca-activated PLC synthesis ---
    ip3_leak = smod.Reac('ip3_leak', vsys, rhs=[ip3])   # null -> IP3
    ip3_deg = smod.Reac('ip3_deg', vsys, lhs=[ip3])     # IP3 -> null
    # Membrane PLC (not consumed) turns cytosolic Ca into Ca + IP3, i.e.
    # Ca-dependent IP3 production without Ca consumption.
    plc_ip3_synthesis = smod.SReac('plc_ip3_synthesis', mb_surf,
                                   slhs=[plc], ilhs=[ca],
                                   srhs=[plc], irhs=[ca, ip3])

    GCaMP6s_bind_ca_f.setKcst(7.78e6)
    GCaMP6s_bind_ca_b.setKcst(1.12)

    # (maybe *10 both later? that's what I was doing before apparently..)
    ca_deg.setKcst(30)
    ca_leak.setKcst(15e-8)

    # IP3 leak does not exist in this model: set 0 explicitly (the STEPS
    # default) so the disabled reaction is visible instead of implicit.
    ip3_leak.setKcst(0)
    ip3_deg.setKcst(1.2e-4)
    plc_ip3_synthesis.setKcst(1)

    # --- IP3R binding kinetics on the ER membrane ---
    # Each binding site has its own on/off rates, independent of which other
    # sites of the receptor are already occupied.
    caa_f, caa_b = 1.2e6, 5e1   # Ca <-> activating site
    cai_f, cai_b = 1.6e4, 1e2   # Ca <-> inactivating site
    ip3_f, ip3_b = 4.1e7, 4e2   # IP3 <-> IP3 site

    def add_reversible_binding(prefix, ligand, before, after, kf, kb):
        # Create '<prefix>_f' (ligand + before -> after) and '<prefix>_b'
        # (after -> before + ligand); the ligand comes from / returns to the
        # patch's inner compartment.
        forward = smod.SReac(prefix + '_f', ssys,
                             ilhs=[ligand], slhs=[before], srhs=[after])
        forward.setKcst(kf)
        backward = smod.SReac(prefix + '_b', ssys,
                              slhs=[after], srhs=[before], irhs=[ligand])
        backward.setKcst(kb)

    # (reaction-id prefix, ligand, state before, state after, k_on, k_off).
    # Row order matches the historical definition order of the reactions.
    ip3r_transitions = (
        ('unb_IP3R_bind_caa', ca, unb_IP3R, caa_IP3R, caa_f, caa_b),            # activating site
        ('unb_IP3R_bind_cai', ca, unb_IP3R, cai_IP3R, cai_f, cai_b),            # inactivating site
        ('caa_IP3R_bind_ca', ca, caa_IP3R, ca2_IP3R, cai_f, cai_b),             # inactivating site
        ('ip3_IP3R_bind_caa', ca, ip3_IP3R, open_IP3R, caa_f, caa_b),           # activating site
        ('ip3_IP3R_bind_cai', ca, ip3_IP3R, cai_ip3_IP3R, cai_f, cai_b),        # inactivating site
        ('cai_IP3R_bind_ca', ca, cai_IP3R, ca2_IP3R, caa_f, caa_b),             # activating site
        ('open_IP3R_bind_ca', ca, open_IP3R, ca2_ip3_IP3R, cai_f, cai_b),       # inactivating site
        ('unb_IP3R_bind_ip3', ip3, unb_IP3R, ip3_IP3R, ip3_f, ip3_b),           # IP3 site
        ('caa_IP3R_bind_ip3', ip3, caa_IP3R, open_IP3R, ip3_f, ip3_b),          # IP3 site
        ('cai_IP3R_bind_ip3', ip3, cai_IP3R, cai_ip3_IP3R, ip3_f, ip3_b),       # IP3 site
        ('cai_ip3_IP3R_bind_ca', ca, cai_ip3_IP3R, ca2_ip3_IP3R, caa_f, caa_b), # activating site
        ('ca2_IP3R_bind_ip3', ip3, ca2_IP3R, ca2_ip3_IP3R, ip3_f, ip3_b),       # IP3 site
    )
    for transition in ip3r_transitions:
        add_reversible_binding(*transition)

    # --- Ca release through open IP3R channels ---
    # open_IP3R -> open_IP3R + Ca: neither the channel nor any ER-side species
    # is consumed, so the ER store is not depleted in this model. The Ca is
    # produced in the patch's inner compartment (irhs), which gen_geom sets to
    # the cytosol.
    Ca_IP3R_flux = smod.SReac('R_Ca_channel_f', ssys,
                              slhs=[open_IP3R], irhs=[ca], srhs=[open_IP3R])
    Ca_IP3R_flux.setKcst(6e3)

    return mdl


########################################################################
def _select_bleach_tets(tet_groups, name_mesh_file):
    """Return the tetrahedra selected for bleaching in a given branchlet mesh.

    Not called by gen_geom; kept so the mesh-specific block choice remains
    available to callers. The mesh is identified by its file base name, so
    ``some/dir/5nodes_ratio1.inp`` selects the same blocks as
    ``5nodes_ratio1.inp``. Any other mesh falls back to EB36/EB37/EB38.

    Parameters:
        tet_groups: mapping from Trelis element-block name (e.g. "EB36") to
            that block's tetrahedron indices, as built by blocksToGroups().
        name_mesh_file: path or file name of the imported mesh.
    """
    import os

    mesh_name = os.path.basename(name_mesh_file)
    if mesh_name == '5nodes_ratio1.inp':
        block_ids = ("EB36", "EB39", "EB40")
    elif mesh_name == '5nodes_ratio2.inp':
        block_ids = ("EB36", "EB37", "EB40")
    else:
        block_ids = ("EB36", "EB37", "EB38")
    return [tet for block_id in block_ids for tet in tet_groups[block_id]]


def gen_geom(name_mesh_file):
    """Import a branchlet mesh and build the cytosol compartment and its patches.

    The mesh is an Abaqus file whose element blocks were defined with the
    Trelis software (https://www.csimsoft.com/trelis). It is imported with a
    length scale of 1e-9, i.e. mesh coordinates are interpreted as nanometres.
    Blocks EB36-EB40 form the cytosol, EB25 the ER and EB36 the central region.
    Only the cytosol is created as a compartment; the ER appears solely through
    its membrane patch.

    Parameters:
        name_mesh_file: path of the Abaqus (.inp) mesh file.

    Returns:
        (mesh, central_tets, ASTRO_TRIS, ER_TRIS, central_tris, cyto_tets,
        er_area, cyto_area): the imported mesh; the central-region tetrahedra;
        the mesh surface triangles (plasma membrane); the cytosol/ER boundary
        triangles; the triangles shared by cytosol and central tetrahedra; the
        cytosol tetrahedra; and the areas of the ER and plasma-membrane patches.
    """
    mesh, _nodeproxy, tetproxy, _triproxy = meshio.importAbaqus(name_mesh_file, 1e-9)

    # Map each Trelis element block to its group of tetrahedra.
    tet_groups = tetproxy.blocksToGroups()
    cyto_tets = (tet_groups["EB36"] + tet_groups["EB37"] + tet_groups["EB38"]
                 + tet_groups["EB39"] + tet_groups["EB40"])
    er_tets = tet_groups["EB25"]
    central_tets = tet_groups["EB36"]

    # Cytosol compartment carrying the cytosolic volume system of the model.
    cyto = stetmesh.TmComp('cyto', mesh, cyto_tets)
    cyto.addVolsys('vsys')

    # ER membrane = triangles shared between cytosol and ER tetrahedra.
    ER_TRIS = meshctrl.findOverlapTris(mesh, cyto_tets, er_tets)
    # Plasma membrane = every surface triangle of the mesh.
    # NOTE(review): assumes all of them border cytosol tetrahedra (the patch
    # below uses cyto as inner compartment) -- confirm for meshes where the
    # ER block reaches the mesh boundary.
    ASTRO_TRIS = mesh.getSurfTris()
    central_tris = meshctrl.findOverlapTris(mesh, cyto_tets, central_tets)

    # ER membrane patch; the cytosol is its inner compartment, so IP3R surface
    # reactions exchange Ca/IP3 with the cytosol.
    er_patch = stetmesh.TmPatch('er_patch', mesh, ER_TRIS, cyto)
    er_patch.addSurfsys('ssys')
    er_area = er_patch.getArea()

    # Plasma membrane patch carrying the PLC surface system.
    cyto_patch = stetmesh.TmPatch('cyto_patch', mesh, ASTRO_TRIS, icomp=cyto)
    cyto_patch.addSurfsys('mb_surf')
    cyto_area = cyto_patch.getArea()

    return mesh, central_tets, ASTRO_TRIS, ER_TRIS, central_tris, cyto_tets, er_area, cyto_area