import json
import numpy as np
import os
import pyx

from helpers import original_data_path
from multiarea_model import MultiAreaModel
from matrix_plot import matrix_plot, rate_histogram_plot

import pylab as pl
from matplotlib import gridspec
from matplotlib import rc_file
# Load the project plot style first, so the figure size set below is
# applied on top of it.
rc_file('plotstyle.rc')

"""
Figure layout
"""
nrows, ncols = 2, 4
# Figure size handed to matplotlib as (width, height).
width, height = 7.0866, 6.
# 0.7 times the golden ratio.
panel_wh_ratio = 0.7 * (1. + np.sqrt(5)) / 2.
pl.rcParams['figure.figsize'] = (width, height)


# All panels below are placed on this single figure.
fig = pl.figure()
# Six vertically stacked axes in the upper-left region of the figure;
# they are later filled with the rate time series, one per value of chi.
gs1 = gridspec.GridSpec(6, 1)
gs1.update(left=0.08, right=0.55, top=0.95,
           bottom=0.5, wspace=0., hspace=0.2)
ax_rates = [pl.subplot(gs1[row:row + 1, 0:1]) for row in range(6)]

# Single axes in the upper-right region for the eigenvalue plot.
gs2 = gridspec.GridSpec(1, 1)
gs2.update(left=0.83, right=0.95, top=0.95,
           bottom=0.5, wspace=0., hspace=0.2)
ax_EV = pl.subplot(gs2[:, :])

# Empty axes in the lower-left region. Nothing is plotted into it; all
# spines and ticks are removed so it only carries its panel label.
# Presumably it marks the place of the sketch EPS merged in at the end
# of the script.
gs4 = gridspec.GridSpec(1, 1)
gs4.update(left=0.08, right=0.7, top=0.4,
           bottom=0.04, wspace=0., hspace=0.2)
ax_sketch = pl.subplot(gs4[:, :])
for side in ('right', 'top', 'left', 'bottom'):
    ax_sketch.spines[side].set_color('none')
for sketch_axis in (ax_sketch.yaxis, ax_sketch.xaxis):
    sketch_axis.set_ticks_position("none")
ax_sketch.set_xticks([])
ax_sketch.set_yticks([])

# Six vertically stacked axes between the rate panels and the eigenvalue
# panel; they are later filled with the activity histograms.
gs3 = gridspec.GridSpec(6, 1)
gs3.update(left=0.62, right=0.75, top=0.95,
           bottom=0.5, wspace=0., hspace=0.2)
ax_phasespace = [pl.subplot(gs3[row:row + 1, 0:1]) for row in range(6)]

# Lower-right region: rate matrix on top, rate histogram below.
# Note that this rebinds gs4, which earlier held the sketch grid.
gs4 = gridspec.GridSpec(2, 1)
gs4.update(left=0.72, right=0.96, top=0.4,
           bottom=0.04, wspace=0., hspace=0.25)
ax_matrix = pl.subplot(gs4[:1, :])
ax_hist = pl.subplot(gs4[1:2, :])

# ax_hist is stripped of spines and ticks; further down only its
# position is read and passed on to rate_histogram_plot.
for side in ('right', 'top', 'left', 'bottom'):
    ax_hist.spines[side].set_color('none')
for hist_axis in (ax_hist.yaxis, ax_hist.xaxis):
    hist_axis.set_ticks_position("none")
ax_hist.set_xticks([])
ax_hist.set_yticks([])

# Put a bold panel letter just outside the top-left corner of the first
# axes of every panel. Coordinates are in axes units (transAxes).
panel_axes = [ax_rates[0], ax_phasespace[0], ax_EV, ax_sketch, ax_matrix]
label_font = {'fontsize': 10.,
              'weight': 'bold',
              'horizontalalignment': 'left',
              'verticalalignment': 'bottom'}
