import os
import os.path as op
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from definitions import PEARCE_RESULTS_DIR, FIGURE_FOLDER

# Result folder of the simulation run analysed for each experimental group.
# Each run lives in a timestamp-named folder below its group's results folder.
data_directories = dict(
    # alternative lesion run timestamp kept from an earlier revision:
    # '2019-11-05 10:46:28.222103'
    lesion=op.join(PEARCE_RESULTS_DIR, 'lesion', '2019-11-06 01:22:41.370359'),
    control=op.join(PEARCE_RESULTS_DIR, 'control', '2019-11-05 00:39:03.764484'),
)

# Folder that receives every figure written by this script.
figure_location = op.join(FIGURE_FOLDER, 'pearce')


def get_immediate_subdirectories(a_dir):
    """Return the names of the directories located directly inside ``a_dir``.

    Only the entry names are returned (not full paths); regular files and
    anything nested deeper than one level are left out.
    """
    subdirectories = []
    for entry in os.listdir(a_dir):
        if os.path.isdir(os.path.join(a_dir, entry)):
            subdirectories.append(entry)
    return subdirectories


def create_summary_file(results_folder, n_agents=None):
    """Write ``summary.csv`` with per-trial means for every agent of a run.

    For each agent index ``i`` in ``range(n_agents)`` the file ``agent<i>.csv``
    in ``results_folder`` is read and averaged per (agent, session, trial).
    The per-agent tables are stacked, a categorical ``'platform location'``
    column is derived from ``'platform'``, and the result is saved as
    ``summary.csv`` in the same folder.

    :param results_folder: folder holding the ``agent<i>.csv`` files (and
        ``params.csv`` when ``n_agents`` is not given).
    :param n_agents: number of agent files to read; when ``None`` it is taken
        from the first row of the ``n_agents`` column of ``params.csv``.
    """
    if n_agents is None:
        # fall back to the agent count stored alongside the results
        n_agents = pd.read_csv(op.join(results_folder, 'params.csv'))['n_agents'].iloc[0]

    agent_ids = tqdm(range(n_agents), desc='loading data...', leave=False)
    per_agent = [
        pd.read_csv(op.join(results_folder, 'agent{}.csv'.format(agent_id)))
        .pivot_table(index=['agent', 'session', 'trial'], aggfunc='mean')
        for agent_id in agent_ids
    ]

    summary = pd.concat(per_agent)
    summary['platform location'] = summary['platform'].astype('category')
    summary.to_csv(op.join(results_folder, 'summary.csv'))


def create_summaries():
    """Create ``summary.csv`` in every control and lesion run folder lacking one.

    The run folders are the immediate subdirectories of the ``control`` and
    ``lesion`` folders below ``PEARCE_RESULTS_DIR`` (control runs are handled
    first). Runs that already contain a ``summary.csv`` are skipped with a
    short progress-bar message.
    """
    group_folders = (op.join(PEARCE_RESULTS_DIR, 'control'),
                     op.join(PEARCE_RESULTS_DIR, 'lesion'))
    run_folders = [
        op.join(group_folder, run_name)
        for group_folder in group_folders
        for run_name in get_immediate_subdirectories(group_folder)
    ]

    for run_folder in tqdm(run_folders):
        if op.isfile(op.join(run_folder, 'summary.csv')):
            tqdm.write('file already exists')
            continue
        create_summary_file(run_folder)


def load_summary(group):
    """Load the ``summary.csv`` of the run registered for ``group``.

    :param group: key into ``data_directories`` (``'lesion'`` or ``'control'``).
    :return: the summary as a DataFrame with two extra columns: ``'group'``
        (the given group name on every row) and ``'Trial'`` (the ``'trial'``
        column converted to strings).
    """
    summary_path = op.join(data_directories[group], 'summary.csv')
    return pd.read_csv(summary_path).assign(
        group=group,
        Trial=lambda frame: frame['trial'].astype('str'),
    )


