{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports and the root results folder (results/pearce/bigmaze) for this notebook.\n",
    "import os\n",
    "import random\n",
    "from datetime import datetime\n",
    "from multiprocessing import Pool\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from tqdm.auto import tqdm  # selects the notebook widget automatically; tqdm_notebook is deprecated\n",
    "\n",
    "from definitions import ROOT_FOLDER\n",
    "from hippocampus.agents import CombinedAgent\n",
    "from hippocampus.environments import HexWaterMaze\n",
    "from hippocampus.plotting import tsplot_boot\n",
    "\n",
    "%matplotlib notebook\n",
    "\n",
    "# Each run of the simulation cell writes its CSVs into a sub-folder of this directory.\n",
    "results_folder = os.path.join(ROOT_FOLDER, 'results', 'pearce', 'bigmaze')\n",
    "os.makedirs(results_folder, exist_ok=True)  # no-op when the folder already exists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by every simulated agent:\n",
    "#   eta           -- learning rate of the reliability estimators (0.06 was a commented-out alternative)\n",
    "#   learning_rate -- step size for the value and SR updates\n",
    "#   inv_temp      -- initial inverse temperature of the softmax action selection\n",
    "eta, learning_rate, inv_temp = 0.04, 0.1, 1.5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maze shared by all agents. 10 is presumably the maze radius (files are saved as results_r10_*) -- TODO confirm.\n",
    "maze_size = 10\n",
    "g = HexWaterMaze(maze_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9db49a6f2e64e269f6c39cf7b13cfa6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=11), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Simulate n_agents agents. At the start of every session the platform moves to the next state of\n",
    "# the agent's shuffled platform sequence; each session has n_trials_per_session trials.\n",
    "n_agents = 30\n",
    "n_sessions = 11\n",
    "n_trials_per_session = 4\n",
    "inv_temp_increment = .8  # added to agent.inv_temp after every session\n",
    "\n",
    "# Timestamped sub-folder for this run. The variable must not be called `datetime`: rebinding that\n",
    "# name shadowed the imported datetime class, so re-running this cell crashed. strftime also avoids\n",
    "# the colons of str(datetime.now()), which are not valid in Windows paths.\n",
    "run_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S_%f')\n",
    "run_results_folder = os.path.join(results_folder, run_timestamp)\n",
    "os.makedirs(run_results_folder, exist_ok=True)\n",
    "\n",
    "#possible_platform_states = [48, 45, 42, 39, 60, 57, 54, 57]\n",
    "possible_platform_states = [192, 185, 181, 174, 216, 210, 203, 197]  # for the r = 10 case\n",
    "\n",
    "records = []  # one summary dict per trial; converted to `df` once, after the loop\n",
    "for n_agent in tqdm(range(n_agents)):\n",
    "    # Seed per agent and shuffle a *copy*: shuffling the shared list in place made each agent's\n",
    "    # platform order depend on every previously simulated agent, not only on its own seed.\n",
    "    # np.random is seeded as well in case the agent/environment draw from it (TODO confirm).\n",
    "    random.seed(n_agent)\n",
    "    np.random.seed(n_agent)\n",
    "    platform_sequence = list(possible_platform_states)\n",
    "    random.shuffle(platform_sequence)\n",
    "    # NOTE(review): this platform is replaced at trial 0 of session 0 below; kept in case\n",
    "    # CombinedAgent reads the current platform when it is constructed -- confirm.\n",
    "    g.set_platform_state(platform_sequence[6])\n",
    "\n",
    "    agent = CombinedAgent(g, init_sr='rw', lesion_dls=False, lesion_hpc=False, eta=eta,\n",
    "                          learning_rate=learning_rate, inv_temp=inv_temp)\n",
    "    agent_results = []  # raw per-step logs of every episode of this agent\n",
    "    total_trial_count = 0\n",
    "\n",
    "    for ses in tqdm(range(n_sessions), leave=False):\n",
    "        for trial in tqdm(range(n_trials_per_session), leave=False):\n",
    "            if trial == 0:\n",
    "                # Move the platform at the start of each session (wrapping around the sequence).\n",
    "                g.set_platform_state(platform_sequence[ses % len(platform_sequence)])\n",
    "            res = agent.one_episode(random_policy=False)\n",
    "            escape_time = res.time.max()\n",
    "            res['trial'] = trial\n",
    "            res['escape time'] = escape_time\n",
    "            res['session'] = ses\n",
    "            res['total trial'] = total_trial_count\n",
    "            agent_results.append(res)\n",
    "\n",
    "            records.append({'agent': n_agent,\n",
    "                            'session': ses,\n",
    "                            'trial': trial,\n",
    "                            'escape time': escape_time,\n",
    "                            'platform': res.platform.iloc[-1],\n",
    "                            'start': agent.env.starting_state,\n",
    "                            'P(SR)': res['P(SR)'].mean(),\n",
    "                            'total trial': total_trial_count})\n",
    "            total_trial_count += 1\n",
    "\n",
    "        # Higher inverse temperature makes the softmax action selection greedier session by session.\n",
    "        agent.inv_temp += inv_temp_increment\n",
    "\n",
    "    agent_df = pd.concat(agent_results)\n",
    "    agent_df['agent'] = n_agent\n",
    "    agent_df['total time'] = np.arange(len(agent_df))\n",
    "    agent_df.to_csv(os.path.join(run_results_folder, 'results_r10_agent{}.csv'.format(n_agent)))\n",
    "\n",
    "# Built once from a list of dicts: DataFrame.append in a loop is quadratic and was removed in pandas 2.0.\n",
    "df = pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First 20 per-trial summaries, i.e. the first 5 sessions of agent 0 (rows are in simulation order).\n",
    "df.iloc[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Escape time on the first (trial 0) and last (trial 3 of 4) trial of each session;\n",
    "# seaborn averages the repeated agents at each session.\n",
    "first_last = df[df['trial'].isin([0, 3])]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(data=first_last, x='session', y='escape time', hue='trial', ci=None, ax=ax)  # ci -> errorbar=None on seaborn >= 0.12\n",
    "ax.set(title='First vs. last trial of each session', xlabel='session', ylabel='escape time');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Escape time of every trial across all sessions, one line per agent.\n",
    "fig, ax = plt.subplots()\n",
    "sns.lineplot(data=df, x='total trial', y='escape time', hue='agent', ax=ax)\n",
    "ax.set(title='Escape time per agent', xlabel='trial (counted across sessions)', ylabel='escape time');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the simulation loop should have produced summaries for agents 0..29.\n",
    "df['agent'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw per-step log of the last simulated episode; `res` is the loop variable left over from the simulation cell.\n",
    "res.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}