{ "cells": [ { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import os\n", "\n", "from definitions import RESULTS_FOLDER\n", "from hippocampus.environments import HexWaterMaze" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "en = HexWaterMaze(10)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "res_dir = os.path.join(RESULTS_FOLDER, 'mb_spatialmemory')" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(os.path.join(res_dir, 'spatial_memory'))" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['Unnamed: 0', 'DLS reliability', 'HPC reliability', 'M_hat', 'P(SR)',\n", " 'Q_mf', 'RPE', 'R_hat', 'SPE', 'alpha', 'beta', 'choice', 'platform',\n", " 'reward', 'state', 'time', 'Q', 'Q_allo', 'features', 'landmark',\n", " 'weights', 'trial', 'escape time', 'session', 'total trial',\n", " 'total time'],\n", " dtype='object')" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Unnamed: 0 | \n", "DLS reliability | \n", "HPC reliability | \n", "M_hat | \n", "P(SR) | \n", "Q_mf | \n", "RPE | \n", "R_hat | \n", "SPE | \n", "alpha | \n", "... | \n", "Q | \n", "Q_allo | \n", "features | \n", "landmark | \n", "weights | \n", "trial | \n", "escape time | \n", "session | \n", "total trial | \n", "total time | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "0.0 | \n", "0.8 | \n", "[1.54201195 0.57061154 0.57058036 ... 0.652488... | \n", "0.900000 | \n", "[0. 0. 0. 0. 0. 0.] | \n", "0.0 | \n", "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ... | \n", "0.0 | \n", "1.069722 | \n", "... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "0 | \n", "1000.0 | \n", "0 | \n", "0 | \n", "0 | \n", "
1 | \n", "1 | \n", "0.0 | \n", "0.8 | \n", "[[1.54201195 0.57061154 0.57058036 ... 0.01565... | \n", "0.867326 | \n", "[0. 0. 0. 0. 0. 0.] | \n", "0.0 | \n", "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ... | \n", "[-1.69531697e-03 -1.46805481e-03 -9.45433645e-... | \n", "1.069722 | \n", "... | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[1.04879789e-46 6.22312918e-43 7.92610564e-40 ... | \n", "(-5, 6.92820323027551) | \n", "[[0. 0. 0. 0. 0. 0.]\\n [0. 0. 0. 0. 0. 0.]\\n [... | \n", "0 | \n", "1000.0 | \n", "0 | \n", "0 | \n", "1 | \n", "
2 | \n", "2 | \n", "0.0 | \n", "0.8 | \n", "[[1.54201195 0.57061154 0.57058036 ... 0.01565... | \n", "0.874674 | \n", "[0. 0. 0. 0. 0. 0.] | \n", "0.0 | \n", "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ... | \n", "[ 1.29177083e-03 1.62847693e-03 9.66093332e-... | \n", "1.069722 | \n", "... | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[6.45282964e-35 1.09111715e-32 3.96028640e-31 ... | \n", "(-5, 6.92820323027551) | \n", "[[0. 0. 0. 0. 0. 0.]\\n [0. 0. 0. 0. 0. 0.]\\n [... | \n", "0 | \n", "1000.0 | \n", "0 | \n", "0 | \n", "2 | \n", "
3 | \n", "3 | \n", "0.0 | \n", "0.8 | \n", "[[1.54201195 0.57061154 0.57058036 ... 0.01565... | \n", "0.873021 | \n", "[0. 0. 0. 0. 0. 0.] | \n", "0.0 | \n", "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ... | \n", "[ 3.91728776e-03 4.93600164e-03 2.89495412e-... | \n", "1.069722 | \n", "... | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[3.44565986e-31 6.04574572e-29 2.27699136e-27 ... | \n", "(-5, 6.92820323027551) | \n", "[[0. 0. 0. 0. 0. 0.]\\n [0. 0. 0. 0. 0. 0.]\\n [... | \n", "0 | \n", "1000.0 | \n", "0 | \n", "0 | \n", "3 | \n", "
4 | \n", "4 | \n", "0.0 | \n", "0.8 | \n", "[[1.54201195 0.57061154 0.57058036 ... 0.01565... | \n", "0.873393 | \n", "[0. 0. 0. 0. 0. 0.] | \n", "0.0 | \n", "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ... | \n", "[ 2.49490053e-03 4.96097982e-03 2.82030639e-... | \n", "1.069722 | \n", "... | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[0. 0. 0. 0. 0. 0.] | \n", "[2.70417405e-34 3.23431029e-31 8.30352991e-29 ... | \n", "(-5, 6.92820323027551) | \n", "[[0. 0. 0. 0. 0. 0.]\\n [0. 0. 0. 0. 0. 0.]\\n [... | \n", "0 | \n", "1000.0 | \n", "0 | \n", "0 | \n", "4 | \n", "
5 rows × 26 columns
\n", "\n", " | escape time | \n", "platform | \n", "previous platform | \n", "trial | \n", "
---|---|---|---|---|
session | \n", "\n", " | \n", " | \n", " | \n", " |
1 | \n", "508.0 | \n", "210.0 | \n", "197.0 | \n", "0 | \n", "
2 | \n", "152.0 | \n", "181.0 | \n", "210.0 | \n", "0 | \n", "
3 | \n", "206.0 | \n", "192.0 | \n", "181.0 | \n", "0 | \n", "
4 | \n", "362.0 | \n", "181.0 | \n", "192.0 | \n", "0 | \n", "
5 | \n", "296.0 | \n", "210.0 | \n", "181.0 | \n", "0 | \n", "
6 | \n", "212.0 | \n", "192.0 | \n", "210.0 | \n", "0 | \n", "
7 | \n", "240.0 | \n", "203.0 | \n", "192.0 | \n", "0 | \n", "
8 | \n", "59.0 | \n", "216.0 | \n", "203.0 | \n", "0 | \n", "
9 | \n", "618.0 | \n", "203.0 | \n", "216.0 | \n", "0 | \n", "
10 | \n", "694.0 | \n", "174.0 | \n", "203.0 | \n", "0 | \n", "