for ax, label in zip(panel_axes, 'ABCDE'):
    # Panel C sits marginally lower than the other labels.
    label_y = 1. if label == 'C' else 1.01
    ax.text(-0.1, label_y, r'\bfseries{}' + label,
            fontdict=label_font,
            transform=ax.transAxes)

"""
Load data
"""
# Values of chi, one per simulation; they are paired with `labels` in
# order further down.
chi_list = [1.0, 1.8, 1.9, 2., 2.1, 2.5]

"""
Create MultiAreaModel instance to have access to data structures
"""
M = MultiAreaModel({})

# True: read the stored original simulations identified by the hashes
# below from original_data_path. False: use the labels of the models
# returned by init_models('Fig4') and the configured data_path.
LOAD_ORIGINAL_DATA = True

if LOAD_ORIGINAL_DATA:
    labels = ['33fb5955558ba8bb15a3fdce49dfd914682ef3ea',
              '1474e1884422b5b2096d3b7a20fd4bdf388af7e0',
              '99c0024eacc275d13f719afd59357f7d12f02b77',
              'f18158895a5d682db5002489d12d27d7a974146f',
              '08a3a1a88c19193b0af9d9d8f7a52344d1b17498',
              '5bdd72887b191ec22a5abcc04ca4a488ea216e32']
    data_path = original_data_path
else:
    from network_simulations import init_models
    from config import data_path
    models = init_models('Fig4')
    # The loop variable must not be called M: that would shadow the
    # model instance created above (and overwrite it under Python 2,
    # where comprehension variables leak).
    labels = [model.simulation.label for model in models]

# The stationary rates are taken from the chi=1.9 simulation in both
# cases; derive its label from chi_list instead of repeating the hash.
label_stat_rate = labels[chi_list.index(1.9)]

# rate_time_series[label][area]           -> area-level array
# rate_time_series[label]['Parameters']   -> contents of the parameter file
# rate_time_series_pops[label][area][pop] -> population-level array
rate_time_series = {}
rate_time_series_pops = {}
for label in labels:
    # All time-series files of one simulation share this directory.
    series_dir = os.path.join(data_path, label,
                              'Analysis',
                              'rate_time_series_full')
    area_series = {}
    pop_series = {}
    for area in M.area_list:
        area_file = 'rate_time_series_full_{}.npy'.format(area)
        area_series[area] = np.load(os.path.join(series_dir, area_file))
        pop_series[area] = {
            pop: np.load(os.path.join(
                series_dir,
                'rate_time_series_full_{}_{}.npy'.format(area, pop)))
            for pop in M.structure[area]
        }

    with open(os.path.join(series_dir,
                           'rate_time_series_full_Parameters.json')) as f:
        area_series['Parameters'] = json.load(f)

    rate_time_series[label] = area_series
    rate_time_series_pops[label] = pop_series

# stationary firing rates (only needed for one simulation)
fn = os.path.join(data_path, label_stat_rate, 'Analysis', 'pop_rates.json')
pop_rates = {}
with open(fn, 'r') as f:
    pop_rates[label_stat_rate] = json.load(f)


# Mean-field part: build the model once with both weight factors set to 1
# and integrate it. The analytical rates for the individual values of chi
# are computed in the eigenvalue section further down.
K_stable_path = '../SchueckerSchmidt2017/K_prime_original.npy'

conn_params = dict(g=-12.,
                   cc_weights_factor=1.,
                   cc_weights_I_factor=1.,
                   K_stable=K_stable_path,
                   fac_nu_ext_5E=1.125,
                   fac_nu_ext_6E=1.125 * 10 / 3. - 7 / 3.,
                   fac_nu_ext_TH=1.2)
input_params = dict(rate_ext=10.)
network_params = dict(connection_params=conn_params,
                      input_params=input_params)

# Start from zero activity; 254 is presumably the number of populations
# in the model -- TODO confirm.
initial_rates = np.zeros(254)
theory_params = dict(T=30.,
                     dt=0.01,
                     initial_rates=initial_rates)

