#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 27 11:01:48 2017.

@author: spiros
"""

from tqdm import tqdm
from place_cell_metrics import field_size
from place_cell_metrics import peak_frequency
from place_cell_metrics import selectivity_index
from place_cell_metrics import sparsity_index2
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
import matplotlib
# Use font type 42 (TrueType) for both PDF and PostScript output so that text
# in the saved figures is embedded as real fonts (see matplotlib's
# pdf.fonttype / ps.fonttype rcParams).
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# ---------------------------------------------------------------------------
# Analysis configuration.
# ---------------------------------------------------------------------------
fnames = 'Simulation_Results/'

Npyramidals = 130   # pyramidal cells loaded per trial (see the loading loop)
Nperms = 500        # shuffles per permutation test; p-value denominator
Nbins = 100         # position bins per rate map (passed as track_length)

# Track extent along x / y; npath_x / Nbins is the length covered by one bin.
npath_x, npath_y = 200, 1

nTrials = 10


spec = 'final_results'

path_figs = spec + '/figures_final/'
path_data = spec + '/data_final/'
# Create the figure directory directly instead of shelling out to `mkdir -p`
# (os.makedirs has no exist_ok argument on Python 2, hence the check).
if not os.path.isdir(path_figs):
    os.makedirs(path_figs)

file_load = spec + '/metrics2/'
file_load_perms = spec + '/metrics_permutations/'
trials = [str(i) for i in range(1, nTrials + 1)]
maindir = os.getcwd()

# Condition labels, used verbatim as components of the result file names:
# 'Control', then 'Desynch<k>' / 'ALL0.50_<k>' for k = 5, 10, ..., 35,
# then 'SOMred<f>' / 'PVred<f>' for f = 0.75, 0.50, 0.25, then the deletions.
my_list = ['Control']
for i in range(5, 36, 5):
    my_list += ['Desynch' + str(i), 'ALL0.50_' + str(i)]
for i in ['0.75', '0.50', '0.25']:
    my_list += ['SOMred' + i, 'PVred' + i]
my_list += ['SOMdel', 'PVdel']

everything = {}

rateMaps = {}
infors = {}
stabs1 = {}
stabs2 = {}
stabsAll = {}


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE(review): pickle.load can execute arbitrary code; this assumes the
    files under ``spec`` were produced locally by this pipeline.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _perm_file(metric, case, npyr, ntrial):
    """Return the path of one per-cell permutation pickle in file_load_perms."""
    return (file_load_perms + 'perms_pickled_' + metric + '_' + case +
            '_Npyr_' + str(npyr) + '_Mouse_' + ntrial + '.pkl')


# For every condition, stack the rate maps of all trials along axis 0 and
# collect the per-cell permutation scores.  Scores are appended trial by
# trial, cell by cell, so entry k belongs to cell k % Npyramidals of trial
# k // Npyramidals (this lines up with the rows of rateMaps[case] assuming
# each trial's 'maps' holds Npyramidals rows -- TODO confirm).
prog_bar1 = tqdm(my_list)
for case in prog_bar1:

    prog_bar1.set_description('Processing Case: %s' % case)

    case_maps = []
    infor = []
    stab1 = []
    stab2 = []
    staball = []
    # A fresh inner bar per case: a tqdm bar that has been iterated to the
    # end is closed and would not display on later passes.
    prog_bar2 = tqdm(trials, leave=False)
    for ntrial in prog_bar2:
        prog_bar2.set_description('Processing Trial %s' % ntrial)
        loaded_data = _load_pickle(
            file_load + '/pickled_sn_' + case + '_' + ntrial + '.pkl')
        case_maps.append(loaded_data['maps'])

        for npyr in range(Npyramidals):
            infor.append(_load_pickle(
                _perm_file('info', case, npyr, ntrial)))
            stab1.append(_load_pickle(
                _perm_file('stab_even_odd', case, npyr, ntrial)))
            stab2.append(_load_pickle(
                _perm_file('stab_first_second', case, npyr, ntrial)))
            stab_all = _load_pickle(
                _perm_file('stab_all', case, npyr, ntrial))
            # Collapse each entry of the all-trials stability scores to its
            # NaN-ignoring mean.
            staball.append([np.nanmean(s) for s in stab_all])

    # Concatenate once per case instead of re-copying the stack every trial.
    rateMaps[case] = np.concatenate(case_maps, axis=0)
    infors[case] = infor
    stabs1[case] = stab1
    stabs2[case] = stab2
    stabsAll[case] = staball

nCells = rateMaps[my_list[0]].shape[0]

SpatialMapsALL = {}
SpatialMapsPLC = {}


def _perm_pvalue(scores, nperms):
    """Return the fraction of permutations the observed score matches or beats.

    ``scores[0]`` is the observed value and ``scores[1:]`` the permuted ones;
    the count is normalised by *nperms* (not by ``len(scores) - 1``).
    np.asarray makes the element-wise comparison work whether the scores were
    stored as an array or as a plain list.
    """
    scores = np.asarray(scores)
    return np.sum(scores[0] >= scores[1:]) / float(nperms)


def _sorted_normalized_maps(maps):
    """Flatten, order by peak position and peak-normalise a stack of maps.

    *maps* is a (cells, ...) slice of ``rateMaps``.  It is flattened to one
    row per cell, rows are stably sorted by the bin of their maximum, and
    every row is divided by its own maximum.  All-zero rows are divided by
    1e-12 so they stay zero rather than becoming NaN.
    """
    ncells = maps.shape[0]
    # reshape instead of squeeze(): squeeze would also drop the cell axis when
    # only one cell is present, breaking the axis=1 reductions below.  The
    # explicit width keeps this valid for an empty (0-cell) stack as well.
    flat = maps.reshape(ncells, int(np.prod(maps.shape[1:])))
    order = np.argsort(np.argmax(flat, axis=1), kind='mergesort')  # stable
    flat = flat[order]
    peaks = flat.max(axis=1).reshape(-1, 1)
    peaks = np.where(peaks == 0, 1e-12, peaks)
    return flat / peaks


def _add_heatmap(ax, maps):
    """Draw *maps* on *ax* as a 'jet' heatmap with a colorbar on its right."""
    image = ax.imshow(maps, cmap="jet", aspect='equal')
    # Colorbar axis 5% of ax's width, padded 0.05 inch from ax.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(image, cax=cax)
    return image


# Place-cell criteria, identical for every condition.  A cell qualifies when
# its peak rate is >= 1.0, its information and all-trials stability scores
# reach >= clevel of the permutation distribution, and its field size (as
# returned by field_size, presumably in bins) corresponds to 5-25 units of
# the npath_x track length.  stabs1 / stabs2 are not used by these criteria.
clevel = 0.99
len_per_bin = float(npath_x) / Nbins  # float: avoid Python 2 int division
min_fsize, max_fsize = 5. / len_per_bin, 25. / len_per_bin

metrics = ['information_all', 'stabilityIndex_all', 'information_plc',
           'stabilityIndex_plc', 'sparsity', 'selectivity', 'fieldsize',
           'PeakRate', 'MeanRateInfield', 'MeanRateOutfield']

# A new bar: the one used while loading is already exhausted and closed.
prog_bar1 = tqdm(my_list)
for case in prog_bar1:
    prog_bar1.set_description('Processing Case: %s' % case)

    inforALL1 = []
    stabALL1 = []
    inforALL2 = []
    stabALL2 = []
    sparsALL = []
    selecALL = []
    peaksALL = []
    sizesALL = []
    average1ALL = []
    average2ALL = []
    place_cells_idx = []

    numbersALL = []
    numbers_plc = 0
    for npyr in range(nCells):
        info_scores = infors[case][npyr]
        stab_scores = stabsAll[case][npyr]
        pval_infor = _perm_pvalue(info_scores, Nperms)
        pval_staball = _perm_pvalue(stab_scores, Nperms)

        rate_map = rateMaps[case][npyr, :, :]
        maxpeak = np.max(rate_map)
        # One call; items 0-2 are field size, in-field and out-field rates.
        fsize, infield, outfield = field_size(
            rate_map, relfreq=0.2 * maxpeak, track_length=Nbins)[:3]

        if (maxpeak >= 1.0 and pval_infor >= clevel and
                pval_staball >= clevel and min_fsize <= fsize <= max_fsize):
            numbers_plc += 1
            place_cells_idx.append(npyr)

            inforALL2.append(info_scores[0])
            stabALL2.append(stab_scores[0])
            sparsALL.append(sparsity_index2(rate_map))
            selecALL.append(selectivity_index(rate_map))
            peaksALL.append(peak_frequency(rate_map))
            sizesALL.append(fsize)
            average1ALL.append(infield)
            average2ALL.append(outfield)

        inforALL1.append(info_scores[0])
        stabALL1.append(stab_scores[0])

    # Fraction of place cells among all trials x Npyramidals cells.
    numbersALL.append(numbers_plc / (len(trials) * float(Npyramidals)))

    rate_maps_all = rateMaps[case]
    SpatialMapsALL[case] = _sorted_normalized_maps(rate_maps_all)

    rate_maps_plc = rateMaps[case][place_cells_idx, :, :]
    rtMaps = _sorted_normalized_maps(rate_maps_plc)
    SpatialMapsPLC[case] = rtMaps

    # Left: all cells; right: place cells only (left blank if none found).
    fig, axes = plt.subplots(nrows=1, ncols=2)
    _add_heatmap(axes[0], SpatialMapsALL[case])
    if rtMaps.shape[0] != 0:
        _add_heatmap(axes[1], rtMaps)

    plt.tight_layout()
    for ext in ('pdf', 'png'):
        plt.savefig(os.path.join(path_figs,
                                 case + '_PlaceCells_heatmap.' + ext),
                    format=ext, dpi=300)
    plt.close(fig)

    mydict_all = {
        'information_all': inforALL1,
        'stabilityIndex_all': stabALL1,
        'information_plc': inforALL2,
        'stabilityIndex_plc': stabALL2,
        'sparsity': sparsALL,
        'selectivity': selecALL,
        'fieldsize': sizesALL,
        'PeakRate': peaksALL,
        'numbers_plc': numbersALL,
        'MeanRateInfield': average1ALL,
        'MeanRateOutfield': average2ALL,
    }

    everything[case] = mydict_all