In [None]:
import sys 
sys.path.append('../..')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from hippocampus.analysis.daw_analysis import add_relevant_columns
import statsmodels.formula.api as smf
import seaborn as sns
from tqdm import tqdm_notebook as tqdm

from definitions import RESULTS_FOLDER
from hippocampus.environments import HexWaterMaze

In [None]:
# Maze environment; its `adjacency_graph` is what get_surrounding_states walks.
# NOTE(review): the meaning of the argument 6 (grid size/radius?) is not visible
# here — confirm against HexWaterMaze.
en = HexWaterMaze(6)

In [None]:
# All result files of this experiment live under RESULTS_FOLDER/mb_spatialmemory.
res_dir = os.path.join(RESULTS_FOLDER, 'mb_spatialmemory')

# Agent parameter table (presumably one row per agent — confirm).
params = pd.read_csv(os.path.join(res_dir,'params.csv'))


In [None]:
# One histogram per column of the parameter table.
params.hist()

In [None]:
# Spatial-task results of agent 0, loaded to inspect the file format.
df = pd.read_csv(os.path.join(res_dir, 'spatial_agent0'))

In [None]:
# Which columns does a spatial results file contain?
df.columns

In [None]:
# First rows of the spatial results.
df.head()

In [None]:
# Escape time against 'total trial' (presumably a trial counter running across
# sessions — confirm).
df.plot(x='total trial', y='escape time')

In [None]:
# Column subset relevant to the analysis.
# NOTE(review): `data` is rebound to a list of DataFrames in the agent-loading loop
# below before this subset is ever used, so it appears to be dead.
data = df[['trial', 'total trial', 'session', 'escape time','platform', 'state']]

In [None]:
# Number of agents per group; result files are indexed 0 .. n_agents - 1.
n_agents = 19

In [None]:
def get_first_trial_info(data):
    """Summarise the first trial of every session except the very first trial.

    All columns are first averaged per 'total trial' (pivot_table's default
    aggregation). Each trial is then tagged with the platform of the trial before
    it ('previous platform'). Rows with trial == 0 are kept; total trial 0 is
    dropped because it has no predecessor. The result is averaged again per
    session.

    :param data: DataFrame with at least the columns 'total trial', 'trial',
        'session' and 'platform'.
    :return: DataFrame indexed by session, holding the averaged columns plus
        'previous platform'.
    """
    per_trial = data.pivot_table(index='total trial')
    per_trial['previous platform'] = per_trial['platform'].shift(1)
    is_first_of_session = per_trial['trial'] == 0
    return (per_trial.loc[is_first_of_session]
            .drop(0)
            .pivot_table(index='session'))

In [None]:
def get_surrounding_states(state, env, rec_depth=2):
    """Return every state reachable from `state` in at most `rec_depth` steps.

    Breadth-first expansion over `env.adjacency_graph`: the result starts with
    `state` itself, followed by the states first reached at depth 1, then those
    first reached at depth 2, and so on. Each state appears once.

    :param state: starting state index.
    :param env: object whose `adjacency_graph[s]` row is non-zero at the
        neighbours of `s`.
    :param rec_depth: number of expansion steps.
    :return: list of state indices (start state first).
    """
    reached = [state]
    visited = {state}
    frontier = [state]
    for _ in range(rec_depth):
        newly_reached = []
        # Only states found in the previous step can contribute new neighbours.
        for node in frontier:
            for neighbour in np.flatnonzero(env.adjacency_graph[node]):
                if neighbour not in visited:
                    visited.add(neighbour)
                    newly_reached.append(neighbour)
        reached.extend(newly_reached)
        frontier = newly_reached
    return reached


In [None]:
# Spatial results of agent 1 (scratch inspection; `df` is overwritten by later cells).
df = pd.read_csv(os.path.join(res_dir, 'spatial_agent1'))

