{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Blocking study figures\n",
    "\n",
    "Loads the landmark- and boundary-blocking results from `results/blocking` and builds the two figures saved to `results/figures/blocking`. Each figure shows a paradigm illustration of the three training stages (initial learning, compound learning, testing) next to the learning curve (time steps per trial)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# Make the repository root (assumed to be two levels up) importable for `definitions` and `hippocampus`.\n",
    "sys.path.append('../..')\n",
    "\n",
    "import matplotlib.patches\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec\n",
    "from matplotlib.patches import Circle\n",
    "\n",
    "from definitions import ROOT_FOLDER\n",
    "from hippocampus.environments import BlockingStudy\n",
    "from hippocampus.plotting import tsplot_boot\n",
    "\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_folder = os.path.join(ROOT_FOLDER, 'results', 'blocking')\n",
    "figure_folder = os.path.join(ROOT_FOLDER, 'results', 'figures', 'blocking')\n",
    "os.makedirs(figure_folder, exist_ok=True)\n",
    "\n",
    "boundary_blocking_data = np.load(os.path.join(results_folder, 'boundary_blocking_results.npy'))\n",
    "landmark_blocking_data = np.load(os.path.join(results_folder, 'landmark_blocking_results.npy'))\n",
    "\n",
    "# Environment instance; only used to draw the boundary layout in the boundary figure.\n",
    "en = BlockingStudy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quick look\n",
    "\n",
    "Both learning curves on one axis, as a sanity check before building the figures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "\n",
    "tsplot_boot(ax, boundary_blocking_data)\n",
    "tsplot_boot(ax, landmark_blocking_data)\n",
    "ax.set_xlabel('Trials')\n",
    "ax.set_ylabel('Time steps');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shared colours and helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cue 1 (L1 / left boundary) and cue 2 (L2 / right boundary) colours, shared by both figures.\n",
    "colour_palette = sns.color_palette()\n",
    "cue1_colour = colour_palette[8]\n",
    "cue2_colour = colour_palette[9]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_results_axis(ax, cue1_label, cue2_label, cue1_colour, cue2_colour):\n",
    "    \"\"\"Label a learning-curve axis and draw when each cue is present.\n",
    "\n",
    "    Cue 1's timeline spans the first two thirds of the x-axis and cue 2's the\n",
    "    last two thirds, each with a text label. Their heights are fractions of\n",
    "    the axis' current upper y-limit, so set any explicit y-limits *before*\n",
    "    calling this.\n",
    "    \"\"\"\n",
    "    ax.set_ylabel('Time steps', fontsize=12)\n",
    "    ax.set_xlabel('Trials', fontsize=12)\n",
    "    # Hide the right and top spines\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    # Cue timelines: xmin/xmax are axes fractions, y is in data coordinates.\n",
    "    # `linewidth` must be lowercase: capitalized aliases were removed in matplotlib 3.5.\n",
    "    ax.axhline(y=ax.get_ylim()[1] * .90, xmin=0, xmax=.6667, color=cue1_colour, linewidth=5)\n",
    "    ax.axhline(y=ax.get_ylim()[1] * .85, xmin=.3333, xmax=1, color=cue2_colour, linewidth=5)\n",
    "    ax.text(ax.get_xlim()[1] * .01, ax.get_ylim()[1] * .92, cue1_label, color=cue1_colour)\n",
    "    ax.text(ax.get_xlim()[1] * .99, ax.get_ylim()[1] * .83, cue2_label,\n",
    "            ha='right', va='top', color=cue2_colour)\n",
    "\n",
    "\n",
    "def connect_panels(fig, panels, start_xy, end_xy):\n",
    "    \"\"\"Draw a block arrow from each panel in `panels` to the next one.\n",
    "\n",
    "    Each arrow starts at `start_xy` (data coordinates of the earlier panel)\n",
    "    and ends at `end_xy` (data coordinates of the later panel). Both points\n",
    "    are converted to figure coordinates so the arrow can cross axes\n",
    "    boundaries; call this once the panels' data limits are final.\n",
    "    \"\"\"\n",
    "    display_to_figure = fig.transFigure.inverted()\n",
    "    for upper, lower in zip(panels[:-1], panels[1:]):\n",
    "        start = display_to_figure.transform(upper.transData.transform(start_xy))\n",
    "        end = display_to_figure.transform(lower.transData.transform(end_xy))\n",
    "        arrow = matplotlib.patches.FancyArrowPatch(\n",
    "            start, end, transform=fig.transFigure, fc='k', arrowstyle='simple',\n",
    "            alpha=1., mutation_scale=40.)\n",
    "        fig.patches.append(arrow)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Landmark blocking figure (panels A and B)\n",
    "\n",
    "Left: the three training stages. Filled markers are landmarks present in that stage; hollow dotted markers are absent. Right: learning curve with the periods during which each landmark is present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "gs = GridSpec(1, 2)\n",
    "\n",
    "### Paradigm illustration: one panel per training stage\n",
    "inner = GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.1)\n",
    "stage_axes = [fig.add_subplot(inner[row, 0]) for row in range(3)]\n",
    "ax1, ax2, ax3 = stage_axes\n",
    "stage_names = ['Initial\\nlearning', 'Compound\\nlearning', 'Testing']\n",
    "for ax, stage_name in zip(stage_axes, stage_names):\n",
    "    ax.axis('equal')\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_ylim(0, 1)\n",
    "    ax.axis('off')\n",
    "    ax.text(-.1, .7, stage_name, fontsize=12, transform=ax.transAxes, ha='left', va='top')\n",
    "    ax.add_artist(Circle((.45, .3), .05, fill=False, linestyle='--'))\n",
    "    ax.text(.45, .23, 'Platform', va='top', ha='center', fontstyle='italic')\n",
    "\n",
    "# Initial learning: L1 only\n",
    "ax1.scatter([.65], [.5], marker='P', color=cue1_colour, s=400, linestyle='None')\n",
    "ax1.text(.65, .43, 'L1', va='top', ha='center', color=cue1_colour)\n",
    "ax1.scatter([.8], [.3], marker='P', color=cue2_colour, s=400, linestyle='dotted', facecolors='none')\n",
    "# Compound learning: L1 and L2\n",
    "ax2.scatter([.65], [.5], marker='P', color=cue1_colour, s=400, linestyle='None')\n",
    "ax2.scatter([.8], [.3], marker='P', color=cue2_colour, s=400)\n",
    "ax2.text(.65, .43, 'L1', va='top', ha='center', color=cue1_colour)\n",
    "ax2.text(.8, .23, 'L2', va='top', ha='center', color=cue2_colour)\n",
    "# Testing: L2 only\n",
    "ax3.scatter([.65], [.5], marker='P', color=cue1_colour, s=400, linestyle='dotted', facecolors='none')\n",
    "ax3.scatter([.8], [.3], marker='P', color=cue2_colour, s=400)\n",
    "ax3.text(.8, .23, 'L2', va='top', ha='center', color=cue2_colour)\n",
    "\n",
    "# Dashed horizontal line near the top of the second and third stage panels.\n",
    "for ax in [ax2, ax3]:\n",
    "    ax.axhline(y=.7, xmin=0.3, xmax=.9, color='k', linestyle='dashed')\n",
    "\n",
    "### Learning curve\n",
    "results_ax = fig.add_subplot(gs[0, 1])\n",
    "tsplot_boot(results_ax, landmark_blocking_data, color=colour_palette[3])\n",
    "format_results_axis(results_ax, 'Landmark 1 present', 'Landmark 2 present', cue1_colour, cue2_colour)\n",
    "results_ax.margins(.05)\n",
    "\n",
    "# Panel labels\n",
    "ax1.text(-.15, 1.15, 'A', transform=ax1.transAxes, fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "results_ax.text(-.15, 1.05, 'B', transform=results_ax.transAxes,\n",
    "                fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "\n",
    "# Arrows linking consecutive stages; points are in the stage panels' data coordinates.\n",
    "connect_panels(fig, stage_axes, start_xy=(0, .4), end_xy=(0, .95))\n",
    "\n",
    "fig.savefig(os.path.join(figure_folder, 'landmark_blocking_plot.pdf'), format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Boundary blocking figure (panels C and D)\n",
    "\n",
    "Same layout as the landmark figure, with the two landmarks replaced by two groups of boundary segments drawn by `BlockingStudy.draw_boundaries`. Solid segments are present in that stage; dotted segments are absent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure()\n",
    "gs = GridSpec(1, 2)\n",
    "\n",
    "# Six boundary segments: cue-1 coloured ones are labelled 'Left boundary' and\n",
    "# cue-2 coloured ones 'Right boundary' in the results panel.\n",
    "boundary_colours = [cue2_colour, cue2_colour, cue1_colour, cue1_colour, cue1_colour, cue2_colour]\n",
    "# Per-stage line styles: '-' = segment present, ':' = segment absent.\n",
    "stage_linestyles = [\n",
    "    [':', ':', '-', '-', '-', ':'],  # initial learning: left boundary only\n",
    "    ['-', '-', '-', '-', '-', '-'],  # compound learning: both boundaries\n",
    "    ['-', '-', ':', ':', ':', '-'],  # testing: right boundary only\n",
    "]\n",
    "stage_names = ['Initial\\nlearning', 'Compound\\nlearning', 'Testing']\n",
    "\n",
    "### Paradigm illustration: one panel per training stage\n",
    "inner = GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.1)\n",
    "stage_axes = [fig.add_subplot(inner[row, 0]) for row in range(3)]\n",
    "for ax, stage_name, linestyles in zip(stage_axes, stage_names, stage_linestyles):\n",
    "    en.draw_boundaries(ax, colors=boundary_colours, linestyles=linestyles)\n",
    "    ax.axis('off')\n",
    "    ax.text(-.2, .7, stage_name, fontsize=12, transform=ax.transAxes, ha='left', va='top')\n",
    "ax1, ax2, ax3 = stage_axes\n",
    "ax1.text(-.15, 1.15, 'C', transform=ax1.transAxes, fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "\n",
    "# Dashed horizontal line in the second and third stage panels (environment data coordinates).\n",
    "for ax in [ax2, ax3]:\n",
    "    ax.axhline(y=1000, xmin=0.2, xmax=.8, color='k', linestyle='dashed')\n",
    "\n",
    "### Learning curve\n",
    "results_ax = fig.add_subplot(gs[0, 1])\n",
    "tsplot_boot(results_ax, boundary_blocking_data, color=colour_palette[0])\n",
    "# Fix the y-range first: the cue timelines and labels are placed relative to the\n",
    "# current y-limits, so they must be computed against the final range.\n",
    "results_ax.set_ylim([0, 230])\n",
    "format_results_axis(results_ax, 'Left boundary present', 'Right boundary present', cue1_colour, cue2_colour)\n",
    "results_ax.margins(.05)\n",
    "\n",
    "# Panel label\n",
    "results_ax.text(-.15, 1.05, 'D', transform=results_ax.transAxes,\n",
    "                fontsize=16, fontweight='bold', va='top', ha='right')\n",
    "\n",
    "# Arrows linking consecutive stages; points are in the stage panels' data coordinates.\n",
    "connect_panels(fig, stage_axes, start_xy=(-800, -200), end_xy=(-800, 500))\n",
    "\n",
    "fig.savefig(os.path.join(figure_folder, 'boundary_blocking_plot.pdf'), format='pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}