M = MultiAreaModel(network_params, theory=True, theory_spec=theory_params)
# NOTE(review): p and r_base are not referenced again in this script.
p, r_base = M.theory.integrate_siegert()

"""
Plotting
"""
print("Plotting rate time series")
# One panel per simulation, showing the area-level rate of V1.
area = 'V1'
last_panel = len(labels) - 1
for i, (cc, label) in enumerate(zip(chi_list, labels)):
    ax = ax_rates[i]
    for side in ('right', 'top'):
        ax.spines[side].set_color('none')
    ax.yaxis.set_ticks_position("left")
    ax.xaxis.set_ticks_position("bottom")

    ax.plot(rate_time_series[label][area], lw=1, color='k')
    ax.set_xlim((0., 5e4))
    ax.set_ylim((-5., 60.))
    ax.set_yticks([5., 40.])
    # The chi annotation is placed to the right of the x-range.
    ax.text(51500., 48., r'$\chi = $' + str(cc))

    if i != last_panel:
        ax.set_xticks([])
    else:
        # Only the bottom panel gets the extra lines and a time axis.
        ax.vlines(0., 0., 40.)
        ax.hlines(0., 0., 2.)

        ax.set_xlabel('Time (s)')
        # Sample indices 0..5e4 are labelled as 0..50 s (presumably one
        # sample per ms -- confirm against the stored Parameters).
        ax.set_xticks([0., 1e4, 2e4, 3e4, 4e4, 5e4])
        ax.set_xticklabels([0, 10, 20, 30, 40, 50])

    # A single y-label on the fourth panel serves the whole column.
    if i == 3:
        ax.set_ylabel(r'Rate $(\mathrm{spikes/s})$')

print("Plotting critical eigenvalues")
# For every chi except the last one, build the mean-field model,
# integrate it and store M.theory.lambda_max evaluated at the resulting
# rates.
lambda_max = []
analytical_rates = {}
for chi, label in zip(chi_list[:-1], labels[:-1]):
    # The inhibitory factor is 1 for chi == 1 and 2 for all other chi.
    chi_I = 1. if chi == 1. else 2.

    conn_params = {'g': -12.,
                   'cc_weights_factor': chi,
                   'cc_weights_I_factor': chi_I,
                   'K_stable': K_stable_path,
                   'fac_nu_ext_5E': 1.125,
                   'fac_nu_ext_6E': 1.125 * 10 / 3. - 7 / 3.,
                   'fac_nu_ext_TH': 1.2}
    input_params = {'rate_ext': 10.}
    network_params = {'connection_params': conn_params,
                      'input_params': input_params}

    initial_rates = np.zeros(254)
    theory_params = {'T': 30.,
                     'dt': 0.01,
                     'initial_rates': initial_rates}

    M = MultiAreaModel(network_params, theory=True, theory_spec=theory_params)

    pops, rates_full = M.theory.integrate_siegert()
    # Keep only the last column (presumably the final integration step).
    analytical_rates[chi] = rates_full[:, -1]
    ana_rates = analytical_rates[chi]
    lambda_max.append(M.theory.lambda_max(ana_rates))

# Number of chi values actually evaluated; drives ticks, labels and limits
# so that they cannot get out of sync with the plotted points.
n_chi = len(lambda_max)

ax = ax_EV
for side in ('right', 'left', 'top'):
    ax.spines[side].set_color('none')
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

ax.plot(lambda_max, np.arange(n_chi), '.', ms=5)
# Vertical reference line at an eigenvalue of 1.
ax.vlines(1., 0, n_chi - 1)
ax.set_ylim((-0.5, n_chi - 0.5))