def plot_escape_time():
    """Plot escape time per session for the first and last trial of each group.

    The lesion and control summaries are loaded, the rows of trial index 0
    and trial index 3 (labelled "trial 1" and "trial 4" in the legend) are
    selected for sessions 0-10, and one line per group/trial combination is
    drawn on a shared axis. Lesion curves use palette colour 3, control
    curves palette colour 0; the last-trial curves are dashed.

    The figure is saved as ``pearce_escapetime_firstlast.pdf`` inside
    ``figure_location`` (created if missing), then shown and closed.
    """
    color_pal = sns.color_palette()

    df_l = load_summary('lesion')
    df_c = load_summary('control')

    agents_l = df_l.agent.unique()
    agents_c = df_c.agent.unique()

    # index by (agent, session, trial) so single trials can be picked with .loc
    df_l = df_l.set_index(['agent', 'session', 'trial'])
    df_c = df_c.set_index(['agent', 'session', 'trial'])

    # all agents, sessions 0..10, first (0) and last (3) trial of each session
    first_l = df_l.loc[(list(agents_l), list(range(11)), 0)]
    last_l = df_l.loc[(list(agents_l), list(range(11)), 3)]
    first_c = df_c.loc[(list(agents_c), list(range(11)), 0)]
    last_c = df_c.loc[(list(agents_c), list(range(11)), 3)]

    # plot the escape time per session for trials 1 and 4
    # NOTE(review): newer seaborn releases deprecate `ci` in favour of
    # `errorbar=None` -- confirm against the installed version.
    fig, ax = plt.subplots()
    sns.lineplot(ax=ax, data=first_l.reset_index(), x='session', y='escape time', ci=None, c=color_pal[3], linewidth=3)
    sns.lineplot(ax=ax, data=last_l.reset_index(), x='session', y='escape time', ci=None, c=color_pal[3], linewidth=3)

    sns.lineplot(ax=ax, data=first_c.reset_index(), x='session', y='escape time', ci=None, c=color_pal[0], linewidth=3)
    sns.lineplot(ax=ax, data=last_c.reset_index(), x='session', y='escape time', ci=None, c=color_pal[0], linewidth=3)

    # lines are stored in drawing order: indices 1 and 3 are the last-trial curves
    ax.lines[1].set_linestyle("--")
    ax.lines[3].set_linestyle("--")

    # legend entries follow the same drawing order as the lines above
    plt.legend(['HPC lesion - trial 1',
                'HPC lesion - trial 4',
                'Control - trial 1',
                'Control - trial 4'])

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(figure_location, exist_ok=True)
    # The extension is spelled out: when `format` is passed, savefig uses the
    # file name verbatim and would otherwise write a file without '.pdf'.
    plt.savefig(os.path.join(figure_location, 'pearce_escapetime_firstlast.pdf'), format='pdf')
    plt.show()
    plt.close()


if __name__ == '__main__':
    from hippocampus.environments import HexWaterMaze

    # Uncomment to (re)build the summary file of a run before plotting.
    # create_summary_file(data_directories['lesion'])
    plot_escape_time()

    maze = HexWaterMaze(10)

    # example agent per group and the session whose first trial is plotted;
    # both are constant, so they are defined once outside the loop
    agents = {'control': 89, 'lesion': 12}
    session = 6

    for group in ['control', 'lesion']:
        df = pd.read_csv(op.join(data_directories[group], 'agent{}.csv'.format(agents[group])))

        # Explicit copy: a column is added below, and assigning into a
        # boolean-filtered slice of `df` would trigger pandas'
        # SettingWithCopyWarning.
        first_trial = df[np.logical_and(df.trial == 0, df.session == session)].copy()

        # platform of the preceding session, taken from its first row
        first_trial['previous platform'] = df[df.session == session - 1].platform.iloc[0]

        maze.plot_occupancy_on_grid(first_trial, alpha=1., show_state_idx=False)

        plt.savefig(op.join(figure_location, 'pearce_occupancy_{}_agent{}_session{}.pdf'.format(group,
                                                                                                agents[group],
                                                                                                session)), format='pdf')
        plt.show()
        # close the figure, as plot_escape_time does, so figures do not pile up
        plt.close()