#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 27 11:01:48 2017.

@author: spiros
"""
from place_cell_metrics import field_size
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import sem
import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


def make_dicts(dict_all, dict_input, case):
    """Merge ``dict_input[case]`` into the accumulator ``dict_all[case]``.

    Integer values (Python ints, bools and numpy integer scalars) are
    collected into a list, one entry per call. Any other value is combined
    with the existing entry via ``+=``. The first such value is stored as a
    shallow copy, so later ``+=`` calls never mutate the caller's
    ``dict_input`` in place.

    Parameters
    ----------
    dict_all : dict
        Accumulator; updated in place.
    dict_input : dict
        Source mapping; must contain ``case``.
    case : hashable
        Key to merge.

    Returns
    -------
    dict
        ``dict_all`` itself, for convenience.

    Raises
    ------
    KeyError
        If ``case`` is not a key of ``dict_input``.
    """
    import copy
    import numbers

    value = dict_input[case]
    if isinstance(value, numbers.Integral):
        # The old code appended to a possibly-missing key (KeyError) and then
        # fell through to ``list += int`` (TypeError); collect scalars in a list.
        dict_all.setdefault(case, []).append(value)
        return dict_all

    if case in dict_all:
        dict_all[case] += value
    else:
        # Copy so subsequent in-place ``+=`` does not alias dict_input's object.
        dict_all[case] = copy.copy(value)
    return dict_all


def bar_plots(mydict, metric, learning, path_figs, baseline, ylim=(0, 0.4)):
    """Save a bar plot and a notched box plot of one metric per condition.

    Parameters
    ----------
    mydict : dict
        Maps every condition name listed in ``cases`` below to a sequence
        of values (e.g. one value per trial).
    metric : str
        Y-axis label and prefix of the output file names.
    learning : str
        Figure title and sub-directory of ``path_figs`` to save into.
    path_figs : str
        Output directory prefix. It is concatenated as-is, so it should end
        with a '/'.
    baseline : float
        Height of the dashed horizontal reference line on the bar plot.
    ylim : sequence of two floats, optional
        Y-axis limits of the bar plot. Defaults to ``(0, 0.4)``, the value
        that used to be hard-coded.

    Writes ``path_figs + learning + '/' + metric + '_barplot.pdf'`` and the
    matching ``'_boxplot.pdf'``. Returns None.

    Raises
    ------
    KeyError
        If a condition in ``cases`` is missing from ``mydict``.
    """
    # Uses the module-level np / plt / matplotlib / sem imports; the former
    # function-local re-imports (including an unused scipy.stats) were dropped.
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['ps.fonttype'] = 42

    cases = ['Control', 'No_VIPcells', 'No_VIPCR', 'No_VIPCCK',
             'No_VIPPVM', 'No_VIPNVM', 'No_VIPCRtoBC', 'No_VIPCRtoOLM']
    colors = ['blue', 'green', 'yellow', 'red',
              'lightblue', 'lightgreen', 'yellowgreen', 'darkred']
    values = [mydict[case] for case in cases]
    out_prefix = path_figs + learning + '/' + metric

    # Bar plot: mean +/- SEM per condition, one colour per bar.
    plt.figure(1, dpi=300)
    x = range(len(cases))
    bars = plt.bar(x, [np.mean(v) for v in values],
                   yerr=[sem(v) for v in values])
    for bar, color in zip(bars, colors):
        bar.set_facecolor(color)
    plt.axhline(y=baseline, linestyle='--', linewidth=2)
    plt.xticks(x, cases)
    plt.ylabel(metric, fontsize=16)
    plt.title(learning)
    plt.ylim(ylim)
    plt.savefig(out_prefix + '_barplot.pdf', format='pdf', dpi=300)
    plt.close()

    # Notched box plot of the raw values, using the same colours.
    plt.figure(1, dpi=300)
    # Tick labels are applied with plt.xticks below. Passing ``labels=``
    # (renamed ``tick_labels`` in Matplotlib 3.9) or ``vert=True`` (already
    # the default) is redundant and warns on newer Matplotlib releases.
    bplot = plt.boxplot([list(v) for v in values], notch=True,
                        patch_artist=True)
    for patch, color in zip(bplot['boxes'], colors):
        patch.set_facecolor(color)
    for element in ['fliers', 'means', 'medians', 'caps']:
        plt.setp(bplot[element], color='black')
    plt.xticks(range(1, len(cases) + 1), cases)
    plt.ylabel(metric, fontsize=16)
    plt.title(learning)
    plt.savefig(out_prefix + '_boxplot.pdf', format='pdf', dpi=300)
    plt.close()


# Directory name for simulation results (not referenced in the code below).
fnames = 'Simulation_Results/'

# Condition labels; each one is also part of the per-condition pickle names.
my_list = [
    'Control',
    'No_VIPcells',
    'No_VIPCR',
    'No_VIPCCK',
    'No_VIPPVM',
    'No_VIPNVM',
    'No_VIPCRtoBC',
    'No_VIPCRtoOLM',
]

# Track extent, reward-zone limits (track units) and number of rate-map bins.
npath_x = 200
npath_y = 1
xlim1 = 80
xlim2 = 110
Nbins = 100

# Reward-zone limits rescaled from track units to bin indices.
xrew1 = xlim1/(npath_x/Nbins)
xrew2 = xlim2/(npath_x/Nbins)+1

trialsAll = 10     # trials per condition (numbered 1..trialsAll)
Npyramidals = 130  # pyramidal cells analysed per rate-map array


# For each learning phase, count place cells per condition and trial, plus the
# fraction of them whose rate-map peak lies inside the reward zone:
#   everything[learning]['numbers_plc'][case] -> list of counts (one per trial)
#   everything[learning]['numbers_rwd'][case] -> list of fractions (one per trial)
everything = {}

# Loop-invariant settings, hoisted out of the loops below.
spec = 'data_analysis'
path_figs = spec+'/figures/'
trials = [str(i) for i in range(1, trialsAll+1)]
maindir = os.getcwd()
lim1 = 8/(npath_x/Nbins)   # minimum place-field size, in bins
lim2 = 40/(npath_x/Nbins)  # NOTE(review): unused; was an upper size bound intended?

for learning in ['prelearning', 'locomotion', 'reward']:

    print("\nLEARNING: ", learning)
    # A bare ``print`` is a no-op expression in Python 3; call it explicitly
    # so the intended blank lines are emitted.
    print('')
    print('')
    file_load = spec+'/metrics/'+learning

    numbers_all = {}
    numbers_rwd_all = {}

    for ntrial in trials:
        print("TRIAL:", ntrial)
        rateMaps = {}
        numbers_plc = {}
        numbers_rwd = {}

        for case in my_list:
            # Locally generated simulation output; never unpickle untrusted files.
            with open(file_load+'/pickled_sn_'+case+'_'+ntrial+'.pkl', 'rb') as f:
                loaded_data = pickle.load(f)

            rateMaps[case] = loaded_data['maps']

            numbersALL = 0
            numbersrwdALL = 0
            for npyr in range(Npyramidals):
                rate_map = rateMaps[case][npyr, :, :]
                maxpeak = np.max(rate_map)
                sizetest1 = field_size(
                    rate_map, relfreq=0.1*maxpeak, track_length=Nbins)[0]

                # Place-cell criterion: peak rate >= 3.0 and field >= lim1 bins.
                if maxpeak >= 3.0 and lim1 <= sizetest1:
                    numbersALL += 1

                    # NOTE(review): np.argmax returns a flat index into the
                    # 2-D map; it equals the x bin only if the map has a
                    # single row (npath_y == 1) - confirm the map layout.
                    if xrew1 <= np.argmax(rate_map) <= xrew2:
                        numbersrwdALL += 1

            numbers_plc[case] = numbersALL
            numbers_rwd[case] = numbersrwdALL

            numbers_all.setdefault(case, []).append(numbersALL)
            # Raises ZeroDivisionError if this trial produced no place cells.
            numbers_rwd_all.setdefault(case, []).append(
                numbersrwdALL/float(numbersALL))

    mydict_all = {}
    mydict_all['numbers_plc'] = numbers_all
    mydict_all['numbers_rwd'] = numbers_rwd_all

    everything[learning] = mydict_all

# Dashed reference level drawn on the enrichment plot.
baseline = 1.0/6.0

# Re-key the reward-zone fractions by phase: A1 holds prelearning, A2
# locomotion and A3 reward; each maps case -> copied list of per-trial values.
A1, A2, A3 = {}, {}, {}
for case in my_list:
    A = []
    for learning in ('prelearning', 'locomotion', 'reward'):
        A.append(everything[learning]['numbers_rwd'][case])
    A1[case], A2[case], A3[case] = list(A[0]), list(A[1]), list(A[2])

# Grouped bar chart: for each condition, the mean +/- SEM reward-zone fraction
# in the three phases (red = prelearning, yellow = locomotion, blue = reward),
# with a dashed line at ``baseline``. Saved to All.pdf.
learning = ['prelearning', 'locomotion', 'reward']
N = len(my_list)
ind = np.arange(N)  # the x locations for the groups
width = 0.25       # the width of the bars

pre_means = [np.mean(A1[case]) for case in my_list]
pre_sems = [sem(A1[case]) for case in my_list]
loc_means = [np.mean(A2[case]) for case in my_list]
loc_sems = [sem(A2[case]) for case in my_list]
rwd_means = [np.mean(A3[case]) for case in my_list]
rwd_sems = [sem(A3[case]) for case in my_list]

fig, ax = plt.subplots()
rects1 = ax.bar(ind, pre_means, width, color='r', yerr=pre_sems)
rects2 = ax.bar(ind+width, loc_means, width, color='y', yerr=loc_sems)
rects3 = ax.bar(ind+2*width, rwd_means, width, color='b', yerr=rwd_sems)

# The plotted values are fractions in [0, 1], not percentages.
ax.set_ylabel('Enrichment (fraction)')
# Bars are centre-aligned at ind, ind+width and ind+2*width, so each group's
# middle is ind+width (ind + width/2 left the labels off-centre).
ax.set_xticks(ind + width)
ax.set_xticklabels(tuple(my_list), rotation='vertical')

ax.legend((rects1[0], rects2[0], rects3[0]),
          ('Before Learning', 'Locomotion', 'Reward'))

ax.axhline(y=baseline, linestyle='--', linewidth=2)
# bbox_inches='tight' keeps the vertical tick labels from being clipped.
plt.savefig('All.pdf', format='pdf', dpi=600, bbox_inches='tight')

# One list per phase, holding each condition's values in my_list order.
pre, loc, rwd = ([phase[case] for case in my_list] for phase in (A1, A2, A3))