{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from definitions import ROOT_FOLDER\n",
    "import os\n",
    "from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec\n",
    "from matplotlib.patches import Circle\n",
    "import matplotlib.patches\n",
    "from hippocampus.environments import BlockingStudy\n",
    "\n",
    "from hippocampus.plotting import tsplot_boot\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate input/output folders and load the saved blocking results.\n",
    "results_folder = os.path.join(ROOT_FOLDER, 'results', 'blocking')\n",
    "figure_folder = os.path.join(ROOT_FOLDER, 'results', 'figures', 'blocking')\n",
    "# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.\n",
    "os.makedirs(figure_folder, exist_ok=True)\n",
    "\n",
    "# NOTE(review): presumably per-agent, per-trial time steps as consumed by tsplot_boot -- confirm shape.\n",
    "boundary_blocking_data = np.load(os.path.join(results_folder, 'boundary_blocking_results.npy'))\n",
    "landmark_blocking_data = np.load(os.path.join(results_folder, 'landmark_blocking_results.npy'))\n",
    "\n",
    "# Environment instance; the boundary-blocking figure uses it to draw the boundaries.\n",
    "en = BlockingStudy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look: both learning curves on one axis (boundary data plotted first, landmark data second).\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "tsplot_boot(ax, boundary_blocking_data)\n",
    "tsplot_boot(ax, landmark_blocking_data)\n",
    "# Same axis labels as the final figures below.\n",
    "ax.set_xlabel('Trials')\n",
    "ax.set_ylabel('Time steps');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cue colours shared by both figures, taken from the current seaborn palette:\n",
    "# cue 1 = landmark 1 / left boundary (entry 8), cue 2 = landmark 2 / right boundary (entry 9).\n",
    "colour_palette = sns.color_palette()\n",
    "cue1_colour, cue2_colour = colour_palette[8], colour_palette[9]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Landmark blocking figure: paradigm illustration (panel A) next to the learning curve (panel B).\n",
    "fig = plt.figure()\n",
    "\n",
    "gs = GridSpec(1, 2)\n",
    "\n",
    "### Make paradigm illustration\n",
    "# Left column: one small axes per training stage, stacked top to bottom.\n",
    "inner = GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.1)\n",
    "ax1 = fig.add_subplot(inner[0, 0])\n",
    "ax2 = fig.add_subplot(inner[1, 0])\n",
    "ax3 = fig.add_subplot(inner[2, 0])\n",
    "\n",
    "# Layout shared by every stage: unit square without axes, stage name on the left, platform.\n",
    "stage_names = ['Initial\\nlearning', 'Compound\\nlearning', 'Testing']\n",
    "for stage_ax, stage_name in zip([ax1, ax2, ax3], stage_names):\n",
    "    stage_ax.axis('equal')\n",
    "    stage_ax.set_xlim(0, 1)\n",
    "    stage_ax.set_ylim(0, 1)\n",
    "    stage_ax.axis('off')\n",
    "    stage_ax.text(-.1, .7, stage_name, fontsize=12, transform=stage_ax.transAxes, ha='left', va='top')\n",
    "    stage_ax.add_artist(Circle((.45, .3), .05, fill=False, linestyle='--'))\n",
    "    stage_ax.text(.45, .23, 'Platform', va='top', ha='center', fontstyle='italic')\n",
    "\n",
    "# Cue markers: filled = landmark present in that stage, hollow dotted outline = landmark absent.\n",
    "# Initial learning: L1 only.\n",
    "ax1.scatter([.65], [.5], marker='P', color=cue1_colour, s=400, linestyle='None')\n",
    "ax1.text(.65, .43, 'L1', va='top', ha='center', color=cue1_colour)\n",
    "ax1.scatter([.8], [.3], marker='P', color=cue2_colour, s=400, linestyle='dotted', facecolors='none')\n",
    "# Compound learning: L1 and L2 together.\n",
    "ax2.scatter([.65], [.5], marker='P', color=cue1_colour, s=400, linestyle='None')\n",
    "ax2.scatter([.8], [.3], marker='P', color=cue2_colour, s=400)\n",
    "ax2.text(.65, .43, 'L1', va='top', ha='center', color=cue1_colour)\n",
    "ax2.text(.8, .23, 'L2', va='top', ha='center', color=cue2_colour)\n",
    "# Testing: L2 only.\n",
    "ax3.scatter([.65], [.5], marker='P', color=cue1_colour, s=400, linestyle='dotted', facecolors='none')\n",
    "ax3.scatter([.8], [.3], marker='P', color=cue2_colour, s=400)\n",
    "ax3.text(.8, .23, 'L2', va='top', ha='center', color=cue2_colour)\n",
    "\n",
    "# Dashed line across the top of the compound-learning and testing panels.\n",
    "for stage_ax in [ax2, ax3]:\n",
    "    stage_ax.axhline(y=.7, xmin=0.3, xmax=.9, color='k', linestyle='dashed')\n",
    "\n",
    "### Plot results\n",
    "results_ax = fig.add_subplot(gs[0, 1])\n",
    "\n",
    "tsplot_boot(results_ax, landmark_blocking_data, color=colour_palette[3])\n",
    "\n",
    "results_ax.set_ylabel('Time steps', fontsize=12)\n",
    "results_ax.set_xlabel('Trials', fontsize=12)\n",
    "# Hide the right and top spines\n",
    "results_ax.spines['right'].set_visible(False)\n",
    "results_ax.spines['top'].set_visible(False)\n",
    "# Landmark timelines: L1 spans the first two thirds of the x-axis, L2 the last two thirds.\n",
    "# Use lowercase `linewidth`: mixed-case names like `LineWidth` were deprecated in matplotlib 3.3 and are rejected later.\n",
    "results_ax.axhline(y=results_ax.get_ylim()[1] * .90, xmin=0, xmax=.6667, color=cue1_colour, linewidth=5)\n",
    "results_ax.axhline(y=results_ax.get_ylim()[1] * .85, xmin=.3333, xmax=1, color=cue2_colour, linewidth=5)\n",
    "results_ax.text(results_ax.get_xlim()[1] * .01, results_ax.get_ylim()[1] * .92, 'Landmark 1 present', color=cue1_colour)\n",
    "results_ax.text(results_ax.get_xlim()[1] * .99, results_ax.get_ylim()[1] * .83, 'Landmark 2 present',\n",
    "                ha='right', va='top', color=cue2_colour)\n",
    "\n",
    "results_ax.margins(.05)\n",
    "\n",
    "# Add figure labels\n",
    "ax1.text(-.15, 1.15, 'A', transform=ax1.transAxes, fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "results_ax.text(-.15, 1.05, 'B', transform=results_ax.transAxes,\n",
    "                fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "\n",
    "# Arrows linking consecutive stages, drawn in figure coordinates.\n",
    "# NOTE(review): end points are converted through transData before the first draw; axis('equal') may\n",
    "# still adjust the limits at draw time, so positions are approximate -- check the saved PDF.\n",
    "display_to_fig = fig.transFigure.inverted()  # Display -> Figure\n",
    "arrow_ends = [\n",
    "    (ax1, (0, .4), ax2, (0, .95)),\n",
    "    (ax2, (0, .4), ax3, (0, .95)),\n",
    "]\n",
    "for src_ax, src_xy, dst_ax, dst_xy in arrow_ends:\n",
    "    start = display_to_fig.transform(src_ax.transData.transform(src_xy))\n",
    "    end = display_to_fig.transform(dst_ax.transData.transform(dst_xy))\n",
    "    arrow = matplotlib.patches.FancyArrowPatch(\n",
    "        start, end, transform=fig.transFigure, fc='k', arrowstyle='simple', alpha=1., mutation_scale=40.)\n",
    "    fig.patches.append(arrow)\n",
    "\n",
    "fig.savefig(os.path.join(figure_folder, 'landmark_blocking_plot.pdf'), format='pdf')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boundary blocking figure: paradigm illustration (panel C) next to the learning curve (panel D).\n",
    "fig = plt.figure()\n",
    "\n",
    "gs = GridSpec(1, 2)\n",
    "\n",
    "# Colours passed to en.draw_boundaries: cue1 marks the left boundary, cue2 the right one (see timeline\n",
    "# labels below). NOTE(review): presumably one entry per boundary segment drawn -- confirm.\n",
    "colours = [cue2_colour, cue2_colour, cue1_colour, cue1_colour, cue1_colour, cue2_colour]\n",
    "\n",
    "### Make paradigm illustration\n",
    "# Solid lines = boundary present in that stage, dotted = absent.\n",
    "inner = GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.1)\n",
    "\n",
    "# Initial learning: left boundary only.\n",
    "ax1 = fig.add_subplot(inner[0, 0])\n",
    "en.draw_boundaries(ax1, colors=colours, linestyles=[':', ':', '-', '-', '-', ':'])\n",
    "ax1.axis('off')\n",
    "ax1.text(-.2, .7, 'Initial\\nlearning', fontsize=12, transform=ax1.transAxes, ha='left', va='top')\n",
    "ax1.text(-.15, 1.15, 'C', transform=ax1.transAxes, fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "\n",
    "# Compound learning: both boundaries.\n",
    "ax2 = fig.add_subplot(inner[1, 0])\n",
    "en.draw_boundaries(ax2, colors=colours, linestyles=['-', '-', '-', '-', '-', '-'])\n",
    "ax2.text(-.2, .7, 'Compound\\nlearning', fontsize=12, transform=ax2.transAxes, ha='left', va='top')\n",
    "ax2.axis('off')\n",
    "\n",
    "# Testing: right boundary only.\n",
    "ax3 = fig.add_subplot(inner[2, 0])\n",
    "en.draw_boundaries(ax3, colors=colours, linestyles=['-', '-', ':', ':', ':', '-'])\n",
    "ax3.axis('off')\n",
    "ax3.text(-.2, .7, 'Testing', fontsize=12, transform=ax3.transAxes, ha='left', va='top')\n",
    "\n",
    "# Dashed line across the top of the compound-learning and testing panels (y in draw_boundaries data units).\n",
    "for stage_ax in [ax2, ax3]:\n",
    "    stage_ax.axhline(y=1000, xmin=0.2, xmax=.8, color='k', linestyle='dashed')\n",
    "\n",
    "### Plot results\n",
    "results_ax = fig.add_subplot(gs[0, 1])\n",
    "\n",
    "tsplot_boot(results_ax, boundary_blocking_data, color=colour_palette[0])\n",
    "\n",
    "results_ax.set_ylabel('Time steps', fontsize=12)\n",
    "results_ax.set_xlabel('Trials', fontsize=12)\n",
    "# Hide the right and top spines\n",
    "results_ax.spines['right'].set_visible(False)\n",
    "results_ax.spines['top'].set_visible(False)\n",
    "# Fix the y-range first so the timelines below sit at 90% / 85% of the range actually shown; placing\n",
    "# them before this call made their position depend on the autoscaled range that this call overrides.\n",
    "# NOTE(review): 230 is a hand-picked upper limit -- revisit if the results change.\n",
    "results_ax.set_ylim([0, 230])\n",
    "# Boundary timelines: left boundary spans the first two thirds of the x-axis, right boundary the last two thirds.\n",
    "# Use lowercase `linewidth`: mixed-case names like `LineWidth` were deprecated in matplotlib 3.3 and are rejected later.\n",
    "results_ax.axhline(y=results_ax.get_ylim()[1] * .90, xmin=0, xmax=.6667, color=cue1_colour, linewidth=5)\n",
    "results_ax.axhline(y=results_ax.get_ylim()[1] * .85, xmin=.3333, xmax=1, color=cue2_colour, linewidth=5)\n",
    "results_ax.text(results_ax.get_xlim()[1] * .01, results_ax.get_ylim()[1] * .92, 'Left boundary present', color=cue1_colour)\n",
    "results_ax.text(results_ax.get_xlim()[1] * .99, results_ax.get_ylim()[1] * .83, 'Right boundary present',\n",
    "                ha='right', va='top', color=cue2_colour)\n",
    "\n",
    "results_ax.margins(.05)\n",
    "\n",
    "# Add figure labels\n",
    "results_ax.text(-.15, 1.05, 'D', transform=results_ax.transAxes,\n",
    "                fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "\n",
    "# Arrows linking consecutive stages, drawn in figure coordinates.\n",
    "# End points are given in the data coordinates of the boundary drawings.\n",
    "display_to_fig = fig.transFigure.inverted()  # Display -> Figure\n",
    "arrow_ends = [\n",
    "    (ax1, (-800, -200), ax2, (-800, 500)),\n",
    "    (ax2, (-800, -200), ax3, (-800, 500)),\n",
    "]\n",
    "for src_ax, src_xy, dst_ax, dst_xy in arrow_ends:\n",
    "    start = display_to_fig.transform(src_ax.transData.transform(src_xy))\n",
    "    end = display_to_fig.transform(dst_ax.transData.transform(dst_xy))\n",
    "    arrow = matplotlib.patches.FancyArrowPatch(\n",
    "        start, end, transform=fig.transFigure, fc='k', arrowstyle='simple', alpha=1., mutation_scale=40.)\n",
    "    fig.patches.append(arrow)\n",
    "\n",
    "fig.savefig(os.path.join(figure_folder, 'boundary_blocking_plot.pdf'), format='pdf')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}