In [None]:
def get_allo_index(agent_data, env, sessions=None):
    """Get the allocentricness index of one agent.

    For each session, count the rows of that session's first trial (trial == 0)
    whose 'state' lies within the neighbourhood returned by
    ``get_surrounding_states`` (default depth 2) around the *previous* platform
    location. The index is the mean of these counts over sessions. Rows are
    presumably time steps, making this the average time spent near the old
    platform — confirm against the simulation code.

    :param agent_data: DataFrame with at least the columns 'total trial', 'trial',
        'session', 'platform' and 'state'.
    :param env: environment exposing ``adjacency_graph``; passed to
        ``get_surrounding_states``.
    :param sessions: iterable of session numbers to include. Defaults to every
        session returned by ``get_first_trial_info``, i.e. each session whose
        first trial has a preceding trial. (This used to be hard-coded as
        ``range(1, 11)``, which matches the default when sessions are numbered
        0..10.)
    :return: mean count over the selected sessions (NaN if there are none).
    """
    first_trials = get_first_trial_info(agent_data)
    if sessions is None:
        sessions = first_trials.index
    first_trial_rows = agent_data[agent_data['trial'] == 0]

    steps_near_previous_platform = []
    for ses in sessions:
        states = first_trial_rows.loc[first_trial_rows['session'] == ses, 'state']
        previous_platform = int(first_trials.loc[ses, 'previous platform'])
        surrounding_states = get_surrounding_states(previous_platform, env)
        # np.isin does not need sorted inputs; the count is all we keep.
        steps_near_previous_platform.append(np.isin(states, surrounding_states).sum())

    return np.mean(steps_near_previous_platform)



In [None]:
# Load every agent's spatial results and compute its allocentricness index.
# `data` keeps the loaded DataFrames; `df` is left holding the last agent's results.
scores = []
data = []
for agent_idx in tqdm(range(n_agents)):
    df = pd.read_csv(os.path.join(res_dir, 'spatial_agent' + str(agent_idx)))
    data.append(df)
    scores.append(get_allo_index(df, en))
    

In [None]:
# Per-session first-trial summary for the first loaded agent (spatial_agent0).
get_first_trial_info(data[0])

In [None]:
# Second row of the last spatial agent loaded above. Indexing `data` explicitly
# instead of relying on the loop variable `df`, which later cells overwrite.
data[-1].iloc[1]

In [None]:
# Two-step task: load each agent's results and estimate its model-based weight

In [None]:
def get_model_weights(data):
    """Estimate model-based and model-free weights from two-step task choices.

    Fits the logistic regression ``Stay ~ PreviousTransition * PreviousReward``
    with the statsmodels formula API (default treatment coding).

    :param data: DataFrame of two-step trials, as loaded from a ``twostep_agent*``
        file. ``add_relevant_columns`` is relied on to provide the 'Stay',
        'PreviousReward' and 'PreviousTransition' columns (TODO confirm — it is
        defined in hippocampus.analysis.daw_analysis). The caller's DataFrame is
        left unmodified.
    :return: (model_based_weight, model_free_weight):
        - model_based_weight is minus the
          'PreviousTransition[T.rare]:PreviousReward' coefficient;
        - model_free_weight is the 'PreviousReward' coefficient.
    """
    # Work on a copy. add_relevant_columns is called for its effect on the frame it
    # receives, and the 'Stay' cast below would otherwise also leak into the caller's
    # DataFrame.
    data = data.copy()
    add_relevant_columns(data)
    data['Stay'] = data['Stay'].astype('int')
    data = data[['Stay', 'PreviousReward', 'PreviousTransition']]
    mod = smf.logit(formula='Stay ~ PreviousTransition * PreviousReward', data=data)
    # disp=0 silences the convergence printout; this is fitted once per agent.
    res = mod.fit(disp=0)
    # The interaction compares the reward effect after rare transitions with the
    # reference level (presumably 'common'). The sign is flipped so that a larger
    # value means more model-based behaviour.
    model_based_weight = -res.params['PreviousTransition[T.rare]:PreviousReward']
    # NOTE(review): under treatment coding this is the reward effect after
    # reference-level transitions only, so it also absorbs model-based influence.
    # Effect (sum) coding would separate the two.
    model_free_weight = res.params['PreviousReward']
    return model_based_weight, model_free_weight


In [None]:
# Model-based weight of each control agent on the two-step task.
# NOTE(review): `mb_scores` in a later cell recomputes the same values.
weights_mb = []
for agent_idx in range(n_agents):
    df = pd.read_csv(os.path.join(res_dir, 'twostep_agent' + str(agent_idx)))
    mb_weight, mf_weight = get_model_weights(df)
    weights_mb.append(mb_weight)


In [None]:
def load_agent_results(task, agent, lesion=False):
    """Read the results CSV of one agent.

    :param task: file prefix, e.g. 'spatial' or 'twostep'.
    :param agent: agent index (0 .. n_agents - 1).
    :param lesion: if True, read the '_lesion' variant of the file.
    :return: the agent's results as a DataFrame.
    """
    suffix = '_lesion' if lesion else ''
    return pd.read_csv(os.path.join(res_dir, '{}_agent{}{}'.format(task, agent, suffix)))


