"""
What needs to happen here:
- simulation of devaluation on plus maze


"""
from hippocampus.environments import DevaluationPlusMaze, Environment
from hippocampus.agents import CombinedAgent, LandmarkLearningAgent
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from hippocampus import utils
from definitions import ROOT_FOLDER
import os


groups = {0: 'control',
          1: 'inactivate_HPC'}


group = groups[1]


inactivate_HPC = group == 'inactivate_HPC'
inactivate_DLS = group == 'inactivate_DLS'


# save location
results_folder = os.path.join(ROOT_FOLDER, 'results', 'plusmaze_deval', group)
figure_folder = os.path.join(results_folder, 'figures')
os.makedirs(results_folder, exist_ok=True)
os.makedirs(figure_folder, exist_ok=True)


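# Model-free (DLS) controller: identical to LandmarkLearningAgent except that the
# feature representation gains one extra binary feature flagging state 6.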
class LM(LandmarkLearningAgent):
    def __init__(self, environment=Environment(), learning_rate=.1, gamma=.9, eta=.03, beta=10):
        super().__init__(environment=environment, learning_rate=learning_rate, gamma=gamma, eta=eta, beta=beta)

    def get_feature_rep(self, state, orientation):
        distance = self.get_distance_to_landmark(state)
        angle = self.angle_to_landmark(state, orientation)
        response = self.features.compute_response(distance, angle)
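        # Append a binary feature flagging state 6 (assumed to mark a special
        # location in DevaluationPlusMaze, cf. the forced action in CA.one_episode).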
        return np.append(response, state == 6)


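# Combined HPC/DLS agent for the plus maze: the model-free controller is replaced
# by LM above, so the weight matrix gets one extra row for the appended feature.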
class CA(CombinedAgent):
    def __init__(self, env=Environment(), init_sr='rw', lesion_dls=False, lesion_hpc=False, gamma=.95, eta=.03,
                 inv_temp=10, learning_rate=.1, inact_hpc=0., inact_dls=0., A_alpha=1., A_beta=.5,
                 alpha1=.01, beta1=.1):
        super().__init__(env=env, init_sr=init_sr, lesion_dls=lesion_dls, lesion_hpc=lesion_hpc, gamma=gamma,
                         eta=eta, inv_temp=inv_temp, learning_rate=learning_rate, inact_hpc=inact_hpc,
                         inact_dls=inact_dls, A_alpha=A_alpha, A_beta=A_beta, alpha1=alpha1, beta1=beta1)
        self.DLS = LM(self.env, eta=self.eta)
        self.weights = np.zeros((self.DLS.features.n_cells + 1, self.env.nr_actions))

    def one_episode(self, random_policy=False, setp_sr=None, random_start_loc=True):
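        """Run one episode and return a per-step DataFrame of the agent's internal
        variables (prediction errors, reliabilities, Q-values, weights, ...)."""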
        if self.lesion_striatum and self.lesion_hippocampus:
            random_policy = True
        time_limit = 1000
        self.env.reset(random_loc=random_start_loc)
        t = 0
        s = self.env.get_current_state()
        cumulative_reward = 0

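        # Initialise the heading to the available action direction that points most
        # directly at the landmark.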
        possible_orientations = np.round(np.degrees(self.env.action_directions))
        angles = []
        for i, o in enumerate(possible_orientations):
            angle = utils.angle_to_landmark(self.env.get_state_location(s), self.env.landmark_location, np.radians(o))
            angles.append(angle)
        orientation = possible_orientations[np.argmin(np.abs(angles))]

        # get MF system features
        f = self.DLS.get_feature_rep(s, orientation)
        Q_mf = self.weights.T @ f

        # Accumulate per-step records in a plain list and build the DataFrame once
        # at the end (DataFrame.append was removed in pandas 2.0).
        step_records = [{'time': t,
                         'reward': 0,
                         'SPE': 0,
                         'RPE': 0,
                         'HPC reliability': self.HPC.reliability,
                         'DLS reliability': self.DLS.reliability,
                         'alpha': self.get_alpha(self.DLS.reliability),
                         'beta': self.get_beta(self.DLS.reliability),
                         'state': s,
                         'P(SR)': self.p_sr,
                         'choice': self.current_choice,
                         'M_hat': self.HPC.M_hat.flatten(),
                         'R_hat': self.HPC.R_hat.copy(),
                         'Q_mf': Q_mf,
                         'platform': self.env.get_goal_state()}]

        while not self.env.is_terminal(s) and t < time_limit:
            if setp_sr is None:
                self.update_p_sr()
            else:
                self.p_sr = setp_sr

            # select action
            Q_combined, Q_allo = self.compute_Q(s, orientation, self.p_sr)
            if random_policy:
                allo_a = np.random.choice(list(range(self.env.nr_actions)))
            else:
                allo_a = self.softmax_selection(s, Q_combined)
            ego_a = self.get_ego_action(allo_a, orientation)

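            # At state 6 the allocentric action is forced to 0 (assumption: this
            # encodes a constrained transition at that location of the maze).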
            if s == 6:
                allo_a = 0
                ego_a = self.get_ego_action(allo_a, orientation)

            # act
            next_state, reward = self.env.act(allo_a)

            # get MF state representation
            orientation = self.DLS.get_orientation(s, next_state, orientation)
            next_f = self.DLS.get_feature_rep(next_state, orientation)

            # SR updates
            SPE = self.HPC.compute_error(next_state, s)
            delta_M = self.HPC.learning_rate * SPE
            self.HPC.M_hat[s, :] += delta_M
            self.HPC.update_R(next_state, reward)

            # MF updates
            next_Q = self.weights.T @ next_f
            if self.env.is_terminal(next_state):
                RPE = reward - Q_mf[ego_a]
            else:
                RPE = reward + self.gamma * np.max(next_Q) - Q_mf[ego_a]

            self.weights[:, ego_a] = self.weights[:, ego_a] + self.learning_rate * RPE * f

            # Reliability updates
            if self.env.is_terminal(next_state):
                self.DLS.update_reliability(RPE)
                self.HPC.update_reliability(SPE, s)

            s = next_state
            f = next_f
            Q_mf = next_Q
            t += 1
            cumulative_reward += reward

            step_records.append({'time': t,
                                 'reward': reward,
                                 'SPE': SPE,
                                 'RPE': RPE,
                                 'HPC reliability': self.HPC.reliability,
                                 'DLS reliability': self.DLS.reliability,
                                 'alpha': self.get_alpha(self.DLS.reliability),
                                 'beta': self.get_beta(self.DLS.reliability),
                                 'state': s,
                                 'P(SR)': self.p_sr,
                                 'choice': self.current_choice,
                                 'M_hat': self.HPC.M_hat.copy(),
                                 'R_hat': self.HPC.R_hat.copy(),
                                 'Q_mf': Q_mf,
                                 'Q_allo': Q_allo,
                                 'Q': Q_combined,
                                 'features': f.copy(),
                                 'weights': self.weights.copy(),
                                 'platform': self.env.get_goal_state(),
                                 'landmark': self.env.landmark_location})
        return pd.DataFrame(step_records)


n_agents = 20
n_trials = 30

pm = DevaluationPlusMaze()
behavioural_scores = []

for ag in tqdm(range(n_agents)):
    agent = CA(env=pm, lesion_hpc=inactivate_HPC, lesion_dls=inactivate_DLS, learning_rate=.07, inv_temp=4)  #inv_temp= 5
    agent_results = []

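    # Trial schedule: probe trials on the third-to-last and last trials, reward
    # devaluation on the second-to-last (training) trial, plain training otherwise.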
    for trial in tqdm(range(n_trials), leave=False):
        if trial in (n_trials - 3, n_trials - 1):
            agent.env.toggle_probe_trial()
        elif trial == n_trials - 2:
            agent.env.toggle_training_trial()
            agent.env.toggle_devaluation()
        else:
            agent.env.toggle_training_trial()

        res = agent.one_episode(random_policy=False)
        res['trial'] = trial
        res['escape time'] = res.time.max()
        res['goal location'] = agent.env.get_goal_state()
        res['total reward'] = res['reward'].sum()
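        # Second-to-last state of the recorded trajectory, used below to score the trial.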
        last_state = res['state'].iloc[-2]
        res['last state'] = last_state
        res['trial type'] = agent.env.trial_type

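        # Probe scoring: ending at the previously rewarded terminal counts as a place
        # strategy; ending at state 5 (assumed to be the arm reached by repeating the
        # trained egocentric response) counts as a response strategy.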
        if agent.env.trial_type == 'probe':
            if last_state == agent.env.rewarded_terminal:
                res['score'] = 'place'
            elif last_state == 5:
                res['score'] = 'response'
            else:
                raise ValueError('unexpected final state {} on probe trial'.format(last_state))

            behavioural_scores.append({'agent': ag,
                                       'trial': trial,
                                       'score': res['score'].iloc[0],
                                       'group': group})
        else:
            if last_state == agent.env.rewarded_terminal:
                res['score'] = 'correct'
            elif last_state == 5:
                res['score'] = 'incorrect'
        agent_results.append(res)

    df = pd.concat(agent_results)
    df['agent'] = ag
    df.to_csv(os.path.join(results_folder, 'agent{}.csv'.format(ag)))


behavioural_scores = pd.DataFrame(behavioural_scores)
behavioural_scores.to_csv(os.path.join(results_folder, 'summary.csv'))

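# Count, for each probe trial, how many agents were scored with each strategy.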
agg = behavioural_scores.pivot_table(index=['trial', 'score'], aggfunc=len, margins=True)

plt.figure()
ax = sns.barplot(x='trial', y='agent', hue='score', data=agg.reset_index())

plt.xticks([0, 1, 2], ['non-deval', 'deval', '_'])
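# Save the summary figure into the figure folder created above (filename is illustrative).
plt.savefig(os.path.join(figure_folder, 'probe_scores.png'))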
plt.show()