import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from definitions import ROOT_FOLDER, FIGURE_FOLDER
from hippocampus.environments import HexWaterMaze
import matplotlib

en = HexWaterMaze(6)


def classify_strategy(trial_data, previous_platform):
    """Classify strategy as place, response or neither.

    :return:
    """
    threshold = 60
    last_state = trial_data.state.iloc[-1]
    if last_state == previous_platform:
        strategy = 'place'
    elif last_state == trial_data.platform.iloc[0]:
        strategy = 'response'

    if len(trial_data) > threshold:
        strategy = 'neither'
    return strategy


results_folder = os.path.join(ROOT_FOLDER, 'results', 'cue_vs_place_watermaze')
figure_dir = os.path.join(FIGURE_FOLDER, 'cue_vs_place_watermaze')
if not os.path.exists(figure_dir):
    os.makedirs(figure_dir)

all_strategies = []
for group in ['sham', 'hpc', 'dls', 'double']:

    escape_times = []
    res = []
    for agent_nr in range(15):

        df = pd.DataFrame.from_csv(os.path.join(results_folder, '{}_agent{}.csv'.format(group, agent_nr)))

        probe = df[df['trial type'] == 'probe']
        previous_platform = int(df.platform.iloc[0])
        probe['previous platform'] = np.repeat(previous_platform, len(probe))

        strategy = classify_strategy(probe, previous_platform)
        res.append(strategy)
        escape_times.append(len(probe))

    all_strategies.append(res)

#en.plot_occupancy_on_grid(probe)


df = pd.DataFrame(columns=['Sham', 'HPC', 'DLS', 'HPC + DLS'], data=np.array(all_strategies).T)
counts = df.stack().reset_index().pivot_table(index=['level_1', 0], aggfunc='count')

counts = counts.reset_index()
counts.columns = ['Lesion', 'Strategy', 'Count']

reshaped = counts.pivot(index='Lesion', columns='Strategy')
reshaped = reshaped.fillna(0)

reshaped = reshaped.reindex(['Sham', 'HPC', 'DLS', 'HPC + DLS'])

reshaped = reshaped['Count'] / 15 * 100

fig, ax = plt.subplots(figsize=(6,4))

bar = reshaped[['place', 'response', 'neither']].plot.bar(stacked=True, edgecolor='k',
                                        color=['black', 'white', 'white'], ax=ax,
                                                          width=.7)

hatches = [''] * 4 + [''] * 4 + ['.....'] * 4

for i, thisbar in enumerate(bar.patches):
    # Set a different hatch for each bar
    thisbar.set_hatch(hatches[i])

font = {'size': 12, 'weight': 'bold'}

matplotlib.rc('font', **font)


plt.title('Probe test (simulation)')
plt.ylabel('agents (%)')
plt.legend(['Place', 'Cue', 'Neither'], bbox_to_anchor=(1.05, .4))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.ylabel('Agents (%)')

plt.tight_layout()
plt.xlabel('')

plt.savefig(os.path.join(figure_dir, 'watermaze_lesions_stackedbar.pdf'))
plt.show()

final_df = pd.DataFrame(columns=['Lesion', 'Place', 'Response', 'Neither'])



#locs = np.array([en.get_state_location(s) for s in probe['state']])
#plt.plot(locs[:, 0],  locs[:, 1])