import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

from definitions import RESULTS_FOLDER
from hippocampus.analysis.daw_analysis import add_relevant_columns
from hippocampus.environments import HexWaterMaze


def get_first_trial_info(data):
    """Aggregate the data per trial and return, for each session, the first-trial
    averages together with the previous session's platform location."""
    d2 = data.pivot_table(index='total trial')
    d2['previous platform'] = d2['platform'].shift(1)
    first_trials = d2[d2['trial'] == 0]
    first_trials = first_trials.drop(0).pivot_table(index='session')
    return first_trials


def get_surrounding_states(state, env, rec_depth=4):
    """Return the given state plus all states reachable within rec_depth steps in
    the environment's adjacency graph (a breadth-first expansion)."""
    surrounding_states = [state]
    for i in range(rec_depth):
        added_states = []
        for s in surrounding_states:
            neighbours = np.flatnonzero(env.adjacency_graph[s])
            for n in neighbours:
                if n not in surrounding_states and n not in added_states:
                    added_states.append(n)
        surrounding_states += added_states
    return surrounding_states
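

# Hedged usage sketch (not called anywhere by the script): the neighbourhood
# returned by get_surrounding_states grows with rec_depth. State 0 is an
# arbitrary illustrative choice.
def _demo_surrounding_states():
    demo_env = HexWaterMaze(6)
    for depth in range(1, 5):
        n_states = len(get_surrounding_states(0, demo_env, rec_depth=depth))
        print('rec_depth={}: {} states'.format(depth, n_states))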


def get_allo_index(agent_data, env):
    """Get the allocentricness index, defined as the amount of time spent around the previous platform
    location during first trials of sessions.

    :param agent_data: per-time-step DataFrame with 'session', 'trial' and 'state' columns
    :param env: HexWaterMaze instance providing the adjacency graph
    :return: mean number of first-trial time steps spent near the previous platform
    """
    first_trials = get_first_trial_info(agent_data)
    times_near_platform = []
    for ses in range(1, 11):
        states = np.sort(agent_data[(agent_data.session == ses) & (agent_data.trial == 0)]['state'])
        previous_platform = first_trials['previous platform'][ses]
        surrounding_states = np.sort(np.array(get_surrounding_states(int(previous_platform), env)))

        time_spent = np.isin(states, surrounding_states).sum()
        times_near_platform.append(time_spent)

    return np.mean(times_near_platform)
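

# Hedged worked sketch of the occupancy count inside get_allo_index: per-step
# visited states are matched against the neighbourhood of the previous platform
# with np.isin. All values below are invented for illustration.
def _demo_occupancy_count():
    visited = np.array([3, 4, 4, 5, 9, 9, 9, 12])      # states over time steps
    neighbourhood = np.array([4, 5, 6, 9])              # states near the old platform
    time_spent = np.isin(visited, neighbourhood).sum()  # 6 of the 8 steps match
    return time_spent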


def get_distance_err(agent_data, val_func_data, env):
    """Get the (negative) allocentricness index, defined as the distance error between the peak of the
    value function and the previous platform location during the first trials of sessions.

    :param agent_data: per-time-step DataFrame with 'session' and 'platform' columns
    :param val_func_data: 3D array of value functions; the last axis is maximised out to give one
        value per session and state
    :param env: HexWaterMaze instance providing Cartesian coordinates per state
    :return: mean distance between the value-function peak and the previous platform location
    """
    distances = []

    platform_locs = agent_data.groupby('session')['platform'].mean().astype('int').to_numpy()

    val_funcs = val_func_data.max(axis=2)

    # estimate the peak of the value function as the centroid of the n highest-valued states
    n = 7
    for ses in range(1, 11):
        top_n_states = val_funcs[ses].argsort()[-n:]

        loc = np.array([0., 0.])
        for s in top_n_states:
            loc += np.array(env.grid.cart_coords[s])
        loc = loc / n

        dist = np.linalg.norm(loc - env.grid.cart_coords[platform_locs[ses - 1]])
        distances.append(dist)
    return np.mean(distances)
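

# Hedged numpy-only sketch of the top-n centroid step in get_distance_err: the
# n highest-valued states are averaged in Cartesian space to estimate the peak
# of the value function. Values and coordinates below are invented.
def _demo_top_n_centroid():
    values = np.array([0.1, 0.9, 0.8, 0.2, 0.7])
    coords = np.array([[0., 0.], [1., 0.], [1., 1.], [2., 2.], [0., 1.]])
    n = 3
    top_n = values.argsort()[-n:]          # indices of the 3 largest values
    centroid = coords[top_n].mean(axis=0)  # -> array([0.667, 0.667])
    return centroid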


def get_model_weights(data):
    """Estimate model-based and model-free weights from two-step task data via a logistic
    regression of stay behaviour on previous reward and previous transition type."""
    add_relevant_columns(data)
    data['Stay'] = data['Stay'].astype('int')
    data = data[['Stay', 'PreviousReward', 'PreviousTransition']]
    mod = smf.logit(formula='Stay ~ PreviousTransition * PreviousReward', data=data)
    res = mod.fit()
    # model-based control shows up as the reward-by-transition interaction,
    # model-free control as the main effect of previous reward
    model_based_weight = -res.params['PreviousTransition[T.rare]:PreviousReward']
    model_free_weight = res.params['PreviousReward']
    return model_based_weight, model_free_weight
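

# Hedged synthetic-data sketch of the stay regression in get_model_weights.
# The column names mirror those produced by add_relevant_columns, but the data
# here is random, so the fitted weights carry no behavioural meaning.
def _demo_model_weights(n_trials=500, seed=0):
    rng = np.random.RandomState(seed)
    demo = pd.DataFrame({
        'Stay': rng.randint(0, 2, n_trials),
        'PreviousReward': rng.randint(0, 2, n_trials),
        'PreviousTransition': rng.choice(['common', 'rare'], n_trials),
    })
    res = smf.logit('Stay ~ PreviousTransition * PreviousReward', data=demo).fit(disp=0)
    return res.params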


def compute_scores():
    """Compute per-agent allocentricness and model-based scores for the control and lesion
    groups. Relies on the module-level globals n_agents, res_dir and en."""
    allocentric_scores = []
    allo_time = []
    for a in tqdm(range(n_agents)):
        df = pd.read_csv(os.path.join(res_dir, 'spatial_agent{}'.format(a)))
        val_funcs = np.load(os.path.join(res_dir, 'spatial_agent{}value_funcs.npy'.format(a)))
        allocentric_scores.append(get_distance_err(df, val_funcs, en))
        allo_time.append(get_allo_index(df, en))

    mb_scores = []
    for a in range(n_agents):
        df = pd.read_csv(os.path.join(res_dir, 'twostep_agent{}'.format(a)))
        mb_weight, mf_weight = get_model_weights(df)
        mb_scores.append(mb_weight)

    allocentric_scores_lesion = []
    allo_time_lesion = []
    for a in tqdm(range(n_agents)):
        df = pd.read_csv(os.path.join(res_dir, 'spatial_partial_lesion_agent{}'.format(a)))
        val_funcs = np.load(os.path.join(res_dir, 'spatial_partial_lesion_agent{}value_funcs.npy'.format(a)))
        allocentric_scores_lesion.append(get_distance_err(df, val_funcs, en))
        allo_time_lesion.append(get_allo_index(df, en))

    mb_scores_lesion = []
    for a in range(n_agents):
        df = pd.read_csv(os.path.join(res_dir, 'twostep_partial_lesion_agent{}'.format(a)))
        mb_weight, mf_weight = get_model_weights(df)
        mb_scores_lesion.append(mb_weight)

    score_data = pd.DataFrame({})
    score_data['model based'] = np.concatenate([mb_scores, mb_scores_lesion])
    score_data['allocentric'] = np.concatenate([allocentric_scores, allocentric_scores_lesion])
    score_data['allo time'] = np.concatenate([np.log(allo_time), np.log(allo_time_lesion)])
    score_data['group'] = ['control'] * n_agents + ['lesion'] * n_agents

    score_data.to_csv(os.path.join(res_dir, 'score_data.csv'))
    return score_data


def get_correlation_diff_z(r1, r2, n1, n2):
    """Z statistic for the difference between two independent Pearson correlations,
    based on Fisher's r-to-z transform."""
    z1 = np.arctanh(r1)
    z2 = np.arctanh(r2)
    denom = np.sqrt(1 / (n1 - 3) + 1 / (n2 - 3))
    return (z1 - z2) / denom
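

# Hedged worked example of the test above: with r1=0.8, r2=0.3 and 20 agents
# per group, z = (arctanh(0.8) - arctanh(0.3)) / sqrt(2/17), roughly 2.30,
# giving a two-sided p of roughly 0.02. Numbers are purely illustrative.
def _demo_correlation_diff():
    z = get_correlation_diff_z(0.8, 0.3, 20, 20)
    p = stats.norm.sf(abs(z)) * 2
    return z, p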


def compute_correlations(ctrl, lesion):
    """Pearson correlations between model-based and allocentric scores for each group,
    plus a Fisher z-test for the difference between the two correlations."""
    r_ctrl, p_ctrl = stats.pearsonr(ctrl['model based'], ctrl['allocentric'])
    r_ctrl, p_ctrl = stats.pearsonr(ctrl['model based'], ctrl['allocentric'])
    r_les, p_les = stats.pearsonr(lesion['model based'], lesion['allocentric'])

    z = get_correlation_diff_z(r_ctrl, r_les, len(ctrl), len(lesion))
    p_diff = stats.norm.sf(abs(z)) * 2
    results = {'z_control': np.arctanh(r_ctrl),
               'p_control': p_ctrl,
               'z_lesion': np.arctanh(r_les),
               'p_lesion': p_les,
               'z_diff': z,
               'p_diff': p_diff}
    return results


def fisher_z(pearson_r):
    """Fisher r-to-z transform; equivalent to np.arctanh(pearson_r). Note the factor
    of one half, which the original version omitted, doubling every z value."""
    z = 0.5 * np.log((1 + pearson_r) / (1 - pearson_r))
    return z
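

# Hedged sanity check for the fix above: the closed form should agree with
# numpy's built-in inverse hyperbolic tangent.
def _check_fisher_z(r=0.5):
    assert np.isclose(fisher_z(r), np.arctanh(r))
    return fisher_z(r)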


if __name__ == '__main__':
    import matplotlib

    font = {'family': 'sans-serif',
            'size': 15}

    matplotlib.rc('font', **font)


    n_agents = 20
    colpal = sns.color_palette()

    en = HexWaterMaze(6)

    res_dir = os.path.join(RESULTS_FOLDER, 'mb_spatialmemory')
    params = pd.read_csv(os.path.join(res_dir, 'params.csv'))

    score_data_file = os.path.join(res_dir, 'score_data.csv')

    if not os.path.exists(score_data_file):
        score_data = compute_scores()
    else:
        score_data = pd.read_csv(score_data_file)

    fig, ax = plt.subplots()

    lesion_data = score_data[score_data.group == 'lesion']
    control_data = score_data[score_data.group == 'control']

    # do stats
    r_cont, p_cont = stats.pearsonr(control_data['allocentric'], control_data['model based'])
    r_les, p_les = stats.pearsonr(lesion_data['allocentric'], lesion_data['model based'])
    z_cont = fisher_z(r_cont)
    z_les = fisher_z(r_les)

    z_diff = (z_les - z_cont) / np.sqrt(1 / (n_agents - 3) + 1 / (n_agents - 3))
    p_diff = stats.norm.sf(np.abs(z_diff)) * 2

    stats_df = pd.DataFrame({'control': [z_cont, p_cont], 'lesion': [z_les, p_les], 'difference': [z_diff, p_diff]},
                            index=['z', 'p']).T
    stats_df.to_csv(os.path.join(res_dir, 'stats.csv'))


    plt.xlim([0, 6])
    plt.ylim([-1, 4])
    sns.regplot(x=control_data['allocentric'], y=control_data['model based'], color=colpal[1], ci=80)
    sns.regplot(x=lesion_data['allocentric'], y=lesion_data['model based'], color=colpal[4], ci=80)

    plt.ylabel('Model-based estimate')
    plt.xlabel('Boundary distance error (a.u.)')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)


    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_FOLDER, 'figures', 'mb_spatial_cov.pdf'))
    plt.show()
    print('done')