from matplotlib import colors

import numpy as np
import matplotlib.pyplot as plt
import json

# Path of the JSON file that stores the neural network description.
neural_net_file = 'neuralnet.json'

# Number of rows/columns of the (square) connectivity matrix, i.e. how many
# module groups the heatmap will show.
net_size = 12

# Start from an all-zero matrix: a cell stays 0 until a connection is found.
connectivity_matrix = np.zeros(shape=(net_size, net_size))

# Read the whole file and parse it as JSON; the parsed structure is kept in
# `modules` for the traversal that follows.
with open(neural_net_file, 'r') as f:
    modules = json.loads(f.read())

# Map every module name used in the neural-net file to its row/column index
# in the connectivity matrix.  Names that share an index (e.g. 'ea1u' and
# 'ia1u') are collapsed into a single row/column of the matrix.
# NOTE(review): the 'e'/'i' prefixes presumably denote the excitatory and
# inhibitory populations of the same region -- confirm against the network
# file.
# NOTE: the name `dict` shadows the Python builtin; it is kept because the
# rest of this script refers to the map by that name.
dict = {
    'mgns': 0,
    'ea1u': 1,
    'ia1u': 1,
    'ea1d': 2,
    'ia1d': 2,
    'ea2u': 3,
    'ia2u': 3,
    'ea2c': 4,
    'ia2c': 4,
    'ea2d': 5,
    'ia2d': 5,
    'estg': 6,
    'istg': 6,
    'exfs': 7,
    'infs': 7,
    'efd1': 8,
    'ifd1': 8,
    'efd2': 9,
    'ifd2': 9,
    'exfr': 10,
    'infr': 10,
    'attv': 11,
    'atts': 11,
}

# Heatmap axis labels: the label at position i names matrix index i of the
# map above, so the two must be kept in the same order.
list_of_modules = ['MGN', 'A1u', 'A1d', 'A2u', 'A2c', 'A2d', 'STG', 'FS', 'D1', 'D2', 'FR', 'Att']

# Traverse the modules data structure, visit every outgoing connection of
# every unit, and mark in the connectivity matrix which module groups are
# connected and with which sign.
#
# Layout read by this loop (as indexed below):
#   modules[m][0], modules[m][1] -- loop bounds for the unit grid of module m
#   modules[m][8][x][y][4]       -- connections leaving unit (x, y); the
#                                   first item of each one is the name of
#                                   the destination module
# NOTE(review): the remaining items of a connection look like the destination
# unit coordinates and the weight value -- confirm against the file format.
# They are not needed here because only the sign of the connection is stored.
for m in modules:
    source_index = dict[m]
    units = modules[m][8]
    for x in range(modules[m][0]):
        for y in range(modules[m][1]):
            # units without outgoing connections simply contribute nothing
            for w in units[x][y][4]:
                dest_module = w[0]
                dest_index = dict[dest_module]

                if source_index == dest_index:
                    # Connections inside the same module group are excluded
                    # from the heatmap (only connections between different
                    # groups are of interest), so the diagonal is kept at 0.
                    connectivity_matrix[source_index, dest_index] = 0.0
                elif dest_module[0] == 'e':
                    # destination name starts with 'e': mark as excitatory
                    connectivity_matrix[source_index, dest_index] = 1.0
                elif dest_module[0] == 'i':
                    # destination name starts with 'i': mark as inhibitory
                    connectivity_matrix[source_index, dest_index] = -1.0
                # Any other prefix leaves the cell untouched.  Note that a
                # later connection between the same pair of groups overwrites
                # an earlier one; values are not accumulated.

# Draw the connectivity matrix as a three-colour heatmap:
# -1 -> blue, 0 -> white, +1 -> red.
fig, ax = plt.subplots()
cmap = colors.ListedColormap(['blue', 'white', 'red'])

# Pin the colour scale to [-1, 1].  Without fixed limits the scale is derived
# from the data, so a matrix lacking one of the three values would be drawn
# with the wrong colours (e.g. 0 shown as blue when there is no -1).
heatmap = ax.pcolor(connectivity_matrix, cmap=cmap, vmin=-1, vmax=1, alpha=0.7)

# put the major ticks at the middle of each cell
ax.set_yticks(np.arange(connectivity_matrix.shape[0]) + 0.5, minor=False)
ax.set_xticks(np.arange(connectivity_matrix.shape[1]) + 0.5, minor=False)

# label both axes with the module names
ax.set_xticklabels(list_of_modules, minor=False)
ax.set_yticklabels(list_of_modules, minor=False)

# Hide the tick marks on all four sides while keeping the labels.
ax.tick_params(axis='both', which='major',
               bottom=False, top=False, left=False, right=False)

# NOTE(review): earlier comments in this section mentioned turning off the
# frame, a "table-like" display and rotating the labels, but no code for
# those steps is present; they are intentionally left unimplemented here.

# Diagnostics: the name-to-index map, the number of matrix cells and the
# number of axis labels.  The call form of print works on Python 2 and 3.
print(dict)

print(connectivity_matrix.size)

print(len(list_of_modules))

# Display the figure; without this the heatmap is built but never shown.
plt.show()
