from tqdm import tqdm
import pandas as pd
import statsmodels.formula.api as smf
import numpy as np
import os
import matplotlib.pyplot as plt

from hippocampus.analysis.daw_analysis import add_relevant_columns
from hippocampus.environments import TwoStepTask
from hippocampus.experiments.twostep import CombinedAgent
from hippocampus.agents import CombinedAgent as SpatialAgent
from hippocampus.environments import HexWaterMaze
from definitions import RESULTS_FOLDER

# Folder that receives every output file written by this script.
res_dir = os.path.join(RESULTS_FOLDER, 'mb_spatialmemory')
# exist_ok avoids the check-then-create race when several runs start at once.
os.makedirs(res_dir, exist_ok=True)


def run_twostep_task(num_agents, params, lesion_hpc=False, **kwargs):
    """Simulate the two-step task for each agent and write one CSV per agent.

    An agent whose output file already exists in ``res_dir`` is skipped, so an
    interrupted run can be resumed without redoing finished agents.

    Args:
        num_agents: Number of agents to simulate (agent indices 0..num_agents-1).
        params: Table with per-agent entries ``'A_alpha'``, ``'alpha1'``,
            ``'A_beta'`` and ``'beta1'``, indexed by agent number.
        lesion_hpc: Forwarded to ``CombinedAgent``; when true the output file
            gets the ``_lesion`` suffix.
        **kwargs: ``inact_hpc`` (optional) is a per-agent sequence whose entry
            for the current agent is forwarded to ``CombinedAgent``. If its sum
            is positive the ``partial_lesion`` filename is used. When omitted,
            0. is passed to the agent, matching the all-zeros array the
            healthy-control run in this file supplies.
    """
    inact_hpc = kwargs.get('inact_hpc')
    partial_lesion = inact_hpc is not None and np.sum(inact_hpc) > 0

    for a in range(num_agents):
        if lesion_hpc:
            filename = 'twostep_agent{}_lesion'.format(a)
        elif partial_lesion:
            filename = 'twostep_partial_lesion_agent{}'.format(a)
        else:
            filename = 'twostep_agent{}'.format(a)

        if os.path.exists(os.path.join(res_dir, filename)):
            tqdm.write('Already done agent {}'.format(a))
            continue

        e = TwoStepTask()
        ag = CombinedAgent(e, A_alpha=params['A_alpha'][a], alpha1=params['alpha1'][a],
                           A_beta=params['A_beta'][a], beta1=params['beta1'][a], lesion_hpc=lesion_hpc,
                           inact_hpc=0. if inact_hpc is None else inact_hpc[a])

        # Collect per-episode rows and concatenate once at the end:
        # DataFrame.append was removed in pandas 2.0 and copied the whole
        # frame on every call.
        rows = []
        for ep in tqdm(range(e.n_trials), leave=False):
            results = ag.one_episode()
            results['Trial'] = ep
            # NOTE(review): every output file records Agent = 0 regardless of
            # the agent index `a` -- confirm this is what the analysis expects.
            results['Agent'] = 0
            # Snapshot the record immediately in case the agent reuses the
            # same object across episodes.
            if isinstance(results, pd.DataFrame):
                rows.append(results)
            else:
                rows.append(pd.DataFrame([results]))

        df = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame({})
        df.to_csv(os.path.join(res_dir, filename))


def determine_platform_seq(platform_states, env, n_sessions=11):
    """Draw a random sequence of platform locations, one per session.

    Each location after the first is sampled uniformly from the platforms
    that (a) have been used fewer than two times so far and (b) lie further
    than ``env.grid.radius`` from the previous session's platform, as measured
    by ``env.grid.distance``.

    The first platform is counted towards the usage limit as well, so no
    location can occur more than twice in the returned sequence.

    Args:
        platform_states: Sequence of candidate platform state indices.
        env: Environment exposing ``grid.distance(a, b)`` and ``grid.radius``.
        n_sessions: Length of the sequence to generate (default 11).

    Returns:
        List of ``n_sessions`` entries taken from ``platform_states``.

    Raises:
        ValueError: If ``n_sessions`` is smaller than 1, or if at some step no
            platform satisfies both constraints.
    """
    if n_sessions < 1:
        raise ValueError('n_sessions must be at least 1, got {}'.format(n_sessions))

    indices = np.arange(len(platform_states))
    usage = np.zeros(len(platform_states))

    # Pick the first platform by index so its use is recorded like all others.
    first_idx = np.random.choice(indices)
    usage[first_idx] += 1.
    plat_seq = [platform_states[first_idx]]

    for sess in range(1, n_sessions):
        distances = np.array([env.grid.distance(plat_seq[-1], s) for s in platform_states])
        candidates = indices[np.logical_and(usage < 2, distances > env.grid.radius)]
        if candidates.size == 0:
            raise ValueError(
                'No admissible platform left for session {}: every candidate is '
                'either used twice or too close to the previous one'.format(sess))
        platform_idx = np.random.choice(candidates)
        plat_seq.append(platform_states[platform_idx])
        usage[platform_idx] += 1.

    return plat_seq


