{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "from hippocampus.plotting import tsplot_boot\n",
    "import matplotlib.pyplot as plt\n",
    "from definitions import ROOT_FOLDER\n",
    "import os\n",
    "import pandas as pd\n",
    "from hippocampus.agents import CombinedAgent\n",
    "from hippocampus.environments import HexWaterMaze\n",
    "import numpy as np\n",
    "import random\n",
    "import seaborn as sns\n",
    "from multiprocessing import Pool\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Build the water-maze environment and draw its grid.\n",
    "# NOTE(review): the argument 6 is presumably the maze radius (a later comment mentions an 'r = 10 case') -- confirm against HexWaterMaze.\n",
    "g = HexWaterMaze(6)\n",
    "g.plot_grid()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inverse temperature handed to CombinedAgent when the agent is created.\n",
    "inv_temp = 6.\n",
    "\n",
    "# Candidate platform states; every session's platform is drawn from this list.\n",
    "possible_platform_states = np.array([48, 45, 42, 39, 60, 57, 54, 51])\n",
    "#possible_platform_states = np.array([192, 185, 181, 174, 216, 210, 203, 197])  # for the r = 10 case\n",
    "\n",
    "indices = np.arange(len(possible_platform_states))\n",
    "usage = np.zeros(possible_platform_states.shape)  # number of times each candidate has been picked so far\n",
    "\n",
    "# Session 0: uniform draw. Record it in `usage` so the two-uses cap below also covers this platform\n",
    "# (previously the first platform was not counted and could end up being used three times).\n",
    "first_idx = np.random.choice(indices)\n",
    "platform_sequence = [possible_platform_states[first_idx]]\n",
    "usage[first_idx] += 1.\n",
    "\n",
    "# Sessions 1-10: pick a platform that has been used fewer than twice and lies further than\n",
    "# g.grid.radius from the previous session's platform.\n",
    "for ses in range(1,11):\n",
    "    distances = np.array([g.grid.distance(platform_sequence[ses-1], s) for s in possible_platform_states])\n",
    "    candidates = indices[np.logical_and(usage < 2, distances > g.grid.radius)]\n",
    "    if candidates.size == 0:\n",
    "        # np.random.choice would fail here with an opaque 'a must be non-empty'; say what actually happened.\n",
    "        raise ValueError('No admissible platform left for session %d; re-run this cell to draw a new sequence.' % ses)\n",
    "    platform_idx = np.random.choice(candidates)\n",
    "    platform_sequence.append(possible_platform_states[platform_idx])\n",
    "    usage[platform_idx] += 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "platform_sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Put a platform in place (entry 6 of the candidate list) before the agent is constructed.\n",
    "#random.shuffle(possible_platform_states)\n",
    "g.set_platform_state(possible_platform_states[6])\n",
    "\n",
    "agent = CombinedAgent(\n",
    "    g,\n",
    "    init_sr='rw',\n",
    "    lesion_dls=False,\n",
    "    lesion_hpc=True,\n",
    "    inv_temp=inv_temp,\n",
    "    gamma=.99,\n",
    ")\n",
    "agent_results, agent_ets = [], []\n",
    "session = 0  # not read anywhere in this cell\n",
    "total_trial_count = 0\n",
    "\n",
    "# 11 sessions of 4 trials each; the platform is moved once per session, before its first trial.\n",
    "for ses in tqdm(range(11)):\n",
    "    g.set_platform_state(platform_sequence[ses])\n",
    "    for trial in tqdm(range(4), leave=False):\n",
    "        res = agent.one_episode(random_policy=False)\n",
    "        escape_time = res.time.max()\n",
    "\n",
    "        # Tag every row of this episode with its trial/session bookkeeping.\n",
    "        res['trial'] = trial\n",
    "        res['escape time'] = escape_time\n",
    "        res['session'] = ses\n",
    "        res['total trial'] = total_trial_count\n",
    "\n",
    "        agent_results.append(res)\n",
    "        agent_ets.append(escape_time)\n",
    "        total_trial_count += 1\n",
    "    #inv_temp += .8\n",
    "    #agent.set_exploration(inv_temp)\n",
    "\n",
    "# Stack all episodes and add a running row counter across the whole experiment.\n",
    "agent_df = pd.concat(agent_results)\n",
    "agent_df['total time'] = np.arange(len(agent_df))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.weights.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.DLS.get_feature_rep(0,30).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "sns.lineplot(data=agent_df, x='total trial', y='escape time')\n",
    "\n",
    "# Faint red vertical line at every fourth trial index (0, 4, ..., 40).\n",
    "for session_start in range(0, 44, 4):\n",
    "    plt.axvline(x=session_start, ymin=0, ymax=1, linewidth=1, color='r', alpha=.3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open a fresh figure first: with the interactive notebook backend the plot would otherwise be\n",
    "# drawn into whichever figure is still active from an earlier cell.\n",
    "plt.figure()\n",
    "sns.lineplot(data=agent_df, x='total trial', y='P(SR)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regress escape time on P(SR) using rows from the first trial (trial == 0) of each session only.\n",
    "# A fresh figure keeps the interactive backend from drawing into an earlier cell's figure.\n",
    "plt.figure()\n",
    "sns.regplot(data=agent_df[agent_df.trial==0], x='P(SR)', y='escape time')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only rows from trial 0 and trial 3 and plot escape time against session, coloured by trial.\n",
    "first_and_last = agent_df[agent_df.trial.isin([0, 3])]\n",
    "plt.figure()\n",
    "sns.lineplot(data=first_and_last, x='session', y='escape time', hue='trial', estimator=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# analyse two subsequent sessions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Rows of two consecutive sessions.\n",
    "# NOTE(review): the names do not match the filters -- ses6 holds session 2 and ses5 holds session 1.\n",
    "ses6 = agent_df[agent_df.session==2]\n",
    "ses5 = agent_df[agent_df.session==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "states = ses6[ses6['trial']==0]['state']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.env.plot_occupancy_on_grid(ses6[ses6['trial']==0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ses6.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ses6.state.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mm = ses6.M_hat.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# previous platform \n",
    "ses5.platform.iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rr = ses6.R_hat.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.env.plot_grid(mm[72], show_state_idx=True) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.env.plot_grid(mm @ rr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ses6[ses6['trial']==0]['Q_mf'].iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ses6[ses6['trial']==0]['Q'].iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = mm @ rr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v[56]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ses6[ses6['trial']==0]['P(SR)'].iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df1 = agent_df[['P(SR)','escape time', 'total trial', 'session', 'HPC reliability', 'DLS reliability']]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.pairplot(df1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Platform recorded in the first row of each of the 11 sessions, followed by the agent's inverse temperature.\n",
    "# The two expressions were fused on one line (a SyntaxError); display() shows the list, the last line shows inv_temp.\n",
    "display([agent_df[agent_df.session==i]['platform'].iloc[0] for i in range(11)])\n",
    "agent.inv_temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of two consecutive trials, then the occupancy plot for the later one.\n",
    "# NOTE(review): the names do not match the filters -- trial_40 selects total trial 36 and trial_39 selects total trial 35.\n",
    "trial_40 = agent_df[agent_df['total trial']==36]\n",
    "trial_39 = agent_df[agent_df['total trial']==35]\n",
    "\n",
    "g.plot_occupancy_on_grid(trial_40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.plot_grid(trial_40['M_hat'].iloc[0] @ trial_40['R_hat'].iloc[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trial_40['R_hat'].iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trial_40['platform'].iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trial_39['platform'].iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hippocampus.utils import softmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax of weights^T @ features, both taken from the second row (iloc[1]) of trial_40, with beta set to agent.inv_temp.\n",
    "softmax(trial_40['weights'].iloc[1].T @ trial_40['features'].iloc[1], beta=agent.inv_temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trial_40['Q_mf'].iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trial_40['Q'].iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V = trial_40['M_hat'].iloc[1] @ trial_40['R_hat'].iloc[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V[88]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# what are platform sequences for good runs versus bad runs? \n",
    "platform_sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}