{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys \n",
    "sys.path.append('../..')\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from hippocampus.analysis.daw_analysis import add_relevant_columns\n",
    "import statsmodels.formula.api as smf\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "from definitions import RESULTS_FOLDER\n",
    "from hippocampus.environments import HexWaterMaze"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "en = HexWaterMaze(6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_dir = os.path.join(RESULTS_FOLDER, 'mb_spatialmemory')\n",
    "\n",
    "params = pd.read_csv(os.path.join(res_dir,'params.csv'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params.hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(os.path.join(res_dir, 'spatial_agent0'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.plot(x='total trial', y='escape time')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = df[['trial', 'total trial', 'session', 'escape time','platform', 'state']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_agents = 19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_first_trial_info(data):\n",
    "    \"\"\"Summarise the first trial of each session together with the platform of the trial before it.\n",
    "\n",
    "    :param data: agent data with (at least) the columns 'total trial', 'trial', 'session' and\n",
    "        'platform'. pivot_table averages the columns, so they should be numeric.\n",
    "    :return: DataFrame indexed by session with the per-trial means of the columns plus a\n",
    "        'previous platform' column. The opening trial of the run has no predecessor and is\n",
    "        excluded.\n",
    "    \"\"\"\n",
    "    # One row per trial (pivot_table's default aggregation is the mean; the index is sorted).\n",
    "    per_trial = data.pivot_table(index='total trial')\n",
    "    # Platform of the preceding trial; NaN for the opening trial.\n",
    "    per_trial['previous platform'] = per_trial['platform'].shift(1)\n",
    "    first_trials = per_trial[per_trial['trial'] == 0]\n",
    "    # Drop the opening trial by position instead of by the hard-coded label 0, which raised a\n",
    "    # KeyError whenever the 'total trial' numbering did not start at 0.\n",
    "    first_trials = first_trials[first_trials.index != per_trial.index[0]]\n",
    "    return first_trials.pivot_table(index='session')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_surrounding_states(state, env, rec_depth=2):\n",
    "    \"\"\"Return `state` plus every state reachable from it in at most `rec_depth` steps.\n",
    "\n",
    "    States are listed in breadth-first order: `state` first, then its neighbours, then the\n",
    "    neighbours' new neighbours, and so on (within a layer, in the order np.flatnonzero yields them).\n",
    "\n",
    "    :param state: index of the starting state.\n",
    "    :param env: environment whose `adjacency_graph[s]` is non-zero at the neighbours of s.\n",
    "    :param rec_depth: number of expansion steps (0 returns [state]).\n",
    "    :return: list of state indices without duplicates.\n",
    "    \"\"\"\n",
    "    surrounding_states = [state]\n",
    "    seen = {state}  # O(1) membership test instead of scanning the list\n",
    "    frontier = [state]\n",
    "    for _ in range(rec_depth):\n",
    "        # Only states added in the previous step can have neighbours not yet collected, so\n",
    "        # expanding the frontier alone yields the same states, in the same order, as re-scanning\n",
    "        # every collected state.\n",
    "        added_states = []\n",
    "        for s in frontier:\n",
    "            for n in np.flatnonzero(env.adjacency_graph[s]):\n",
    "                if n not in seen:\n",
    "                    seen.add(n)\n",
    "                    added_states.append(n)\n",
    "        surrounding_states += added_states\n",
    "        frontier = added_states\n",
    "    return surrounding_states\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(os.path.join(res_dir, 'spatial_agent{}'.format(1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_allo_index(agent_data, env, sessions=range(1, 11)):\n",
    "    \"\"\"Get the allocentricness index: how long the agent stays near the previous platform\n",
    "    location during the first trial of each session.\n",
    "\n",
    "    For every session in `sessions`, count the rows (presumably one per time step) of that\n",
    "    session's first trial (trial == 0) whose 'state' lies within two steps of the platform\n",
    "    used in the preceding trial, then average these counts over the sessions. The result is\n",
    "    a mean count, not a proportion.\n",
    "\n",
    "    :param agent_data: spatial-task data for one agent, with the columns 'session', 'trial',\n",
    "        'state', 'total trial' and 'platform'.\n",
    "    :param env: environment exposing `adjacency_graph`, used to find the states around the\n",
    "        previous platform.\n",
    "    :param sessions: sessions to average over. Defaults to sessions 1-10 (previously\n",
    "        hard-coded); every listed session must have a 'previous platform' entry.\n",
    "    :return: mean number of first-trial rows spent near the previous platform.\n",
    "    \"\"\"\n",
    "    first_trials = get_first_trial_info(agent_data)\n",
    "    time_near_platform = []\n",
    "    for ses in sessions:\n",
    "        is_first_trial = (agent_data.session == ses) & (agent_data.trial == 0)\n",
    "        states = agent_data.loc[is_first_trial, 'state']\n",
    "        previous_platform = first_trials.loc[ses, 'previous platform']\n",
    "        surrounding_states = get_surrounding_states(int(previous_platform), env)\n",
    "        # np.isin does not require sorted inputs.\n",
    "        time_near_platform.append(np.isin(states, surrounding_states).sum())\n",
    "\n",
    "    return np.mean(time_near_platform)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = []\n",
    "data = []\n",
    "\n",
    "# Load every spatial agent once, keeping the raw frames in `data` for inspection below.\n",
    "for a in tqdm(range(n_agents)):\n",
    "    df = pd.read_csv(os.path.join(res_dir, 'spatial_agent{}'.format(a)))\n",
    "    data.append(df)\n",
    "    scores.append(get_allo_index(df, en))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_first_trial_info(data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model-based data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_weights(data):\n",
    "    \"\"\"Fit the two-step stay/switch logistic regression and return the agent's weights.\n",
    "\n",
    "    Fits `Stay ~ PreviousTransition * PreviousReward` with statsmodels. The model-based weight\n",
    "    is the negated (rare transition x reward) interaction coefficient; the model-free weight is\n",
    "    the coefficient of the previous reward.\n",
    "\n",
    "    :param data: raw two-step task data for one agent. It is copied first, so the caller's\n",
    "        DataFrame is not modified by add_relevant_columns or by the 'Stay' cast.\n",
    "    :return: tuple (model_based_weight, model_free_weight).\n",
    "    \"\"\"\n",
    "    data = data.copy()\n",
    "    # NOTE(review): add_relevant_columns presumably adds 'Stay', 'PreviousReward' and\n",
    "    # 'PreviousTransition' in place (its return value is ignored) - confirm in daw_analysis.\n",
    "    add_relevant_columns(data)\n",
    "    data['Stay'] = data['Stay'].astype('int')\n",
    "    data = data[['Stay', 'PreviousReward', 'PreviousTransition']]\n",
    "    mod = smf.logit(formula='Stay ~ PreviousTransition * PreviousReward', data=data)\n",
    "    # disp=False silences the optimiser's convergence report, otherwise printed once per agent.\n",
    "    res = mod.fit(disp=False)\n",
    "    model_based_weight = -res.params['PreviousTransition[T.rare]:PreviousReward']\n",
    "    model_free_weight = res.params['PreviousReward']\n",
    "    return model_based_weight, model_free_weight\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the next cell recomputes these values as `mb_scores`; `weights_mb` is not\n",
    "# used anywhere else in this notebook.\n",
    "weights_mb = []\n",
    "for a in range(n_agents):\n",
    "    twostep_path = os.path.join(res_dir, 'twostep_agent{}'.format(a))\n",
    "    df = pd.read_csv(twostep_path)\n",
    "    mb_weight, mf_weight = get_model_weights(df)\n",
    "    weights_mb += [mb_weight]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_allocentric_scores(suffix=''):\n",
    "    \"\"\"Allocentricness index of every spatial agent, read from 'spatial_agent<a><suffix>'.\"\"\"\n",
    "    return [get_allo_index(pd.read_csv(os.path.join(res_dir, 'spatial_agent{}{}'.format(a, suffix))), en)\n",
    "            for a in tqdm(range(n_agents))]\n",
    "\n",
    "\n",
    "def load_mb_scores(suffix=''):\n",
    "    \"\"\"Model-based weight of every two-step agent, read from 'twostep_agent<a><suffix>'.\"\"\"\n",
    "    return [get_model_weights(pd.read_csv(os.path.join(res_dir, 'twostep_agent{}{}'.format(a, suffix))))[0]\n",
    "            for a in tqdm(range(n_agents))]\n",
    "\n",
    "\n",
    "allocentric_scores = load_allocentric_scores()\n",
    "mb_scores = load_mb_scores()\n",
    "allocentric_scores_lesion = load_allocentric_scores('_lesion')\n",
    "mb_scores_lesion = load_mb_scores('_lesion')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "# One colour per agent; the palette is sized from the plotted data, not from `scores` of another cell.\n",
    "ax.scatter(np.log(allocentric_scores), mb_scores, c=sns.cubehelix_palette(len(allocentric_scores)))\n",
    "ax.set_xlabel('Allocentricness score (log)')\n",
    "ax.set_ylabel('Model based index');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(4, 4))\n",
    "# seaborn >= 0.12 only accepts x and y as keyword arguments.\n",
    "sns.regplot(x=allocentric_scores, y=mb_scores, ax=ax)\n",
    "\n",
    "ax.set_ylabel('Model based index')\n",
    "ax.set_xlabel('Allocentricness score');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colpal = sns.color_palette()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "\n",
    "# x is the natural log of the allocentricness score; xlim (4.5 -> 0) reverses the axis so\n",
    "# larger values sit on the left. Limits are set before plotting, as originally.\n",
    "ax.set_xlim([4.5, 0])\n",
    "ax.set_ylim([-1, 4.5])\n",
    "# seaborn >= 0.12 only accepts x and y as keyword arguments.\n",
    "sns.regplot(x=np.log(allocentric_scores), y=mb_scores, color=colpal[1], label='control', ax=ax)\n",
    "sns.regplot(x=np.log(allocentric_scores_lesion), y=mb_scores_lesion, color=colpal[4], label='lesion', ax=ax)\n",
    "\n",
    "ax.set_ylabel('Model based index')\n",
    "ax.set_xlabel('Allocentricness score (log)')\n",
    "ax.legend(frameon=False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "ax.spines['top'].set_visible(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-format table (one row per agent) so lmplot can fit and colour each group separately.\n",
    "score_data = pd.DataFrame({\n",
    "    'model based': np.concatenate([mb_scores, mb_scores_lesion]),\n",
    "    'allocentric': np.log(np.concatenate([allocentric_scores, allocentric_scores_lesion])),\n",
    "    'group': ['control'] * n_agents + ['lesion'] * n_agents,\n",
    "})\n",
    "\n",
    "sns.lmplot(data=score_data, x='allocentric', y='model based', hue='group', palette=[colpal[1], colpal[4]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.palplot(sns.color_palette())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr = np.load(os.path.join(res_dir, 'spatial_agent0value_funcs.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg = arr.mean(axis=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[np.argmax(a) for a in avg]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.argmax(avg, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Platform of each session for spatial agent 0 (presumably to compare with the value-function\n",
    "# argmax above). Read the CSV explicitly: `df` at this point is whatever frame an earlier loop\n",
    "# happened to leave behind.\n",
    "spatial_agent0 = pd.read_csv(os.path.join(res_dir, 'spatial_agent0'))\n",
    "spatial_agent0.groupby('session')['platform'].mean().astype('int').to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 54 and 86 are hard-coded state indices, presumably read off the outputs above\n",
    "# (value-function argmax vs. platform) - confirm and replace with named variables.\n",
    "en.grid.distance(54,86)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr.mean(axis=2).argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}