{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "from hippocampus.environments import SimpleMDP, HexWaterMaze, TwoStepTask\n",
    "from hippocampus.experiments.reliability_in_twostep import CombinedAgent\n",
    "from definitions import FIGURE_FOLDER\n",
    "# TODO: do this on linear track "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import  tqdm\n",
    "import pandas as pd\n",
    "import os\n",
    "import matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the active seaborn palette; the figures below pick colours from it by index.\n",
    "default_palette = sns.color_palette()\n",
    "sns.palplot(default_palette)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Starting value assigned to the agent's p_sr; reused later as the first point of the P(SR) curve.\n",
    "init_p_sr = .5\n",
    "\n",
    "# Single agent on SimpleMDP(5, ...) with reward probability .85.\n",
    "ag = CombinedAgent(\n",
    "    env=SimpleMDP(5, reward_probability=.85),\n",
    "    inv_temp=10,\n",
    ")\n",
    "ag.p_sr = init_p_sr\n",
    "ag.HPC.learning_rate = .01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run 399 episodes with a stochastic policy, keeping one record per episode.\n",
    "# Records are gathered in a list and turned into a DataFrame once at the end:\n",
    "# DataFrame.append was removed in pandas 2.0 and copies the whole frame on\n",
    "# every call, which makes the loop quadratic.\n",
    "# NOTE(review): assumes one_episode() returns a dict-like record per episode - confirm.\n",
    "episode_rows = []\n",
    "for ep in tqdm(range(1, 400)):\n",
    "    results = ag.one_episode(deterministic_policy=False)\n",
    "    results['trial'] = ep\n",
    "    episode_rows.append(results)\n",
    "    # Decay the HPC learning rate by 5% after every episode.\n",
    "    ag.HPC.learning_rate *= .95\n",
    "\n",
    "df = pd.DataFrame(episode_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-trial omega entries into one array (one row per trial) and plot it.\n",
    "omega_by_trial = np.array(df['omega'].tolist())\n",
    "plt.plot(omega_by_trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepend a 0.0 starting value to each reliability trace.\n",
    "# ignore_index=True renumbers the result 0..n; without it the prepended value\n",
    "# and the first trial both carry index label 0, and plotting the Series\n",
    "# (which uses the index as x) draws them at the same x position.\n",
    "initial_reliab = pd.Series([0.])\n",
    "dls_reliab = pd.concat([initial_reliab, df['DLS reliability']], ignore_index=True)\n",
    "hpc_reliab = pd.concat([initial_reliab, df['HPC reliability']], ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global font size for this and the following figures.\n",
    "font = {'size': 22}\n",
    "matplotlib.rc('font', **font)\n",
    "\n",
    "palette = sns.color_palette()\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(dls_reliab, color=palette[1], linewidth=2)\n",
    "ax.plot(hpc_reliab, color=palette[2], linewidth=2)\n",
    "\n",
    "# Push the left/bottom spines outward by 10 points and hide the other two.\n",
    "for side in ('left', 'bottom'):\n",
    "    ax.spines[side].set_position(('outward', 10))\n",
    "for side in ('right', 'top'):\n",
    "    ax.spines[side].set_visible(False)\n",
    "# Keep ticks only on the spines that remain visible.\n",
    "ax.yaxis.set_ticks_position('left')\n",
    "ax.xaxis.set_ticks_position('bottom')\n",
    "\n",
    "ax.set_ylabel('Reliability')\n",
    "ax.set_xlabel('Trial')\n",
    "ax.legend(['DLS reliability', 'HPC reliability'])\n",
    "ax.set_ylim([-.1, 1])\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(FIGURE_FOLDER, 'reliability.pdf'))\n",
    "# To zoom in on the early trials: ax.set_xlim([-10, 150])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# P(SR) across trials, starting from the initial value assigned to the agent.\n",
    "# ignore_index=True gives the combined Series a clean 0..n index; without it\n",
    "# the initial value and the first trial share index label 0 and are drawn at\n",
    "# the same x position.\n",
    "psr_trace = pd.concat([pd.Series([init_p_sr]), df['P(SR)']], ignore_index=True)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(psr_trace, color=sns.color_palette()[0], linewidth=2)\n",
    "\n",
    "# Push the left/bottom spines outward by 10 points and hide the other two.\n",
    "for side in ('left', 'bottom'):\n",
    "    ax.spines[side].set_position(('outward', 10))\n",
    "for side in ('right', 'top'):\n",
    "    ax.spines[side].set_visible(False)\n",
    "# Keep ticks only on the spines that remain visible.\n",
    "ax.yaxis.set_ticks_position('left')\n",
    "ax.xaxis.set_ticks_position('bottom')\n",
    "\n",
    "ax.set_ylabel('Pr(HPC)')\n",
    "ax.set_xlabel('Trial')\n",
    "ax.set_ylim([0, 1])\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(FIGURE_FOLDER, 'psr.pdf'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_agents = 10\n",
    "n_trials = 300\n",
    "init_p_sr = .5\n",
    "\n",
    "# One row per agent, one column per trial. Column 0 holds the value before any\n",
    "# episode has run (0 for the reliabilities, init_p_sr for P(SR)).\n",
    "dls_reliab_M = np.zeros((n_agents, n_trials))\n",
    "hpc_reliab_M = np.zeros((n_agents, n_trials))\n",
    "psr_M = np.zeros((n_agents, n_trials))\n",
    "psr_M[:, 0] = init_p_sr\n",
    "\n",
    "for ia in tqdm(range(n_agents)):\n",
    "    ag = CombinedAgent(env=SimpleMDP(5, reward_probability=.8))\n",
    "    ag.p_sr = init_p_sr\n",
    "    ag.HPC.learning_rate = .01\n",
    "\n",
    "    # Gather the episode records in a list and build the frame once:\n",
    "    # DataFrame.append was removed in pandas 2.0 and is quadratic in a loop.\n",
    "    # NOTE(review): assumes one_episode() returns a dict-like record - confirm.\n",
    "    episode_rows = []\n",
    "    for ep in tqdm(range(1, n_trials), leave=False):\n",
    "        results = ag.one_episode()\n",
    "        results['trial'] = ep\n",
    "        episode_rows.append(results)\n",
    "    # df (and ag) of the last agent stay available to the cells below.\n",
    "    df = pd.DataFrame(episode_rows)\n",
    "\n",
    "    # The n_trials - 1 recorded episodes fill columns 1..n_trials-1.\n",
    "    dls_reliab_M[ia, 1:] = df['DLS reliability'].to_numpy()\n",
    "    hpc_reliab_M[ia, 1:] = df['HPC reliability'].to_numpy()\n",
    "    psr_M[ia, 1:] = df['P(SR)'].to_numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reliability averaged over agents (axis 0), DLS first and HPC second.\n",
    "fig, ax = plt.subplots()\n",
    "for reliab_M in (dls_reliab_M, hpc_reliab_M):\n",
    "    ax.plot(reliab_M.mean(axis=0))\n",
    "ax.set_ylim([-.1, 1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same agent-averaged curves as above, with the x-axis limited to [-10, 150].\n",
    "fig, ax = plt.subplots()\n",
    "for reliab_M in (dls_reliab_M, hpc_reliab_M):\n",
    "    ax.plot(reliab_M.mean(axis=0))\n",
    "ax.set_ylim([-.1, 1])\n",
    "ax.set_xlim([-10, 150])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-trial omega entries of the most recently simulated agent.\n",
    "omegas = df.loc[:, 'omega']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-trial omega vectors as rows, then transpose so that row k is\n",
    "# omega component k across trials. The earlier flatten + reshape(nr_states, -1)\n",
    "# cut the trial-major flat array into consecutive chunks, so each row mixed\n",
    "# several trials and components instead of following one component.\n",
    "# NOTE(review): assumes every df['omega'] entry is a 1-D vector of equal length - confirm.\n",
    "alloms = np.array(df['omega'].tolist()).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot a single row of alloms.\n",
    "omega_row = 4\n",
    "plt.plot(alloms[omega_row])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# omega_dls against trial number, y-axis fixed to [0, 1].\n",
    "ax = df.plot(x='trial', y='omega_dls')\n",
    "ax.set_ylim([0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the per-trial omega vectors into one list per component (indices 0-8).\n",
    "# Every row of df is used; the earlier hard-coded range(272) silently dropped\n",
    "# the later trials and would raise if df ever had fewer rows.\n",
    "omega_per_trial = df['omega'].tolist()\n",
    "(omg0, omg1, omg2, omg3, omg4,\n",
    " omg5, omg6, omg7, omg8) = (\n",
    "    [omega[k] for omega in omega_per_trial] for k in range(9)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw each omega component as its own line on the current axes.\n",
    "for omg in (omg0, omg1, omg2, omg3, omg4, omg5, omg6, omg7, omg8):\n",
    "    plt.plot(omg)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the component lists row-wise: one row per component, one column per trial.\n",
    "omega_components = [omg0, omg1, omg2, omg3, omg4, omg5, omg6, omg7, omg8]\n",
    "alloms = np.vstack(omega_components)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean over the rows of alloms next to the recorded omega_dls column.\n",
    "ax = plt.gca()\n",
    "ax.plot(alloms.mean(axis=0))\n",
    "ax.plot(df['omega_dls'])\n",
    "ax.set_ylim([0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}