{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "\n",
    "from definitions import RESULTS_FOLDER, FIGURE_FOLDER\n",
    "import statsmodels.formula.api as smf\n",
    "\n",
    "\n",
    "figure_location = os.path.join(FIGURE_FOLDER,'twostep')\n",
    "if not os.path.exists(figure_location):\n",
    "    os.makedirs(figure_location)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_hpc.csv'))\n",
    "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_dls.csv'))\n",
    "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_control.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lesion_dls.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_common_or_rare(action, out):\n",
    "    left_outcomes = (5, 6)\n",
    "    right_outcomes = (7, 8)\n",
    "    if action == 0 and out in left_outcomes:\n",
    "        return 'common'\n",
    "    elif action == 0 and out in right_outcomes:\n",
    "        return 'rare'\n",
    "    elif action == 1 and out in left_outcomes:\n",
    "        return 'rare'\n",
    "    elif action == 1 and out in right_outcomes:\n",
    "        return 'common'\n",
    "    else:\n",
    "        raise ValueError('The combination of action and outcome does not make sense')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_relevant_columns(dataframe):\n",
    "    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
    "    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
    "    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
    "    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
    "    dataframe['Transition'] = np.vectorize(is_common_or_rare)(dataframe['Action1'], dataframe['Terminus'])\n",
    "    dataframe['PreviousTransition'] = dataframe.groupby(['Agent'])['Transition'].shift(1)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "add_relevant_columns(data_control)\n",
    "add_relevant_columns(data_lesion_dls)\n",
    "add_relevant_columns(data_lesion_hpc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lesion_dls.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_mean_stay_prob(data):\n",
    "    means = data[data['Trial']>0].groupby(['PreviousTransition', 'PreviousReward'])['Stay'].mean()\n",
    "    sems = data.groupby(['PreviousTransition', 'PreviousReward'])['Stay'].sem()\n",
    "    return means, sems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n",
    "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n",
    "mean_full, sem_full = compute_mean_stay_prob(data_control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_daw_style(ax, data, yerr=None, title=''):\n",
    "    lightgray = '#d1d1d1'\n",
    "    darkgray = '#929292'\n",
    "\n",
    "    bar_width= 0.2\n",
    "\n",
    "    bars1 = data[:2][::-1]\n",
    "    bars2 = data[2:][::-1]\n",
    "    if yerr is not None:\n",
    "        errs1 = yerr[:2][::-1]\n",
    "        errs2 = yerr[2:][::-1]\n",
    "    else:\n",
    "        errs1 = yerr\n",
    "        errs2 = yerr\n",
    "        \n",
    "    # The x position of bars\n",
    "    r1 = np.array([0.125, 0.625]) \n",
    "    r2 = [x + bar_width + .05 for x in r1]\n",
    "    list(sem_full),\n",
    "    plt.sca(ax)\n",
    "    \n",
    "    plt.bar(r1, bars1, width=bar_width, color='blue',  capsize=4,yerr=errs1)\n",
    "    plt.bar(r2, bars2, width=bar_width, color='red',  capsize=4,yerr=errs1)\n",
    "    plt.xticks([r+ bar_width/2 +.025 for r in r1], ['Rewarded', 'Unrewarded'], fontsize=12)\n",
    "    plt.yticks(fontsize=12)\n",
    "    plt.title(title, fontsize=16)\n",
    "    plt.ylim([0.4, 1])\n",
    "    plt.xlim([0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1,3, figsize= (7.5,2.4), sharey=True)\n",
    "\n",
    "plot_daw_style(axes[0], list(mean_lesion_hpc), yerr=sem_lesion_hpc,  title='Striatum')\n",
    "plot_daw_style(axes[1], list(mean_lesion_dls), yerr=sem_lesion_dls, title='Hippocampus')\n",
    "plot_daw_style(axes[2], list(mean_full), yerr=sem_full, title='Full model')\n",
    "\n",
    "\n",
    "leg = axes[1].legend(['Common', 'Rare'], fontsize=10, frameon=False, handlelength=0.7, title='Previous transition')\n",
    "plt.sca(axes[0])\n",
    "plt.ylabel('Stay probability', fontsize=12)\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(figure_location, 'StayProbability.pdf'))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Doll et al analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure_location = os.path.join(FIGURE_FOLDER,'twostep_deterministic')\n",
    "if not os.path.exists(figure_location):\n",
    "    os.makedirs(figure_location)\n",
    "\n",
    "# load data\n",
    "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_hpc.csv'))\n",
    "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_dls.csv'))\n",
    "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_control.csv'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_control.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_relevant_columns(dataframe):\n",
    "    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
    "    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
    "    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
    "    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
    "    dataframe['SameStart'] = (dataframe.StartState == dataframe.PreviousStart)\n",
    "\n",
    "\n",
    "add_relevant_columns(data_control)\n",
    "add_relevant_columns(data_lesion_dls)\n",
    "add_relevant_columns(data_lesion_hpc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_mean_stay_prob(data):\n",
    "    means = data[data['Trial']>0].groupby(['PreviousReward', 'SameStart'])['Stay'].mean()\n",
    "    sems = data.groupby(['PreviousReward', 'SameStart'])['Stay'].sem()\n",
    "    return means[::-1], sems[::-1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n",
    "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n",
    "mean_full, sem_full = compute_mean_stay_prob(data_control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_dls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_doll_style(ax, data, yerr=None, title=''):\n",
    "    lightgray = '#d1d1d1'\n",
    "    darkgray = '#929292'\n",
    "\n",
    "    bar_width = 0.2\n",
    "\n",
    "    bars1 = data[:2]\n",
    "    bars2 = data[2:]\n",
    "    if yerr is not None:\n",
    "        errs1 = yerr[:2]\n",
    "        errs2 = yerr[2:]\n",
    "    else:\n",
    "        errs1 = None\n",
    "        errs2 = None \n",
    "        \n",
    "    # The x position of bars\n",
    "    r1 = np.arange(len(bars1)) * .8 + 1.5 * bar_width\n",
    "    r2 = [x + bar_width for x in r1]\n",
    "\n",
    "    plt.sca(ax)\n",
    "\n",
    "    handle1 = plt.bar(r1, bars1, width=bar_width, color=lightgray, yerr=errs1, capsize=4)\n",
    "    handle2 = plt.bar(r2, bars2, width=bar_width, color=darkgray, yerr=errs2, capsize=4)\n",
    "    plt.ylabel('Stay probability', fontsize=15)\n",
    "    plt.xticks([r + bar_width / 2 for r in r1], ['same', 'different'], fontsize=15)\n",
    "    plt.yticks(fontsize=15)\n",
    "    plt.title(title, fontsize=18)\n",
    "    plt.ylim([0.45, 1.])\n",
    "    plt.xlim([0, 1.6])\n",
    "\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    return handle1, handle2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1,3, figsize= (3.7,2.5), sharey=True)\n",
    "\n",
    "h1, h2 = plot_doll_style(axes[0], list(mean_lesion_hpc), yerr=sem_lesion_hpc, title='Striatum')\n",
    "h1, h2 = plot_doll_style(axes[1], list(mean_lesion_dls), yerr=sem_lesion_dls, title='Hippocampus')\n",
    "h1, h2 = plot_doll_style(axes[2], list(mean_full), yerr=sem_full, title='Full model')\n",
    "\n",
    "\n",
    "axes[0].set_position([0.1,0.1,0.5,0.7])\n",
    "axes[1].set_position([0.8,0.1,0.5,0.7])\n",
    "axes[2].set_position([1.5,0.1,0.5,0.7])\n",
    "\n",
    "#leg = axes[2].legend(['Reward', 'No reward'], fontsize=12, frameon=False, handlelength=.7)\n",
    "\n",
    "#plt.subplots_adjust(left=0.07, right=.93, wspace=0.25, hspace=0.35)\n",
    "leg = fig.legend([h1, h2], ['Reward', 'No reward'], bbox_to_anchor=(2.1, .5), loc = (1,.5), title=\"Previous outcome\")\n",
    "leg.set_title('Previous outcome', prop={'size': 12})\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(os.path.join(figure_location, 'StayProbability_DeterministicTask.pdf'), bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_control.groupby([\"SameStart\", \"PreviousReward\"])['P(SR)'].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_control.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = data_control[['Agent', 'PreviousReward', 'PreviousAction', 'SameStart', 'Action1']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart', data=data)\n",
    "res = mod.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame({}, columns=res.params.keys())\n",
    "for agent in data.Agent.unique():\n",
    "    mod = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart', data=data[data.Agent==agent])\n",
    "    res = mod.fit()\n",
    "    df = df.append(res.params.to_dict(), ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"PreviousReward\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['PreviousReward']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import ttest_ind\n",
    "from scipy.stats import ttest_1samp\n",
    "\n",
    "ttest_1samp(df['PreviousReward'],0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ttest_1samp(df['PreviousReward:SameStart[T.True]'], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "help(ttest_1samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart * Agent', data=data)\n",
    "res = mod.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res.pvalues['PreviousReward:SameStart[T.True]']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}