{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('../..')\n", "\n", "from tqdm import tqdm_notebook as tqdm\n", "from hippocampus.plotting import tsplot_boot\n", "import matplotlib.pyplot as plt\n", "from definitions import ROOT_FOLDER\n", "import os\n", "import pandas as pd\n", "from hippocampus.agents import CombinedAgent\n", "from hippocampus.environments import HexWaterMaze\n", "import numpy as np\n", "import random\n", "import seaborn as sns\n", "from multiprocessing import Pool\n", "%matplotlib notebook" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "g = HexWaterMaze(6)\n", "g.plot_grid()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inv_temp = 6.\n", "\n", "# determine platform sequence\n", "possible_platform_states = np.array([48, 45, 42, 39, 60, 57, 54, 51])\n", "#possible_platform_states = np.array([192, 185, 181, 174, 216, 210, 203, 197]) # for the r = 10 case\n", "\n", "indices = np.arange(len(possible_platform_states))\n", "usage = np.zeros(possible_platform_states.shape)\n", "\n", "platform_sequence = [np.random.choice(possible_platform_states)]\n", "for ses in range(1,11):\n", " distances = np.array([g.grid.distance(platform_sequence[ses-1], s) for s in possible_platform_states])\n", " candidates = indices[np.logical_and(usage < 2, distances > g.grid.radius)]\n", " platform_idx = np.random.choice(candidates)\n", " platform_sequence.append(possible_platform_states[platform_idx])\n", " usage[platform_idx] += 1." 
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "platform_sequence"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Run the agent for 11 sessions of 4 trials each and collect one result frame per trial.\n",
 "\n",
 "# Set an initial platform before building the agent; the session loop below\n",
 "# sets the platform again at the first trial of every session.\n",
 "#random.shuffle(possible_platform_states)\n",
 "g.set_platform_state(possible_platform_states[6])\n",
 "\n",
 "agent = CombinedAgent(g, init_sr='rw', lesion_dls=False, lesion_hpc=True, inv_temp=inv_temp, gamma=.99)\n",
 "agent_results = []\n",
 "agent_ets = []\n",
 "\n",
 "total_trial_count = 0\n",
 "\n",
 "for ses in tqdm(range(11)):\n",
 "    for trial in tqdm(range(4), leave=False):\n",
 "        if trial == 0:\n",
 "            # the platform moves once per session, before its first trial\n",
 "            g.set_platform_state(platform_sequence[ses])\n",
 "        res = agent.one_episode(random_policy=False)\n",
 "        escape_time = res.time.max()\n",
 "        res['trial'] = trial\n",
 "        res['escape time'] = escape_time\n",
 "        res['session'] = ses\n",
 "        res['total trial'] = total_trial_count\n",
 "        agent_results.append(res)\n",
 "        agent_ets.append(escape_time)\n",
 "\n",
 "        total_trial_count += 1\n",
 "        #inv_temp += .8\n",
 "        #agent.set_exploration(inv_temp)\n",
 "\n",
 "agent_df = pd.concat(agent_results)\n",
 "# running step counter over the whole concatenated frame\n",
 "agent_df['total time'] = np.arange(len(agent_df))\n"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "agent.weights.shape"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "agent.DLS.get_feature_rep(0,30).shape"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Escape time per trial, with a vertical line at the first trial of each session.\n",
 "plt.figure()\n",
 "sns.lineplot(data=agent_df, x='total trial', y='escape time')\n",
 "\n",
 "# 11 sessions x 4 trials = 44 trials; a session starts every 4th trial\n",
 "for i in range(0, 44, 4):\n",
 "    plt.axvline(x=i, ymin=0, ymax=1, linewidth=1, color='r', alpha=.3)\n"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Open a new figure first: with the interactive `%matplotlib notebook` backend a plot\n",
 "# call without one is drawn into the figure that is still active from the previous cell.\n",
 "plt.figure()\n",
 "sns.lineplot(data=agent_df, x='total trial', y='P(SR)')"
] },
{ "cell_type":
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=agent_df[agent_df.trial==0], x='P(SR)', y='escape time')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "first_and_last = agent_df[ np.logical_or(agent_df.trial == 0, agent_df.trial==3)]\n", "plt.figure()\n", "sns.lineplot(data=first_and_last, x='session', y='escape time', hue='trial', estimator=None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# analyse two subsequent sessions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "\n", "ses6 = agent_df[agent_df.session==2]\n", "ses5 = agent_df[agent_df.session==1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "states = ses6[ses6['trial']==0]['state']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "agent.env.plot_occupancy_on_grid(ses6[ses6['trial']==0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ses6.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ses6.state.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mm = ses6.M_hat.iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# previous platform \n", "ses5.platform.iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rr = ses6.R_hat.iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "agent.env.plot_grid(mm[72], show_state_idx=True) " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "agent.env.plot_grid(mm @ rr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ses6[ses6['trial']==0]['Q_mf'].iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ses6[ses6['trial']==0]['Q'].iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = mm @ rr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v[56]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ses6[ses6['trial']==0]['P(SR)'].iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df1 = agent_df[['P(SR)','escape time', 'total trial', 'session', 'HPC reliability', 'DLS reliability']]\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(df1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "[agent_df[agent_df.session==i]['platform'].iloc[0] for i in range(11)]agent.inv_temp" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trial_40 = agent_df[agent_df['total trial']==36]\n", "trial_39 = agent_df[agent_df['total trial']==35]\n", "\n", "g.plot_occupancy_on_grid(trial_40)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g.plot_grid(trial_40['M_hat'].iloc[0] @ trial_40['R_hat'].iloc[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trial_40['R_hat'].iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trial_40['platform'].iloc[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trial_39['platform'].iloc[0]" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from hippocampus.utils import softmax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "softmax(trial_40['weights'].iloc[1].T @ trial_40['features'].iloc[1], beta=agent.inv_temp)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trial_40['Q_mf'].iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trial_40['Q'].iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "V = trial_40['M_hat'].iloc[1] @ trial_40['R_hat'].iloc[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "V[88]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# what are platform sequences for good runs versus bad runs? \n", "platform_sequence" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }