{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Escape-time simulations in the r = 10 hex water maze\n",
    "\n",
    "Runs `CombinedAgent` in `HexWaterMaze(10)` for `N_AGENTS` independent agents. Every agent completes `N_SESSIONS` sessions of `TRIALS_PER_SESSION` trials each, and the platform state is set at the start of every session. Each agent's full trajectory is written to a CSV file in a time-stamped run folder, and a per-trial summary table (`df`) is built and plotted at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "from datetime import datetime\n",
    "from multiprocessing import Pool\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from tqdm.auto import tqdm  # tqdm_notebook is deprecated; tqdm.auto selects the notebook widget\n",
    "\n",
    "from definitions import ROOT_FOLDER\n",
    "from hippocampus.agents import CombinedAgent\n",
    "from hippocampus.environments import HexWaterMaze\n",
    "from hippocampus.plotting import tsplot_boot\n",
    "\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agent hyperparameters\n",
    "eta = 0.04  # .06  # learning rate for reliability estimators\n",
    "learning_rate = .1  # step size for value and SR learning\n",
    "inv_temp = 1.5  # initial inverse temperature for softmax action selection\n",
    "\n",
    "# Experiment design\n",
    "N_AGENTS = 30  # number of independently simulated agents\n",
    "N_SESSIONS = 11  # sessions per agent\n",
    "TRIALS_PER_SESSION = 4  # trials per session\n",
    "INV_TEMP_INCREMENT = .8  # added to the agent's inverse temperature after each session\n",
    "INITIAL_PLATFORM_INDEX = 6  # entry of the shuffled platform list that is set before the agent is built\n",
    "\n",
    "# Candidate platform states for the r = 10 case.\n",
    "# An earlier, commented-out list was [48, 45, 42, 39, 60, 57, 54, 57].\n",
    "possible_platform_states = [192, 185, 181, 174, 216, 210, 203, 197]\n",
    "\n",
    "# All runs are stored below this folder; every run gets its own sub-folder.\n",
    "results_folder = os.path.join(ROOT_FOLDER, 'results/pearce/bigmaze')\n",
    "os.makedirs(results_folder, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The argument 10 corresponds to the \"r = 10 case\" of the platform list (presumably the maze radius - confirm).\n",
    "g = HexWaterMaze(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the simulations\n",
    "\n",
    "This cell is expensive. It appends one summary record per trial to `trial_records` and writes one CSV per agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same folder naming as before (str(datetime.now())), but stored under a new name so that the\n",
    "# imported `datetime` class is not overwritten and this cell can be run more than once.\n",
    "run_timestamp = str(datetime.now())\n",
    "run_results_folder = os.path.join(results_folder, run_timestamp)\n",
    "os.makedirs(run_results_folder, exist_ok=True)\n",
    "\n",
    "trial_records = []  # one summary dict per trial, across all agents\n",
    "\n",
    "for n_agent in tqdm(range(N_AGENTS)):\n",
    "    # Seed per agent so that an agent's run does not depend on which agents ran before it.\n",
    "    random.seed(n_agent)\n",
    "    # NOTE(review): assumes the agent/environment sample from NumPy's global RNG - confirm.\n",
    "    np.random.seed(n_agent)\n",
    "\n",
    "    # Shuffle a copy: shuffling the shared list in place made the platform order of agent k\n",
    "    # depend on the shuffles of all earlier agents.\n",
    "    platform_sequence = list(possible_platform_states)\n",
    "    random.shuffle(platform_sequence)\n",
    "    g.set_platform_state(platform_sequence[INITIAL_PLATFORM_INDEX])\n",
    "\n",
    "    agent = CombinedAgent(g, init_sr='rw', lesion_dls=False, lesion_hpc=False, eta=eta,\n",
    "                          learning_rate=learning_rate, inv_temp=inv_temp)\n",
    "    agent_results = []  # per-trial trajectory frames of this agent\n",
    "    total_trial_count = 0  # trial counter that runs across sessions\n",
    "\n",
    "    for ses in tqdm(range(N_SESSIONS), leave=False):\n",
    "        # The platform is set once per session; the list is cycled when there are more\n",
    "        # sessions than candidate states.\n",
    "        g.set_platform_state(platform_sequence[ses % len(platform_sequence)])\n",
    "\n",
    "        for trial in tqdm(range(TRIALS_PER_SESSION), leave=False):\n",
    "            res = agent.one_episode(random_policy=False)\n",
    "            escape_time = res.time.max()\n",
    "            res['trial'] = trial\n",
    "            res['escape time'] = escape_time\n",
    "            res['session'] = ses\n",
    "            res['total trial'] = total_trial_count\n",
    "            agent_results.append(res)\n",
    "\n",
    "            # Collect plain records and build the DataFrame once afterwards; growing a frame\n",
    "            # with DataFrame.append is quadratic and the method no longer exists in pandas 2.\n",
    "            trial_records.append({'agent': n_agent,\n",
    "                                  'session': ses,\n",
    "                                  'trial': trial,\n",
    "                                  'escape time': escape_time,\n",
    "                                  'platform': res.platform.iloc[-1],\n",
    "                                  'start': agent.env.starting_state,\n",
    "                                  'P(SR)': res['P(SR)'].mean(),\n",
    "                                  'total trial': total_trial_count})\n",
    "            total_trial_count += 1\n",
    "\n",
    "        # NOTE(review): the stored source had lost its indentation; the increment is applied\n",
    "        # once per session here - confirm that this is the intended schedule.\n",
    "        agent.inv_temp += INV_TEMP_INCREMENT\n",
    "\n",
    "    agent_df = pd.concat(agent_results)\n",
    "    agent_df['agent'] = n_agent\n",
    "    agent_df['total time'] = np.arange(len(agent_df))\n",
    "    agent_df.to_csv(os.path.join(run_results_folder, 'results_r10_agent{}.csv'.format(n_agent)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Per-trial summary\n",
    "\n",
    "Built in its own cell so that a partial table can still be created from `trial_records` if the simulation cell was interrupted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(trial_records)\n",
    "df.head(20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the first and the last trial of every session.\n",
    "first_last = df[df.trial.isin([0, TRIALS_PER_SESSION - 1])]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(data=first_last, x='session', y='escape time', hue='trial', ci=None, ax=ax)\n",
    "ax.set_title('Escape time on the first and last trial of each session');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(data=df, x='total trial', y='escape time', hue='agent', ax=ax)\n",
    "ax.set_title('Escape time per trial for each agent');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Agents that have at least one recorded trial.\n",
    "df.agent.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `res` is left over from the simulation loop: the trajectory of the last simulated trial.\n",
    "res.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}