{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys \n", "sys.path.append('../..')\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import os\n", "from hippocampus.analysis.daw_analysis import add_relevant_columns\n", "import statsmodels.formula.api as smf\n", "import seaborn as sns\n", "from tqdm import tqdm_notebook as tqdm\n", "\n", "from definitions import RESULTS_FOLDER\n", "from hippocampus.environments import HexWaterMaze" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en = HexWaterMaze(6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res_dir = os.path.join(RESULTS_FOLDER, 'mb_spatialmemory')\n", "\n", "params = pd.read_csv(os.path.join(res_dir,'params.csv'))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params.hist()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(os.path.join(res_dir, 'spatial_agent0'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.plot(x='total trial', y='escape time')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = df[['trial', 'total trial', 'session', 'escape time','platform', 'state']]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_agents = 19" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_first_trial_info(data):\n", " d2 = data.pivot_table(index='total trial')\n", " d2['previous platform'] = d2['platform'].shift(1)\n", " first_trials = d2[d2['trial']==0]\n", " first_trials = first_trials.drop(0).pivot_table(index='session')\n", " return first_trials" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_surrounding_states(state, env, rec_depth=2):\n", " surrounding_states = [state]\n", " for i in range(rec_depth):\n", " added_states = []\n", " for s in surrounding_states:\n", " neighbours = np.flatnonzero(env.adjacency_graph[s])\n", " for n in neighbours:\n", " if n not in surrounding_states and n not in added_states:\n", " added_states.append(n)\n", " surrounding_states += added_states\n", " return surrounding_states\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(os.path.join(res_dir, 'spatial_agent{}'.format(1)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_allo_index(agent_data, env):\n", " \"\"\"Get the allocentricness index, defined as the amount of time spent around the previous platform location during\n", " first trials of sessions.\n", "\n", " :param agent_data:\n", " :param env:\n", " :return:\n", " \"\"\"\n", " first_trials = get_first_trial_info(agent_data)\n", " prop_times = []\n", " for ses in range(1, 11):\n", " states = np.sort(agent_data[(agent_data.session == ses) & (agent_data.trial == 0)]['state'])\n", " previous_platform = first_trials['previous platform'][ses]\n", " surrounding_states = np.sort(np.array(get_surrounding_states(int(previous_platform), env)))\n", "\n", " time_spent = np.isin(states, surrounding_states).sum()\n", " prop_times.append(time_spent)\n", "\n", " return np.mean(prop_times)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "scores = []\n", "\n", "data = []\n", "\n", "for a in tqdm(range(n_agents)):\n", " df = pd.read_csv(os.path.join(res_dir, 'spatial_agent{}'.format(a)))\n", " data.append(df)\n", " #ft = get_first_trial_info(df)\n", " scores.append(get_allo_index(df, en))\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "get_first_trial_info(data[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# load model-based data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_model_weights(data):\n", " add_relevant_columns(data)\n", " data['Stay'] = data['Stay'].astype('int')\n", " data = data[['Stay', 'PreviousReward', 'PreviousTransition']]\n", " mod = smf.logit(formula='Stay ~ PreviousTransition * PreviousReward', data=data)\n", " res = mod.fit()\n", " model_based_weight = -res.params['PreviousTransition[T.rare]:PreviousReward']\n", " model_free_weight = res.params['PreviousReward']\n", " return model_based_weight, model_free_weight\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weights_mb = []\n", "for a in range(n_agents):\n", " df = pd.read_csv(os.path.join(res_dir, 'twostep_agent{}'.format(a)))\n", " mb_weight, mf_weight = get_model_weights(df)\n", " weights_mb.append(mb_weight)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "allocentric_scores = []\n", "for a in tqdm(range(n_agents)):\n", " df = pd.read_csv(os.path.join(res_dir, 'spatial_agent{}'.format(a)))\n", " allocentric_scores.append(get_allo_index(df, en))\n", "\n", "mb_scores = []\n", "for a in range(n_agents):\n", " df = pd.read_csv(os.path.join(res_dir, 'twostep_agent{}'.format(a)))\n", " mb_weight, mf_weight = get_model_weights(df)\n", " mb_scores.append(mb_weight)\n", "\n", "allocentric_scores_lesion = []\n", "for a in tqdm(range(n_agents)):\n", " df = pd.read_csv(os.path.join(res_dir, 'spatial_agent{}_lesion'.format(a)))\n", " allocentric_scores_lesion.append(get_allo_index(df, en))\n", "\n", "mb_scores_lesion = []\n", "for a in range(n_agents):\n", " df = pd.read_csv(os.path.join(res_dir, 'twostep_agent{}_lesion'.format(a)))\n", " mb_weight, mf_weight = get_model_weights(df)\n", " mb_scores_lesion.append(mb_weight)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.scatter(np.log(allocentric_scores), mb_scores, c=sns.cubehelix_palette(len(scores)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(4,4))\n", "sns.regplot(allocentric_scores, mb_scores)\n", "\n", "plt.ylabel('Model based index')\n", "plt.xlabel('Allocentricness score')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "colpal = sns.color_palette()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "\n", "plt.xlim([4.5,0])\n", "plt.ylim([-1,4.5])\n", "sns.regplot(np.log(allocentric_scores), mb_scores,color=colpal[1])\n", "sns.regplot(np.log(allocentric_scores_lesion), mb_scores_lesion, color=colpal[4])\n", "\n", "plt.ylabel('Model based index')\n", "plt.xlabel('Allocentricness score')\n", "ax.spines['right'].set_visible(False)\n", "ax.spines['top'].set_visible(False)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "score_data = pd.DataFrame({})\n", "score_data['model based'] = np.concatenate([mb_scores, mb_scores_lesion])\n", "score_data['allocentric'] = np.concatenate([np.log(allocentric_scores), np.log(allocentric_scores_lesion)])\n", "score_data['group'] = ['control'] * n_agents + ['lesion'] * n_agents\n", "\n", "\n", "sns.lmplot(y='model based', x='allocentric', data=score_data, hue='group', palette=[colpal[1], colpal[4]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.palplot(sns.color_palette())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arr = np.load(os.path.join(res_dir, 'spatial_agent0value_funcs.npy'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arr.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "avg = arr.mean(axis=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "avg.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "[np.argmax(a) for a in avg]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.argmax(avg, axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df.groupby('session')['platform'].mean().astype('int').to_numpy()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en.grid.distance(54,86)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arr.mean(axis=2).argmax(axis=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }