from math import sqrt
from collections import deque
import neuron as nrn
import os
import random
import numpy as np

class Node:
    """One NEURON segment in the morphology tree built by ``buildNodes``.

    Each node records its parent/children, path distances (to parent, to the
    soma root, to the nearest trunk node), its 3-D point and its vertical (y)
    offset from the soma root.

    Attributes set in ``__init__``:
        soma: the tree root (a root node is its own soma).
        syndist: path distance to the nearest cluster synapse; initialised
            large and lowered by ``adjust_synapse_distances``.
        level: depth in the tree (root is 0).
        trunk: True while every ancestor is trunk and the parent's
            ``item.diam`` exceeds 1.5 (the root is always trunk).
        nodeid: index into the section's 3-D points, or -1 when the section
            has no 3-D points.
        dist_to_parent / somadist / dist_to_trunk: path distances.
        point: ``[x, y, z, diam]`` of this node, or [] without 3-D points.
        soma_y_dist: y of this node minus y of the soma root (0 is used for a
            missing point).
        weight_by_distance: ``soma_y_dist * 0.002``.
    """

    def __init__(self, item, soma=None, parent=None):
        """Create a node for segment *item* and attach it under *parent*.

        Args:
            item: a NEURON segment (needs ``.x``, ``.diam`` and ``.sec``).
            soma: the root node; None makes this node its own root.
            parent: parent node, or None for the root.
        """
        self.soma = self if soma is None else soma
        self.syndist = 9999999
        self.level = 0
        self.parent = None
        self.children = []
        self.x = item.x
        self.item = item
        self.somadist = 0
        self.trunk = True
        if parent is not None:
            self.parent = parent
            self.level = parent.level + 1
            parent.children.append(self)
            # Trunk status is inherited only through thick (> 1.5) trunk parents.
            self.trunk = bool(parent.trunk and parent.item.diam > 1.5)

        # Order matters: nodeid feeds getpoint(), dist_to_parent feeds
        # somadist and calc_dist_to_trunk().
        self.nodeid = Node.x_to_id(self)
        self.dist_to_parent = self.calcdistancetoparent()
        if self.parent is not None:
            self.somadist = self.parent.somadist + self.dist_to_parent
        self.dist_to_trunk = self.calc_dist_to_trunk()

        self.point = self.getpoint()
        soma_point = self.soma.getpoint()
        own_y = self.point[1] if self.point else 0
        soma_y = soma_point[1] if soma_point else 0
        self.soma_y_dist = own_y - soma_y
        self.weight_by_distance = self.soma_y_dist * 0.002

    def calc_neigh_diam(self):
        """Return the mean ``item.diam`` of this node, its parent (if any) and its children."""
        neighbours = [self]
        if self.parent is not None:
            neighbours.append(self.parent)
        neighbours.extend(self.children)
        return sum(node.item.diam for node in neighbours) / len(neighbours)

    def calc_dist_to_trunk(self):
        """Return the path distance from this node up to its nearest trunk ancestor (0 if trunk)."""
        node = self
        dist = 0
        while not node.trunk:
            dist += node.dist_to_parent
            node = node.parent
        return dist

    def eucl_dist_to_node(self, node):
        """Return the straight-line distance to *node*, or 99999 if either has no 3-D point."""
        if not self.point or not node.point:
            return 99999
        return Node.eucliddistance(self.point, node.point)

    def in_SR_range(self):
        """Return True if ``soma_y_dist`` lies in [100, 300] (treated as Stratum Radiatum)."""
        return 100 <= self.soma_y_dist <= 300

    def x_to_id(obj):
        # Intentionally a plain function (no @staticmethod) so both
        # ``Node.x_to_id(node)`` and ``node.x_to_id()`` keep working.
        """Map ``obj.x`` to an index into its section's 3-D points, or -1 if there are none."""
        n = obj.item.sec.n3d()
        if n <= 0:
            return -1
        return min(round(n * obj.x), n - 1)

    def getstr(self, level=0):
        """Return ``"<indent><item>[<level>]"`` with *level* spaces of indentation."""
        pre = " " * level
        return f"{pre}{self.item}[{self.level}]"

    def printtree(self, level=0):
        """Print this subtree (indented by depth) when MILEDIDEBUG=1."""
        # .get(): an unset variable must not raise KeyError.
        if os.environ.get("MILEDIDEBUG") == '1':
            print(self.getstr(level))
        for child in self.children:
            child.printtree(level + 1)

    def getpoints(self):
        """Return all 3-D points of this node's section as ``[x, y, z, sec.diam]`` lists."""
        sec = self.item.sec
        return [[sec.x3d(i), sec.y3d(i), sec.z3d(i), sec.diam] for i in range(sec.n3d())]

    def getpoint(self):
        """Return ``[x, y, z, sec.diam]`` at ``nodeid``, or [] when the section has no 3-D points."""
        if self.nodeid == -1:
            return []
        sec = self.item.sec
        i = self.nodeid
        # Read just the one point instead of materialising the whole section.
        return [sec.x3d(i), sec.y3d(i), sec.z3d(i), sec.diam]

    def gettreepoints(self):
        """Return the points of this node and all descendants, depth-first (empty points skipped)."""
        points = []
        own = self.getpoint()
        if own:
            points.append(own)
        for child in self.children:
            points.extend(child.gettreepoints())
        return points

    def getpointstox(self, x, y):
        """Return the consecutive 3-D points of this section between positions *x* and *y*.

        Positions map to point indices with the same rounding as ``x_to_id``;
        argument order does not matter. Returns [] when there are no points.
        """
        points = self.getpoints()
        n = len(points)
        if n <= 0:
            return []
        lo, hi = sorted((min(round(n * y), n - 1), min(round(n * x), n - 1)))
        return points[lo:hi + 1]

    @staticmethod
    def eucliddistance(p1, p2):
        """Return the 3-D Euclidean distance between the first three coordinates of *p1* and *p2*."""
        dx = p2[0] - p1[0]
        dy = p2[1] - p1[1]
        dz = p2[2] - p1[2]
        return sqrt(dx ** 2 + dy ** 2 + dz ** 2)

    @staticmethod
    def calcdistancefrompoints(points):
        """Return the length of the polyline through *points* (0 for fewer than two)."""
        return sum(Node.eucliddistance(a, b) for a, b in zip(points, points[1:]))

    def calcdistancetoparent(self):
        """Return the path length along 3-D points from the parent node to this node (0 for the root)."""
        if self.parent is None:
            return 0
        parent = self.parent
        start_x = parent.x
        dist = 0
        if not parent.item.sec.same(self.item.sec):
            # Different sections: add the parent's path from parent.x to its
            # end (x=1), then ours from x=0. Assumes the child section hangs
            # off the parent's x=1 end — TODO confirm for mid-section branches.
            start_x = 0
            dist += Node.calcdistancefrompoints(parent.getpointstox(parent.x, 1))
        dist += Node.calcdistancefrompoints(self.getpointstox(start_x, self.x))
        return dist

    def get_neighbour_dists(self, caller=None):
        """Return ``(neighbour, edge_length)`` for parent and children, excluding *caller*."""
        nodes = []
        if caller != self.parent and self.parent is not None:
            nodes.append((self.parent, self.dist_to_parent))
        for child in self.children:
            if caller != child and child is not None:
                nodes.append((child, child.dist_to_parent))
        return nodes

    def adjust_synapse_distances(self, caller=None):
        """Propagate ``self.syndist`` through the whole tree.

        Each reached node's ``syndist`` becomes min(current, path distance via
        self). There is no pruning, so every node in the tree is visited.
        """
        for node, node_dist in self.get_neighbour_dists(caller):
            node.syndist = min(node.syndist, node_dist + self.syndist)
            node.adjust_synapse_distances(self)

    def foreach(self, func, reverse=False):
        """Call *func* on every node of this subtree.

        Pre-order by default; with ``reverse=True`` post-order (children before
        parent) at every depth.
        """
        if not reverse:
            func(self)
        for child in self.children:
            # Propagate *reverse* so the traversal order holds at every level.
            child.foreach(func, reverse)
        if reverse:
            func(self)

    def function_for_neighbors_in_distance(self, function, dist, caller=None):
        """Call *function* on every node reachable from here by a path shorter than *dist*.

        Each hop must be strictly shorter than the remaining budget. This node
        itself is not included.
        """
        for node, node_dist in self.get_neighbour_dists(caller):
            if node_dist < dist:
                function(node)
                node.function_for_neighbors_in_distance(function, dist - node_dist, self)

    def __str__(self):
        return str(self.item)

    def __repr__(self):
        return str(self.item)


def buildNodes(soma):
    """Build a ``Node`` for every segment reachable from *soma*.

    Sections are visited depth-first (a stack). Within each section, segments
    come from ``allseg()``. A segment's parent node is chosen as follows:
      * on the root section (no ``trueparentseg``): the section's 0.5 node;
      * otherwise, if it is on the same section as the previously created
        node: that previous node;
      * otherwise: the node for the section's ``trueparentseg``.

    Returns:
        dict mapping ``str(segment)`` to its ``Node``. The root is keyed by
        ``str(soma(0.5))``.
    """
    root = Node(soma(0.5))
    nodes_by_name = {str(soma(0.5)): root}

    pending = deque([soma])
    last_node = None
    while pending:
        section = pending.pop()
        pending.extend(section.children())
        for seg in list(section.allseg()):
            key = str(seg)
            parent_seg = seg.sec.trueparentseg()
            if parent_seg is None:
                # Root section: the 0.5 node already exists as `root`.
                if seg.x != 0.5:
                    nodes_by_name[key] = Node(seg, root, parent=nodes_by_name[str(seg.sec(0.5))])
            elif seg.sec.same(last_node.item.sec):
                nodes_by_name[key] = Node(seg, root, parent=last_node)
            else:
                # First segment of a new section hangs off its parent segment.
                nodes_by_name[key] = Node(seg, root, parent=nodes_by_name[str(parent_seg)])
            last_node = nodes_by_name[key]
    return nodes_by_name

def get_nodes_in_radius(nodes, center, radius=20):
    """Return the nodes whose ``center.eucl_dist_to_node`` distance is below *radius*.

    Input order is preserved.
    """
    return [candidate for candidate in nodes if center.eucl_dist_to_node(candidate) < radius]


def select_random_synapse(allnodes, somadist_L=None, trunkdist_L=None, syn_diam=None, seed=None, centerneuron=None):
    """Select a single random synapse location.

    Candidate nodes are those that are:
    1. in Stratum Radiatum (``node.in_SR_range()``),
    2. not trunk (``node.trunk`` is False),
    3. not at a section edge (x == 0 or x == 1), since a synapse cannot be
       played there.

    Each optional criterion then narrows the candidates in turn. It keeps the
    nodes within 20 of the target; if none qualify, it keeps the 5 closest.

    Args:
        allnodes: mapping whose values are ``Node`` objects.
        somadist_L: target ``soma_y_dist``.
        trunkdist_L: target ``dist_to_trunk``.
        syn_diam: target ``calc_neigh_diam()``.
        seed: seed for the global ``random`` module (None reseeds).
        centerneuron: if given, returned instead of the random pick. The
            random pick is still drawn when possible, so the RNG stream is
            unchanged.

    Returns:
        ``(centernodes, centernodes[0])`` where ``centernodes`` is a
        one-element list.

    Raises:
        ValueError: if no node qualifies and no ``centerneuron`` is given.
    """
    # random.seed(None) reseeds exactly like random.seed().
    random.seed(seed)
    debug = os.environ.get("NRN_DEBUG") == "1"

    candidates = [node for node in allnodes.values() if node.in_SR_range() and not node.trunk and not node.x == 0 and not node.x == 1]

    if somadist_L is not None:
        close = [x for x in candidates if abs(x.soma_y_dist - somadist_L) < 20]
        candidates = close if close else sorted(candidates, key=lambda x: abs(x.soma_y_dist - somadist_L))[:5]
        if debug:
            print("1. Selected distance from soma nodes: ")
            for x, dist in [(x, x.soma_y_dist) for x in candidates]:
                print(f" {f'{str(x)}:'.ljust(42)}  {f'{dist:.2f}'.ljust(5)}")

    if trunkdist_L is not None:
        close = [x for x in candidates if abs(x.dist_to_trunk - trunkdist_L) < 20]
        candidates = close if close else sorted(candidates, key=lambda x: abs(x.dist_to_trunk - trunkdist_L))[:5]
        if debug:
            print("2. Selected distance from trunk nodes: ")
            for x, dist in [(x, x.dist_to_trunk) for x in candidates]:
                print(f" {f'{str(x)}:'.ljust(42)}  {f'{dist:.2f}'.ljust(5)}")

    if syn_diam is not None:
        # NOTE(review): a tolerance of 20 on a diameter is very loose — confirm intent.
        close = [x for x in candidates if abs(x.calc_neigh_diam() - syn_diam) < 20]
        # Fixed typo: was `calc_neight_diam`, which raised AttributeError.
        candidates = close if close else sorted(candidates, key=lambda x: abs(x.calc_neigh_diam() - syn_diam))[:5]
        if debug:
            print("3. Selected dendrite diameters: ")
            for x, dist in [(x, x.calc_neigh_diam()) for x in candidates]:
                print(f" {f'{str(x)}:'.ljust(42)}  {f'{dist:.2f}'.ljust(5)}")

    centernodes = random.sample(candidates, 1) if candidates else []
    if centerneuron is not None:
        centernodes = [centerneuron]
    if not centernodes:
        raise ValueError("no candidate synapse nodes match the selection criteria")

    if os.environ.get("MILEDIDEBUG") == '1':
        print("Final cluster: ")
        for x in centernodes:
            print(f" {x}")
        print("Center node: ")
        print(f" {centernodes[0]}")

    return (centernodes, centernodes[0])


def create_random_cluster(allnodes, syn_density, syn_count, somadist_L=None, trunkdist_L=None, syn_diam=None, seed=None, centerneuron=None):
    """Pick a random center node and grow a synapse cluster of *syn_count* around it.

    Candidate and center selection follows ``select_random_synapse``: Stratum
    Radiatum, non-trunk, non-edge nodes, narrowed by the optional soma, trunk
    and diameter targets. For ``syn_count == 1`` that function is used
    directly.

    Args:
        allnodes: mapping whose values are ``Node`` objects.
        syn_density: synapses per unit path length; spacing is ``1/syn_density``.
        syn_count: total number of synapses in the cluster.
        somadist_L, trunkdist_L, syn_diam: optional selection targets.
        seed: seed for the global ``random`` module (None reseeds).
        centerneuron: if given, used as the cluster center instead of the
            random pick. The random pick is still drawn when possible, so the
            RNG stream is unchanged.

    Returns:
        ``(cluster, center)``: the list of selected nodes and the center node.

    Raises:
        ValueError: if no node qualifies and no ``centerneuron`` is given.
    """
    if syn_count == 1:
        return select_random_synapse(allnodes, somadist_L, trunkdist_L, syn_diam, seed, centerneuron)
    random.seed(seed)

    possible_synapse_nodes = [node for node in allnodes.values() if node.in_SR_range() and not node.trunk and not node.x == 0 and not node.x == 1]
    candidates = possible_synapse_nodes

    if somadist_L is not None:
        close = [x for x in candidates if abs(x.soma_y_dist - somadist_L) < 20]
        candidates = close if close else sorted(candidates, key=lambda x: abs(x.soma_y_dist - somadist_L))[:5]

    if trunkdist_L is not None:
        close = [x for x in candidates if abs(x.dist_to_trunk - trunkdist_L) < 20]
        candidates = close if close else sorted(candidates, key=lambda x: abs(x.dist_to_trunk - trunkdist_L))[:5]

    if syn_diam is not None:
        close = [x for x in candidates if abs(x.calc_neigh_diam() - syn_diam) < 20]
        # Fixed typo: was `calc_neight_diam`, which raised AttributeError.
        candidates = close if close else sorted(candidates, key=lambda x: abs(x.calc_neigh_diam() - syn_diam))[:5]

    centernodes = random.sample(candidates, 1) if candidates else []
    if centerneuron is not None:
        centernodes = [centerneuron]
    if not centernodes:
        raise ValueError("no candidate synapse nodes match the selection criteria")

    radius = 30
    # create_cluster currently ignores circle_nodes; it is kept for its interface.
    circle_nodes = get_nodes_in_radius(possible_synapse_nodes, centernodes[0], radius)

    cluster = create_cluster(possible_synapse_nodes, centernodes, syn_density, syn_count, circle_nodes, somadist_L)
    return (cluster, centernodes[0])


