{ "cells": [
 {
  "cell_type": "code",
  "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
   "import sys\n",
   "sys.path.append('../..')\n",
   "from hippocampus.environments import SimpleMDP, HexWaterMaze, TwoStepTask\n",
   "from hippocampus.experiments.reliability_in_twostep import CombinedAgent\n",
   "# TODO: repeat this analysis on the linear track"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
   "import numpy as np\n",
   "import seaborn as sns\n",
   "import matplotlib.pyplot as plt\n",
   "from tqdm.notebook import tqdm\n",
   "import pandas as pd"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Single agent acting in SimpleMDP(5, reward_probability=.8).\n",
   "ag = CombinedAgent(env=SimpleMDP(5, reward_probability=.8))\n",
   "\n",
   "# Value assigned to ag.p_sr before any episode is run; kept in a variable because\n",
   "# later cells prepend it to the recorded P(SR) column.\n",
   "init_p_sr = .5\n",
   "ag.p_sr = init_p_sr\n",
   "\n",
   "ag.HPC.learning_rate = .01"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 4,
  "metadata": {},
  "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dae9808b1bb047199394955250c90b26", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=399.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ],
  "source": [
   "# Run 399 episodes (trial numbers 1..399) and record one row of results per episode.\n",
   "# Rows are collected in a list and the DataFrame is built once at the end: growing a\n",
   "# frame with DataFrame.append inside the loop copies the whole frame on every episode,\n",
   "# and DataFrame.append no longer exists in pandas >= 2.0.\n",
   "records = []\n",
   "for ep in tqdm(range(1, 400)):\n",
   "    results = ag.one_episode()\n",
   "    results['trial'] = ep\n",
   "    # dict(...) takes a shallow copy so each row keeps its own key/value container\n",
   "    # even if one_episode() were to hand back the same object every time.\n",
   "    records.append(dict(results))\n",
   "    # Shrink the HPC learning rate by 5% after every episode.\n",
   "    ag.HPC.learning_rate *= .95\n",
   "\n",
   "df = pd.DataFrame(records)"
  ]
 },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Action1</th>\n", " <th>DLS 
reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Qvs</th>\n", " <th>RPE</th>\n", " <th>Reward</th>\n", " <th>SPE0</th>\n", " <th>SPE1</th>\n", " <th>SPE2</th>\n", " <th>StartState</th>\n", " <th>omega</th>\n", " <th>omega_dls</th>\n", " <th>trial</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.0</td>\n", " <td>0.000000</td>\n", " <td>0.498243</td>\n", " <td>0.666667</td>\n", " <td>[[0.022561646066606045, 0.025811816006898882],...</td>\n", " <td>1.000000</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>-0.009500</td>\n", " <td>0.0</td>\n", " <td>[0.6633530998498768, 0.5483920786517998, 0.477...</td>\n", " <td>1.000000</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0.0</td>\n", " <td>0.006107</td>\n", " <td>0.547370</td>\n", " <td>0.749355</td>\n", " <td>[[0.04058395073055687, 0.060248956915999394], ...</td>\n", " <td>0.796448</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>-0.009405</td>\n", " <td>0.0</td>\n", " <td>[0.5974361142281511, 0.49704715842377384, 0.43...</td>\n", " <td>0.993893</td>\n", " <td>2.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0.0</td>\n", " <td>0.017100</td>\n", " <td>0.666473</td>\n", " <td>0.755606</td>\n", " <td>[[0.08548155436781643, 0.15174347191463325], [...</td>\n", " <td>0.627461</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>-0.009311</td>\n", " <td>0.0</td>\n", " <td>[0.4772037669379944, 0.4228482607952359, 0.338...</td>\n", " <td>0.982900</td>\n", " <td>3.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>1.0</td>\n", " <td>0.031965</td>\n", " <td>0.687799</td>\n", " <td>0.771413</td>\n", " <td>[[0.08548735088500274, 0.19567688657430504], [...</td>\n", " <td>0.487392</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>-0.009218</td>\n", " <td>0.0</td>\n", " <td>[0.42615305268415854, 0.40067758366540407, 0.3...</td>\n", " 
<td>0.968035</td>\n", " <td>4.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0.0</td>\n", " <td>0.050122</td>\n", " <td>0.714651</td>\n", " <td>0.769122</td>\n", " <td>[[0.1262231251822318, 0.27866828742762323], [0...</td>\n", " <td>0.362783</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>-0.009126</td>\n", " <td>0.0</td>\n", " <td>[0.4002969186405008, 0.3743271738561978, 0.295...</td>\n", " <td>0.949878</td>\n", " <td>5.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Action1 DLS reliability HPC reliability P(SR) \\\n", "0 0.0 0.000000 0.498243 0.666667 \n", "1 0.0 0.006107 0.547370 0.749355 \n", "2 0.0 0.017100 0.666473 0.755606 \n", "3 1.0 0.031965 0.687799 0.771413 \n", "4 0.0 0.050122 0.714651 0.769122 \n", "\n", " Qvs RPE Reward SPE0 \\\n", "0 [[0.022561646066606045, 0.025811816006898882],... 1.000000 1.0 0.0 \n", "1 [[0.04058395073055687, 0.060248956915999394], ... 0.796448 1.0 0.0 \n", "2 [[0.08548155436781643, 0.15174347191463325], [... 0.627461 1.0 0.0 \n", "3 [[0.08548735088500274, 0.19567688657430504], [... 0.487392 1.0 0.0 \n", "4 [[0.1262231251822318, 0.27866828742762323], [0... 0.362783 1.0 0.0 \n", "\n", " SPE1 SPE2 StartState \\\n", "0 0.0 -0.009500 0.0 \n", "1 0.0 -0.009405 0.0 \n", "2 0.0 -0.009311 0.0 \n", "3 0.0 -0.009218 0.0 \n", "4 0.0 -0.009126 0.0 \n", "\n", " omega omega_dls trial \n", "0 [0.6633530998498768, 0.5483920786517998, 0.477... 1.000000 1.0 \n", "1 [0.5974361142281511, 0.49704715842377384, 0.43... 0.993893 2.0 \n", "2 [0.4772037669379944, 0.4228482607952359, 0.338... 0.982900 3.0 \n", "3 [0.42615305268415854, 0.40067758366540407, 0.3... 0.968035 4.0 \n", "4 [0.4002969186405008, 0.3743271738561978, 0.295... 
0.949878 5.0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] },
 {
  "cell_type": "code",
  "execution_count": 6,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Prepend an initial value of 0 to each per-trial reliability column.\n",
   "# ignore_index=True renumbers the result 0..N. Without it the prepended point and the\n",
   "# first recorded trial both carry index label 0, so ax.plot() draws them at the same x.\n",
   "dls_reliab = pd.concat([pd.Series([0.]), df['DLS reliability']], ignore_index=True)\n",
   "hpc_reliab = pd.concat([pd.Series([0.]), df['HPC reliability']], ignore_index=True)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "def despine_outward(ax, offset=10):\n",
   "    \"\"\"Hide the top/right spines of `ax` and move the left/bottom spines outward.\n",
   "\n",
   "    Parameters\n",
   "    ----------\n",
   "    ax : matplotlib.axes.Axes\n",
   "        Axes to restyle in place.\n",
   "    offset : float, default 10\n",
   "        Distance, in points, by which the left and bottom spines are moved outward.\n",
   "    \"\"\"\n",
   "    for side in ('left', 'bottom'):\n",
   "        ax.spines[side].set_position(('outward', offset))\n",
   "    for side in ('right', 'top'):\n",
   "        ax.spines[side].set_visible(False)\n",
   "    # Only show ticks on the left and bottom spines\n",
   "    ax.yaxis.set_ticks_position('left')\n",
   "    ax.xaxis.set_ticks_position('bottom')"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 7,
  "metadata": {},
  "outputs": [ { "data": { "text/plain": [ "(-10, 150)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
  "source": [
   "# DLS vs. HPC reliability of the single agent, zoomed in on the first 150 trials.\n",
   "fig, ax = plt.subplots()\n",
   "ax.plot(dls_reliab, label='DLS reliability')\n",
   "ax.plot(hpc_reliab, label='HPC reliability')\n",
   "ax.set_xlabel('trial')\n",
   "ax.set_ylabel('reliability')\n",
   "ax.legend()\n",
   "despine_outward(ax)\n",
   "ax.set_xlim([-10, 150])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 9,
  "metadata": {},
  "outputs": [ { "data": { "text/plain": [ "[<matplotlib.lines.Line2D at 0x7fc796c89320>]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
  "source": [
   "# Mean of each trial's omega vector, in trial order.\n",
   "mean_omega = [np.mean(w) for w in df['omega']]\n",
   "plt.plot(mean_omega)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# P(SR) of the single agent, starting from the value assigned before the first episode.\n",
   "fig, ax = plt.subplots()\n",
   "# ignore_index=True puts the initial value at x=0 and recorded row k at x=k+1,\n",
   "# instead of drawing the first two points on top of each other at x=0.\n",
   "ax.plot(pd.concat([pd.Series([init_p_sr]), df['P(SR)']], ignore_index=True))\n",
   "ax.set_xlabel('trial')\n",
   "ax.set_ylabel('P(SR)')\n",
   "despine_outward(ax)\n",
   "ax.set_ylim([0, 1])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Repeat the experiment with several freshly constructed agents and store each\n",
   "# agent's traces as one row of an (agent, trial) matrix.\n",
   "n_agents = 10\n",
   "n_trials = 300\n",
   "dls_reliab_M = np.zeros((n_agents, n_trials))\n",
   "hpc_reliab_M = np.zeros((n_agents, n_trials))\n",
   "psr_M = np.zeros((n_agents, n_trials))\n",
   "\n",
   "for ia in tqdm(range(n_agents)):\n",
   "\n",
   "    ag = CombinedAgent(env=SimpleMDP(5, reward_probability=.8))\n",
   "\n",
   "    init_p_sr = .5\n",
   "    ag.p_sr = init_p_sr\n",
   "    ag.HPC.learning_rate = .01\n",
   "    # NOTE(review): the single-agent loop earlier in this notebook multiplies\n",
   "    # ag.HPC.learning_rate by .95 after every episode; this loop does not. Confirm intended.\n",
   "\n",
   "    # Collect rows in a list and build the frame once (DataFrame.append in a loop is\n",
   "    # quadratic and was removed in pandas 2.0).\n",
   "    records = []\n",
   "    for ep in tqdm(range(1, n_trials), leave=False):\n",
   "        results = ag.one_episode()\n",
   "        results['trial'] = ep\n",
   "        records.append(dict(results))\n",
   "    df = pd.DataFrame(records)\n",
   "\n",
   "    # Column 0 is the initial value; columns 1..n_trials-1 are the n_trials-1 episodes.\n",
   "    dls_reliab_M[ia, 0] = 0.\n",
   "    dls_reliab_M[ia, 1:] = df['DLS reliability'].values\n",
   "    hpc_reliab_M[ia, 0] = 0.\n",
   "    hpc_reliab_M[ia, 1:] = df['HPC reliability'].values\n",
   "    psr_M[ia, 0] = init_p_sr\n",
   "    psr_M[ia, 1:] = df['P(SR)'].values\n",
   "\n",
   "# After the loop, `ag` and `df` refer to the last agent; the cells below rely on that."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "def plot_mean_reliability(dls, hpc, xlim=None):\n",
   "    \"\"\"Plot the across-agent mean of two (agent, trial) reliability matrices.\n",
   "\n",
   "    Parameters\n",
   "    ----------\n",
   "    dls, hpc : np.ndarray\n",
   "        Arrays indexed (agent, trial); each is averaged over axis 0 before plotting.\n",
   "    xlim : pair of float, optional\n",
   "        x-axis limits. When None the matplotlib default is kept.\n",
   "\n",
   "    Returns\n",
   "    -------\n",
   "    matplotlib.axes.Axes\n",
   "        The axes the two curves were drawn on.\n",
   "    \"\"\"\n",
   "    fig, ax = plt.subplots()\n",
   "    ax.plot(dls.mean(axis=0), label='DLS reliability')\n",
   "    ax.plot(hpc.mean(axis=0), label='HPC reliability')\n",
   "    ax.set_xlabel('trial')\n",
   "    ax.set_ylabel('mean reliability across agents')\n",
   "    ax.legend()\n",
   "    ax.set_ylim([-.1, 1])\n",
   "    if xlim is not None:\n",
   "        ax.set_xlim(xlim)\n",
   "    return ax"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Full range of trials.\n",
   "plot_mean_reliability(dls_reliab_M, hpc_reliab_M);"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Same curves, zoomed in on the first 150 trials.\n",
   "plot_mean_reliability(dls_reliab_M, hpc_reliab_M, xlim=(-10, 150));"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": []
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Per-trial omega vectors of the last agent simulated above.\n",
   "omegas = df['omega']"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Matrix with one row per omega entry and one column per trial.\n",
   "# np.stack puts each trial's vector in its own row, and the transpose turns that into\n",
   "# (entry, trial). Flattening all trials and reshaping to (nr_states, -1) instead cuts\n",
   "# the trial-major sequence into rows, so a row would mix values of different entries.\n",
   "# NOTE(review): presumably len(omega) == ag.env.nr_states -- confirm.\n",
   "alloms = np.stack(omegas.tolist()).T"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Trace of omega entry 4 across trials.\n",
   "plt.plot(alloms[4])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "ax = df.plot(x='trial', y='omega_dls')\n",
   "ax.set_ylim([0, 1])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Per-trial trace of omega entries 0-8, over every recorded trial.\n",
   "# This replaces nine copy-pasted comprehensions over a hard-coded range(272), which\n",
   "# ignored the rows past 271 (df has n_trials - 1 = 299 rows here) and would raise on\n",
   "# a run shorter than 272 trials.\n",
   "n_omega_entries = 9  # the original analysis tracked entries 0..8\n",
   "omega_traces = [[w[k] for w in df['omega']] for k in range(n_omega_entries)]"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "fig, ax = plt.subplots()\n",
   "for k, trace in enumerate(omega_traces):\n",
   "    ax.plot(trace, label=f'omega[{k}]')\n",
   "ax.set_xlabel('trial')\n",
   "ax.set_ylabel('omega')\n",
   "ax.legend();"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# (entry, trial) matrix of the traces collected above.\n",
   "alloms = np.array(omega_traces)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# Mean over omega entries 0-8 per trial, next to the recorded omega_dls column.\n",
   "fig, ax = plt.subplots()\n",
   "ax.plot(alloms.mean(axis=0), label='mean omega (entries 0-8)')\n",
   "ax.plot(df['omega_dls'], label='omega_dls')\n",
   "ax.set_xlabel('trial')\n",
   "ax.legend()\n",
   "ax.set_ylim([0, 1])"
  ]
 },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }