import numpy as np
import networkx as nx
from networkx.algorithms import bipartite, community

import matplotlib.pyplot as plt

def pr_full_to_pr_sparse(pr_full, Ntot, cluster_size):
    """Rescale a connection probability from a full population to a smaller candidate pool.

    In the "full" network each neuron has ``Ntot`` independent candidate
    targets, each connected with probability ``pr_full``, so its expected
    out-degree is ``pr_full * Ntot``. This returns the probability that gives
    the same expected out-degree when only ``cluster_size`` candidates are
    available, clipped to [0, 1]. When clipping occurs, the expected
    out-degree is no longer preserved.

    Args:
        pr_full: connection probability in the full network.
        Ntot: number of candidate targets in the full network.
        cluster_size: number of candidate targets in the sparse network.

    Returns:
        The rescaled probability as a numpy float in [0, 1].

    Raises:
        ValueError: if ``cluster_size`` is not positive.
    """
    if cluster_size <= 0:
        raise ValueError("cluster_size must be positive")
    # Use the exact expectation rather than a random estimate: the result is
    # deterministic, and the out-degree is matched exactly unless clipped.
    avg_num_connections = pr_full * Ntot
    pr_sparse = avg_num_connections / cluster_size
    return np.clip(pr_sparse, 0, 1)


def gaussian_clustered(Ntot=1000, NE=500, NI=500, cluster_size=10, variance=5, p_in=0.25, p_out=0.25, no_EE=True, no_II=True, no_IE=False, no_EI=False):
    """Label a Gaussian random-partition graph as E/I neurons and extract E<->I edges.

    A graph of ``Ntot`` nodes is drawn with
    ``nx.gaussian_random_partition_graph(Ntot, cluster_size, variance, p_in, p_out)``.
    Within each partition, the first ``len(partition) // 2`` nodes (in the
    partition's iteration order) are labelled excitatory and the rest
    inhibitory. When one population's budget (``NE`` or ``NI``) is exhausted,
    the remaining nodes go to the other population, so every node is labelled.

    Args:
        Ntot: total number of nodes; must equal ``NE + NI``.
        NE, NI: number of excitatory / inhibitory neurons.
        cluster_size, variance, p_in, p_out: forwarded positionally to
            ``nx.gaussian_random_partition_graph`` (mean partition size,
            size-distribution shape parameter, intra- and inter-partition
            edge probability).
        no_EE, no_II, no_IE, no_EI: currently unused; accepted for call
            compatibility.

    Returns:
        (EI_connection_array, IE_connection_array): int arrays of shape (k, 2).
        EI rows are (E index, I index) and IE rows are (I index, E index);
        indices count within each population. An empty result has shape (0, 2).

    Raises:
        ValueError: if ``NE + NI != Ntot``.
    """
    if NE+NI != Ntot:
        raise ValueError("NE and NI must sum to Ntot")

    # NOTE(review): no `directed` argument is passed, so networkx builds an
    # undirected graph. Each E-I edge therefore appears in both returned
    # arrays (as E->I and mirrored as I->E). Confirm this symmetry is intended.
    partition_graph = nx.gaussian_random_partition_graph(Ntot, cluster_size, variance, p_in, p_out)

    # node id -> (population letter, index within that population)
    labels = {}
    e_count = 0
    i_count = 0
    for partition in partition_graph.graph['partition']:
        half = len(partition) // 2
        for position, node in enumerate(partition):
            # Prefer E for the first half of the partition. Also fall back to E
            # once the I budget is full, so no node is left unlabelled while E
            # still has room.
            if e_count < NE and (position < half or i_count >= NI):
                labels[node] = ('E', e_count)
                e_count += 1
            elif i_count < NI:
                labels[node] = ('I', i_count)
                i_count += 1

    EI_edges = []
    IE_edges = []
    for node, (node_type, node_number) in labels.items():
        for neighbor in partition_graph.neighbors(node):
            neighbor_label = labels.get(neighbor)
            if neighbor_label is None:
                continue
            neighbor_type, neighbor_number = neighbor_label
            if node_type == 'E' and neighbor_type == 'I':
                EI_edges.append((node_number, neighbor_number))
            elif node_type == 'I' and neighbor_type == 'E':
                IE_edges.append((node_number, neighbor_number))
            # E-E and I-I edges are not part of the return value.

    # reshape keeps a consistent (0, 2) shape when no edges were found
    EI_connection_array = np.asarray(EI_edges, dtype=int).reshape(-1, 2)
    IE_connection_array = np.asarray(IE_edges, dtype=int).reshape(-1, 2)
    return EI_connection_array, IE_connection_array
   

def stochastic_block(Ntot=1000, NE=500, NI=500, NUM_CLUSTER=15, p_EE=0.00, p_II=0.00, p_IE=0.04, p_EI=0.02):
    """Sample a directed E/I network from a stochastic block model.

    The E population is split into ``NUM_CLUSTER`` blocks and the I population
    into another ``NUM_CLUSTER`` blocks. Block sizes differ by at most one, so
    every neuron is included. The edge probability depends only on the
    pre/post population: ``p_EE`` (E->E), ``p_EI`` (E->I), ``p_IE`` (I->E),
    ``p_II`` (I->I).

    Args:
        Ntot: unused; accepted for call compatibility.
        NE, NI: number of excitatory / inhibitory neurons.
        NUM_CLUSTER: number of blocks per population.
        p_EE, p_II, p_IE, p_EI: directed connection probabilities by type.

    Returns:
        (EI_connection_array, IE_connection_array): int arrays of shape (k, 2).
        EI rows are (E index, I index) and IE rows are (I index, E index);
        indices count within each population. E->E and I->I edges are sampled
        but not returned.

    Raises:
        ValueError: if ``NUM_CLUSTER < 1``.
    """
    if NUM_CLUSTER < 1:
        raise ValueError("NUM_CLUSTER must be at least 1")

    def _split_evenly(total):
        # Block sizes that sum to `total` and differ by at most one.
        base, extra = divmod(total, NUM_CLUSTER)
        return [base + (1 if k < extra else 0) for k in range(NUM_CLUSTER)]

    # E blocks first, then I blocks. networkx numbers nodes consecutively in
    # block order, so E neurons are nodes 0..NE-1 and I neurons NE..NE+NI-1.
    cluster_sizes = _split_evenly(NE) + _split_evenly(NI)

    # prob_matrix[r, s] is the probability of an edge from block r to block s.
    k = NUM_CLUSTER
    prob_matrix = np.empty((2 * k, 2 * k))
    prob_matrix[:k, :k] = p_EE
    prob_matrix[:k, k:] = p_EI
    prob_matrix[k:, :k] = p_IE
    prob_matrix[k:, k:] = p_II

    directed_graph = nx.stochastic_block_model(cluster_sizes, prob_matrix.tolist(), directed=True)

    EI_edges = []
    IE_edges = []
    for pre, post in directed_graph.edges():
        # Classify by population (node id range), not by block membership.
        pre_is_e = pre < NE
        post_is_e = post < NE
        if pre_is_e and not post_is_e:
            EI_edges.append((pre, post - NE))
        elif post_is_e and not pre_is_e:
            IE_edges.append((pre - NE, post))

    # reshape keeps a consistent (0, 2) shape when no edges were sampled
    EI_connection_array = np.asarray(EI_edges, dtype=int).reshape(-1, 2)
    IE_connection_array = np.asarray(IE_edges, dtype=int).reshape(-1, 2)
    return EI_connection_array, IE_connection_array
   
    



def _sample_cluster_connections(number_idx, pair_table, p):
    """Sample (pre, post) index pairs from each neuron toward its cluster's paired clusters.

    ``number_idx[i]`` is the cluster id of neuron ``i``. ``pair_table`` has one
    row per cluster id: the first column is that id and the remaining columns
    are the clusters it projects to. For every neuron and every entry in its
    cluster's row, the connection count is the number of ``len(number_idx)``
    uniform draws at or below ``p``. That many targets are then sampled with
    replacement from the paired cluster.

    Returns an int array of shape (k, 2); k may be 0.
    """
    n = len(number_idx)
    # Lookups the per-neuron loop would otherwise recompute on every pass.
    row_of = {cluster: row for row, cluster in enumerate(pair_table[:, 0])}
    members = {cluster: np.flatnonzero(number_idx == cluster) for cluster in np.unique(number_idx)}

    chunks = []
    for i in range(n):
        for pair in pair_table[row_of[number_idx[i]]]:
            count = np.count_nonzero(np.random.rand(n) <= p)
            targets = np.random.choice(members[pair], size=count, replace=True)
            chunks.append(np.column_stack((np.full(count, i), targets)))
    if not chunks:
        return np.empty((0, 2), dtype=int)
    # A single concatenation instead of growing the array inside the loop.
    return np.vstack(chunks)


def custom_circ_graph(Ntot=1000, EI_ratio=0.5,NUM_CLUSTERS=5, p_ei=0.02, p_ie=0.02, inter_cluster_num=2, return_clust_labels=False):
    """Build cluster-structured E->I and I->E connection lists.

    Both populations contain ``N = int(Ntot * EI_ratio)`` neurons and share
    one assignment of neurons to ``NUM_CLUSTERS`` clusters of near-equal size.
    Each cluster is paired with itself plus ``inter_cluster_num`` randomly
    drawn clusters, and neurons only connect into their cluster's paired
    clusters. The full-network probabilities ``p_ei`` / ``p_ie`` are rescaled
    with ``pr_full_to_pr_sparse`` to the smaller candidate pool. With
    ``NUM_CLUSTERS == 1`` this delegates to ``random_connect_matrix``.

    Args:
        Ntot: total neuron count used to derive ``N``.
        EI_ratio: fraction of ``Ntot`` used as the size of each population.
        NUM_CLUSTERS: number of clusters.
        p_ei, p_ie: full-network E->I / I->E connection probabilities.
        inter_cluster_num: number of extra partner clusters per cluster.
        return_clust_labels: also return per-neuron cluster labels.

    Returns:
        (EI_connection_array, IE_connection_array) int arrays of
        (pre index, post index) rows. An empty array is replaced by a single
        ``[0, 0]`` row. When ``return_clust_labels`` is true, a third element
        ``{'EI_clust': ..., 'IE_clust': ...}`` holds per-neuron cluster labels
        (I labels are offset by ``NUM_CLUSTERS``).
    """
    N = int(Ntot*EI_ratio)

    if NUM_CLUSTERS == 1:
        # A single cluster is plain random wiring; forward the caller's
        # arguments instead of fixed defaults.
        EI_connection_array, IE_connection_array = random_connect_matrix(
            Ntot=Ntot, EI_ratio=EI_ratio, p_ei=p_ei, p_ie=p_ie)
        # Same labelling scheme as the clustered path: every neuron is in
        # cluster 0, with I labels offset by NUM_CLUSTERS.
        EI_cluster_labels = np.zeros(N, dtype=int)
        IE_cluster_labels = np.full(N, NUM_CLUSTERS, dtype=int)
    else:
        DIV = int(N/NUM_CLUSTERS)
        # DIV+1 slots per cluster, then trim (or pad) at random to exactly N.
        number_idx = np.repeat(np.arange(NUM_CLUSTERS), DIV+1)
        if len(number_idx) < N:
            number_idx = np.append(number_idx, np.random.choice(number_idx, size=N-len(number_idx), replace=True))
        elif len(number_idx) > N:
            number_idx = np.delete(number_idx, np.random.choice(np.arange(len(number_idx)), size=len(number_idx)-N, replace=False))

        unique_idx = np.unique(number_idx)
        # NOTE(review): `i` is the partner-column counter, not the source
        # cluster. Column i therefore only avoids cluster id i, and a cluster
        # can still be paired with itself. Confirm whether excluding the
        # source cluster was intended.
        ei_partners = [np.random.choice(unique_idx[unique_idx != i], size=len(unique_idx)) for i in np.arange(inter_cluster_num)]
        ei_pairs = np.hstack((unique_idx.reshape(-1, 1), np.array(ei_partners).T))

        # The full-network candidate pool is the N neurons of the target
        # population (as in random_connect_matrix), not Ntot//2.
        p_ei_sparse = pr_full_to_pr_sparse(p_ei, N, DIV*inter_cluster_num)
        p_ie_sparse = pr_full_to_pr_sparse(p_ie, N, DIV*inter_cluster_num)

        # Position of each neuron's cluster id among the ids that survived trimming.
        EI_cluster_labels = np.searchsorted(unique_idx, number_idx)
        EI_connection_array = _sample_cluster_connections(number_idx, ei_pairs, p_ei_sparse)

        # I->E partners are drawn independently, with no cluster excluded.
        ie_partners = [np.random.choice(unique_idx, size=len(unique_idx)) for _ in np.arange(inter_cluster_num)]
        ie_pairs = np.hstack((unique_idx.reshape(-1, 1), np.array(ie_partners).T))
        IE_cluster_labels = number_idx + NUM_CLUSTERS
        IE_connection_array = _sample_cluster_connections(number_idx, ie_pairs, p_ie_sparse)

    # if the IE_connection_array or EI_connection_array is empty, return a single [0, 0] row
    if len(IE_connection_array) == 0:
        IE_connection_array = np.zeros((1,2)).astype(int)
    if len(EI_connection_array) == 0:
        EI_connection_array = np.zeros((1,2)).astype(int)

    if return_clust_labels:
        return EI_connection_array, IE_connection_array, {'EI_clust': EI_cluster_labels, 'IE_clust': IE_cluster_labels}

    return EI_connection_array, IE_connection_array

def random_connect_matrix(Ntot=1000, EI_ratio=0.5, p_ei=0.02, p_ie=0.02):
    """Draw unstructured E->I and I->E connections between two populations.

    Both populations have ``N = int(Ntot * EI_ratio)`` neurons. Every ordered
    (pre, post) pair is connected independently: E->I with probability
    ``p_ei`` and I->E with probability ``p_ie``. A uniform draw in [0, 1) at
    or below the probability counts as a connection.

    Returns:
        (EI_connection_array, IE_connection_array): int arrays of shape (k, 2)
        holding (pre index, post index) rows, ordered by pre then post index.
        ``k`` may be 0.
    """
    N = int(Ntot*EI_ratio)

    def _sample(p):
        # One row of N draws per presynaptic neuron keeps memory at O(N)
        # instead of materialising an N x N matrix.
        chunks = []
        for pre in range(N):
            post = np.flatnonzero(np.random.rand(N) <= p)
            # The pre column must have one entry per *selected* target (not N),
            # otherwise the two columns cannot be stacked.
            chunks.append(np.column_stack((np.full(post.size, pre), post)))
        if not chunks:
            return np.empty((0, 2), dtype=int)
        return np.vstack(chunks)

    EI_connection_array = _sample(p_ei)
    IE_connection_array = _sample(p_ie)
    return EI_connection_array, IE_connection_array


if __name__ == "__main__":
    # Manual smoke run: many small clusters, each paired with 16 others,
    # with a much larger E->I than I->E connection probability.
    demo_config = {
        "NUM_CLUSTERS": 27,
        "inter_cluster_num": 16,
        "p_ei": 0.04,
        "p_ie": 0.0001,
    }
    custom_circ_graph(**demo_config)