def run_watermaze(num_agents, params, lesion_hpc=False, **kwargs):
    """Simulate the water-maze task for each agent and save the results.

    For every agent a fresh platform sequence is drawn and the agent runs
    11 sessions of 4 trials each; the platform is moved at the first trial of
    each session. Two files are written to ``res_dir`` per agent: a CSV with
    the concatenated trial results and ``<filename>value_funcs.npy`` holding,
    for each session, state and allocentric angle, the mean of the first
    element returned by ``agent.compute_Q(state, angle, agent.p_sr)`` evaluated
    right after the platform was moved. Agents whose CSV already exists are
    skipped.

    Args:
        num_agents: Number of agents to simulate (agent indices 0..num_agents-1).
        params: Table with per-agent entries ``'A_alpha'``, ``'alpha1'``,
            ``'A_beta'`` and ``'beta1'``, indexed by agent number.
        lesion_hpc: Forwarded to the agent; when true the output file gets the
            ``_lesion`` suffix.
        **kwargs: ``inact_hpc`` (optional) is a per-agent sequence whose entry
            for the current agent is forwarded to the agent. If its sum is
            positive the ``partial_lesion`` filename is used. When omitted,
            0. is passed to the agent, matching the all-zeros array the
            healthy-control run in this file supplies.
    """
    n_sessions = 11
    trials_per_session = 4

    g = HexWaterMaze(6)

    # Candidate platform locations (the same for every agent).
    # For the r = 10 case these were [192, 185, 181, 174, 216, 210, 203, 197].
    possible_platform_states = np.array([48, 45, 42, 39, 60, 57, 54, 51])

    inact_hpc = kwargs.get('inact_hpc')
    partial_lesion = inact_hpc is not None and np.sum(inact_hpc) > 0

    for i_a in tqdm(range(num_agents), desc='Agent'):

        if lesion_hpc:
            filename = 'spatial_agent{}_lesion'.format(i_a)
        elif partial_lesion:
            filename = 'spatial_partial_lesion_agent{}'.format(i_a)
        else:
            filename = 'spatial_agent{}'.format(i_a)

        if os.path.exists(os.path.join(res_dir, filename)):
            tqdm.write('Already done')
            continue

        platform_sequence = determine_platform_seq(possible_platform_states, g)

        # initialise agent
        agent = SpatialAgent(g, init_sr='rw', A_alpha=params['A_alpha'][i_a], alpha1=params['alpha1'][i_a],
                             A_beta=params['A_beta'][i_a], beta1=params['beta1'][i_a], lesion_hpc=lesion_hpc,
                             inact_hpc=0. if inact_hpc is None else inact_hpc[i_a])
        agent_results = []

        value_func_results = np.zeros((n_sessions, g.nr_states, len(g.allo_angles)))
        total_trial_count = 0

        for ses in tqdm(range(n_sessions), desc='Session', leave=False):
            for trial in tqdm(range(trials_per_session), leave=False, desc='Trial'):
                # every first trial of a session, change the platform location
                if trial == 0:
                    g.set_platform_state(platform_sequence[ses])

                    # snapshot the mean value per (state, angle) before the
                    # agent has experienced the new platform location
                    for s in g.state_indices:
                        for io, o in enumerate(g.allo_angles):
                            value_func_results[ses, s, io] = np.mean(agent.compute_Q(s, o, agent.p_sr)[0])

                res = agent.one_episode(random_policy=False)
                res['trial'] = trial
                res['escape time'] = res.time.max()
                res['session'] = ses
                res['total trial'] = total_trial_count
                agent_results.append(res)
                total_trial_count += 1

        agent_df = pd.concat(agent_results)
        # running row counter across all trials of this agent
        agent_df['total time'] = np.arange(len(agent_df))

        agent_df.to_csv(os.path.join(res_dir, filename))
        np.save(os.path.join(res_dir, filename + 'value_funcs.npy'), value_func_results)