# Put the first chi at the top, matching the order of the other columns.
ax.invert_yaxis()
ax.set_xlabel(r'$\mathrm{max}\{\mathrm{Re}\left(\lambda_i\right)\}$')
ax.set_ylabel(r'$\chi$')
ax.set_xticks([0.5, 1.])
ax.set_yticks(np.arange(float(n_chi)))
# One label per tick: passing the full chi_list would give one label too
# many, which recent matplotlib versions reject with a ValueError.
ax.set_yticklabels(chi_list[:n_chi])

        
# One histogram per chi, built from the precomputed results file
# 'results_<chi>.npy' in load_path.
load_path = 'Fig4_theory_data'
for i, cc_weights_factor in enumerate(chi_list):
    ax = ax_phasespace[i]
    for side in ('right', 'left', 'top'):
        ax.spines[side].set_color('none')
    ax.set_xticks([0., 5., 35.])
    # The loop runs over chi_list, so the last panel is determined from
    # chi_list as well (not from the separate labels list).
    if i != len(chi_list) - 1:
        ax.set_xticklabels([])
    else:
        ax.set_xlabel(r'Average activity ($\mathrm{spikes/s}$)')
    ax.yaxis.set_ticks_position("none")
    ax.xaxis.set_ticks_position("bottom")
    ax.set_yticks([])

    # NOTE(review): an earlier comment here said the data are trimmed to
    # 10000 samples so that all configurations have the same sample
    # number. No trimming is performed in this script; presumably the
    # stored files already have equal size -- confirm.
    data = np.load(os.path.join(load_path,
                                'results_{}.npy'.format(cc_weights_factor)))
    # Take the last entry along the third axis, average over the second
    # axis, and histogram the result over the first axis.
    vals, bins = np.histogram(
        np.mean(data[:, :, -1], axis=1), bins=50, range=(0, 40))

    # bins[:-1] are the left bin edges, so the bars must be edge-aligned;
    # the default centre alignment would shift them by half a bin.
    ax.bar(bins[:-1], vals, width=np.diff(bins)[0], align='edge',
           color='k', edgecolor='none')

print("Plotting rate matrix")
# pop_rates only holds the simulation selected via label_stat_rate, so use
# that label instead of a hard-coded hash (which raises a KeyError as soon
# as LOAD_ORIGINAL_DATA is False).
label = label_stat_rate

# Columns follow the populations of V1 in reversed order; areas that lack
# one of these populations get NaN in that column.
pop_order = M.structure['V1'][::-1]
matrix = np.zeros((len(M.area_list), len(pop_order)))
for i, area in enumerate(M.area_list):
    for j, pop in enumerate(pop_order):
        if pop not in M.structure[area]:
            rate = np.nan
        else:
            # First entry of the stored per-population record.
            rate = pop_rates[label][area][pop][0]
        # Replace exact zeros by a small positive value (presumably so
        # that a logarithmic colour scale can handle them -- confirm).
        if rate == 0.0:
            rate = 1e-5
        matrix[i, j] = rate
# matrix_plot receives populations as rows and areas as columns.
matrix = np.transpose(matrix)
matrix_plot(fig, ax_matrix, matrix, position='single')

# rate_histogram_plot takes a position rectangle [x0, y0, width, height]
# rather than an axes object.
pos = ax_hist.get_position()
ax_hist_pos = [pos.x0, pos.y0, pos.x1 - pos.x0, pos.y1 - pos.y0]
rate_histogram_plot(fig, ax_hist_pos, matrix, position='single')

# Write the matplotlib part of the figure to its own EPS file.
mpl_eps = 'Fig4_metastability_mpl.eps'
pl.savefig(mpl_eps)

"""
Merge with sketch figure
"""
# Place the matplotlib EPS and the separately prepared sketch EPS on one
# pyx canvas, in this order, and write the combined figure.
c = pyx.canvas.canvas()
eps_parts = (
    (0.5, 0.5, mpl_eps, 17.6),
    (0.8, 1., "Fig4_metastability_phasespace_sketch.eps", 12.2),
)
for eps_x, eps_y, eps_name, eps_width in eps_parts:
    c.insert(pyx.epsfile.epsfile(eps_x, eps_y, eps_name, width=eps_width))

c.writeEPSfile("Fig4_metastability.eps")