{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('../..')\n", "from hippocampus.environments import SimpleMDP, HexWaterMaze, TwoStepTask\n", "from hippocampus.experiments.reliability_in_twostep import CombinedAgent\n", "# TODO: do this on linear track " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from tqdm.notebook import tqdm\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "ag = CombinedAgent(env=SimpleMDP(5, reward_probability=.8))\n", "\n", "init_p_sr = .5\n", "ag.p_sr = init_p_sr\n", "\n", "ag.HPC.learning_rate =.01" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dae9808b1bb047199394955250c90b26", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=399.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Run the simulation. Re-running this cell keeps training the same agent `ag`\n", "# with an already-decayed HPC learning rate, so re-run the setup cell first.\n", "# Collect one record per episode and build the frame once: DataFrame.append\n", "# was removed in pandas 2.0, and growing a frame row by row is quadratic.\n", "records = []\n", "for ep in tqdm(range(1, 400)):  # trials 1..399\n", "    results = ag.one_episode()  # assumed to be a dict-like record; TODO confirm\n", "    records.append({**results, 'trial': ep})  # shallow copy, like the old append snapshot\n", "    ag.HPC.learning_rate *= .95  # decay the HPC learning rate after every episode\n", "df = pd.DataFrame(records)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Action1 | \n", "DLS reliability | \n", "HPC reliability | \n", "P(SR) | \n", "Qvs | \n", "RPE | \n", "Reward | \n", "SPE0 | \n", "SPE1 | \n", "SPE2 | \n", "StartState | \n", "omega | \n", "omega_dls | \n", "trial | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0.0 | \n", "0.000000 | \n", "0.498243 | \n", "0.666667 | \n", "[[0.022561646066606045, 0.025811816006898882],... | \n", "1.000000 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "-0.009500 | \n", "0.0 | \n", "[0.6633530998498768, 0.5483920786517998, 0.477... | \n", "1.000000 | \n", "1.0 | \n", "
1 | \n", "0.0 | \n", "0.006107 | \n", "0.547370 | \n", "0.749355 | \n", "[[0.04058395073055687, 0.060248956915999394], ... | \n", "0.796448 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "-0.009405 | \n", "0.0 | \n", "[0.5974361142281511, 0.49704715842377384, 0.43... | \n", "0.993893 | \n", "2.0 | \n", "
2 | \n", "0.0 | \n", "0.017100 | \n", "0.666473 | \n", "0.755606 | \n", "[[0.08548155436781643, 0.15174347191463325], [... | \n", "0.627461 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "-0.009311 | \n", "0.0 | \n", "[0.4772037669379944, 0.4228482607952359, 0.338... | \n", "0.982900 | \n", "3.0 | \n", "
3 | \n", "1.0 | \n", "0.031965 | \n", "0.687799 | \n", "0.771413 | \n", "[[0.08548735088500274, 0.19567688657430504], [... | \n", "0.487392 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "-0.009218 | \n", "0.0 | \n", "[0.42615305268415854, 0.40067758366540407, 0.3... | \n", "0.968035 | \n", "4.0 | \n", "
4 | \n", "0.0 | \n", "0.050122 | \n", "0.714651 | \n", "0.769122 | \n", "[[0.1262231251822318, 0.27866828742762323], [0... | \n", "0.362783 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "-0.009126 | \n", "0.0 | \n", "[0.4002969186405008, 0.3743271738561978, 0.295... | \n", "0.949878 | \n", "5.0 | \n", "