def create_cluster(allnodes, centernodes, syn_density, syn_count, circle_nodes, somadist_L):
    """Grow a synapse cluster outward from *centernodes* along the dendritic tree.

    Each step scores the candidates in *allnodes* whose path distance to the
    current cluster (``syndist``) exceeds 0.9 x spacing. It prefers nodes
    whose ``syndist`` is closest to the spacing ``1/syn_density``. If even the
    best one is more than 2 x spacing away, it instead prefers nodes close to
    the center (Euclidean distance plus 5 x the deviation of ``soma_y_dist``
    from *somadist_L*). One of the top five is then picked at random.

    Args:
        allnodes: re-iterable collection of candidate ``Node`` objects.
        centernodes: a list of seed nodes, or a single node. The caller's
            list is not modified.
        syn_density: synapses per unit path length.
        syn_count: total cluster size, including the seed nodes.
        circle_nodes: unused; kept for interface compatibility.
        somadist_L: target ``soma_y_dist`` for the fallback score. None drops
            that term.

    Returns:
        List of cluster nodes, seeds first. It may be shorter than
        *syn_count* if no further candidate qualifies.
    """
    # Copy so the caller's list is not mutated by the appends below.
    cluster_synapses = list(centernodes) if isinstance(centernodes, list) else [centernodes]

    # Reset syndist on every node of every involved tree, not just the
    # candidates. Stale values on other nodes from an earlier call would
    # otherwise leak into this cluster's distances.
    roots = {id(node.soma): node.soma for node in [*allnodes, *cluster_synapses]}
    stack = list(roots.values())
    while stack:
        node = stack.pop()
        node.syndist = 999999
        stack.extend(node.children)
    for node in allnodes:
        node.syndist = 999999

    distance = 1 / syn_density
    syn_count -= len(cluster_synapses)

    for node in cluster_synapses:
        node.syndist = 0
    for node in cluster_synapses:
        node.adjust_synapse_distances()

    while syn_count > 0:
        center = cluster_synapses[0]
        # Each entry: [node, syndist, center-proximity score, |spacing - syndist|].
        scored = [
            [
                node,
                node.syndist,
                center.eucl_dist_to_node(node)
                + (abs(node.soma_y_dist - somadist_L) * 5 if somadist_L is not None else 0),
                abs(distance - node.syndist),
            ]
            for node in allnodes
            if node.in_SR_range() and node.syndist > 0 and node.syndist > (0.9 * distance)
        ]
        if not scored:
            # No node is far enough from the existing synapses; stop early.
            break

        scored.sort(key=lambda entry: entry[3])
        if scored[0][1] > (distance * 2.0):
            # Stable re-sort: ties keep the spacing order.
            scored.sort(key=lambda entry: entry[2])

        added = random.sample(scored[0:5], 1)[0][0]
        added.syndist = 0
        cluster_synapses.append(added)
        syn_count -= 1
        added.adjust_synapse_distances()

    return cluster_synapses