{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import os\n", "\n", "from definitions import RESULTS_FOLDER, FIGURE_FOLDER\n", "\n", "figure_location = os.path.join(FIGURE_FOLDER,'twostep')\n", "if not os.path.exists(figure_location):\n", " os.makedirs(figure_location)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_hpc.csv'))\n", "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_dls.csv'))\n", "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_control.csv'))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Unnamed: 0</th>\n", " <th>Action1</th>\n", " <th>Action2</th>\n", " <th>Agent</th>\n", " <th>DLS reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Reward</th>\n", " <th>StartState</th>\n", " <th>State2</th>\n", " <th>Terminus</th>\n", " <th>Trial</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.057327</td>\n", " <td>0.087327</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.115198</td>\n", " <td>0.167028</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.175993</td>\n", " <td>0.239769</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>2.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.224014</td>\n", " <td>0.306158</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>3.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.252047</td>\n", " <td>0.366749</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>4.0</td>\n", " <td>7.0</td>\n", " <td>4.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Unnamed: 0 Action1 Action2 Agent DLS reliability HPC reliability \\\n", "0 0 0.0 0.0 0.0 0.057327 0.087327 \n", "1 1 0.0 0.0 0.0 0.115198 0.167028 \n", "2 2 0.0 0.0 0.0 0.175993 0.239769 \n", "3 3 0.0 0.0 0.0 0.224014 0.306158 \n", "4 4 0.0 0.0 0.0 0.252047 0.366749 \n", "\n", " P(SR) Reward StartState State2 Terminus Trial \n", "0 1.0 1.0 0.0 3.0 5.0 0.0 \n", "1 1.0 1.0 0.0 3.0 5.0 1.0 \n", "2 1.0 1.0 0.0 3.0 5.0 2.0 \n", "3 1.0 0.0 0.0 3.0 5.0 3.0 \n", "4 1.0 1.0 0.0 4.0 7.0 4.0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_lesion_dls.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def is_common_or_rare(action, out):\n", " left_outcomes = (5, 6)\n", " right_outcomes = (7, 8)\n", " if action == 0 and out in left_outcomes:\n", " return 'common'\n", " elif action == 0 and out in right_outcomes:\n", " return 'rare'\n", " elif action == 1 and out in left_outcomes:\n", " return 'rare'\n", " elif action == 1 and out in right_outcomes:\n", " return 'common'\n", " else:\n", " raise ValueError('The combination of action and outcome does not make sense')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def add_relevant_columns(dataframe):\n", " dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n", " dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n", " dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n", " dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n", " dataframe['Transition'] = np.vectorize(is_common_or_rare)(dataframe['Action1'], dataframe['Terminus'])\n", " dataframe['PreviousTransition'] = dataframe.groupby(['Agent'])['Transition'].shift(1)\n", " " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "add_relevant_columns(data_control)\n", "add_relevant_columns(data_lesion_dls)\n", "add_relevant_columns(data_lesion_hpc)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Unnamed: 0</th>\n", " <th>Action1</th>\n", " <th>Action2</th>\n", " <th>Agent</th>\n", " <th>DLS reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Reward</th>\n", " <th>StartState</th>\n", " <th>State2</th>\n", " <th>Terminus</th>\n", " <th>Trial</th>\n", " <th>PreviousAction</th>\n", " <th>PreviousStart</th>\n", " <th>PreviousReward</th>\n", " <th>Stay</th>\n", " <th>Transition</th>\n", " <th>PreviousTransition</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.057327</td>\n", " <td>0.087327</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>0.0</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>False</td>\n", " <td>common</td>\n", " <td>NaN</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.115198</td>\n", " <td>0.167028</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>True</td>\n", " <td>common</td>\n", " <td>common</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.175993</td>\n", " <td>0.239769</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>2.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>True</td>\n", " <td>common</td>\n", " <td>common</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.224014</td>\n", " <td>0.306158</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>3.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>True</td>\n", " <td>common</td>\n", " <td>common</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.252047</td>\n", " <td>0.366749</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>4.0</td>\n", " <td>7.0</td>\n", " <td>4.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>True</td>\n", " <td>rare</td>\n", " <td>common</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Unnamed: 0 Action1 Action2 Agent DLS reliability HPC reliability \\\n", "0 0 0.0 0.0 0.0 0.057327 0.087327 \n", "1 1 0.0 0.0 0.0 0.115198 0.167028 \n", "2 2 0.0 0.0 0.0 0.175993 0.239769 \n", "3 3 0.0 0.0 0.0 0.224014 0.306158 \n", "4 4 0.0 0.0 0.0 0.252047 0.366749 \n", "\n", " P(SR) Reward StartState State2 Terminus Trial PreviousAction \\\n", "0 1.0 1.0 0.0 3.0 5.0 0.0 NaN \n", "1 1.0 1.0 0.0 3.0 5.0 1.0 0.0 \n", "2 1.0 1.0 0.0 3.0 5.0 2.0 0.0 \n", "3 1.0 0.0 0.0 3.0 5.0 3.0 0.0 \n", "4 1.0 1.0 0.0 4.0 7.0 4.0 0.0 \n", "\n", " PreviousStart PreviousReward Stay Transition PreviousTransition \n", "0 NaN NaN False common NaN \n", "1 0.0 1.0 True common common \n", "2 0.0 1.0 True common common \n", "3 0.0 1.0 True common common \n", "4 0.0 0.0 True rare common " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_lesion_dls.head()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def compute_mean_stay_prob(data):\n", " means = data[data['Trial']>0].groupby(['PreviousTransition', 'PreviousReward'])['Stay'].mean()\n", " sems = data.groupby(['PreviousTransition', 'PreviousReward'])['Stay'].sem()\n", " return means, sems" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n", "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n", "mean_full, sem_full = compute_mean_stay_prob(data_control)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PreviousTransition PreviousReward\n", "common 0.0 0.558473\n", " 1.0 0.836690\n", "rare 0.0 0.563003\n", " 1.0 0.823360\n", "Name: Stay, dtype: float64" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean_lesion_hpc" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def plot_daw_style(ax, data, yerr, title=''):\n", " lightgray = '#d1d1d1'\n", " darkgray = '#929292'\n", "\n", " bar_width= 0.2\n", "\n", " bars1 = data[:2][::-1]\n", " bars2 = data[2:][::-1]\n", " errs1 = yerr[:2][::-1]\n", " errs2 = yerr[2:][::-1]\n", "\n", " # The x position of bars\n", " r1 = np.array([0.125, 0.625]) \n", " r2 = [x + bar_width + .05 for x in r1]\n", " \n", " plt.sca(ax)\n", " \n", " plt.bar(r1, bars1, width=bar_width, color='blue', capsize=4)\n", " plt.bar(r2, bars2, width=bar_width, color='red', capsize=4)\n", " plt.xticks([r+ bar_width/2 +.025 for r in r1], ['Rewarded', 'Unrewarded'], fontsize=12)\n", " plt.yticks(fontsize=12)\n", " plt.title(title, fontsize=16)\n", " plt.ylim([0.4, 1])\n", " plt.xlim([0, 1])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 720x180 with 3 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(1,3, figsize= (10,2.5), sharey=True)\n", "\n", "plot_daw_style(axes[0], list(mean_lesion_hpc), list(sem_lesion_hpc), title='Striatum')\n", "plot_daw_style(axes[1], list(mean_lesion_dls), list(sem_lesion_dls), title='Hippocampus')\n", "plot_daw_style(axes[2], list(mean_full), list(sem_full), title='Full model')\n", "\n", "\n", "\n", "leg = axes[0].legend(['Common', 'Rare'], fontsize=12, frameon=False, handlelength=0.7, title='Previous transition')\n", "plt.sca(axes[0])\n", "plt.ylabel('Stay probability', fontsize=12)\n", "\n", "plt.savefig(os.path.join(figure_location, 'StayProbability.pdf'))\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Doll et al analysis" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "figure_location = os.path.join(FIGURE_FOLDER,'twostep_deterministic')\n", "if not os.path.exists(figure_location):\n", " os.makedirs(figure_location)\n", "\n", "# load data\n", "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_hpc.csv'))\n", "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_dls.csv'))\n", "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_control.csv'))\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Unnamed: 0</th>\n", " <th>Action1</th>\n", " <th>Action2</th>\n", " <th>Agent</th>\n", " <th>DLS reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Reward</th>\n", " <th>StartState</th>\n", " <th>State2</th>\n", " <th>Terminus</th>\n", " <th>Trial</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.030000</td>\n", " <td>0.030000</td>\n", " <td>0.681250</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>2.0</td>\n", " <td>4.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.059100</td>\n", " <td>0.059100</td>\n", " <td>0.657567</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.087327</td>\n", " <td>0.087327</td>\n", " <td>0.644885</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>2.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.084707</td>\n", " <td>0.114707</td>\n", " <td>0.632511</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>3.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.109166</td>\n", " <td>0.141266</td>\n", " <td>0.636400</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>4.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Unnamed: 0 Action1 Action2 Agent DLS reliability HPC reliability \\\n", "0 0 0.0 0.0 0.0 0.030000 0.030000 \n", "1 1 1.0 1.0 0.0 0.059100 0.059100 \n", "2 2 1.0 1.0 0.0 0.087327 0.087327 \n", "3 3 1.0 1.0 0.0 0.084707 0.114707 \n", "4 4 1.0 1.0 0.0 0.109166 0.141266 \n", "\n", " P(SR) Reward StartState State2 Terminus Trial \n", "0 0.681250 0.0 0.0 2.0 4.0 0.0 \n", "1 0.657567 0.0 1.0 3.0 7.0 1.0 \n", "2 0.644885 0.0 1.0 3.0 7.0 2.0 \n", "3 0.632511 1.0 1.0 3.0 7.0 3.0 \n", "4 0.636400 0.0 1.0 3.0 7.0 4.0 " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_control.head()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def add_relevant_columns(dataframe):\n", " dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n", " dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n", " dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n", " dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n", " dataframe['SameStart'] = (dataframe.StartState == dataframe.PreviousStart)\n", "\n", "\n", "add_relevant_columns(data_control)\n", "add_relevant_columns(data_lesion_dls)\n", "add_relevant_columns(data_lesion_hpc)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def compute_mean_stay_prob(data):\n", " means = data[data['Trial']>0].groupby(['PreviousReward', 'SameStart'])['Stay'].mean()\n", " sems = data.groupby(['PreviousReward', 'SameStart'])['Stay'].sem()\n", " return means[::-1], sems[::-1]\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n", "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n", "mean_full, sem_full = compute_mean_stay_prob(data_control)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PreviousReward SameStart\n", "1.0 True 0.953491\n", " False 0.956948\n", "0.0 True 0.475589\n", " False 0.475560\n", "Name: Stay, dtype: float64" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean_lesion_dls" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def plot_doll_style(ax, data, yerr, title=''):\n", " lightgray = '#d1d1d1'\n", " darkgray = '#929292'\n", "\n", " bar_width = 0.2\n", "\n", " bars1 = data[:2]\n", " bars2 = data[2:]\n", " errs1 = yerr[:2]\n", " errs2 = yerr[2:]\n", "\n", " # The x position of bars\n", " r1 = np.arange(len(bars1)) * .8 + 1.5 * bar_width\n", " r2 = [x + bar_width for x in r1]\n", "\n", " plt.sca(ax)\n", "\n", " plt.bar(r1, bars1, width=bar_width, color=lightgray, yerr=errs1, capsize=4)\n", " plt.bar(r2, bars2, width=bar_width, color=darkgray, yerr=errs2, capsize=4)\n", " plt.ylabel('Stay probability', fontsize=15)\n", " plt.xticks([r + bar_width / 2 for r in r1], ['same', 'different'], fontsize=15)\n", " plt.yticks(fontsize=15)\n", " plt.title(title, fontsize=18)\n", " plt.ylim([0.40, .96])\n", " plt.xlim([0, 1.6])\n", "\n", " ax.spines['right'].set_visible(False)\n", " ax.spines['top'].set_visible(False)\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 720x180 with 3 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(1,3, figsize= (10,2.5), sharey=True)\n", "\n", "plot_doll_style(axes[0], list(mean_lesion_hpc), list(sem_lesion_hpc), title='Striatum')\n", "plot_doll_style(axes[1], list(mean_lesion_dls), list(sem_lesion_dls), title='Hippocampus')\n", "plot_doll_style(axes[2], list(mean_full), list(sem_full), title='Full model')\n", "\n", "\n", "\n", "leg = axes[1].legend(['Reward', 'No reward'], fontsize=12, frameon=False, handlelength=.7)\n", "leg.set_title('Previous outcome', prop={'size': 12})\n", "plt.tight_layout()\n", "\n", "plt.savefig(os.path.join(figure_location, 'StayProbability_DeterministicTask.pdf'))\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PreviousReward SameStart\n", "1.0 True 0.742813\n", " False 0.551995\n", "0.0 True 0.632391\n", " False 0.525771\n", "Name: Stay, dtype: float64" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean_lesion_hpc" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }