import copy
import json
import numpy as np
import pyx
import os

from helpers import original_data_path, population_labels
from multiarea_model import MultiAreaModel
from multiarea_model.multiarea_helpers import vector_to_dict, create_vector_mask
from plotcolors import myblue, myred
from scipy.signal import find_peaks_cwt
from scipy.optimize import minimize
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

import matplotlib.pyplot as pl
from matplotlib import gridspec
from matplotlib import rc_file
# Apply the matplotlib rc settings stored in plotstyle.rc to all panels.
rc_file('plotstyle.rc')

# Diverging colormap for the extremum-time matrices (panels C and D):
# myblue -> white -> myred. Those panels use a symmetric range around 0,
# so white marks zero lag.
# NOTE(review): the coolwarm instance is only used to reach the
# ``from_list`` constructor; the new map is built solely from the three
# listed colours.
cmap = pl.cm.coolwarm
cmap = cmap.from_list('mycmap', [myblue, 'white', myred], N=256)

"""
Figure layout
"""
width = 7.0866
panel_wh_ratio = 0.7 * (1. + np.sqrt(5)) / 2.  # golden ratio

nrows = 3
ncols = 3
height = width / panel_wh_ratio * float(nrows) / ncols
height = 6.375
pl.rcParams['figure.figsize'] = (width, height)

fig = pl.figure()
axes = {}


gs1 = gridspec.GridSpec(3, 3)
gs1.update(left=0.055, right=0.9225, top=0.95,
           bottom=0.07, wspace=0.4, hspace=0.4)

axes['A'] = pl.subplot(gs1[0, 0])
axes['B'] = pl.subplot(gs1[0, 1])
axes['C'] = pl.subplot(gs1[0, 2])
axes['D'] = pl.subplot(gs1[1, 0])
axes['E'] = pl.subplot(gs1[1, 1])
axes['F'] = pl.subplot(gs1[1, 2])

axes['G'] = pl.subplot(gs1[2, 0])
pos = axes['G'].get_position()
axes['G2'] = pl.axes([pos.x1 - 0.08 + 0.5, pos.y0+0.05,
                      0.1,
                      pos.y1 - pos.y0])


fd = {'fontsize': 10, 'weight': 'bold', 'horizontalalignment':
      'left', 'verticalalignment': 'bottom'}


for label in ['C', 'D', 'E', 'F']:
    label_pos = [-0.15, 1.04]
    pl.text(label_pos[0], label_pos[1], r'\bfseries{}' + label,
            fontdict=fd, transform=axes[label].transAxes)

pl.text(-0.15, 1.0, r'\bfseries{}' + 'A',
        fontdict=fd, transform=axes['A'].transAxes)

pl.text(-0.05, 1.0, r'\bfseries{}' + 'G',
        fontdict=fd, transform=axes['G'].transAxes)

"""
Load data
"""

"""
Create MultiAreaModel instance to have access to data structures
"""
M = MultiAreaModel({})

LOAD_ORIGINAL_DATA = True
if LOAD_ORIGINAL_DATA:
    label = '99c0024eacc275d13f719afd59357f7d12f02b77'
    data_path = original_data_path
else:
    from network_simulations import init_models
    from config import data_path
    models = init_models('Fig7')
    label = models[0].simulation.label


# Rate time series per area, one .npy file per area. The accompanying JSON
# parameters are stored under 'Parameters'; their 't_min' entry is used
# below to turn absolute times into array indices.
rate_time_series = {}
for area in M.area_list:
    fn = os.path.join(data_path, label,
                      'Analysis',
                      'rate_time_series_full',
                      'rate_time_series_full_{}.npy'.format(area))
    rate_time_series[area] = np.load(fn)

fn = os.path.join(data_path, label,
                  'Analysis',
                  'rate_time_series_full',
                  'rate_time_series_full_Parameters.json')
with open(fn, 'r') as f:
    rate_time_series['Parameters'] = json.load(f)

# Pairwise cross-correlations, cross_correlation[area][area2], for every
# ordered pair of areas, plus the shared time-lag axis under 'time'.
cross_correlation = {}
for area in M.area_list:
    cross_correlation[area] = {}
    for area2 in M.area_list:
        fn = os.path.join(data_path, label,
                          'Analysis',
                          'cross_correlation',
                          'cross_correlation_{}_{}.npy'.format(area, area2))
        cross_correlation[area][area2] = np.load(fn)

fn = os.path.join(data_path, label,
                  'Analysis',
                  'cross_correlation',
                  'cross_correlation_time.npy')
cross_correlation['time'] = np.load(fn)

# Correlation peak
def correlation_peak(area, area2, t_min, t_max, width=[5.], max_lag=100.):
    t = cross_correlation['time']
    cross = cross_correlation[area][area2]

    # First look for maxima in cc
    cc = cross[0][1]
    indices = np.where(np.logical_and(t > -max_lag, t < max_lag))
    maxima = find_peaks_cwt(cc[indices], np.array(width))
    max_times = t[indices][maxima]
    if len(maxima) > 0:
        if area == area2:
            selected_max = np.where(t[indices] == 0)[0]
            selected_max_time = 0.
        else:
            selected_max = maxima[np.argmax(cc[indices][maxima])]
            selected_max_time = max_times[np.argmax(cc[indices][maxima])]
        selected_max_value = cc[indices][selected_max]
    else:
        selected_max = np.nan
        selected_max_time = np.nan
        selected_max_value = 0.

    # 2nd, look for minima in cc, i.e., maxima in -cc
    cc = -1. * cross[0][1]
    minima = find_peaks_cwt(cc[indices], np.array(width))
    min_times = t[indices][minima]
    if len(minima) > 0:
        if area == area2:
            selected_min = np.where(t[indices] == 0)[0]
            selected_min_time = 0.
        else:
            selected_min = minima[np.argmax(cc[indices][minima])]
            selected_min_time = min_times[np.argmax(cc[indices][minima])]
        selected_min_value = cc[indices][selected_min]
    else:
        selected_min = np.nan
        selected_min_time = np.nan
        selected_min_value = 0.

    # Then select the one with the larger absolute value
    if abs(selected_max_value) > abs(selected_min_value):
        selected_peak = selected_max
        selected_peak_time = selected_max_time
    else:
        selected_peak = selected_min
        selected_peak_time = selected_min_time
 
    # selected_peak = selected_max
    # selected_peak_time = selected_max_time
    cc = cross[0][1]
    if area != 'MDP' and area2 != 'MDP':
        return selected_peak_time, cc[indices][selected_peak]
    else:
        return np.nan, np.nan


"""
Plotting
"""

"""
Panel A: Matrix plot of rate time series
"""
interval = (10750., 11250.)
x_ticks = np.array([10800, 11200])

tmin = interval[0]
tmax = interval[1]

i_min = int(tmin - rate_time_series['Parameters']['t_min'])
i_max = int(tmax - rate_time_series['Parameters']['t_min'])

areas = []
transitions = []
for area in M.area_list:
    rate = rate_time_series[area][i_min:i_max]
    indices = np.where(rate > 2. * np.mean(rate))
    if len(indices[0]) > 0:
        transitions.append(indices[0][0])
    else:
        transitions.append(rate.size - 1)
    areas.append(area)
areas = np.array(areas)
areas = areas[np.argsort(transitions)]
area_string_A = areas[0]
for area in areas[1:]:
    area_string_A += ' '
    area_string_A += area

transitions = np.sort(transitions)


# One row per area (in the transition order computed above): the area's
# rate over the interval, divided by its mean over that interval.
time_interval = int(tmax - tmin)
normalized_rows = []
for area in areas:
    rate = rate_time_series[area][i_min:i_max]
    normalized_rows.append(rate / np.mean(rate))
# Build the array once instead of re-allocating it with np.append on every
# iteration. The reshape checks that the total size matches
# len(areas) x time_interval.
matrix = np.asarray(normalized_rows, dtype=float).reshape(
    (len(areas), time_interval))


# NOTE(review): ``pl.yticks`` acts on the *current* axes, which at this
# point is G2 (the last axes created in the layout section), not panel A.
# G2's tick labels are cleared later, but setting ticks also expands G2's
# y view limits to include them (see matplotlib's set_ticks). The
# 'Temporal hierarchy' label drawn on G2 at data y=23 appears to rely on
# that. Reposition that label before removing this call.
y_index = list(range(len(areas)))
y_index = [(a + 0.5) for a in y_index]
ytick_labels = areas
pl.yticks(y_index, ytick_labels)


ax1 = axes['A']
ax1.yaxis.set_ticks_position("left")
ax1.xaxis.set_ticks_position("bottom")
ax1.tick_params(axis='y', length=0.)

# The matrix is drawn against column indices, so map the requested tick
# times onto indices of np.arange(tmin, tmax, 1.).
x_index = [np.where(np.arange(tmin, tmax, 1.) == i)[0][0] for i in x_ticks]
ax1.set_xticks(x_index)
ax1.set_xticklabels(x_ticks)

ax1.set_ylabel('Area')
ax1.yaxis.set_label_coords(-0.18, 0.5)
# NOTE(review): these y ticks are removed again two lines below.
ax1.set_yticks(np.arange(32.) + 0.5)
# ax1.set_yticklabels(areas, size=8)
ax1.set_yticks([])
ax1.set_xlabel('Time (ms)', labelpad=-0.05)
# One row per area; colour range [0, 12] in units of the area's mean rate.
# NOTE(review): the colorbar label says spikes/s although ``matrix`` holds
# rates divided by their mean. Tick 20 lies outside vmax=12 and is not shown.
im = ax1.pcolormesh(matrix, cmap=pl.get_cmap('inferno'), vmin=0., vmax=12.)
pl.colorbar(im, ax=ax1, ticks=[0., 10., 20.])
ax1.set_ylim((0., 32.))
ax1.text(1.3, 0.65, r'$\nu (\mathrm{spikes/s})$',
         rotation=90, transform=ax1.transAxes)


# ################### PANEL B ####################
"""
Panel B: Cross-correlation of 3 pairs of areas
"""

ax2 = axes['B']
ax2.spines['right'].set_color('none')
ax2.spines['top'].set_color('none')
ax2.yaxis.set_ticks_position("left")
ax2.xaxis.set_ticks_position("bottom")
pl.text(-0.1, 1.0, r'\bfseries{}' + 'B',
        fontdict=fd, transform=ax2.transAxes)

# V1 paired with itself, V2 and FEF, drawn in decreasingly dark grey.
area = 'V1'
areas2 = [area, 'V2', 'FEF']
colors = ['0.0', '0.4', '0.6']

for i, area2 in enumerate(areas2):
    t = cross_correlation['time']
    cross = cross_correlation[area][area2]
    cc = cross[0][1]
    # 500. and 100500. fill correlation_peak's t_min/t_max parameters,
    # which that function does not use.
    tp, pp = correlation_peak(
        area, area2, 500., 100500., max_lag=100.)

    ax2.plot(t, cc, color=colors[i])
    # Dashed marker from the bottom of the axis up to the selected extremum.
    ax2.vlines(tp, -5000., pp, color=colors[i], linestyle='dashed')

ax2.set_xticks([-100., -50., 0., 50., 100.])
ax2.set_xlim((-100., 100.))
# NOTE(review): the 1e5 tick lies outside the y-limits set below and is not shown.
ax2.set_yticks([0., 20000., 100000.])
ax2.set_yticklabels([r'$0.$', r'$2\cdot10^4$', r'$10^5$'])
ax2.set_ylim((-5000., 35000.))
ax2.text(-130., 15000., r'$C(\tau)$', rotation=90)
ax2.set_xlabel(r'Time lag $\tau$ ($\mathrm{ms}$)', labelpad=-0.05)


print("Constructing CC Matrix")
cc_matrix = []
peak_matrix = []

cc_matrix = np.zeros((32, 32))
for i, area in enumerate(M.area_list):
    cc_list = []
    peak_list = []
    for j, area2 in enumerate(M.area_list):
        if cc_matrix[j][i] != 0.:
            cc_matrix[i][j] = -1. * cc_matrix[j][i]
        else:
            tp, peak = correlation_peak(area, area2,
                                        500., 100500., max_lag=100.)
            cc_matrix[i][j] = tp
            peak_list.append(peak)
    peak_matrix.append(peak_list)

peak_matrix = np.array(peak_matrix)
d = {'matrix': cc_matrix}

cc_matrix_masked = np.ma.masked_where(np.isnan(cc_matrix), cc_matrix)

"""
Panel C: Extremum matrix unsorted
"""
ax = axes['C']
ax.yaxis.set_ticks_position("none")
ax.xaxis.set_ticks_position("none")

ax.set_xlabel('Area B')
# ax.xaxis.set_label_coords(0.5, -0.2)

ax.set_ylabel('Area A')
ax.set_xlim((0, 32))
ax.set_ylim((0, 32))

ax.set_aspect(1. / ax.get_data_ratio())
vlim = np.max(np.abs(cc_matrix_masked))
cmap.set_bad('0.5')
im = ax.pcolormesh(cc_matrix_masked[::-1], cmap=cmap, vmin=-vlim, vmax=vlim)

area_string_C = M.area_list[0]
for area in M.area_list[1:]:
    area_string_C += ' '
    area_string_C += area

ax.set_xticks([])
ax.set_yticks([])

pl.colorbar(im, ax=ax, fraction=0.044, ticks=[-80, -40, 0, 40, 80])
ax.text(44.8, 27., r'Extremum time ($\mathrm{ms}$)', rotation=90)


def dev(i, j, hierarchy, cc_matrix):
    """
    Signed residual of the pair (i, j) in the hierarchy fit.

    It is the gap between the hierarchy positions of areas i and j, minus
    the measured peak lag ``cc_matrix[i][j]``. It is zero when the
    hierarchy explains the lag exactly.
    """
    position_gap = hierarchy[i] - hierarchy[j]
    measured_lag = cc_matrix[i][j]
    return position_gap - measured_lag


def hier_dev(hierarchy, cc_matrix):
    """
    Objective function of the hierarchy fit.

    Returns the Euclidean norm of all pairwise residuals
    ``hierarchy[i] - hierarchy[j] - cc_matrix[i][j]`` (the quantity computed
    by ``dev``), for i and j ranging over ``hierarchy``. This is the same as
    summing ``dev(i, j, ...) ** 2`` over all pairs. The computation is
    vectorised because ``minimize`` evaluates this function many times per
    fit. Only the top-left ``len(hierarchy)`` x ``len(hierarchy)`` block of
    ``cc_matrix`` is used.
    """
    positions = np.asarray(hierarchy, dtype=float)
    n = positions.size
    lags = np.asarray(cc_matrix, dtype=float)[:n, :n]
    residuals = positions[:, np.newaxis] - positions[np.newaxis, :] - lags
    return np.sqrt(np.sum(residuals ** 2))


def create_hierarchy(cc_matrix, areas):
    """
    Determine the hierarchy for a given set of areas and their
    cross-correlation peak matrix.

    A scalar position is fitted for every area by minimising ``hier_dev``,
    starting from uniform random values drawn with ``np.random.rand``. The
    fit makes pairwise position differences approximate the peak lags in
    ``cc_matrix``.

    Returns
    -------
    res : scipy.optimize.OptimizeResult
        Raw result of ``minimize``; ``res['x']`` is in the input order.
    hierarchy : ndarray
        Fitted positions sorted ascending and mapped onto [0, 1]. They are
        left at 0 if all positions coincide.
    hierarchical_areas : ndarray
        ``areas`` reordered to match ``hierarchy``.
    index_transformation : ndarray
        Argsort indices mapping input order to hierarchical order.
    """
    res = minimize(hier_dev, np.random.rand(
        cc_matrix[0].size), args=(cc_matrix,))
    hierarchy = res['x']
    index_transformation = np.argsort(hierarchy)
    # Fancy indexing returns a copy, so the caller's ``areas`` is untouched.
    hierarchical_areas = np.array(areas)[index_transformation]
    hierarchy = hierarchy[index_transformation]
    # Map hierarchy onto [0, 1]. Guard against a degenerate (constant) fit,
    # which would otherwise divide by zero and produce NaNs.
    hierarchy -= np.min(hierarchy)
    span = np.max(hierarchy)
    if span > 0:
        hierarchy /= span
    return res, hierarchy, hierarchical_areas, index_transformation


def count_violations(hierarchy, cc_matrix):
    """
    Count how many area pairs contradict the given hierarchy.

    The ordered pair (i, j) is a violation when the sign of
    ``hierarchy[i] - hierarchy[j]`` differs from the sign of the peak lag
    ``cc_matrix[i][j]``. All ordered pairs are tested and the total is
    halved, so for an antisymmetric lag matrix each unordered pair counts
    once.
    """
    n = hierarchy.size
    predicted = np.sign(hierarchy[:, np.newaxis] - hierarchy[np.newaxis, :])
    observed = np.sign(np.asarray(cc_matrix, dtype=float)[:n, :n])
    mismatches = int(np.count_nonzero(predicted != observed))
    return mismatches / 2.  # Divide by two to not count each pair double


print("Computing hierarchy")
area_list = np.array(M.area_list)
# We exclude area MDP because it does not receive connections from any
# other area and thus does not participate in cortico-cortical
# communication
ind_without_MDP = np.isfinite(cc_matrix[0])
cc_matrix_without_MDP = cc_matrix[ind_without_MDP][:, ind_without_MDP]
area_list_without_MDP = area_list[ind_without_MDP]
res, hierarchy, hierarchical_areas, index_transformation = create_hierarchy(
    cc_matrix_without_MDP, area_list_without_MDP)
hierarchy_to_area_list = []

for area in area_list:
    if area not in ['MDP']:
        hierarchy_to_area_list.append(
            np.where(hierarchical_areas == area)[0][0])

# Export hierarchy to csv
with open('Fig7_temporal_hierarchy.csv', 'w') as f:
    for hier, area in zip(hierarchy, hierarchical_areas):
        f.write(area + ',' + str(hier) + '\n')


"""
Panel D: Extremum matrix sorted
"""

ax = axes['D']
cc_matrix_hier = cc_matrix_without_MDP[index_transformation][:,
                                                             index_transformation]
# Add area MDP to the matrix for plotting purposes
cc_matrix_hier = np.insert(cc_matrix_hier, 0, cc_matrix[14][:-1], axis=0)
cc_matrix_hier = np.insert(cc_matrix_hier, 0, cc_matrix[14], axis=1)

cc_matrix_hier_masked = np.ma.masked_where(
    np.isnan(cc_matrix_hier), cc_matrix_hier)
ax.yaxis.set_ticks_position("none")
ax.xaxis.set_ticks_position("none")
ax.set_xlabel('Area B')
# ax.xaxis.set_label_coords(0.5, -0.2)
ax.set_ylabel('Area A')
# ax.yaxis.set_label_coords(-0.17, 0.5)


ax.set_xlim((0, 32))
ax.set_ylim((0, 32))

ax.set_aspect(1. / ax.get_data_ratio())
im = ax.pcolormesh(
    cc_matrix_hier_masked[:, ::-1], cmap=cmap, vmin=-vlim, vmax=vlim)

area_string_D = hierarchical_areas[0]
for area in hierarchical_areas[1:]:
    area_string_D += ' '
    area_string_D += area

ax.set_xticks([])
ax.set_yticks([])

pl.colorbar(im, ax=ax, fraction=0.044, ticks=[-80, -40, 0, 40, 80])
ax.text(43.3, 27., r'Extremum time ($\mathrm{ms}$)', rotation=90)

"""
Eigenvalue spectrum and eigenvector projection
"""
conn_params = {'replace_non_simulated_areas': 'het_poisson_stat',
               'g': -11.,
               'K_stable': '../../K_stable.npy',
               'fac_nu_ext_TH': 1.2,
               'fac_nu_ext_5E': 1.125,
               'fac_nu_ext_6E': 1.41666667,
               'cc_weights_I_factor': 2.,
               'cc_weights_factor': 1.9,
               'av_indegree_V1': 3950.}
input_params = {'rate_ext': 10.}
neuron_params = {'V0_mean': -150.,
                 'V0_sd': 50.}
network_params = {'connection_params': conn_params,
                  'neuron_params': neuron_params}

theory_params = {'T': 50.,
                 'dt': 0.1,
                 'initial_rates': 'random_uniform',
                 'initial_rates_iter': 15}

M = MultiAreaModel(network_params, theory=True, simulation=False,
                   theory_spec=theory_params)
pops, rates_full = M.theory.integrate_siegert()
# Here, pick a calculation that converges to the LA state
ana_rates = rates_full[12][:, -1]
lambda_max, slope, slope_sigma, G, EV = M.theory.lambda_max(ana_rates, full_output=True)

"""
Panel E: Eigenvalues
"""
ax = axes['E']
ax.set_frame_on(False)
ax.set_xticks([])
ax.set_yticks([])

pos = ax.get_position()

# Real part < 0
ax0 = pl.axes([pos.x0 + 0.02,
               pos.y0,
               0.6 * (pos.x1 - pos.x0),
               pos.y1 - pos.y0])
ax0.spines['right'].set_color('none')
ax0.spines['top'].set_color('none')

ax0.plot(np.real(EV[0]),
         np.imag(EV[0]), '.')

ax0.set_xlabel(r'$\mathrm{Re}(\lambda_i)$')
ax0.xaxis.set_label_coords(1., -0.2)
ax0.yaxis.set_label_coords(-0.15, 0.5)
ax0.set_ylabel(r'$\mathrm{Im}(\lambda_i)$')
ax0.yaxis.set_label_coords(-0.17, 0.5)

ax0.vlines(1., -3., 3., lw=0.9)
ax0.set_xlim((-20., 0.1))
kwargs = dict(transform=ax.transAxes, color='k', clip_on=False, lw=1.)
ax0.plot([0.7, 0.72], [-0.05, 0.03], **kwargs)
ax0.plot([0.73, 0.75], [-0.05, 0.03], **kwargs)

# Real part > 0
ax1 = pl.axes([pos.x0 + 0.18,
               pos.y0,
               0.35 * (pos.x1 - pos.x0),
               pos.y1 - pos.y0])
ax1.spines['right'].set_color('none')
ax1.spines['left'].set_color('none')
ax1.spines['top'].set_color('none')
ax1.set_yticks([])
ax1.plot(np.real(EV[0]),
         np.imag(EV[0]), '.')
critical_eval = EV[0][np.argsort(np.real(EV[0]))[-1]]
ax1.plot(np.real(critical_eval),
         np.imag(critical_eval), '.', color=myred)

ax1.vlines(1., -3., 3., lw=0.7)
ax1.set_xlim((0., 1.))

"""
Panel F: Projection of critical eigenvector onto network
"""
ax = axes['F']
pos = ax.get_position()
divider = make_axes_locatable(ax)
ax_cb = pl.axes([pos.x1,
                 pos.y0,
                 0.02,
                 pos.y1 - pos.y0])

ax_cb.set_frame_on(False)
ax_cb.set_xticks([])
ax_cb.set_yticks([])

critical_eigenvector = np.real(EV[1][:, np.argsort(np.real(EV[0]))[-1]])
r = vector_to_dict(critical_eigenvector, area_list, M.structure)

ev_matrix = np.zeros((8, 32))
for i, area in enumerate(area_list):
    vm = create_vector_mask(M.structure, areas=[area])
    r = critical_eigenvector[vm]
    if area == 'TH':
        r = np.insert(r, 2, np.zeros(2))
    ev_matrix[:, i] = r

ind = [list(area_list).index(area) for area in hierarchical_areas[::-1]]

im = ax.pcolormesh(np.abs(ev_matrix[::-1][:, ind]), cmap=pl.get_cmap('inferno'),
                   norm=LogNorm(vmin=1e-3, vmax=1e0))

area_string_F = area_list[ind][0]
for area in area_list[ind][1:]:
    area_string_F += ' '
    area_string_F += area

ax.set_xlabel('Area')
ax.set_xticks([])
ax.set_yticklabels(population_labels[::-1])
ax.set_yticks(np.arange(8.) + 0.5)
cb = pl.colorbar(im, ax=ax_cb, fraction=1.)
cb.set_ticks([0.001, 1.])
cb.set_ticklabels([r'$0.001$', r'$1$'])
cb.ax.tick_params(labelsize=8, length=0, rotation=0)
ax_cb.text(1.2, 0.8, 'critical eigenvector',
           rotation=90, transform=ax.transAxes)


"""
Create 100 surrogate matrices by shuffling the cross-correlation peak
matrix and measure the violations of the temporal hierarchy to judge
the significance of the temporal hierarchy.
"""
print("Surrogate matrices")
surrogate_matrix = copy.deepcopy(cc_matrix_without_MDP)
violation_list = []
np.random.seed(123)
for i in range(1):
    for j in range(len(area_list_without_MDP)):
        ind = np.extract(np.arange(len(area_list_without_MDP)) != i,
                         np.arange(len(area_list_without_MDP)))
        ind = np.arange(j, len(area_list_without_MDP))
        surr = surrogate_matrix[j][ind][np.random.shuffle(ind)]
        surrogate_matrix[j][ind] = surr
        surrogate_matrix[:, j][ind] = -1.*surr
    (surrogate_res, surrogate_hierarchy,
     surrogate_hierarchical_areas,
     sit) = create_hierarchy(surrogate_matrix, area_list_without_MDP)
    violation_list.append(count_violations(hierarchy, surrogate_matrix[sit][:, sit]))

print("Mean violations of surrogates: ", np.mean(violation_list), " +- ", np.std(violation_list))

print(("Violations of hierarchy: ", count_violations(
    hierarchy, cc_matrix[index_transformation][:, index_transformation])))


# Remove spines and ticks from the container axes. E hosts the eigenvalue
# insets; G and G2 presumably reserve space for the surface plots and their
# colorbar.
for label in ('E', 'G', 'G2'):
    bare_ax = axes[label]
    for side in ('right', 'left', 'top', 'bottom'):
        bare_ax.spines[side].set_color('none')
    for axis_obj in (bare_ax.yaxis, bare_ax.xaxis):
        axis_obj.set_ticks_position("none")
    bare_ax.set_xticks([])
    bare_ax.set_yticks([])


"""
Plot the colorbar for the surface plots.
"""
ax = axes['G2']

sm = pl.cm.ScalarMappable(cmap=pl.get_cmap('inferno_r'), norm=pl.Normalize(
    vmax=np.min(hierarchy), vmin=np.max(hierarchy)))
sm.set_array([])
cbticks = []
cbar = pl.colorbar(sm, ax=ax, ticks=cbticks, shrink=0.9)
ax.annotate('', xy=(1.3, 0.9), xycoords='axes fraction',
            xytext=(1.3, 0.1), arrowprops=dict(arrowstyle="->", color='k'))
ax.text(1.45, 23., 'Temporal hierarchy', rotation=90)

pl.text(0.02, 0.1, r'\bfseries{}Order of cortical areas', transform=fig.transFigure)
pl.text(0.02, 0.08, ' '.join((r'\textbf{A}:', area_string_A)),
        transform=fig.transFigure, fontsize=7)
pl.text(0.02, 0.06, ' '.join((r'\textbf{C}:', area_string_C)),
        transform=fig.transFigure, fontsize=7)
pl.text(0.02, 0.04, ' '.join((r'\textbf{D}:', area_string_D)),
        transform=fig.transFigure, fontsize=7)
pl.text(0.02, 0.02, ' '.join((r'\textbf{F}:', area_string_F)),
        transform=fig.transFigure, fontsize=7)


"""
Save figure
"""
pl.savefig('Fig7_temporal_hierarchy_mpl.eps')


"""
Merge surface plots
"""
c = pyx.canvas.canvas()

c.insert(pyx.epsfile.epsfile(
    0., 0., "Fig7_temporal_hierarchy_mpl.eps", width=18.))
c.insert(pyx.epsfile.epsfile(
    2., 2.2, "Fig7_surface_plot_lateral.eps", width=5.5))
c.insert(pyx.epsfile.epsfile(
    8., 2.2, "Fig7_surface_plot_medial.eps", width=5.5))

c.writeEPSfile("Fig7_temporal_hierarchy.eps")