#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 27 11:01:48 2017.

@author: spiros
"""

import math
import os
import pickle

import numpy as np
import pandas as pd
import scipy.stats
# matplotlib.use('agg')
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Save in MATLAB format
from scipy.io import savemat

from place_cell_metrics import sparsity_index2
from place_cell_metrics import selectivity_index
from place_cell_metrics import peak_frequency
from place_cell_metrics import field_size, spatial_coherence
from visualization_ import make_dicts, bar_plots

# Track geometry (presumably extent along x and y) and the number of
# spatial bins every rate map is divided into.
npath_x = 200
npath_y = 1
Nbins = 100

# Dataset size: number of trials (one per simulated mouse), pyramidal cells
# per trial, and number of shuffles behind each permutation test.
trials = 10
Npyramidals = 130
Nperms = 500

# Selects which group of conditions is analysed (0, 1, 2 or 3).
what_to_do = 3


# Build the list of simulation conditions to analyse.  'Control' always comes
# first; the other entries depend on what_to_do.
#   0 -> a fixed figure subset
#   1 -> desynchronisation sweep (Desynch5..35 interleaved with ALL0.50_5..35)
#   2 -> SOM/PV interneuron reductions (0.75, 0.50, 0.25) plus deletions
#   3 -> the sweep of mode 1 followed by the interneuron cases of mode 2
_desynch_cases = []
for i in range(5, 36, 5):
    _desynch_cases += ['Desynch'+str(i), 'ALL0.50_'+str(i)]

_interneuron_cases = []
for i in ['0.75', '0.50', '0.25']:
    _interneuron_cases += ['SOMred'+i, 'PVred'+i]
_interneuron_cases += ['SOMdel', 'PVdel']

if what_to_do == 0:
    my_list = ['Control', 'ALL0.50_20', 'Desynch20', 'SOMred0.50', 'PVred0.50']
elif what_to_do == 1:
    my_list = ['Control'] + _desynch_cases
elif what_to_do == 2:
    my_list = ['Control'] + _interneuron_cases
elif what_to_do == 3:
    my_list = ['Control'] + _desynch_cases + _interneuron_cases
else:
    # Fail here with a clear message instead of a NameError on my_list later.
    raise ValueError('what_to_do must be 0, 1, 2 or 3, got %r' % (what_to_do,))

# Input/output locations, all below the results root `spec`.
spec = 'final_results/'
path_figs = spec+'/figures2/'          # per-condition heatmaps (PDF)
path_figs2 = spec+'/figures2_metrics/'  # summary bar/box plots
file_load = spec+'/metrics2/'           # pickled rate maps; place-cell idx out
file_load_perms = spec+'/metrics_permutations/'  # permutation-test pickles
path_data = spec+'/data_final3/'        # .mat and CSV exports

# Trial identifiers as strings '1'..'N' (they are embedded in file names).
trials = [str(i) for i in range(1, trials+1)]

# Create every directory the script writes into.  path_figs was previously
# never created even though the heatmaps are saved there.  os.makedirs
# avoids spawning a shell; the isdir guard keeps it Python 2 compatible
# (no exist_ok there).
for _out_dir in (path_figs, path_figs2, path_data):
    if not os.path.isdir(_out_dir):
        os.makedirs(_out_dir)

# Cross-trial accumulators, keyed by condition name; the trial loop feeds
# them through make_dicts().  Each name gets its own, distinct dict.
numbers_all = dict()
information_all, information_plc = dict(), dict()
stability_index_plc, stability_index_all = dict(), dict()
stability_all_plc, stability_all_all = dict(), dict()
sparsity_all = dict()
selectivity_all = dict()
peak_freq_all = dict()
mean_freq_in_all, mean_freq_out_all = dict(), dict()
fieldsize_all, fieldsize_plc = dict(), dict()
coherence_all, coherence_plc = dict(), dict()

# Normalised, peak-sorted spatial maps keyed by 'Mouse<trial>' then case.
SpatialMapsALL, SpatialMapsPLC = dict(), dict()

# Per-condition (trial, bin, cell) rate-map arrays.
RateMapMat = dict()

# Gather every condition's rate maps into one array per condition, indexed
# (trial, bin, cell), from the pickled per-trial results.
# NOTE(review): RateMapMat is not read anywhere later in this script.
for case in my_list:
    rate_map_mat = np.zeros((len(trials), Nbins, Npyramidals))
    for trial in trials:
        with open(file_load+'pickled_sn_'+case+'_'+trial+'.pkl', 'rb') as f:
            loaded_data = pickle.load(f)
        # Squeeze once per file (it used to be redone for every cell).  Rows
        # of the squeezed maps are cells and columns are bins, so the first
        # Npyramidals rows transposed fill this trial's (bin, cell) slab.
        maps = loaded_data['maps'].squeeze()
        rate_map_mat[int(trial)-1] = maps[:Npyramidals, :].T

    RateMapMat[case] = rate_map_mat

# Threshold applied to the permutation-test fractions (see _perm_fraction)
# for a cell to count as a place cell.  Defined once, before any loop, so it
# is always bound; the CSV export at the end of the script also reads it.
clevel = 0.99

# Accepted place-field size range: the limits 5 and 25 divided by
# npath_x / Nbins.  Float division keeps the bounds identical under
# Python 2 and Python 3.
_units_per_bin = float(npath_x) / Nbins
_min_field_size = 5. / _units_per_bin
_max_field_size = 25. / _units_per_bin


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE: unpickling can execute arbitrary code; only use this on result
    files produced by the simulation pipeline itself.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _perm_fraction(values):
    """Return the fraction of shuffled values not above the observed one.

    values[0] is the observed statistic and values[1:] the shuffled ones;
    the count of shuffled values <= values[0] is divided by Nperms, so a
    value close to 1 means the observation beats almost every shuffle.
    *values* may be a list or an array.
    """
    values = np.asarray(values)
    return np.sum(values[0] >= values[1:]) / float(Nperms)


def _sorted_normalised_maps(maps):
    """Sort cells by the bin of their peak and scale each map to max 1.

    *maps* is indexed (cell, ...), the remaining axes holding one cell's
    rate map (presumably (Nbins, 1) here).  Ties in peak bin keep the
    original cell order; all-zero maps are divided by 1e-12 instead of 0.
    Always returns a 2-D (n_cells, n_bins) array, including for zero or
    one cell (a squeeze() would collapse a single cell to 1-D).
    """
    n_cells = maps.shape[0]
    flat = maps.reshape(n_cells, int(np.prod(maps.shape[1:])))
    order = np.lexsort((np.arange(n_cells), np.argmax(flat, axis=1)))
    ordered = flat[order]
    peaks = ordered.max(axis=1, keepdims=True)
    peaks[peaks == 0] = 1e-12
    return ordered / peaks


def _add_heatmap(ax, image_data):
    """Draw *image_data* on *ax* (jet colormap) with a colorbar beside it."""
    image = ax.imshow(image_data, cmap="jet", aspect='equal')
    # Colorbar axes on the right of ax: 5% of its width, 0.05 inch padding.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(image, cax=cax)


for ntrial in trials:
    print("TRIAL:", ntrial)
    # Per-trial containers keyed by condition; they are handed to
    # make_dicts() together with the cross-trial dictionaries.
    rateMaps = {}
    time_bin = {}
    information1 = {}
    information2 = {}
    sparsity = {}
    selectivity = {}
    peak_freq = {}
    mean_freq_in = {}
    mean_freq_out = {}
    fieldsize1 = {}
    fieldsize2 = {}
    numbers_plc = {}
    stability1 = {}
    stability2 = {}
    stability1_all = {}
    stability2_all = {}
    coherence1 = {}
    coherence2 = {}

    mouse = 'Mouse'+str(ntrial)
    SpatialMapsALL[mouse] = {}
    SpatialMapsPLC[mouse] = {}

    # NOTE(review): the input place cells of this trial are not used below.
    # They only depend on the trial, so they are loaded once per trial.
    place_cells = np.loadtxt('../Simulation_Results/Control/Trial_' +
                             str(ntrial)+'/Run_1/input_plcs.txt', delimiter=',')
    place_cells = [int(x) for x in list(place_cells[:, 0])]

    for case in my_list:
        loaded_data = _load_pickle(file_load+'/pickled_sn_'+case+'_'+ntrial+'.pkl')
        rateMaps[case] = loaded_data['maps']
        time_bin[case] = loaded_data['time_in_bin']

        # *1 lists: every cell; *2 lists: accepted place cells only.
        inforALL1 = []
        stabALL1 = []
        stabALL11 = []
        inforALL2 = []
        stabALL2 = []
        stabALL22 = []
        sparsALL = []
        selecALL = []
        peaksALL = []
        sizesALL1 = []
        sizesALL2 = []
        average1ALL = []
        average2ALL = []
        numbersALL = 0
        place_cells_idx = []
        coher1 = []
        coher2 = []

        for npyr in range(Npyramidals):
            perm_tail = case+'_Npyr_'+str(npyr)+'_Mouse_'+ntrial+'.pkl'
            infor = _load_pickle(file_load_perms+'perms_pickled_info_'+perm_tail)
            stab1 = _load_pickle(file_load_perms+'perms_pickled_stab_even_odd_'+perm_tail)
            stab2 = _load_pickle(file_load_perms+'perms_pickled_stab_first_second_'+perm_tail)
            # One NaN-ignoring mean per (observed, shuffled...) repetition.
            stab_all = [np.nanmean(i) for i in
                        _load_pickle(file_load_perms+'perms_pickled_stab_all_'+perm_tail)]

            pval_infor = _perm_fraction(infor)
            pval_staball = _perm_fraction(stab_all)

            rate_map = rateMaps[case][npyr, :, :]
            maxpeak = np.max(rate_map)
            # field_size is evaluated once (it was called five times with
            # identical arguments): [0] field size, [1]/[2] the in-/out-of-
            # field values stored as mean_in/mean_out.
            fsize = field_size(rate_map, relfreq=0.2*maxpeak, track_length=Nbins)
            sizetest1, infield, outfield = fsize[0], fsize[1], fsize[2]
            sp_coher = spatial_coherence(rate_map.squeeze(), window=3)

            # atanh of the observed stability values (a Fisher z-transform,
            # assuming they are correlation coefficients).  math.atanh
            # replaces np.math.atanh: np.math was removed in NumPy 2.0.
            stab_mean_z = (math.atanh(stab1[0])+math.atanh(stab2[0]))/2.0
            stab_all_z = math.atanh(stab_all[0])

            is_place_cell = (maxpeak >= 1.0 and pval_infor >= clevel and
                             pval_staball >= clevel and
                             _min_field_size <= sizetest1 <= _max_field_size)
            if is_place_cell:
                numbersALL += 1
                place_cells_idx.append(npyr)

                inforALL2.append(infor[0])
                stabALL2.append(stab_mean_z)
                stabALL22.append(stab_all_z)
                sparsALL.append(sparsity_index2(rate_map))
                selecALL.append(selectivity_index(rate_map))
                peaksALL.append(peak_frequency(rate_map))
                sizesALL2.append(sizetest1)
                average1ALL.append(infield)
                average2ALL.append(outfield)
                coher2.append(sp_coher)

            inforALL1.append(infor[0])
            stabALL1.append(stab_mean_z)
            stabALL11.append(stab_all_z)
            sizesALL1.append(sizetest1)
            coher1.append(sp_coher)

        fig, axes = plt.subplots(nrows=1, ncols=2)

        # Left panel: every pyramidal cell.
        rtMaps = _sorted_normalised_maps(rateMaps[case])
        SpatialMapsALL[mouse][case] = rtMaps
        _add_heatmap(axes[0], rtMaps)

        np.savetxt(file_load+'place_cell_idx_trial_'+str(ntrial)+'_'+case+'.txt',
                   place_cells_idx, fmt='%i')

        # Right panel: only the cells accepted as place cells above.
        rtMaps = _sorted_normalised_maps(rateMaps[case][place_cells_idx, :, :])
        SpatialMapsPLC[mouse][case] = rtMaps
        _add_heatmap(axes[1], rtMaps)
        plt.tight_layout()
        plt.savefig(path_figs+'/'+case+'_PlaceCells_heatmap_' +
                    str(ntrial)+'.pdf', format='pdf', dpi=300)
        plt.close(fig)

        print(numbersALL/float(Npyramidals), case)
        numbers_plc[case] = numbersALL
        information1[case] = inforALL1
        stability1[case] = stabALL1
        stability1_all[case] = stabALL11
        information2[case] = inforALL2
        stability2[case] = stabALL2
        stability2_all[case] = stabALL22
        sparsity[case] = sparsALL
        selectivity[case] = selecALL
        peak_freq[case] = peaksALL
        mean_freq_in[case] = average1ALL
        mean_freq_out[case] = average2ALL
        fieldsize1[case] = sizesALL1
        fieldsize2[case] = sizesALL2
        coherence1[case] = coher1
        coherence2[case] = coher2

        # Fraction of pyramidal cells that are place cells, one per trial.
        numbers_all.setdefault(case, []).append(
            numbers_plc[case]/float(Npyramidals))

        information_all = make_dicts(information_all, information1, case)
        information_plc = make_dicts(information_plc, information2, case)
        stability_index_all = make_dicts(stability_index_all, stability1, case)
        stability_index_plc = make_dicts(stability_index_plc, stability2, case)
        stability_all_all = make_dicts(stability_all_all, stability1_all, case)
        stability_all_plc = make_dicts(stability_all_plc, stability2_all, case)
        sparsity_all = make_dicts(sparsity_all, sparsity, case)
        selectivity_all = make_dicts(selectivity_all, selectivity, case)
        peak_freq_all = make_dicts(peak_freq_all, peak_freq, case)
        mean_freq_in_all = make_dicts(mean_freq_in_all, mean_freq_in, case)
        mean_freq_out_all = make_dicts(mean_freq_out_all, mean_freq_out, case)
        fieldsize_all = make_dicts(fieldsize_all, fieldsize1, case)
        fieldsize_plc = make_dicts(fieldsize_plc, fieldsize2, case)
        coherence_all = make_dicts(coherence_all, coherence1, case)
        coherence_plc = make_dicts(coherence_plc, coherence2, case)

# All per-condition results, keyed by metric name; used for the plots and
# CSV exports below.  Note that 'numbers_place_cells' and 'numbers_plc' are
# both the same numbers_all dictionary.
mydict_all = {
    'numbers_place_cells': numbers_all,
    'information_all': information_all,
    'stabilityIndex_all': stability_index_all,
    'stabilityALL_all': stability_all_all,
    'information_plc': information_plc,
    'stabilityIndex_plc': stability_index_plc,
    'stabilityALL_plc': stability_all_plc,
    'sparsity': sparsity_all,
    'selectivity': selectivity_all,
    'fieldsize_all': fieldsize_all,
    'fieldsize_plc': fieldsize_plc,
    'peak_freq': peak_freq_all,
    'numbers_plc': numbers_all,
    'mean_in': mean_freq_in_all,
    'mean_out': mean_freq_out_all,
    'coherence_all': coherence_all,
    'coherence_plc': coherence_plc,
}


# Subset exported to MATLAB, with descriptive variable names.
mydict_all1 = {
    'Description': '10 animals, 10 trials/animal, 130 pyramidal cells/animal',
    'number_of_place_cells': numbers_all,
    'information_of_all_cells': information_all,
    'information_of_place_cells': information_plc,
    'stability_of_all_cells': stability_index_all,
    'stability_of_place_cells': stability_index_plc,
    'spatial_maps_of_all_cells': SpatialMapsALL,
    'spatial_maps_of_place_cells': SpatialMapsPLC,
}

savemat(path_data+'Data_spiros.mat', mydict_all1, oned_as='column')

# Bar plot of the mean (+/- SEM over trials) place-cell value per condition.
# numbers_all holds place-cell counts divided by the number of pyramidal
# cells, i.e. fractions, although the axis label says "Number".
placecells = [numbers_all[cased] for cased in my_list]

fig = plt.figure(1, figsize=(5, 5))
y = list(np.mean(placecells, axis=1))
ye = list(scipy.stats.sem(placecells, axis=1))
labels = my_list
N = len(y)
x = range(N)

colors = ['blue', 'red', 'green', 'yellow',
          'lightblue', 'olive', 'darkmagenta', 'darkorange']
plt.bar(x, y, color=colors[:len(my_list)], yerr=ye)

# Rotation must be a number: recent matplotlib raises ValueError for
# numeric strings such as '90'.
plt.xticks(x, labels, rotation=90)
plt.ylabel('Number of Place Cells', fontsize=16)
plt.savefig(path_figs2+'/'+'numberOfPlaceCells_barplot.eps',
            format='eps', dpi=300)
plt.savefig(path_figs2+'/'+'numberOfPlaceCells_barplot.png',
            format='png', dpi=300)
plt.cla()
plt.clf()
plt.close()


# Box plot of the per-trial place-cell values, one box per condition.
placecells = [numbers_all[cased] for cased in my_list]

plt.figure(1, dpi=150)

y = placecells
labels = my_list
N = len(y)
# boxplot() places the boxes at positions 1..N by default.
x = range(1, N+1)

# Tick labels are set by plt.xticks below, so boxplot's deprecated `labels`
# keyword (renamed `tick_labels` in matplotlib 3.9) is not needed; vertical
# boxes are the default.
bplot = plt.boxplot(y, notch=False, patch_artist=True)

# Fill the boxes, reusing the palette cyclically so conditions beyond the
# eighth are coloured too (matplotlib's bar() likewise cycles a short
# colour list).
colors = ['blue', 'red', 'green', 'yellow',
          'lightblue', 'olive', 'darkmagenta', 'darkorange']
for k, patch in enumerate(bplot['boxes']):
    patch.set_facecolor(colors[k % len(colors)])

for element in ['fliers', 'means', 'medians', 'caps']:
    plt.setp(bplot[element], color='black')

# Numeric rotation: numeric strings such as '45' are rejected by recent
# matplotlib.
plt.xticks(x, labels, rotation=45)
plt.ylabel('number of place cells', fontsize=16)
plt.savefig(path_figs2+'/'+'numberOfPlaceCells_boxplot.eps',
            format='eps', dpi=300)
plt.savefig(path_figs2+'/'+'numberOfPlaceCells_boxplot.png',
            format='png', dpi=300)
plt.cla()
plt.clf()
plt.close()


# Metrics (keys of mydict_all) drawn by bar_plots(), one figure per metric.
my_list2 = ['information_all', 'information_plc', 'sparsity',
            'selectivity', 'peak_freq', 'fieldsize_all', 'fieldsize_plc', 'stabilityIndex_all',
            'stabilityIndex_plc', 'mean_in', 'mean_out', 'stabilityALL_all',
            'stabilityALL_plc', 'coherence_all', 'coherence_plc']

# File-name tag of the exported CSV tables, one per analysis mode.
_fnam_by_mode = {
    0: '_figure_pval_',
    1: '_Desynch_pval_',
    2: '_INs_pval_',
    3: '_ALL_pval_',
}
if what_to_do not in _fnam_by_mode:
    # Otherwise fnam stays undefined and the CSV export fails with a
    # NameError much later.
    raise ValueError('what_to_do must be 0, 1, 2 or 3, got %r' % (what_to_do,))
fnam = _fnam_by_mode[what_to_do]

for metric in my_list2:
    bar_plots(mydict_all[metric], metric, path_figs2, my_list)


# USE this for PRISM-GraphPad plotting: one space-separated table per metric
# with len(trials) rows and three columns per condition (in my_list order).
my_list2 = ['information_all', 'information_plc', 'sparsity',
            'selectivity', 'peak_freq', 'fieldsize_all', 'fieldsize_plc', 'stabilityIndex_all',
            'stabilityIndex_plc', 'mean_in', 'mean_out', 'stabilityALL_all',
            'stabilityALL_plc', 'coherence_all', 'coherence_plc']

# Output directory, set once (it used to be reassigned on every iteration
# of the inner loop).
path_data = spec+'/data_final3/'
L = len(my_list)
for metric in my_list2:
    B = np.zeros((len(trials), L*3))
    for cnt, case in enumerate(my_list):
        # NOTE(review): assumes make_dicts() stores, per condition, values
        # that fit (or broadcast into) a len(trials) x 3 slab; confirm.
        A = np.array(mydict_all[metric][case])
        B[:, 3*cnt:3*(cnt+1)] = A

    df = pd.DataFrame(B)
    df.to_csv(path_data+metric+fnam+str(clevel)+'.csv', sep=' ', index=False)