{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "\n",
    "from definitions import RESULTS_FOLDER, FIGURE_FOLDER\n",
    "import statsmodels.formula.api as smf\n",
    "\n",
    "\n",
    "figure_location = os.path.join(FIGURE_FOLDER,'twostep')\n",
    "if not os.path.exists(figure_location):\n",
    "    os.makedirs(figure_location)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_hpc.csv'))\n",
    "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_dls.csv'))\n",
    "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_control.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>Action1</th>\n",
       "      <th>Action2</th>\n",
       "      <th>Agent</th>\n",
       "      <th>DLS reliability</th>\n",
       "      <th>HPC reliability</th>\n",
       "      <th>P(SR)</th>\n",
       "      <th>Reward</th>\n",
       "      <th>StartState</th>\n",
       "      <th>State2</th>\n",
       "      <th>Terminus</th>\n",
       "      <th>Trial</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.057327</td>\n",
       "      <td>0.087327</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.115198</td>\n",
       "      <td>0.167028</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.175993</td>\n",
       "      <td>0.239769</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.224014</td>\n",
       "      <td>0.306158</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.252047</td>\n",
       "      <td>0.366749</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  Action1  Action2  Agent  DLS reliability  HPC reliability  \\\n",
       "0           0      0.0      0.0    0.0         0.057327         0.087327   \n",
       "1           1      0.0      0.0    0.0         0.115198         0.167028   \n",
       "2           2      0.0      0.0    0.0         0.175993         0.239769   \n",
       "3           3      0.0      0.0    0.0         0.224014         0.306158   \n",
       "4           4      0.0      0.0    0.0         0.252047         0.366749   \n",
       "\n",
       "   P(SR)  Reward  StartState  State2  Terminus  Trial  \n",
       "0    1.0     1.0         0.0     3.0       5.0    0.0  \n",
       "1    1.0     1.0         0.0     3.0       5.0    1.0  \n",
       "2    1.0     1.0         0.0     3.0       5.0    2.0  \n",
       "3    1.0     0.0         0.0     3.0       5.0    3.0  \n",
       "4    1.0     1.0         0.0     4.0       7.0    4.0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_lesion_dls.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_common_or_rare(action, out):\n",
    "    left_outcomes = (5, 6)\n",
    "    right_outcomes = (7, 8)\n",
    "    if action == 0 and out in left_outcomes:\n",
    "        return 'common'\n",
    "    elif action == 0 and out in right_outcomes:\n",
    "        return 'rare'\n",
    "    elif action == 1 and out in left_outcomes:\n",
    "        return 'rare'\n",
    "    elif action == 1 and out in right_outcomes:\n",
    "        return 'common'\n",
    "    else:\n",
    "        raise ValueError('The combination of action and outcome does not make sense')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_relevant_columns(dataframe):\n",
    "    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
    "    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
    "    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
    "    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
    "    dataframe['Transition'] = np.vectorize(is_common_or_rare)(dataframe['Action1'], dataframe['Terminus'])\n",
    "    dataframe['PreviousTransition'] = dataframe.groupby(['Agent'])['Transition'].shift(1)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "add_relevant_columns(data_control)\n",
    "add_relevant_columns(data_lesion_dls)\n",
    "add_relevant_columns(data_lesion_hpc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>Action1</th>\n",
       "      <th>Action2</th>\n",
       "      <th>Agent</th>\n",
       "      <th>DLS reliability</th>\n",
       "      <th>HPC reliability</th>\n",
       "      <th>P(SR)</th>\n",
       "      <th>Reward</th>\n",
       "      <th>StartState</th>\n",
       "      <th>State2</th>\n",
       "      <th>Terminus</th>\n",
       "      <th>Trial</th>\n",
       "      <th>PreviousAction</th>\n",
       "      <th>PreviousStart</th>\n",
       "      <th>PreviousReward</th>\n",
       "      <th>Stay</th>\n",
       "      <th>Transition</th>\n",
       "      <th>PreviousTransition</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.057327</td>\n",
       "      <td>0.087327</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>False</td>\n",
       "      <td>common</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.115198</td>\n",
       "      <td>0.167028</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>True</td>\n",
       "      <td>common</td>\n",
       "      <td>common</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.175993</td>\n",
       "      <td>0.239769</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>True</td>\n",
       "      <td>common</td>\n",
       "      <td>common</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.224014</td>\n",
       "      <td>0.306158</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>True</td>\n",
       "      <td>common</td>\n",
       "      <td>common</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.252047</td>\n",
       "      <td>0.366749</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>True</td>\n",
       "      <td>rare</td>\n",
       "      <td>common</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  Action1  Action2  Agent  DLS reliability  HPC reliability  \\\n",
       "0           0      0.0      0.0    0.0         0.057327         0.087327   \n",
       "1           1      0.0      0.0    0.0         0.115198         0.167028   \n",
       "2           2      0.0      0.0    0.0         0.175993         0.239769   \n",
       "3           3      0.0      0.0    0.0         0.224014         0.306158   \n",
       "4           4      0.0      0.0    0.0         0.252047         0.366749   \n",
       "\n",
       "   P(SR)  Reward  StartState  State2  Terminus  Trial  PreviousAction  \\\n",
       "0    1.0     1.0         0.0     3.0       5.0    0.0             NaN   \n",
       "1    1.0     1.0         0.0     3.0       5.0    1.0             0.0   \n",
       "2    1.0     1.0         0.0     3.0       5.0    2.0             0.0   \n",
       "3    1.0     0.0         0.0     3.0       5.0    3.0             0.0   \n",
       "4    1.0     1.0         0.0     4.0       7.0    4.0             0.0   \n",
       "\n",
       "   PreviousStart  PreviousReward   Stay Transition PreviousTransition  \n",
       "0            NaN             NaN  False     common                NaN  \n",
       "1            0.0             1.0   True     common             common  \n",
       "2            0.0             1.0   True     common             common  \n",
       "3            0.0             1.0   True     common             common  \n",
       "4            0.0             0.0   True       rare             common  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_lesion_dls.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_mean_stay_prob(data):\n",
    "    means = data[data['Trial']>0].groupby(['PreviousTransition', 'PreviousReward'])['Stay'].mean()\n",
    "    sems = data.groupby(['PreviousTransition', 'PreviousReward'])['Stay'].sem()\n",
    "    return means, sems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n",
    "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n",
    "mean_full, sem_full = compute_mean_stay_prob(data_control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreviousTransition  PreviousReward\n",
       "common              0.0               0.558473\n",
       "                    1.0               0.836690\n",
       "rare                0.0               0.563003\n",
       "                    1.0               0.823360\n",
       "Name: Stay, dtype: float64"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_lesion_hpc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_daw_style(ax, data, yerr=None, title=''):\n",
    "    lightgray = '#d1d1d1'\n",
    "    darkgray = '#929292'\n",
    "\n",
    "    bar_width= 0.2\n",
    "\n",
    "    bars1 = data[:2][::-1]\n",
    "    bars2 = data[2:][::-1]\n",
    "    if yerr is not None:\n",
    "        errs1 = yerr[:2][::-1]\n",
    "        errs2 = yerr[2:][::-1]\n",
    "    else:\n",
    "        errs1 = yerr\n",
    "        errs2 = yerr\n",
    "        \n",
    "    # The x position of bars\n",
    "    r1 = np.array([0.125, 0.625]) \n",
    "    r2 = [x + bar_width + .05 for x in r1]\n",
    "    list(sem_full),\n",
    "    plt.sca(ax)\n",
    "    \n",
    "    plt.bar(r1, bars1, width=bar_width, color='blue',  capsize=4,yerr=errs1)\n",
    "    plt.bar(r2, bars2, width=bar_width, color='red',  capsize=4,yerr=errs1)\n",
    "    plt.xticks([r+ bar_width/2 +.025 for r in r1], ['Rewarded', 'Unrewarded'], fontsize=12)\n",
    "    plt.yticks(fontsize=12)\n",
    "    plt.title(title, fontsize=16)\n",
    "    plt.ylim([0.4, 1])\n",
    "    plt.xlim([0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 648x180 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, axes = plt.subplots(1,3, figsize= (9.5,2.5), sharey=True)\n",
    "\n",
    "plot_daw_style(axes[0], list(mean_lesion_hpc), yerr=sem_lesion_hpc,  title='Striatum')\n",
    "plot_daw_style(axes[1], list(mean_lesion_dls), yerr=sem_lesion_dls, title='Hippocampus')\n",
    "plot_daw_style(axes[2], list(mean_full), yerr=sem_full, title='Full model')\n",
    "\n",
    "\n",
    "leg = axes[1].legend(['Common', 'Rare'], fontsize=12, frameon=False, handlelength=0.7, title='Previous transition')\n",
    "plt.sca(axes[0])\n",
    "plt.ylabel('Stay probability', fontsize=12)\n",
    "\n",
    "plt.savefig(os.path.join(figure_location, 'StayProbability.pdf'))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Doll et al analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "figure_location = os.path.join(FIGURE_FOLDER,'twostep_deterministic')\n",
    "if not os.path.exists(figure_location):\n",
    "    os.makedirs(figure_location)\n",
    "\n",
    "# load data\n",
    "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_hpc.csv'))\n",
    "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_dls.csv'))\n",
    "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_control.csv'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>Action1</th>\n",
       "      <th>Action2</th>\n",
       "      <th>Agent</th>\n",
       "      <th>DLS reliability</th>\n",
       "      <th>HPC reliability</th>\n",
       "      <th>P(SR)</th>\n",
       "      <th>Reward</th>\n",
       "      <th>StartState</th>\n",
       "      <th>State2</th>\n",
       "      <th>Terminus</th>\n",
       "      <th>Trial</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.030000</td>\n",
       "      <td>0.681250</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.059100</td>\n",
       "      <td>0.671966</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.021000</td>\n",
       "      <td>0.087327</td>\n",
       "      <td>0.675625</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.029370</td>\n",
       "      <td>0.114707</td>\n",
       "      <td>0.670047</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.055789</td>\n",
       "      <td>0.141266</td>\n",
       "      <td>0.669577</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  Action1  Action2  Agent  DLS reliability  HPC reliability  \\\n",
       "0           0      0.0      0.0    0.0         0.000000         0.030000   \n",
       "1           1      1.0      1.0    0.0         0.000000         0.059100   \n",
       "2           2      1.0      1.0    0.0         0.021000         0.087327   \n",
       "3           3      0.0      0.0    0.0         0.029370         0.114707   \n",
       "4           4      1.0      1.0    0.0         0.055789         0.141266   \n",
       "\n",
       "      P(SR)  Reward  StartState  State2  Terminus  Trial  \n",
       "0  0.681250     1.0         1.0     2.0       4.0    0.0  \n",
       "1  0.671966     1.0         0.0     3.0       7.0    1.0  \n",
       "2  0.675625     1.0         1.0     3.0       7.0    2.0  \n",
       "3  0.670047     0.0         1.0     2.0       4.0    3.0  \n",
       "4  0.669577     1.0         0.0     3.0       7.0    4.0  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_control.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_relevant_columns(dataframe):\n",
    "    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
    "    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
    "    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
    "    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
    "    dataframe['SameStart'] = (dataframe.StartState == dataframe.PreviousStart)\n",
    "\n",
    "\n",
    "add_relevant_columns(data_control)\n",
    "add_relevant_columns(data_lesion_dls)\n",
    "add_relevant_columns(data_lesion_hpc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_mean_stay_prob(data):\n",
    "    means = data[data['Trial']>0].groupby(['PreviousReward', 'SameStart'])['Stay'].mean()\n",
    "    sems = data.groupby(['PreviousReward', 'SameStart'])['Stay'].sem()\n",
    "    return means[::-1], sems[::-1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n",
    "mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n",
    "mean_full, sem_full = compute_mean_stay_prob(data_control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreviousReward  SameStart\n",
       "1.0             True         0.944564\n",
       "                False        0.953787\n",
       "0.0             True         0.484103\n",
       "                False        0.484945\n",
       "Name: Stay, dtype: float64"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_lesion_dls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_doll_style(ax, data, yerr=None, title=''):\n",
    "    lightgray = '#d1d1d1'\n",
    "    darkgray = '#929292'\n",
    "\n",
    "    bar_width = 0.2\n",
    "\n",
    "    bars1 = data[:2]\n",
    "    bars2 = data[2:]\n",
    "    if yerr is not None:\n",
    "        errs1 = yerr[:2]\n",
    "        errs2 = yerr[2:]\n",
    "    else:\n",
    "        errs1 = None\n",
    "        errs2 = None \n",
    "        \n",
    "    # The x position of bars\n",
    "    r1 = np.arange(len(bars1)) * .8 + 1.5 * bar_width\n",
    "    r2 = [x + bar_width for x in r1]\n",
    "\n",
    "    plt.sca(ax)\n",
    "\n",
    "    handle1 = plt.bar(r1, bars1, width=bar_width, color=lightgray, yerr=errs1, capsize=4)\n",
    "    handle2 = plt.bar(r2, bars2, width=bar_width, color=darkgray, yerr=errs2, capsize=4)\n",
    "    plt.ylabel('Stay probability', fontsize=15)\n",
    "    plt.xticks([r + bar_width / 2 for r in r1], ['same', 'different'], fontsize=15)\n",
    "    plt.yticks(fontsize=15)\n",
    "    plt.title(title, fontsize=18)\n",
    "    plt.ylim([0.45, 1.])\n",
    "    plt.xlim([0, 1.6])\n",
    "\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    return handle1, handle2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/matplotlib/tight_layout.py:176: UserWarning: Tight layout not applied. The left and right margins cannot be made large enough to accommodate all axes decorations. \n",
      "  warnings.warn('Tight layout not applied. The left and right margins '\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 266.4x180 with 3 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, axes = plt.subplots(1,3, figsize= (3.7,2.5), sharey=True)\n",
    "\n",
    "h1, h2 = plot_doll_style(axes[0], list(mean_lesion_hpc), title='Striatum')\n",
    "h1, h2 = plot_doll_style(axes[1], list(mean_lesion_dls), title='Hippocampus')\n",
    "h1, h2 = plot_doll_style(axes[2], list(mean_full), title='Full model')\n",
    "\n",
    "\n",
    "axes[0].set_position([0.1,0.1,0.5,0.7])\n",
    "axes[1].set_position([0.8,0.1,0.5,0.7])\n",
    "axes[2].set_position([1.5,0.1,0.5,0.7])\n",
    "\n",
    "#leg = axes[2].legend(['Reward', 'No reward'], fontsize=12, frameon=False, handlelength=.7)\n",
    "\n",
    "#plt.subplots_adjust(left=0.07, right=.93, wspace=0.25, hspace=0.35)\n",
    "leg = fig.legend([h1, h2], ['Reward', 'No reward'], bbox_to_anchor=(2.1, .5), loc = (1,.5), title=\"Previous outcome\")\n",
    "leg.set_title('Previous outcome', prop={'size': 12})\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(os.path.join(figure_location, 'StayProbability_DeterministicTask.pdf'), bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SameStart  PreviousReward\n",
       "False      0.0               0.492189\n",
       "           1.0               0.493106\n",
       "True       0.0               0.491967\n",
       "           1.0               0.492369\n",
       "Name: P(SR), dtype: float64"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_control.groupby([\"SameStart\", \"PreviousReward\"])['P(SR)'].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>Action1</th>\n",
       "      <th>Action2</th>\n",
       "      <th>Agent</th>\n",
       "      <th>DLS reliability</th>\n",
       "      <th>HPC reliability</th>\n",
       "      <th>P(SR)</th>\n",
       "      <th>Reward</th>\n",
       "      <th>StartState</th>\n",
       "      <th>State2</th>\n",
       "      <th>Terminus</th>\n",
       "      <th>Trial</th>\n",
       "      <th>PreviousAction</th>\n",
       "      <th>PreviousStart</th>\n",
       "      <th>PreviousReward</th>\n",
       "      <th>Stay</th>\n",
       "      <th>SameStart</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.030000</td>\n",
       "      <td>0.681250</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.059100</td>\n",
       "      <td>0.671966</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.021000</td>\n",
       "      <td>0.087327</td>\n",
       "      <td>0.675625</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.029370</td>\n",
       "      <td>0.114707</td>\n",
       "      <td>0.670047</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.055789</td>\n",
       "      <td>0.141266</td>\n",
       "      <td>0.669577</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  Action1  Action2  Agent  DLS reliability  HPC reliability  \\\n",
       "0           0      0.0      0.0    0.0         0.000000         0.030000   \n",
       "1           1      1.0      1.0    0.0         0.000000         0.059100   \n",
       "2           2      1.0      1.0    0.0         0.021000         0.087327   \n",
       "3           3      0.0      0.0    0.0         0.029370         0.114707   \n",
       "4           4      1.0      1.0    0.0         0.055789         0.141266   \n",
       "\n",
       "      P(SR)  Reward  StartState  State2  Terminus  Trial  PreviousAction  \\\n",
       "0  0.681250     1.0         1.0     2.0       4.0    0.0             NaN   \n",
       "1  0.671966     1.0         0.0     3.0       7.0    1.0             0.0   \n",
       "2  0.675625     1.0         1.0     3.0       7.0    2.0             1.0   \n",
       "3  0.670047     0.0         1.0     2.0       4.0    3.0             1.0   \n",
       "4  0.669577     1.0         0.0     3.0       7.0    4.0             0.0   \n",
       "\n",
       "   PreviousStart  PreviousReward   Stay  SameStart  \n",
       "0            NaN             NaN  False      False  \n",
       "1            1.0             1.0  False      False  \n",
       "2            0.0             1.0   True      False  \n",
       "3            1.0             1.0  False       True  \n",
       "4            1.0             0.0  False      False  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_control.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the agent id, the regressors and the outcome (Action1) used by the logistic regressions below.\n",
    "data = data_control.loc[:, ['Agent', 'PreviousReward', 'PreviousAction', 'SameStart', 'Action1']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.499005\n",
      "         Iterations 7\n"
     ]
    }
   ],
   "source": [
    "# Pooled logistic regression over every agent's trials (agent identity is not a regressor).\n",
    "mod = smf.logit('Action1 ~ PreviousReward * PreviousAction * SameStart', data=data)\n",
    "res = mod.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the logistic regression separately to each agent's trials; one row of coefficients per agent.\n",
    "# A fit whose Hessian cannot be inverted raises LinAlgError ('Singular matrix'); such agents are skipped\n",
    "# and listed, instead of aborting the loop and leaving `df` with only the agents fitted so far.\n",
    "formula = 'Action1 ~ PreviousReward * PreviousAction * SameStart'\n",
    "records = []\n",
    "failed_agents = []\n",
    "for agent in data.Agent.unique():\n",
    "    mod = smf.logit(formula=formula, data=data[data.Agent == agent])\n",
    "    try:\n",
    "        res = mod.fit(disp=0)\n",
    "    except np.linalg.LinAlgError:\n",
    "        failed_agents.append(agent)\n",
    "        continue\n",
    "    # Non-converged fits (max iterations hit) can have runaway coefficients, so keep a flag to filter on.\n",
    "    row = {'Agent': agent, 'converged': res.mle_retvals.get('converged', False)}\n",
    "    row.update(res.params.to_dict())\n",
    "    records.append(row)\n",
    "\n",
    "# Build the table once; DataFrame.append grows it quadratically and was removed in pandas 2.0.\n",
    "df = pd.DataFrame(records)\n",
    "print('fitted {} agents; skipped {} with a singular Hessian: {}'.format(len(records), len(failed_agents), failed_agents))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    -1.215337\n",
       "1    -0.680725\n",
       "2    -1.007641\n",
       "3    -0.481838\n",
       "4    -1.121497\n",
       "5    -0.757686\n",
       "6    -0.867501\n",
       "7    -1.722767\n",
       "8    -1.446227\n",
       "9    -1.963610\n",
       "10   -0.885519\n",
       "11   -0.344840\n",
       "12   -1.178655\n",
       "13   -1.174598\n",
       "14   -1.515912\n",
       "15   -2.018533\n",
       "16   -1.317301\n",
       "17   -1.163151\n",
       "18   -0.597837\n",
       "19   -0.346625\n",
       "20   -0.938596\n",
       "21   -1.128959\n",
       "22   -1.376725\n",
       "23   -0.385504\n",
       "24   -0.693147\n",
       "25   -1.402043\n",
       "26   -1.709521\n",
       "27   -1.317707\n",
       "28   -0.435318\n",
       "29   -1.410987\n",
       "30   -1.544083\n",
       "31   -0.976010\n",
       "Name: PreviousReward, dtype: float64"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-agent coefficient on PreviousReward (one value per fitted agent).\n",
    "df.loc[:, 'PreviousReward']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<caption>Logit Regression Results</caption>\n",
       "<tr>\n",
       "  <th>Dep. Variable:</th>      <td>Action1</td>     <th>  No. Observations:  </th>  <td>   271</td>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Model:</th>               <td>Logit</td>      <th>  Df Residuals:      </th>  <td>   263</td>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Method:</th>               <td>MLE</td>       <th>  Df Model:          </th>  <td>     7</td>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Date:</th>          <td>Wed, 29 Jul 2020</td> <th>  Pseudo R-squ.:     </th>  <td>0.2475</td>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Time:</th>              <td>15:23:33</td>     <th>  Log-Likelihood:    </th> <td> -133.01</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>converged:</th>           <td>False</td>      <th>  LL-Null:           </th> <td> -176.75</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th> </th>                      <td> </td>        <th>  LLR p-value:       </th> <td>4.054e-16</td>\n",
       "</tr>\n",
       "</table>\n",
       "<table class=\"simpletable\">\n",
       "<tr>\n",
       "                         <td></td>                            <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Intercept</th>                                       <td>   -0.7732</td> <td>    0.349</td> <td>   -2.215</td> <td> 0.027</td> <td>   -1.457</td> <td>   -0.089</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>SameStart[T.True]</th>                               <td>    0.6678</td> <td>    0.477</td> <td>    1.401</td> <td> 0.161</td> <td>   -0.267</td> <td>    1.602</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward</th>                                  <td>   -0.9760</td> <td>    0.518</td> <td>   -1.883</td> <td> 0.060</td> <td>   -1.992</td> <td>    0.040</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:SameStart[T.True]</th>                <td>   -1.9391</td> <td>    0.948</td> <td>   -2.046</td> <td> 0.041</td> <td>   -3.797</td> <td>   -0.081</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousAction</th>                                  <td>    0.2426</td> <td>    0.530</td> <td>    0.458</td> <td> 0.647</td> <td>   -0.796</td> <td>    1.281</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousAction:SameStart[T.True]</th>                <td>   -0.4474</td> <td>    0.737</td> <td>   -0.607</td> <td> 0.544</td> <td>   -1.893</td> <td>    0.998</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:PreviousAction</th>                   <td>    2.1426</td> <td>    0.773</td> <td>    2.772</td> <td> 0.006</td> <td>    0.628</td> <td>    3.657</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:PreviousAction:SameStart[T.True]</th> <td>   22.1332</td> <td> 8544.817</td> <td>    0.003</td> <td> 0.998</td> <td>-1.67e+04</td> <td> 1.68e+04</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<class 'statsmodels.iolib.summary.Summary'>\n",
       "\"\"\"\n",
       "                           Logit Regression Results                           \n",
       "==============================================================================\n",
       "Dep. Variable:                Action1   No. Observations:                  271\n",
       "Model:                          Logit   Df Residuals:                      263\n",
       "Method:                           MLE   Df Model:                            7\n",
       "Date:                Wed, 29 Jul 2020   Pseudo R-squ.:                  0.2475\n",
       "Time:                        15:23:33   Log-Likelihood:                -133.01\n",
       "converged:                      False   LL-Null:                       -176.75\n",
       "                                        LLR p-value:                 4.054e-16\n",
       "===================================================================================================================\n",
       "                                                      coef    std err          z      P>|z|      [0.025      0.975]\n",
       "-------------------------------------------------------------------------------------------------------------------\n",
       "Intercept                                          -0.7732      0.349     -2.215      0.027      -1.457      -0.089\n",
       "SameStart[T.True]                                   0.6678      0.477      1.401      0.161      -0.267       1.602\n",
       "PreviousReward                                     -0.9760      0.518     -1.883      0.060      -1.992       0.040\n",
       "PreviousReward:SameStart[T.True]                   -1.9391      0.948     -2.046      0.041      -3.797      -0.081\n",
       "PreviousAction                                      0.2426      0.530      0.458      0.647      -0.796       1.281\n",
       "PreviousAction:SameStart[T.True]                   -0.4474      0.737     -0.607      0.544      -1.893       0.998\n",
       "PreviousReward:PreviousAction                       2.1426      0.773      2.772      0.006       0.628       3.657\n",
       "PreviousReward:PreviousAction:SameStart[T.True]    22.1332   8544.817      0.003      0.998   -1.67e+04    1.68e+04\n",
       "===================================================================================================================\n",
       "\"\"\""
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE: `res` is left over from the cells above; after the per-agent loop it is the last agent's fit,\n",
    "# not the pooled model (e.g. the stored output's PreviousReward coef equals the last row of `df`).\n",
    "res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The raw PreviousReward column is already displayed above; summarise every coefficient across agents\n",
    "# instead (count / mean / std are what the one-sample t-tests below are built on). Drop the agent id if present.\n",
    "df.drop(columns=['Agent'], errors='ignore').describe().T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Ttest_1sampResult(statistic=-13.497324667045547, pvalue=1.6042574528289812e-14)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy.stats import ttest_ind\n",
    "from scipy.stats import ttest_1samp\n",
    "\n",
    "ttest_1samp(df['PreviousReward'],0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Ttest_1sampResult(statistic=-4.5382898405157395, pvalue=8.029457107185251e-05)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ttest_1samp(df['PreviousReward:SameStart[T.True]'], 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: `scipy.stats.ttest_1samp` performs a two-sided test of the null hypothesis that the mean of the sample equals `popmean`, so the p-values reported above are two-tailed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.498717\n",
      "         Iterations 7\n"
     ]
    }
   ],
   "source": [
    "mod = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart * Agent', data=data)\n",
    "res = mod.fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<caption>Logit Regression Results</caption>\n",
       "<tr>\n",
       "  <th>Dep. Variable:</th>      <td>Action1</td>     <th>  No. Observations:  </th>  <td> 10840</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Model:</th>               <td>Logit</td>      <th>  Df Residuals:      </th>  <td> 10824</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Method:</th>               <td>MLE</td>       <th>  Df Model:          </th>  <td>    15</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Date:</th>          <td>Wed, 29 Jul 2020</td> <th>  Pseudo R-squ.:     </th>  <td>0.2800</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Time:</th>              <td>15:23:35</td>     <th>  Log-Likelihood:    </th> <td> -5406.1</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>converged:</th>           <td>True</td>       <th>  LL-Null:           </th> <td> -7508.1</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th> </th>                      <td> </td>        <th>  LLR p-value:       </th>  <td> 0.000</td> \n",
       "</tr>\n",
       "</table>\n",
       "<table class=\"simpletable\">\n",
       "<tr>\n",
       "                            <td></td>                               <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Intercept</th>                                             <td>   -0.4366</td> <td>    0.105</td> <td>   -4.157</td> <td> 0.000</td> <td>   -0.642</td> <td>   -0.231</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>SameStart[T.True]</th>                                     <td>    0.4108</td> <td>    0.149</td> <td>    2.760</td> <td> 0.006</td> <td>    0.119</td> <td>    0.702</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward</th>                                        <td>   -1.0693</td> <td>    0.166</td> <td>   -6.456</td> <td> 0.000</td> <td>   -1.394</td> <td>   -0.745</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:SameStart[T.True]</th>                      <td>   -2.3645</td> <td>    0.356</td> <td>   -6.638</td> <td> 0.000</td> <td>   -3.063</td> <td>   -1.666</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousAction</th>                                        <td>    0.5712</td> <td>    0.154</td> <td>    3.699</td> <td> 0.000</td> <td>    0.269</td> <td>    0.874</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousAction:SameStart[T.True]</th>                      <td>   -0.8381</td> <td>    0.218</td> <td>   -3.843</td> <td> 0.000</td> <td>   -1.265</td> <td>   -0.411</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:PreviousAction</th>                         <td>    2.3226</td> <td>    0.243</td> <td>    9.567</td> <td> 0.000</td> <td>    1.847</td> <td>    2.798</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:PreviousAction:SameStart[T.True]</th>       <td>    4.4519</td> <td>    0.502</td> <td>    8.867</td> <td> 0.000</td> <td>    3.468</td> <td>    5.436</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Agent</th>                                                 <td>    0.0032</td> <td>    0.005</td> <td>    0.683</td> <td> 0.495</td> <td>   -0.006</td> <td>    0.012</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>SameStart[T.True]:Agent</th>                               <td>   -0.0003</td> <td>    0.007</td> <td>   -0.043</td> <td> 0.966</td> <td>   -0.013</td> <td>    0.013</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:Agent</th>                                  <td>   -0.0011</td> <td>    0.007</td> <td>   -0.144</td> <td> 0.885</td> <td>   -0.015</td> <td>    0.013</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:SameStart[T.True]:Agent</th>                <td>   -0.0007</td> <td>    0.016</td> <td>   -0.045</td> <td> 0.964</td> <td>   -0.031</td> <td>    0.030</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousAction:Agent</th>                                  <td>    0.0009</td> <td>    0.007</td> <td>    0.128</td> <td> 0.898</td> <td>   -0.013</td> <td>    0.014</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousAction:SameStart[T.True]:Agent</th>                <td>    0.0057</td> <td>    0.010</td> <td>    0.590</td> <td> 0.555</td> <td>   -0.013</td> <td>    0.025</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:PreviousAction:Agent</th>                   <td>   -0.0020</td> <td>    0.011</td> <td>   -0.188</td> <td> 0.851</td> <td>   -0.023</td> <td>    0.019</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>PreviousReward:PreviousAction:SameStart[T.True]:Agent</th> <td>    0.0063</td> <td>    0.022</td> <td>    0.282</td> <td> 0.778</td> <td>   -0.037</td> <td>    0.050</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<class 'statsmodels.iolib.summary.Summary'>\n",
       "\"\"\"\n",
       "                           Logit Regression Results                           \n",
       "==============================================================================\n",
       "Dep. Variable:                Action1   No. Observations:                10840\n",
       "Model:                          Logit   Df Residuals:                    10824\n",
       "Method:                           MLE   Df Model:                           15\n",
       "Date:                Wed, 29 Jul 2020   Pseudo R-squ.:                  0.2800\n",
       "Time:                        15:23:35   Log-Likelihood:                -5406.1\n",
       "converged:                       True   LL-Null:                       -7508.1\n",
       "                                        LLR p-value:                     0.000\n",
       "=========================================================================================================================\n",
       "                                                            coef    std err          z      P>|z|      [0.025      0.975]\n",
       "-------------------------------------------------------------------------------------------------------------------------\n",
       "Intercept                                                -0.4366      0.105     -4.157      0.000      -0.642      -0.231\n",
       "SameStart[T.True]                                         0.4108      0.149      2.760      0.006       0.119       0.702\n",
       "PreviousReward                                           -1.0693      0.166     -6.456      0.000      -1.394      -0.745\n",
       "PreviousReward:SameStart[T.True]                         -2.3645      0.356     -6.638      0.000      -3.063      -1.666\n",
       "PreviousAction                                            0.5712      0.154      3.699      0.000       0.269       0.874\n",
       "PreviousAction:SameStart[T.True]                         -0.8381      0.218     -3.843      0.000      -1.265      -0.411\n",
       "PreviousReward:PreviousAction                             2.3226      0.243      9.567      0.000       1.847       2.798\n",
       "PreviousReward:PreviousAction:SameStart[T.True]           4.4519      0.502      8.867      0.000       3.468       5.436\n",
       "Agent                                                     0.0032      0.005      0.683      0.495      -0.006       0.012\n",
       "SameStart[T.True]:Agent                                  -0.0003      0.007     -0.043      0.966      -0.013       0.013\n",
       "PreviousReward:Agent                                     -0.0011      0.007     -0.144      0.885      -0.015       0.013\n",
       "PreviousReward:SameStart[T.True]:Agent                   -0.0007      0.016     -0.045      0.964      -0.031       0.030\n",
       "PreviousAction:Agent                                      0.0009      0.007      0.128      0.898      -0.013       0.014\n",
       "PreviousAction:SameStart[T.True]:Agent                    0.0057      0.010      0.590      0.555      -0.013       0.025\n",
       "PreviousReward:PreviousAction:Agent                      -0.0020      0.011     -0.188      0.851      -0.023       0.019\n",
       "PreviousReward:PreviousAction:SameStart[T.True]:Agent     0.0063      0.022      0.282      0.778      -0.037       0.050\n",
       "=========================================================================================================================\n",
       "\"\"\""
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.17537174578342e-11"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res.pvalues['PreviousReward:SameStart[T.True]']"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}