def get_states_around_last_platform(env, rec_depth=2):
    """Collect the states surrounding the environment's previous platform.

    Starting from ``env.previous_platform_state``, the set of states is grown
    ``rec_depth`` times by adding every not-yet-included neighbour (non-zero
    entry in ``env.adjacency_graph``) of the states gathered so far.

    Args:
        env: Environment exposing ``previous_platform_state`` and an
            ``adjacency_graph`` indexable by state.
        rec_depth: Number of expansion rounds.

    Returns:
        List of states in discovery order, beginning with the platform state.
    """
    visited = [env.previous_platform_state]
    for _ in range(rec_depth):
        fresh = []
        for state in visited:
            for neighbour in np.flatnonzero(env.adjacency_graph[state]):
                already_seen = neighbour in visited or neighbour in fresh
                if not already_seen:
                    fresh.append(neighbour)
        # Extend only after the scan so this round sees a fixed list.
        visited.extend(fresh)
    return visited


# Model-based / model-free analysis
def get_model_weights(data):
    """Estimate model-based and model-free weights with a logistic regression.

    ``add_relevant_columns`` is first applied to ``data`` and the ``'Stay'``
    column is cast to int; both steps act on the frame passed in by the
    caller. A logit model ``Stay ~ PreviousTransition * PreviousReward`` is
    then fitted on those three columns.

    Args:
        data: Trial table accepted by ``add_relevant_columns``.

    Returns:
        Tuple ``(model_based_weight, model_free_weight)``: the negated
        coefficient of the ``PreviousTransition[T.rare]:PreviousReward``
        interaction, and the coefficient of ``PreviousReward``.
    """
    add_relevant_columns(data)
    data['Stay'] = data['Stay'].astype('int')

    regression_columns = ['Stay', 'PreviousReward', 'PreviousTransition']
    fitted = smf.logit(formula='Stay ~ PreviousTransition * PreviousReward',
                       data=data[regression_columns]).fit()

    coefficients = fitted.params
    return (-coefficients['PreviousTransition[T.rare]:PreviousReward'],
            coefficients['PreviousReward'])


if __name__ == "__main__":
    # One parameter set is generated per simulated agent.
    n_agents = 20
    np.random.seed(10)

    # Sweep the parameters governing the transitions from MF to SR and vice
    # versa. The full ranges originally considered were
    #   A_alpha: .5 -> 5,  alpha1: .01 -> 2,  A_beta: 2 -> .5,  beta1: .3 -> .1
    # and each linspace below is written as "start +/- .7 * full range".
    # NOTE(review): A_beta starts at 5, whereas its end point is expressed
    # relative to 2 (and the full range above starts at 2) -- confirm that 5 is
    # intended and not a typo.
    parameters = pd.DataFrame({
        'A_alpha': np.linspace(.5, .5 + .7 * (5 - .5), n_agents),
        'alpha1': np.linspace(.01, .01 + .7 * (2 - .01), n_agents),
        'A_beta': np.linspace(5, 2 - .7 * (2 - .5), n_agents),
        'beta1': np.linspace(.3, .3 - .7 * (.3 - .1), n_agents),
    })
    parameters.to_csv(os.path.join(res_dir, 'params.csv'))

    # Healthy control: no hippocampal inactivation for any agent.
    tqdm.write('\n RUNNING TWO-STEP TASK \n')
    run_twostep_task(n_agents, parameters, inact_hpc=np.zeros(n_agents))
    tqdm.write('\n RUNNING WATER MAZE \n')
    run_watermaze(n_agents, parameters, inact_hpc=np.zeros(n_agents))

    # Partial hippocampal lesion: each agent gets an inactivation proportion
    # drawn uniformly from [min_inact_prop, 1). Re-seeding here makes the draw
    # independent of how much randomness the runs above consumed.
    min_inact_prop = .6
    np.random.seed(5)
    inact_hpc = np.random.uniform(min_inact_prop, 1., n_agents)

    tqdm.write('\n RUNNING TWO-STEP TASK (HPC PARTIAL LESION) \n')
    run_twostep_task(n_agents, parameters, inact_hpc=inact_hpc)
    tqdm.write('\n RUNNING WATER MAZE (HPC PARTIAL LESION) \n')
    run_watermaze(n_agents, parameters, inact_hpc=inact_hpc)

    tqdm.write('done')