import matplotlib.pyplot as plt
import numpy as np
from plotcolors import *
from matplotlib.colors import LinearSegmentedColormap, LogNorm
from area_list import area_list, population_labels


def rate_matrix_plot(fig, ax, matrix, position):
    ''' Create a matrix plot of pop. rates for the stabilization manuscript.

    Draws ``matrix`` with ``pcolormesh`` on ``ax`` using a logarithmic
    colour scale (``LogNorm`` from 0.01 to 500) and attaches a colorbar
    to ``ax``.

    Parameters
    ----------
    fig : object
        Unused; kept for interface compatibility. The colorbar is added
        to the figure that owns ``ax``.
    ax : matplotlib Axes
        Axes to draw into.
    matrix : 2-D array
        Rates to plot (the colorbar label reads nu in 1/s). NaN entries
        are masked and drawn black. Presumably rows are the 8 populations
        and columns are areas (x range fixed to 0..32) -- TODO confirm.
    position : str
        ``'left'``: label the y axis with ``population_labels``.
        ``'right'``: hide the y ticks and write the rate label at data
        position (42, 5), to the right of the axes.
        Any other value does neither.
    '''
    # Blue -> white -> red colormap built from the plotcolors palette.
    cmap = LinearSegmentedColormap.from_list(
        'mycmap', [myblue, myblue2, 'white', myred2, myred], N=256)
    cmap.set_under('0.3')  # values below vmin (0.01) are drawn grey
    cmap.set_bad('k')      # masked entries (the NaNs below) are drawn black
    masked_matrix = np.ma.masked_where(np.isnan(matrix), matrix)

    im = ax.pcolormesh(
        masked_matrix, norm=LogNorm(vmin=0.01, vmax=500.), cmap=cmap)
    ax.set_xlim(0, 32)

    # Black box covering one column (area_list.index('TH') + 1) over rows 4-5.
    # NOTE(review): the +1 offset depends on how the caller orders the
    # columns of ``matrix`` relative to ``area_list`` -- confirm.
    ax.add_patch(plt.Rectangle((area_list.index('TH') + 1, 4), 1, 2,
                               edgecolor='none', facecolor='k'))

    # Unlabelled major ticks delimit groups of columns; the minor ticks sit
    # at the midpoints of those intervals (last one up to xlim=32), carry
    # the architectural-type labels and are drawn without tick marks.
    ax.xaxis.set_major_locator(plt.FixedLocator([0, 1, 4, 9, 24, 31]))
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.xaxis.set_minor_locator(plt.FixedLocator([0.5, 2.5, 6.5, 16.5,
                                                 27.5, 31.5]))
    ax.set_xticklabels([8, 7, 6, 5, 4, 2], minor=True)
    ax.tick_params(axis='x', which='minor', length=0.)
    ax.set_xlabel('Arch. type', labelpad=-0.2)

    # Use the figure that owns ``ax`` rather than pyplot's current figure.
    ax.figure.colorbar(
        im, ax=ax, ticks=plt.FixedLocator([0.01, 0.1, 1., 10., 100.]))

    if position == 'left':
        # Set the tick locations before their labels so the labels bind to
        # a FixedLocator. The reversed order (7.5 ... 0.5) puts
        # population_labels[0] on the top row.
        ax.set_yticks(np.arange(8)[::-1] + 0.5)
        ax.set_yticklabels(population_labels)
        ax.set_ylabel('Population')
        ax.yaxis.set_label_coords(-0.2, 0.5)
    elif position == 'right':
        ax.set_yticks([])
        ax.text(42, 5, r'$\nu (1/\mathrm{s})$', rotation=90)


def _draw_break_marks(ax, d=0.02):
    ''' Draw two diagonal broken-axis strokes near the bottom-left corner of ``ax``.

    The strokes are placed in axes coordinates (``ax.transAxes``) with
    clipping disabled, so they can extend outside the axes. ``d`` is their
    half-width; their half-height is ``5 * d``.
    '''
    kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)
    ax.plot((-d, d), (-5 * d, 5 * d), **kwargs)
    ax.plot((-d - 0.05, d - 0.05), (-5 * d, 5 * d), **kwargs)


def _style_rate_axis(ax, xticklabels, xlabel=None):
    ''' Configure a main histogram panel: log x over [1e-4, 1e3], y hidden.

    ``xticklabels`` labels the decade ticks 1e-3 ... 1e3: seven entries,
    or an empty list to hide the labels. ``xlabel`` is set when given.
    Broken-axis marks are drawn at the panel's left edge.
    '''
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(axis='x', which='minor', length=0.)
    ax.tick_params(axis='y', length=0.)
    for side in ('right', 'top', 'left'):
        ax.spines[side].set_color('none')
    # Scale names are case-sensitive: current Matplotlib rejects 'Log'.
    ax.set_xscale('log')
    if xlabel is not None:
        ax.set_xlabel(xlabel, labelpad=2)
    ax.set_xticks(10. ** np.arange(-3, 4))
    ax.set_xticklabels(xticklabels)
    ax.set_yticks([])
    ax.set_xlim(10**-4, 10**3)
    _draw_break_marks(ax)