def compute_group_scores(lesion=False):
    """Return (allocentricness indices, model-based weights) for every agent of a group.

    :param lesion: score the lesioned agents instead of the controls.
    :return: two lists of length n_agents, ordered by agent index.
    """
    allo = [get_allo_index(load_agent_results('spatial', a, lesion), en)
            for a in tqdm(range(n_agents))]
    mb = [get_model_weights(load_agent_results('twostep', a, lesion))[0]
          for a in range(n_agents)]
    return allo, mb


# Control and lesion groups are scored by the same procedure.
allocentric_scores, mb_scores = compute_group_scores(lesion=False)
allocentric_scores_lesion, mb_scores_lesion = compute_group_scores(lesion=True)


In [None]:
# Log allocentricness vs model-based weight, one palette colour per agent.
# The palette is sized from the plotted data itself rather than from an unrelated list.
plt.scatter(np.log(allocentric_scores), mb_scores,
            c=sns.cubehelix_palette(len(allocentric_scores)))
plt.xlabel('Allocentricness score (log)')
plt.ylabel('Model based index');

In [None]:
# Control agents: allocentricness vs model-based weight with a linear fit.
# x/y are passed by keyword; seaborn >= 0.12 rejects positional data vectors.
plt.figure(figsize=(4,4))
sns.regplot(x=allocentric_scores, y=mb_scores)

plt.ylabel('Model based index')
plt.xlabel('Allocentricness score');

In [None]:
# Default seaborn palette; entries 1 and 4 colour the control and lesion groups below.
colpal = sns.color_palette()

In [None]:
# Control vs lesion: log allocentricness against model-based weight, each with a
# linear fit. x/y are passed by keyword because seaborn >= 0.12 rejects positional
# data vectors. The x-axis is reversed (limits 4.5 -> 0). Limits are set before
# plotting so that, with truncate=False, the fit line spans them.
fig, ax = plt.subplots()

ax.set_xlim([4.5, 0])
ax.set_ylim([-1, 4.5])
sns.regplot(x=np.log(allocentric_scores), y=mb_scores, color=colpal[1], ax=ax)
sns.regplot(x=np.log(allocentric_scores_lesion), y=mb_scores_lesion, color=colpal[4], ax=ax)

ax.set_ylabel('Model based index')
ax.set_xlabel('Allocentricness score (log)')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)




In [None]:
# Long-format table of both groups, so lmplot fits one regression per group
# (control first, lesion second, matching the palette order).
score_data = pd.DataFrame({
    'model based': np.concatenate([mb_scores, mb_scores_lesion]),
    'allocentric': np.log(np.concatenate([allocentric_scores, allocentric_scores_lesion])),
    'group': ['control'] * n_agents + ['lesion'] * n_agents,
})

sns.lmplot(y='model based', x='allocentric', data=score_data, hue='group', palette=[colpal[1], colpal[4]])

In [None]:
# Display the default palette (used to pick the group colours).
sns.palplot(sns.color_palette())

In [None]:
# Saved value functions of spatial agent 0.
# NOTE(review): the array layout is not visible here. The cells below average over
# axis 2 and take the argmax over axis 1, presumably giving one row per session and
# one column per state — confirm.
arr = np.load(os.path.join(res_dir, 'spatial_agent0value_funcs.npy'))

In [None]:
# Check the array layout.
arr.shape

In [None]:
# Average over axis 2 (NOTE(review): presumably repetitions/time steps — confirm).
avg = arr.mean(axis=2)

In [None]:
# Shape after averaging out axis 2.
avg.shape

In [None]:
# Index of the largest entry in each row of `avg`.
list(map(np.argmax, avg))

In [None]:
# Row-wise argmax of `avg` as a single array (vectorised form of the list above).
np.argmax(avg, axis=1)

In [None]:
# Platform location of each session for spatial agent 0. This presumably lines up
# with the per-row argmax of `arr` (spatial_agent0value_funcs.npy) above — confirm.
# data[0] is used explicitly: at this point `df` holds the last file read by an
# earlier loop (a two-step results file), not agent 0's spatial run.
data[0].groupby('session')['platform'].mean().astype('int').to_numpy()

In [None]:
# Grid distance between states 54 and 86.
# NOTE(review): hard-coded state ids, presumably two platform locations — confirm.
en.grid.distance(54,86)

In [None]:
# Same quantity as np.argmax(avg, axis=1), computed directly from `arr`.
arr.mean(axis=2).argmax(axis=1)