import math
import random
import sys
from random import randrange as randr

import numpy as np
import steps.geom as stetmesh
import steps.rng as srng
import steps.solver as ssolver

import ip3r_model_nodeshaft

# Set simulation parameters: total duration, time step and the derived grid of sampling times
T_END, DT = 30, 0.001
tpnts = np.arange(0.0, T_END, DT)  # sampling times 0, DT, 2*DT, ... (T_END itself excluded)
ntpnts = len(tpnts)                # number of sampling times = number of main-loop iterations
POINTS = int(T_END / DT)           # nominal number of steps implied by T_END and DT

# Command-line interface: SEED MESH_FILE PROBA_FAIL T
# Fail early with a usage message instead of an IndexError part-way through the setup.
# (T, the inter-stimulation interval, is parsed later together with the stimulation protocol.)
if len(sys.argv) < 5:
    sys.exit("usage: %s SEED MESH_FILE PROBA_FAIL T" % sys.argv[0])

seed = int(sys.argv[1])
name_mesh_file = str(sys.argv[2])
# Probability that a node stimulation fails; it is used as a probability, so it must lie in [0, 1]
proba_fail = float(sys.argv[3])
if not 0.0 <= proba_fail <= 1.0:
    raise ValueError("proba_fail must be in [0, 1], got %r" % proba_fail)

# Import model
mdl = ip3r_model_nodeshaft.getModel()

# Create the STEPS random number generator (Mersenne Twister) seeded from the command line
r = srng.create('mt19937', 512)
r.initialize(seed)

# Import geometry
mesh, central_tets, cyt_tris, er_tris, central_tris, cyto_tets, er_area, cyto_area = ip3r_model_nodeshaft.gen_geom(
    name_mesh_file)

# Calculate nb of IP3R on ER membrane
# density: 50 IP3R molecules on 1.76243263552e-13 m2 (denizot et al, PLOS CB, 2019)
nb_ip3r = int(er_area * 50 / (1.76243263552e-13))

# Number of IP3 molecules added to a node at each successful stimulation
ip3_inf = 50

# Create solver object
sim = ssolver.Tetexact(mdl, mesh, r)

###############################
######## ISOLATE NODES ########
###############################


def _group_by_x_slab(elements, barycenter_of, centers, half_width):
    """Bucket mesh elements by the x-slab their barycenter falls into.

    Returns one list per entry of ``centers``. The list holds every element whose
    barycenter x-coordinate lies strictly between ``center - half_width`` and
    ``center + half_width``. Only the x-coordinate is tested, so each "node" is a slab
    along x rather than a true sphere. An element lands in several lists if slabs overlap.
    """
    groups = [[] for _ in centers]
    for elem in elements:
        x = barycenter_of(elem)[0]
        for center, group in zip(centers, groups):
            if center - half_width < x < center + half_width:
                group.append(elem)
    return groups


min_xcoord = mesh.getBoundMin()[0]
max_xcoord = mesh.getBoundMax()[0]
# Half-width along x of the slab that defines the cytosol of a node
sphere_rad = 189.5 * math.pow(10, -9)
# x-coordinate of each node center: one slab-width in from each mesh end, the origin,
# and half of each end center.
# NOTE(review): the middle centers presumably assume the mesh is centred on x = 0 — confirm
first_center = min_xcoord + sphere_rad
last_center = max_xcoord - sphere_rad
spheres_centers = [first_center, first_center / 2, 0, last_center / 2, last_center]

# Cytosolic tetrahedra belonging to each node
nodes_tets_table = _group_by_x_slab(cyto_tets, mesh.getTetBarycenter, spheres_centers, sphere_rad)

# ER-membrane triangles belonging to each node, using a narrower slab
sphere_rad_er = 85 * math.pow(10, -9)
nodes_tris_table = _group_by_x_slab(er_tris, mesh.getTriBarycenter, spheres_centers, sphere_rad_er)

# Register, for every node, a tetrahedron ROI 'node_<k>' and an ER-triangle ROI 'node_surf_<k>'
for node_idx, (node_tets, node_tris) in enumerate(zip(nodes_tets_table, nodes_tris_table)):
    mesh.addROI('node_' + str(node_idx), stetmesh.ELEM_TET, node_tets)
    mesh.addROI('node_surf_' + str(node_idx), stetmesh.ELEM_TRI, node_tris)

######################################
# Run the simulation and record data #
######################################
# Reset the simulation object
sim.reset()

# Set initial conditions
sim.setCompConc('cyto', 'ca', 120e-9)
sim.setCompConc('cyto', 'ip3', 120e-9)
# PLC on the cytosolic membrane: 1696 molecules per 6.88566421434e-13 of area
# (same area units as cyto_area — presumably m2, TODO confirm)
nb_plc = int(cyto_area * 1696 / 6.88566421434e-13)
sim.setPatchCount('cyto_patch', 'plc', nb_plc)
sim.setCompConc('cyto', 'GCaMP6s', 9.9e-6)
sim.setCompConc('cyto', 'ca_GCaMP6s', 0.1e-6)
sim.setPatchCount('er_patch', 'unb_IP3R', nb_ip3r)

# Index (into tpnts) of the time step at which the 1st node is stimulated
t0 = 5000

# Inter-stimulation time interval tau_ip3, expressed in time steps.
# It is used as a modulus below, so it must be a positive integer.
T = int(sys.argv[4])
if T <= 0:
    raise ValueError("T (inter-stimulation interval, in time steps) must be positive, got %d" % T)

# Nodes node_0, node_1, ... are stimulated one after the other every T steps, starting at t0.
# Bounding by the number of node ROIs prevents stimulating a non-existent node (e.g. when T == 1).
n_nodes = len(spheres_centers)
count = 0

# ip3 infusion occurs with probability 1 - proba_fail. A dedicated generator seeded with the
# run seed makes the failure pattern reproducible for a given seed.
proba_inf = 1 - proba_fail
fail_rng = random.Random(seed)

# Recorded ROIs: the last node (cytosol) and its ER membrane
rec_roi = 'node_4'
rec_surf_roi = 'node_surf_4'

# set file name
fna = 'out.' + str(name_mesh_file) + '.T=' + str(T) + '.pfail=' + str(proba_fail) + '.' + str(seed)

# run the simulation; each output line holds, for the recorded node:
# ca_GCaMP6s count, ip3 count, open_IP3R count (on its ER triangles)
with open(fna, "w") as f:
    for i, t_i in enumerate(tpnts):
        sim.run(t_i)
        ##### Node stimulation: IP3 infusion in the node, following protocol described in Fig2 ######
        if i >= t0 and (i - t0) % T == 0 and count < n_nodes:
            stim_roi = 'node_' + str(count)
            init_ip3 = sim.getROICount(stim_roi, 'ip3')
            # The first stimulation is exempt from failure and always applied; later ones succeed
            # with probability proba_inf.
            # NOTE(review): forced success is the assumed intent of special-casing i == t0 — confirm
            # against the protocol of Fig2.
            if i == t0 or fail_rng.random() <= proba_inf:
                sim.setROICount(stim_roi, 'ip3', init_ip3 + ip3_inf)
            count += 1

        f.write("%d %d %d\n" % (sim.getROICount(rec_roi, 'ca_GCaMP6s'), sim.getROICount(rec_roi, 'ip3'),
                                sim.getROICount(rec_surf_roi, 'open_IP3R')))
        # Flush every step so the output file stays current while the long run is in progress
        f.flush()