def _style_zero_axis(ax, xticklabels):
    ''' Configure a narrow side panel: log x over [1e-6, 1e-4], one tick at 1e-5.

    ``xticklabels`` labels that single tick: one entry, or an empty list
    to hide it. The y locator is limited to ``nbins=3``.
    '''
    ax.locator_params(axis='y', nbins=3)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(axis='x', which='minor', length=0.)
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')
    ax.set_xscale('log')
    ax.set_xticks(10. ** np.array([-5]))
    ax.set_xticklabels(xticklabels)
    ax.set_xlim(10**-6, 10**-4)


def rate_histogram_plot(fig, ax_pos, matrix, position):
    ''' Plot histograms of population rates on broken logarithmic axes.

    Four axes are created with ``plt.axes`` (i.e. in pyplot's current
    figure) inside the rectangle ``ax_pos = (left, bottom, width, height)``
    given in figure coordinates. There are two rows, each made of a narrow
    side panel (x in [1e-6, 1e-4]) and a main panel (x in [1e-4, 1e3]):

    * lower row: rows ``matrix[::2]`` (named ``I`` here), drawn in ``myred``;
    * upper row: rows ``matrix[1::2]`` (named ``E`` here), drawn in ``myblue``.

    NaN entries are ignored.

    Parameters
    ----------
    fig : object
        Unused; kept for interface compatibility.
    ax_pos : sequence of 4 floats
        Bounding rectangle of the whole plot, in figure coordinates.
    matrix : numpy.ndarray
        Rates; rows are split into the two groups above.
    position : str
        ``'left'``: add a 'Count' y label.
        ``'right'``: draw a dashed line at 500 in both main panels and
        label it ``1/tau_r``.
        Any other value does neither.
    '''
    x0, y0, width, height = ax_pos[0], ax_pos[1], ax_pos[2], ax_pos[3]
    main_x0 = x0 + 1.2 / 5. * width
    main_w = 3.8 / 5. * width
    side_w = 1 / 5. * width
    row_h = 2 / 5. * height
    upper_y0 = y0 + 3 / 5. * height

    # Creation order fixes drawing order. The unclipped break marks of a
    # main panel reach slightly into its side panel, which is drawn after it.
    ax = plt.axes((main_x0, y0, main_w, row_h))
    ax_2 = plt.axes((x0, y0, side_w, row_h))
    ax_upper = plt.axes((main_x0, upper_y0, main_w, row_h))
    ax_upper_2 = plt.axes((x0, upper_y0, side_w, row_h))

    _style_rate_axis(
        ax,
        [r'$10^{-3}$', r'', r'$10^{-1}$', r'',
         r'$10^{1}$', r'', r'$10^{3}$'],
        xlabel=r'$\nu (1/\mathrm{s})$')
    _style_zero_axis(ax_2, [r'$0.0$'])
    _style_rate_axis(ax_upper, [])
    _style_zero_axis(ax_upper_2, [])

    if position == 'left':
        ax_2.set_ylabel('Count')
        # In ax_2 axes coordinates: y=1.4 lies in the gap between the rows.
        ax_2.yaxis.set_label_coords(-0.55, 1.4)
    elif position == 'right':
        for panel in (ax, ax_upper):
            panel.plot([500., 500.], [0., 60.], '--',
                       dashes=(2, 1), color='k', lw=0.5)
        ax_upper.text(400., 62., r'$1/\tau_{\mathrm{r}}$')

    # Log-spaced bins, 0.15 decades wide, with edges 1e-6 ... 10**2.85.
    # np.histogram ignores values outside that range.
    # NOTE(review): the side panels' tick is labelled 0.0, but exact zeros
    # fall below the first edge and are not counted; presumably the caller
    # maps silent populations to a small positive rate -- confirm.
    bins = 10**np.arange(-6., 3., 0.15)
    E_rates = matrix[1::2]
    I_rates = matrix[::2]
    E_vals, _ = np.histogram(E_rates[~np.isnan(E_rates)], bins=bins)
    I_vals, _ = np.histogram(I_rates[~np.isnan(I_rates)], bins=bins)

    left_edges = bins[:-1]
    widths = np.diff(bins)
    # align='edge': x values are left bin edges. Matplotlib >= 2.0 centres
    # bars by default, which would shift every bar by half a bin.
    for panel in (ax, ax_2):
        panel.bar(left_edges, I_vals, width=widths, align='edge',
                  color=myred, linewidth=0.)
    for panel in (ax_upper, ax_upper_2):
        panel.bar(left_edges, E_vals, width=widths, align='edge',
                  color=myblue, linewidth=0.)