{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Two-step task analyses\n",
    "\n",
    "Stay-probability plots and logistic regressions for the lesioned (`lesion_hpc`, `lesion_dls`) and control (full) models, first on the two-step task (Daw et al. style) and then on the deterministic two-step task (Doll et al. style). Results are read from `RESULTS_FOLDER` and figures are written to `FIGURE_FOLDER`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import statsmodels.formula.api as smf\n",
    "from scipy.stats import ttest_1samp, ttest_ind\n",
    "\n",
    "from definitions import RESULTS_FOLDER, FIGURE_FOLDER"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Two-step task (Daw et al. style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure_location = os.path.join(FIGURE_FOLDER, 'twostep')\n",
    "os.makedirs(figure_location, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_hpc.csv'))\n",
    "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_dls.csv'))\n",
    "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_control.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lesion_dls.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_common_or_rare(action, out):\n",
    "    \"\"\"Classify a first-stage action and terminal state as a 'common' or 'rare' transition.\n",
    "\n",
    "    Action 0 leads commonly to terminal states 5/6 and action 1 to states 7/8;\n",
    "    the opposite pairings are rare transitions.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the action/terminal-state combination is not one of the above.\n",
    "    \"\"\"\n",
    "    left_outcomes = (5, 6)\n",
    "    right_outcomes = (7, 8)\n",
    "    if action == 0 and out in left_outcomes:\n",
    "        return 'common'\n",
    "    elif action == 0 and out in right_outcomes:\n",
    "        return 'rare'\n",
    "    elif action == 1 and out in left_outcomes:\n",
    "        return 'rare'\n",
    "    elif action == 1 and out in right_outcomes:\n",
    "        return 'common'\n",
    "    else:\n",
    "        raise ValueError('The combination of action and outcome does not make sense')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_relevant_columns(dataframe):\n",
    "    \"\"\"Add previous-trial and transition-type columns to `dataframe` in place.\n",
    "\n",
    "    PreviousAction, PreviousStart, PreviousReward and PreviousTransition are lagged by\n",
    "    one row within each Agent, so they are NaN on each agent's first row (where Stay is\n",
    "    therefore False). Transition labels every row via `is_common_or_rare`.\n",
    "    \"\"\"\n",
    "    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
    "    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
    "    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
    "    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
    "    dataframe['Transition'] = np.vectorize(is_common_or_rare)(dataframe['Action1'], dataframe['Terminus'])\n",
    "    dataframe['PreviousTransition'] = dataframe.groupby(['Agent'])['Transition'].shift(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "add_relevant_columns(data_control)\n",
    "add_relevant_columns(data_lesion_dls)\n",
    "add_relevant_columns(data_lesion_hpc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lesion_dls.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_mean_stay_prob(data):\n",
    "    \"\"\"Return the mean and SEM of Stay per (PreviousTransition, PreviousReward) cell.\n",
    "\n",
    "    Means and SEMs are computed on the same rows (Trial > 0), so the error bars describe\n",
    "    the same sample as the bars. Rows with a missing grouping key (each agent's first\n",
    "    row) are dropped by groupby.\n",
    "    \"\"\"\n",
    "    stay_by_condition = data[data['Trial'] > 0].groupby(['PreviousTransition', 'PreviousReward'])['Stay']\n",
    "    return stay_by_condition.mean(), stay_by_condition.sem()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n",
    "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n",
    "mean_full, sem_full = compute_mean_stay_prob(data_control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_daw_style(ax, data, yerr=None, title=''):\n",
    "    \"\"\"Draw a Daw et al.-style stay-probability bar chart on `ax`.\n",
    "\n",
    "    `data` (and `yerr`, if given) hold four values in the order produced by\n",
    "    `compute_mean_stay_prob`: the first two are common transitions, the last two rare,\n",
    "    each pair ordered by ascending PreviousReward. Each pair is reversed to match the\n",
    "    'Rewarded', 'Unrewarded' tick labels (assumes the larger reward value means\n",
    "    rewarded -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        The bar containers for the common (blue) and rare (red) bars.\n",
    "    \"\"\"\n",
    "    bar_width = 0.2\n",
    "\n",
    "    values = np.asarray(data)\n",
    "    bars_common = values[:2][::-1]\n",
    "    bars_rare = values[2:][::-1]\n",
    "    if yerr is not None:\n",
    "        errors = np.asarray(yerr)\n",
    "        errs_common = errors[:2][::-1]\n",
    "        errs_rare = errors[2:][::-1]\n",
    "    else:\n",
    "        errs_common = None\n",
    "        errs_rare = None\n",
    "\n",
    "    # x positions: each rare bar sits just to the right of its common bar.\n",
    "    x_common = np.array([0.125, 0.625])\n",
    "    x_rare = x_common + bar_width + .05\n",
    "\n",
    "    handle_common = ax.bar(x_common, bars_common, width=bar_width, color='blue', capsize=4, yerr=errs_common)\n",
    "    handle_rare = ax.bar(x_rare, bars_rare, width=bar_width, color='red', capsize=4, yerr=errs_rare)\n",
    "    ax.set_xticks(x_common + bar_width / 2 + .025)\n",
    "    ax.set_xticklabels(['Rewarded', 'Unrewarded'], fontsize=12)\n",
    "    ax.tick_params(axis='y', labelsize=12)\n",
    "    ax.set_title(title, fontsize=16)\n",
    "    ax.set_ylim([0.4, 1])\n",
    "    ax.set_xlim([0, 1])\n",
    "    return handle_common, handle_rare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(7.5, 2.4), sharey=True)\n",
    "\n",
    "# NOTE(review): each lesioned model appears to be titled by the system left intact\n",
    "# (HPC lesion -> 'Striatum', DLS lesion -> 'Hippocampus') -- confirm.\n",
    "panels = [\n",
    "    (mean_lesion_hpc, sem_lesion_hpc, 'Striatum'),\n",
    "    (mean_lesion_dls, sem_lesion_dls, 'Hippocampus'),\n",
    "    (mean_full, sem_full, 'Full model'),\n",
    "]\n",
    "for ax, (means, sems, title) in zip(axes, panels):\n",
    "    handles = plot_daw_style(ax, means, yerr=sems, title=title)\n",
    "\n",
    "axes[1].legend(handles, ['Common', 'Rare'], fontsize=10, frameon=False, handlelength=0.7, title='Previous transition')\n",
    "axes[0].set_ylabel('Stay probability', fontsize=12)\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(figure_location, 'StayProbability.pdf'))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deterministic two-step task (Doll et al. analysis)\n",
    "\n",
    "This section reuses (overwrites) the `data_*`, `mean_*`, `sem_*` and `figure_location` names from the previous section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure_location = os.path.join(FIGURE_FOLDER, 'twostep_deterministic')\n",
    "os.makedirs(figure_location, exist_ok=True)\n",
    "\n",
    "# load data\n",
    "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_hpc.csv'))\n",
    "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_dls.csv'))\n",
    "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_control.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_control.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_relevant_columns_deterministic(dataframe):\n",
    "    \"\"\"Add previous-trial columns for the deterministic-task analysis to `dataframe` in place.\n",
    "\n",
    "    PreviousAction, PreviousStart and PreviousReward are lagged by one row within each\n",
    "    Agent. Stay marks a repeated first-stage action and SameStart a repeated start state;\n",
    "    both are False on each agent's first row, where the lagged values are NaN.\n",
    "    \"\"\"\n",
    "    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
    "    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
    "    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
    "    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
    "    dataframe['SameStart'] = (dataframe.StartState == dataframe.PreviousStart)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "add_relevant_columns_deterministic(data_control)\n",
    "add_relevant_columns_deterministic(data_lesion_dls)\n",
    "add_relevant_columns_deterministic(data_lesion_hpc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_mean_stay_prob_deterministic(data):\n",
    "    \"\"\"Return the mean and SEM of Stay per (PreviousReward, SameStart) cell, in reversed group order.\n",
    "\n",
    "    Both statistics use the same rows (Trial > 0). The group order is reversed so that,\n",
    "    assuming PreviousReward is coded 0/1 (TODO confirm), the result runs\n",
    "    (rewarded, same), (rewarded, different), (unrewarded, same), (unrewarded, different),\n",
    "    matching the 'same'/'different' ticks and Reward/No reward legend used below.\n",
    "    \"\"\"\n",
    "    stay_by_condition = data[data['Trial'] > 0].groupby(['PreviousReward', 'SameStart'])['Stay']\n",
    "    return stay_by_condition.mean()[::-1], stay_by_condition.sem()[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob_deterministic(data_lesion_hpc)\n",
    "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob_deterministic(data_lesion_dls)\n",
    "mean_full, sem_full = compute_mean_stay_prob_deterministic(data_control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_dls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_doll_style(ax, data, yerr=None, title=''):\n",
    "    \"\"\"Draw a Doll et al.-style stay-probability bar chart on `ax`.\n",
    "\n",
    "    `data` (and `yerr`, if given) hold four values: the first two are drawn as light-gray\n",
    "    bars and the last two as dark-gray bars, each pair placed at the 'same' and\n",
    "    'different' tick positions.\n",
    "\n",
    "    Returns:\n",
    "        The bar containers for the light-gray and dark-gray bars.\n",
    "    \"\"\"\n",
    "    lightgray = '#d1d1d1'\n",
    "    darkgray = '#929292'\n",
    "    bar_width = 0.2\n",
    "\n",
    "    values = np.asarray(data)\n",
    "    bars1 = values[:2]\n",
    "    bars2 = values[2:]\n",
    "    if yerr is not None:\n",
    "        errors = np.asarray(yerr)\n",
    "        errs1 = errors[:2]\n",
    "        errs2 = errors[2:]\n",
    "    else:\n",
    "        errs1 = None\n",
    "        errs2 = None\n",
    "\n",
    "    # x positions of the first bar in each pair; the second bar sits immediately to its right.\n",
    "    r1 = np.arange(len(bars1)) * .8 + 1.5 * bar_width\n",
    "    r2 = r1 + bar_width\n",
    "\n",
    "    handle1 = ax.bar(r1, bars1, width=bar_width, color=lightgray, yerr=errs1, capsize=4)\n",
    "    handle2 = ax.bar(r2, bars2, width=bar_width, color=darkgray, yerr=errs2, capsize=4)\n",
    "    ax.set_ylabel('Stay probability', fontsize=15)\n",
    "    ax.set_xticks(r1 + bar_width / 2)\n",
    "    ax.set_xticklabels(['same', 'different'], fontsize=15)\n",
    "    ax.tick_params(axis='y', labelsize=15)\n",
    "    ax.set_title(title, fontsize=18)\n",
    "    ax.set_ylim([0.45, 1.])\n",
    "    ax.set_xlim([0, 1.6])\n",
    "\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    return handle1, handle2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(3.7, 2.5), sharey=True)\n",
    "\n",
    "panels = [\n",
    "    (mean_lesion_hpc, sem_lesion_hpc, 'Striatum'),\n",
    "    (mean_lesion_dls, sem_lesion_dls, 'Hippocampus'),\n",
    "    (mean_full, sem_full, 'Full model'),\n",
    "]\n",
    "for ax, (means, sems, title) in zip(axes, panels):\n",
    "    handles = plot_doll_style(ax, means, yerr=sems, title=title)\n",
    "\n",
    "# Manual placement in figure coordinates; the panels extend past x = 1 and the\n",
    "# bbox_inches='tight' save below sizes the output to the drawn artists.\n",
    "for ax, left in zip(axes, [0.1, 0.8, 1.5]):\n",
    "    ax.set_position([left, 0.1, 0.5, 0.7])\n",
    "\n",
    "leg = fig.legend(handles, ['Reward', 'No reward'], bbox_to_anchor=(2.1, .5), loc=(1, .5))\n",
    "leg.set_title('Previous outcome', prop={'size': 12})\n",
    "\n",
    "# NOTE(review): tight_layout may interact with the manual set_position calls above -- check the saved layout.\n",
    "fig.tight_layout()\n",
    "\n",
    "fig.savefig(os.path.join(figure_location, 'StayProbability_DeterministicTask.pdf'), bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean `P(SR)` by start-state repetition and previous reward (control model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_control.groupby(['SameStart', 'PreviousReward'])['P(SR)'].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression analysis\n",
    "\n",
    "Logistic regressions of the first-stage action (`Action1`) on the previous reward, previous action and start-state repetition, using the control model's deterministic-task data. The model is fit once on the pooled data and once per agent; the per-agent coefficients are then tested against zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_control.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "REGRESSION_FORMULA = 'Action1 ~ PreviousReward * PreviousAction * SameStart'\n",
    "regression_data = data_control[['Agent', 'PreviousReward', 'PreviousAction', 'SameStart', 'Action1']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pooled_res = smf.logit(formula=REGRESSION_FORMULA, data=regression_data).fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pooled_res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit separately per agent; distinct names keep `pooled_res` from being overwritten.\n",
    "agent_param_records = []\n",
    "for agent in regression_data.Agent.unique():\n",
    "    agent_data = regression_data[regression_data.Agent == agent]\n",
    "    agent_res = smf.logit(formula=REGRESSION_FORMULA, data=agent_data).fit(disp=False)\n",
    "    agent_param_records.append(agent_res.params.to_dict())\n",
    "\n",
    "# Build the frame once from records (DataFrame.append was removed in pandas 2.0).\n",
    "agent_params = pd.DataFrame(agent_param_records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_params['PreviousReward']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ttest_1samp(agent_params['PreviousReward'], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ttest_1samp(agent_params['PreviousReward:SameStart[T.True]'], 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pooled model with an agent interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): if Agent is a numeric ID it enters this formula as a continuous covariate;\n",
    "# C(Agent) would treat it as a factor -- confirm which is intended.\n",
    "interaction_res = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart * Agent', data=regression_data).fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interaction_res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interaction_res.pvalues['PreviousReward:SameStart[T.True]']"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}