{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DLS / HPC reliability in `SimpleMDP`\n",
    "\n",
    "Runs `CombinedAgent` on `SimpleMDP`, records the per-episode results, and plots the\n",
    "`DLS reliability`, `HPC reliability`, `P(SR)` and `omega` columns across trials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Project modules are resolved relative to the directory two levels up.\n",
    "sys.path.append('../..')\n",
    "from hippocampus.environments import SimpleMDP, HexWaterMaze, TwoStepTask\n",
    "from hippocampus.experiments.reliability_in_twostep import CombinedAgent\n",
    "from definitions import FIGURE_FOLDER\n",
    "# TODO: do this on linear track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "# NOTE(review): this only seeds NumPy's global RNG; confirm that is what the\n",
    "# agent and environment draw from.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Create the output folder up front so the savefig calls below do not fail\n",
    "# when it is missing.\n",
    "os.makedirs(FIGURE_FOLDER, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def with_initial(initial, series):\n",
    "    \"\"\"Prepend `initial` to `series` and renumber the index 0..n.\n",
    "\n",
    "    Without `ignore_index=True` the prepended value and the first trial share\n",
    "    index label 0, so `ax.plot` would draw both of them at x=0.\n",
    "    \"\"\"\n",
    "    return pd.concat([pd.Series([initial]), series], ignore_index=True)\n",
    "\n",
    "\n",
    "def run_episodes(agent, trial_numbers, lr_decay=None, leave=True, **episode_kwargs):\n",
    "    \"\"\"Run one episode per entry of `trial_numbers` and collect the results.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    agent : object exposing `one_episode(**episode_kwargs)` and `HPC.learning_rate`.\n",
    "    trial_numbers : iterable of labels stored in the 'trial' column.\n",
    "    lr_decay : if given, `agent.HPC.learning_rate` is multiplied by it after every episode.\n",
    "    leave : forwarded to `tqdm`.\n",
    "    episode_kwargs : forwarded to `agent.one_episode`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame with one row per episode.\n",
    "\n",
    "    NOTE(review): assumes `one_episode` returns a dict-like mapping of\n",
    "    column name -> value -- confirm against `CombinedAgent`.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for trial in tqdm(trial_numbers, leave=leave):\n",
    "        results = agent.one_episode(**episode_kwargs)\n",
    "        results['trial'] = trial\n",
    "        # Rows are gathered in a list and converted once: DataFrame.append was\n",
    "        # removed in pandas 2.0 and copied the whole frame on every call.\n",
    "        records.append(dict(results))\n",
    "        if lr_decay is not None:\n",
    "            agent.HPC.learning_rate *= lr_decay\n",
    "    return pd.DataFrame(records)\n",
    "\n",
    "\n",
    "def despine_outward(ax, offset=10):\n",
    "    \"\"\"Move the left/bottom spines outward by `offset` points and hide the other two.\"\"\"\n",
    "    for side in ('left', 'bottom'):\n",
    "        ax.spines[side].set_position(('outward', offset))\n",
    "    for side in ('right', 'top'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "    # Only show ticks on the left and bottom spines.\n",
    "    ax.yaxis.set_ticks_position('left')\n",
    "    ax.xaxis.set_ticks_position('bottom')\n",
    "\n",
    "\n",
    "def plot_mean_reliability(dls_matrix, hpc_matrix, xlim=None):\n",
    "    \"\"\"Plot the across-agent mean (axis 0) of both reliability matrices.\n",
    "\n",
    "    `xlim`, when given, is applied to the x axis. Returns the Axes.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(dls_matrix.mean(axis=0))\n",
    "    ax.plot(hpc_matrix.mean(axis=0))\n",
    "    ax.set_ylim([-.1, 1])\n",
    "    if xlim is not None:\n",
    "        ax.set_xlim(xlim)\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.palplot(sns.color_palette())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ag = CombinedAgent(env=SimpleMDP(5, reward_probability=.85), inv_temp=10)\n",
    "\n",
    "init_p_sr = .5\n",
    "ag.p_sr = init_p_sr\n",
    "\n",
    "ag.HPC.learning_rate = .01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 399 episodes (trials 1..399); the HPC learning rate shrinks by 5% after each one.\n",
    "df = run_episodes(ag, range(1, 400), lr_decay=.95, deterministic_policy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per entry of omega, trials along x.\n",
    "plt.plot(np.array(list(df['omega'])));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reliabilities are plotted starting from 0 at x=0, before the first trial.\n",
    "dls_reliab = with_initial(0., df['DLS reliability'])\n",
    "hpc_reliab = with_initial(0., df['HPC reliability'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rc('font', size=22)\n",
    "\n",
    "palette = sns.color_palette()\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(dls_reliab, color=palette[1], linewidth=2, label='DLS reliability')\n",
    "ax.plot(hpc_reliab, color=palette[2], linewidth=2, label='HPC reliability')\n",
    "despine_outward(ax)\n",
    "\n",
    "ax.set_ylabel('Reliability')\n",
    "ax.set_xlabel('Trial')\n",
    "ax.legend()\n",
    "ax.set_ylim([-.1, 1])\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(FIGURE_FOLDER, 'reliability.pdf'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(with_initial(init_p_sr, df['P(SR)']), color=sns.color_palette()[0], linewidth=2)\n",
    "despine_outward(ax)\n",
    "\n",
    "ax.set_ylabel('Pr(HPC)')\n",
    "ax.set_xlabel('Trial')\n",
    "ax.set_ylim([0, 1])\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(FIGURE_FOLDER, 'psr.pdf'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Several agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_agents = 10\n",
    "n_trials = 300\n",
    "dls_reliab_M = np.zeros((n_agents, n_trials))\n",
    "hpc_reliab_M = np.zeros((n_agents, n_trials))\n",
    "psr_M = np.zeros((n_agents, n_trials))\n",
    "\n",
    "for ia in tqdm(range(n_agents)):\n",
    "    agent = CombinedAgent(env=SimpleMDP(5, reward_probability=.8))\n",
    "    agent.p_sr = init_p_sr\n",
    "    agent.HPC.learning_rate = .01\n",
    "\n",
    "    # n_trials - 1 episodes; the prepended initial value fills column 0.\n",
    "    agent_df = run_episodes(agent, range(1, n_trials), leave=False)\n",
    "\n",
    "    dls_reliab_M[ia, :] = with_initial(0., agent_df['DLS reliability'])\n",
    "    hpc_reliab_M[ia, :] = with_initial(0., agent_df['HPC reliability'])\n",
    "    psr_M[ia, :] = with_initial(init_p_sr, agent_df['P(SR)'])\n",
    "\n",
    "# The omega analysis below looks at the last simulated agent only.\n",
    "last_agent_df = agent_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_mean_reliability(dls_reliab_M, hpc_reliab_M);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same figure, zoomed in on the early trials.\n",
    "plot_mean_reliability(dls_reliab_M, hpc_reliab_M, xlim=[-10, 150]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `omega` of the last agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "omegas = last_agent_df['omega']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows = entries of omega, columns = trials. Stacking the per-trial vectors and\n",
    "# transposing keeps each entry's trace in one row; flattening and reshaping to\n",
    "# (n_states, -1) would interleave trials and entries.\n",
    "# NOTE(review): assumes every omega is a 1-D array of the same length -- confirm.\n",
    "alloms = np.vstack(list(omegas)).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(alloms[4]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = last_agent_df.plot(x='trial', y='omega_dls')\n",
    "ax.set_ylim([0, 1]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 272 trials and 9 omega entries were hard-coded in the original\n",
    "# analysis; confirm whether the full range should be used instead.\n",
    "N_OMEGA_TRIALS = 272\n",
    "N_OMEGA_ENTRIES = 9\n",
    "\n",
    "# omega_traces[s, t] is entry s of omega on row t of the last agent's results.\n",
    "omega_traces = np.array([\n",
    "    [omegas[t][s] for t in range(N_OMEGA_TRIALS)]\n",
    "    for s in range(N_OMEGA_ENTRIES)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for trace in omega_traces:\n",
    "    plt.plot(trace)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean over the selected omega entries next to the recorded 'omega_dls' column.\n",
    "plt.plot(omega_traces.mean(axis=0))\n",
    "plt.plot(last_agent_df['omega_dls'])\n",
    "plt.ylim([0, 1]);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}