{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"import sys\n",
"sys.path.append('../..')\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import os\n",
"\n",
"from definitions import RESULTS_FOLDER, FIGURE_FOLDER\n",
"import statsmodels.formula.api as smf\n",
"\n",
"\n",
"figure_location = os.path.join(FIGURE_FOLDER,'twostep')\n",
"# exist_ok=True: idempotent on re-run, no check-then-create race\n",
"os.makedirs(figure_location, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
"data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_hpc.csv'))\n",
"data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_lesion_dls.csv'))\n",
"data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep', 'results_control.csv'))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Unnamed: 0</th>\n", " <th>Action1</th>\n", " <th>Action2</th>\n", " <th>Agent</th>\n", " <th>DLS reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Reward</th>\n", " <th>StartState</th>\n", " <th>State2</th>\n", " <th>Terminus</th>\n", " <th>Trial</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n",
" <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.057327</td>\n", " <td>0.087327</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>0.0</td>\n", " </tr>\n",
" <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.115198</td>\n", " <td>0.167028</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>1.0</td>\n", " </tr>\n",
" <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.175993</td>\n", " <td>0.239769</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>2.0</td>\n", " </tr>\n",
" <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.224014</td>\n", " <td>0.306158</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>3.0</td>\n", " </tr>\n",
" <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.252047</td>\n", " <td>0.366749</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>4.0</td>\n", " <td>7.0</td>\n", " <td>4.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Unnamed: 0 Action1 Action2 Agent DLS reliability HPC reliability \\\n", "0 0 0.0 0.0 0.0 0.057327 0.087327 \n", "1 1 0.0 0.0 0.0 0.115198 0.167028 \n", "2 2 0.0 0.0 0.0 0.175993 0.239769 \n", "3 3 0.0 0.0 0.0 0.224014 0.306158 \n", "4 4 0.0 0.0 0.0 0.252047 0.366749 \n", "\n", " P(SR) Reward StartState State2 Terminus Trial \n", "0 1.0 1.0 0.0 3.0 5.0 0.0 \n", "1 1.0 1.0 0.0 3.0 5.0 1.0 \n", "2 1.0 1.0 0.0 3.0 5.0 2.0 \n", "3 1.0 0.0 0.0 3.0 5.0 3.0 \n", "4 1.0 1.0 0.0 4.0 7.0 4.0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_lesion_dls.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"def is_common_or_rare(action, out):\n",
"    \"\"\"Label a first-stage transition as 'common' or 'rare'.\n",
"\n",
"    Action 0 reaching outcome 5 or 6, or action 1 reaching outcome 7 or 8, is\n",
"    'common'; the crossed pairings are 'rare'. Values are compared with ==, so\n",
"    floats read from CSV (e.g. 0.0, 5.0) match their integer counterparts.\n",
"\n",
"    Args:\n",
"        action: first-stage action (0 or 1).\n",
"        out: terminal state reached on the trial (5, 6, 7 or 8).\n",
"\n",
"    Returns:\n",
"        'common' or 'rare'.\n",
"\n",
"    Raises:\n",
"        ValueError: for any other action/outcome combination.\n",
"    \"\"\"\n",
"    left_outcomes = (5, 6)\n",
"    right_outcomes = (7, 8)\n",
"    if action == 0 and out in left_outcomes:\n",
"        return 'common'\n",
"    elif action == 0 and out in right_outcomes:\n",
"        return 'rare'\n",
"    elif action == 1 and out in left_outcomes:\n",
"        return 'rare'\n",
"    elif action == 1 and out in right_outcomes:\n",
"        return 'common'\n",
"    else:\n",
"        raise ValueError('The combination of action and outcome does not make sense')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
"def add_relevant_columns(dataframe):\n",
"    \"\"\"Add previous-trial and transition-type columns to `dataframe` in place.\n",
"\n",
"    Previous-trial columns are shifted within each 'Agent', so each agent's first\n",
"    row gets NaN there (and 'Stay' is False, since NaN never compares equal).\n",
"\n",
"    Adds PreviousAction, PreviousStart, PreviousReward, Stay (first-stage action\n",
"    repeated), Transition ('common'/'rare' from is_common_or_rare, whose\n",
"    ValueError propagates for unexpected rows) and PreviousTransition.\n",
"    Returns None.\n",
"    \"\"\"\n",
"    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
"    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
"    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
"    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
"    dataframe['Transition'] = np.vectorize(is_common_or_rare)(dataframe['Action1'], dataframe['Terminus'])\n",
"    dataframe['PreviousTransition'] = dataframe.groupby(['Agent'])['Transition'].shift(1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
"add_relevant_columns(data_control)\n",
"add_relevant_columns(data_lesion_dls)\n",
"add_relevant_columns(data_lesion_hpc)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Unnamed: 0</th>\n", " <th>Action1</th>\n", " <th>Action2</th>\n", " <th>Agent</th>\n", " <th>DLS reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Reward</th>\n", " <th>StartState</th>\n", " <th>State2</th>\n", " 
<th>Terminus</th>\n", " <th>Trial</th>\n", " <th>PreviousAction</th>\n", " <th>PreviousStart</th>\n", " <th>PreviousReward</th>\n", " <th>Stay</th>\n", " <th>Transition</th>\n", " <th>PreviousTransition</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.057327</td>\n", " <td>0.087327</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>0.0</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>False</td>\n", " <td>common</td>\n", " <td>NaN</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.115198</td>\n", " <td>0.167028</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>True</td>\n", " <td>common</td>\n", " <td>common</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.175993</td>\n", " <td>0.239769</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>2.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>True</td>\n", " <td>common</td>\n", " <td>common</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.224014</td>\n", " <td>0.306158</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>5.0</td>\n", " <td>3.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>True</td>\n", " <td>common</td>\n", " <td>common</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.252047</td>\n", " <td>0.366749</td>\n", " 
<td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>4.0</td>\n", " <td>7.0</td>\n", " <td>4.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>True</td>\n", " <td>rare</td>\n", " <td>common</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Unnamed: 0 Action1 Action2 Agent DLS reliability HPC reliability \\\n", "0 0 0.0 0.0 0.0 0.057327 0.087327 \n", "1 1 0.0 0.0 0.0 0.115198 0.167028 \n", "2 2 0.0 0.0 0.0 0.175993 0.239769 \n", "3 3 0.0 0.0 0.0 0.224014 0.306158 \n", "4 4 0.0 0.0 0.0 0.252047 0.366749 \n", "\n", " P(SR) Reward StartState State2 Terminus Trial PreviousAction \\\n", "0 1.0 1.0 0.0 3.0 5.0 0.0 NaN \n", "1 1.0 1.0 0.0 3.0 5.0 1.0 0.0 \n", "2 1.0 1.0 0.0 3.0 5.0 2.0 0.0 \n", "3 1.0 0.0 0.0 3.0 5.0 3.0 0.0 \n", "4 1.0 1.0 0.0 4.0 7.0 4.0 0.0 \n", "\n", " PreviousStart PreviousReward Stay Transition PreviousTransition \n", "0 NaN NaN False common NaN \n", "1 0.0 1.0 True common common \n", "2 0.0 1.0 True common common \n", "3 0.0 1.0 True common common \n", "4 0.0 0.0 True rare common " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_lesion_dls.head()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
"def compute_mean_stay_prob(data):\n",
"    \"\"\"Mean and SEM of 'Stay' per previous transition x previous reward.\n",
"\n",
"    Both statistics are computed from the same rows (Trial > 0), so every SEM\n",
"    describes exactly the sample its mean was taken over.\n",
"\n",
"    Returns:\n",
"        (means, sems): Series indexed by ('PreviousTransition', 'PreviousReward')\n",
"        in sorted order: (common, 0), (common, 1), (rare, 0), (rare, 1).\n",
"    \"\"\"\n",
"    grouped = data[data['Trial'] > 0].groupby(['PreviousTransition', 'PreviousReward'])['Stay']\n",
"    means = grouped.mean()\n",
"    sems = grouped.sem()\n",
"    return means, sems" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
"mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n",
"mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n",
"mean_full, sem_full = compute_mean_stay_prob(data_control)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PreviousTransition PreviousReward\n", "common 0.0 0.558473\n", " 1.0 0.836690\n", "rare 0.0 0.563003\n", " 1.0 0.823360\n", "Name: Stay, dtype: float64" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean_lesion_hpc" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [
"def plot_daw_style(ax, data, yerr=None, title=''):\n",
"    \"\"\"Draw a Daw-style stay-probability bar plot on `ax`.\n",
"\n",
"    `data` (and `yerr`, when given) holds four values in the order returned by\n",
"    compute_mean_stay_prob: common/unrewarded, common/rewarded, rare/unrewarded,\n",
"    rare/rewarded. Common-transition bars are blue and rare ones red, with the\n",
"    rewarded condition drawn first in each pair.\n",
"\n",
"    Returns:\n",
"        (common_bars, rare_bars): the two bar containers, for building a legend.\n",
"    \"\"\"\n",
"    bar_width = 0.2\n",
"\n",
"    # [::-1] puts 'rewarded' before 'unrewarded' within each transition type\n",
"    bars1 = data[:2][::-1]\n",
"    bars2 = data[2:][::-1]\n",
"    if yerr is not None:\n",
"        errs1 = yerr[:2][::-1]\n",
"        errs2 = yerr[2:][::-1]\n",
"    else:\n",
"        errs1 = errs2 = None\n",
"\n",
"    # x positions: one common/rare pair per reward condition\n",
"    r1 = np.array([0.125, 0.625])\n",
"    r2 = r1 + bar_width + .05\n",
"\n",
"    common_bars = ax.bar(r1, bars1, width=bar_width, color='blue', capsize=4, yerr=errs1)\n",
"    # rare bars get the rare-transition errors (errs2), not the common ones\n",
"    rare_bars = ax.bar(r2, bars2, width=bar_width, color='red', capsize=4, yerr=errs2)\n",
"    # tick halfway between the centres of each common/rare pair\n",
"    ax.set_xticks(r1 + bar_width / 2 + .025)\n",
"    ax.set_xticklabels(['Rewarded', 'Unrewarded'], fontsize=12)\n",
"    ax.tick_params(axis='y', labelsize=12)\n",
"    ax.set_title(title, fontsize=16)\n",
"    ax.set_ylim([0.4, 1])\n",
"    ax.set_xlim([0, 1])\n",
"    return common_bars, rare_bars" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 648x180 with 3 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"fig, axes = plt.subplots(1,3, figsize= (9.5,2.5), sharey=True)\n",
"\n",
"plot_daw_style(axes[0], list(mean_lesion_hpc), yerr=sem_lesion_hpc, title='Striatum')\n",
"common_bars, rare_bars = plot_daw_style(axes[1], list(mean_lesion_dls), yerr=sem_lesion_dls, title='Hippocampus')\n",
"plot_daw_style(axes[2], list(mean_full), yerr=sem_full, title='Full model')\n",
"\n",
"# pass the bar containers explicitly: with labels alone, matplotlib may pair\n",
"# them with the error-bar cap lines instead of the bars\n",
"leg = axes[1].legend([common_bars, rare_bars], ['Common', 'Rare'], fontsize=12, frameon=False, handlelength=0.7, title='Previous transition')\n",
"plt.sca(axes[0])\n", "plt.ylabel('Stay probability', fontsize=12)\n", "\n", "plt.savefig(os.path.join(figure_location, 'StayProbability.pdf'))\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "plt.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Doll et al analysis" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "figure_location = os.path.join(FIGURE_FOLDER,'twostep_deterministic')\n", "if not os.path.exists(figure_location):\n", " os.makedirs(figure_location)\n", "\n", "# load data\n", "data_lesion_hpc = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_hpc.csv'))\n", "data_lesion_dls = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_lesion_dls.csv'))\n", "data_control = pd.read_csv(os.path.join(RESULTS_FOLDER, 'twostep_deterministic', 'results_control.csv'))\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Unnamed: 0</th>\n", " <th>Action1</th>\n", " <th>Action2</th>\n", " <th>Agent</th>\n", " <th>DLS reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Reward</th>\n", " <th>StartState</th>\n", " <th>State2</th>\n", " <th>Terminus</th>\n", " <th>Trial</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.000000</td>\n", " <td>0.030000</td>\n", " <td>0.681250</td>\n", " <td>1.0</td>\n", " 
<td>1.0</td>\n", " <td>2.0</td>\n", " <td>4.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.000000</td>\n", " <td>0.059100</td>\n", " <td>0.671966</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.021000</td>\n", " <td>0.087327</td>\n", " <td>0.675625</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>2.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.029370</td>\n", " <td>0.114707</td>\n", " <td>0.670047</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>2.0</td>\n", " <td>4.0</td>\n", " <td>3.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.055789</td>\n", " <td>0.141266</td>\n", " <td>0.669577</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>4.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Unnamed: 0 Action1 Action2 Agent DLS reliability HPC reliability \\\n", "0 0 0.0 0.0 0.0 0.000000 0.030000 \n", "1 1 1.0 1.0 0.0 0.000000 0.059100 \n", "2 2 1.0 1.0 0.0 0.021000 0.087327 \n", "3 3 0.0 0.0 0.0 0.029370 0.114707 \n", "4 4 1.0 1.0 0.0 0.055789 0.141266 \n", "\n", " P(SR) Reward StartState State2 Terminus Trial \n", "0 0.681250 1.0 1.0 2.0 4.0 0.0 \n", "1 0.671966 1.0 0.0 3.0 7.0 1.0 \n", "2 0.675625 1.0 1.0 3.0 7.0 2.0 \n", "3 0.670047 0.0 1.0 2.0 4.0 3.0 \n", "4 0.669577 1.0 0.0 3.0 7.0 4.0 " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_control.head()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def 
add_relevant_columns(dataframe):\n",
"    \"\"\"Add previous-trial columns for the deterministic task to `dataframe`, in place.\n",
"\n",
"    Redefines the earlier add_relevant_columns: no transition labels here, but a\n",
"    SameStart column instead. Shifts are done within each 'Agent', so each agent's\n",
"    first row has NaN previous values and False for Stay and SameStart.\n",
"    Returns None.\n",
"    \"\"\"\n",
"    dataframe['PreviousAction'] = dataframe.groupby(['Agent'])['Action1'].shift(1)\n",
"    dataframe['PreviousStart'] = dataframe.groupby(['Agent'])['StartState'].shift(1)\n",
"    dataframe['PreviousReward'] = dataframe.groupby(['Agent'])['Reward'].shift(1)\n",
"    dataframe['Stay'] = (dataframe.PreviousAction == dataframe.Action1)\n",
"    dataframe['SameStart'] = (dataframe.StartState == dataframe.PreviousStart)\n",
"\n",
"\n",
"add_relevant_columns(data_control)\n",
"add_relevant_columns(data_lesion_dls)\n",
"add_relevant_columns(data_lesion_hpc)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [
"def compute_mean_stay_prob(data):\n",
"    \"\"\"Mean and SEM of 'Stay' per previous reward x same/different start state.\n",
"\n",
"    Redefines the earlier compute_mean_stay_prob for the deterministic task. Both\n",
"    statistics come from the same rows (Trial > 0). Groups are returned in\n",
"    reverse sorted order -- (1.0, True), (1.0, False), (0.0, True), (0.0, False) --\n",
"    which is the order plot_doll_style expects.\n",
"    \"\"\"\n",
"    grouped = data[data['Trial'] > 0].groupby(['PreviousReward', 'SameStart'])['Stay']\n",
"    means = grouped.mean()\n",
"    sems = grouped.sem()\n",
"    return means[::-1], sems[::-1]\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [
"mean_lesion_hpc, sem_lesion_hpc = compute_mean_stay_prob(data_lesion_hpc)\n",
"mean_lesion_dls, sem_lesion_dls = compute_mean_stay_prob(data_lesion_dls)\n",
"mean_full, sem_full = compute_mean_stay_prob(data_control)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PreviousReward SameStart\n", "1.0 True 0.944564\n", " False 0.953787\n", "0.0 True 0.484103\n", " False 0.484945\n", "Name: Stay, dtype: float64" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mean_lesion_dls" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
"def plot_doll_style(ax, data, yerr=None, title=''):\n",
"    \"\"\"Draw a Doll-style stay-probability bar plot on `ax`.\n",
"\n",
"    `data` (and `yerr`, when given) holds four values in the order returned by\n",
"    compute_mean_stay_prob: rewarded/same start, rewarded/different start,\n",
"    unrewarded/same start, unrewarded/different start. Rewarded bars are light\n",
"    gray, unrewarded dark gray, grouped under 'same' and 'different'.\n",
"\n",
"    Returns:\n",
"        (handle1, handle2): the rewarded and unrewarded bar containers.\n",
"    \"\"\"\n",
"    lightgray = '#d1d1d1'\n",
"    darkgray = '#929292'\n",
"\n",
"    bar_width = 0.2\n",
"\n",
"    bars1 = data[:2]\n",
"    bars2 = data[2:]\n",
"    if yerr is not None:\n",
"        errs1 = yerr[:2]\n",
"        errs2 = yerr[2:]\n",
"    else:\n",
"        errs1 = errs2 = None\n",
"\n",
"    # x positions: one rewarded/unrewarded pair per start-state condition\n",
"    r1 = np.arange(len(bars1)) * .8 + 1.5 * bar_width\n",
"    r2 = r1 + bar_width\n",
"\n",
"    handle1 = ax.bar(r1, bars1, width=bar_width, color=lightgray, yerr=errs1, capsize=4)\n",
"    handle2 = ax.bar(r2, bars2, width=bar_width, color=darkgray, yerr=errs2, capsize=4)\n",
"    ax.set_ylabel('Stay probability', fontsize=15)\n",
"    # tick halfway between the centres of each pair\n",
"    ax.set_xticks(r1 + bar_width / 2)\n",
"    ax.set_xticklabels(['same', 'different'], fontsize=15)\n",
"    ax.tick_params(axis='y', labelsize=15)\n",
"    ax.set_title(title, fontsize=18)\n",
"    ax.set_ylim([0.45, 1.])\n",
"    ax.set_xlim([0, 1.6])\n",
"\n",
"    ax.spines['right'].set_visible(False)\n",
"    ax.spines['top'].set_visible(False)\n",
"    return handle1, handle2" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/matplotlib/tight_layout.py:176: UserWarning: Tight layout not applied. The left and right margins cannot be made large enough to accommodate all axes decorations. \n", " warnings.warn('Tight layout not applied. 
The left and right margins '\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "<Figure size 266.4x180 with 3 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
"fig, axes = plt.subplots(1,3, figsize= (3.7,2.5), sharey=True)\n",
"\n",
"plot_doll_style(axes[0], list(mean_lesion_hpc), title='Striatum')\n",
"plot_doll_style(axes[1], list(mean_lesion_dls), title='Hippocampus')\n",
"h1, h2 = plot_doll_style(axes[2], list(mean_full), title='Full model')\n",
"\n",
"# Axes are placed by hand (partly outside the 0-1 figure box); the saved PDF\n",
"# relies on bbox_inches='tight'. tight_layout() is deliberately not called: it\n",
"# cannot fit these positions, and if it did apply it could override them.\n",
"axes[0].set_position([0.1,0.1,0.5,0.7])\n",
"axes[1].set_position([0.8,0.1,0.5,0.7])\n",
"axes[2].set_position([1.5,0.1,0.5,0.7])\n",
"\n",
"leg = fig.legend([h1, h2], ['Reward', 'No reward'], bbox_to_anchor=(2.1, .5), loc = (1,.5), title=\"Previous outcome\")\n",
"leg.set_title('Previous outcome', prop={'size': 12})\n",
"\n",
"plt.savefig(os.path.join(figure_location, 'StayProbability_DeterministicTask.pdf'), bbox_inches='tight')\n",
"\n",
"plt.show()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "plt.close()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SameStart PreviousReward\n", "False 0.0 0.492189\n", " 1.0 0.493106\n", "True 0.0 0.491967\n", " 1.0 0.492369\n", "Name: P(SR), dtype: float64" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_control.groupby([\"SameStart\", \"PreviousReward\"])['P(SR)'].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 
Regression analysis" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Unnamed: 0</th>\n", " <th>Action1</th>\n", " <th>Action2</th>\n", " <th>Agent</th>\n", " <th>DLS reliability</th>\n", " <th>HPC reliability</th>\n", " <th>P(SR)</th>\n", " <th>Reward</th>\n", " <th>StartState</th>\n", " <th>State2</th>\n", " <th>Terminus</th>\n", " <th>Trial</th>\n", " <th>PreviousAction</th>\n", " <th>PreviousStart</th>\n", " <th>PreviousReward</th>\n", " <th>Stay</th>\n", " <th>SameStart</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.000000</td>\n", " <td>0.030000</td>\n", " <td>0.681250</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>2.0</td>\n", " <td>4.0</td>\n", " <td>0.0</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>False</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.000000</td>\n", " <td>0.059100</td>\n", " <td>0.671966</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>False</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.021000</td>\n", " <td>0.087327</td>\n", " <td>0.675625</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " 
<td>2.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>True</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.029370</td>\n", " <td>0.114707</td>\n", " <td>0.670047</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>2.0</td>\n", " <td>4.0</td>\n", " <td>3.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>False</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.055789</td>\n", " <td>0.141266</td>\n", " <td>0.669577</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>7.0</td>\n", " <td>4.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>False</td>\n", " <td>False</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Unnamed: 0 Action1 Action2 Agent DLS reliability HPC reliability \\\n", "0 0 0.0 0.0 0.0 0.000000 0.030000 \n", "1 1 1.0 1.0 0.0 0.000000 0.059100 \n", "2 2 1.0 1.0 0.0 0.021000 0.087327 \n", "3 3 0.0 0.0 0.0 0.029370 0.114707 \n", "4 4 1.0 1.0 0.0 0.055789 0.141266 \n", "\n", " P(SR) Reward StartState State2 Terminus Trial PreviousAction \\\n", "0 0.681250 1.0 1.0 2.0 4.0 0.0 NaN \n", "1 0.671966 1.0 0.0 3.0 7.0 1.0 0.0 \n", "2 0.675625 1.0 1.0 3.0 7.0 2.0 1.0 \n", "3 0.670047 0.0 1.0 2.0 4.0 3.0 1.0 \n", "4 0.669577 1.0 0.0 3.0 7.0 4.0 0.0 \n", "\n", " PreviousStart PreviousReward Stay SameStart \n", "0 NaN NaN False False \n", "1 1.0 1.0 False False \n", "2 0.0 1.0 True False \n", "3 1.0 1.0 False True \n", "4 1.0 0.0 False False " ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_control.head()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "data = data_control[['Agent', 'PreviousReward', 'PreviousAction', 'SameStart', 'Action1']]" ] }, { 
"cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimization terminated successfully.\n", " Current function value: 0.499005\n", " Iterations 7\n" ] } ], "source": [ "mod = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart', data=data)\n", "res = mod.fit()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimization terminated successfully.\n", " Current function value: 0.489139\n", " Iterations 6\n", "Optimization terminated successfully.\n", " Current function value: 0.538574\n", " Iterations 8\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.412727\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.507848\n", " Iterations 7\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.496648\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.518046\n", " Iterations 8\n", "Optimization terminated successfully.\n", " Current function value: 0.498005\n", " Iterations 7\n", "Optimization terminated successfully.\n", " Current function value: 0.425808\n", " Iterations 8\n", "Optimization terminated successfully.\n", " Current function value: 0.549229\n", " Iterations 6\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.444553\n", " Iterations: 35\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. 
Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.450759\n", " Iterations: 35\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.473348\n", " Iterations: 35\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.456640\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.476260\n", " Iterations 8\n", "Optimization terminated successfully.\n", " Current function value: 0.442589\n", " Iterations 8\n", "Optimization terminated successfully.\n", " Current function value: 0.523259\n", " Iterations 7\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.478907\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.471507\n", " Iterations 
21\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.507018\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.506375\n", " Iterations 7\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. 
Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.453712\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.497200\n", " Iterations 7\n", "Optimization terminated successfully.\n", " Current function value: 0.486236\n", " Iterations 8\n", "Optimization terminated successfully.\n", " Current function value: 0.529446\n", " Iterations 7\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.483552\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.525430\n", " Iterations 7\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.426731\n", " Iterations: 35\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.492921\n", " Iterations: 35\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.514083\n", " Iterations: 35\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.473534\n", " Iterations: 35\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. 
Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n", "/Users/jessegeerts/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py:508: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n", " \"Check mle_retvals\", ConvergenceWarning)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.445715\n", " Iterations: 35\n", "Warning: Maximum number of iterations has been exceeded.\n", " Current function value: 0.490803\n", " Iterations: 35\n", "Optimization terminated successfully.\n", " Current function value: 0.428061\n", " Iterations 24\n" ] }, { "ename": "LinAlgError", "evalue": "Singular matrix", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mLinAlgError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m<ipython-input-27-f72b44884927>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0magent\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAgent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munique\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mmod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mformula\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Action1 ~ PreviousReward * PreviousAction * SameStart'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAgent\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0magent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/discrete/discrete_model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, start_params, method, maxiter, full_output, disp, callback, **kwargs)\u001b[0m\n\u001b[1;32m 1832\u001b[0m bnryfit = super(Logit, self).fit(start_params=start_params,\n\u001b[1;32m 1833\u001b[0m 
\u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaxiter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmaxiter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfull_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfull_output\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1834\u001b[0;31m disp=disp, callback=callback, **kwargs)\n\u001b[0m\u001b[1;32m 1835\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1836\u001b[0m \u001b[0mdiscretefit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLogitResults\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbnryfit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/discrete/discrete_model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, start_params, method, maxiter, full_output, disp, callback, **kwargs)\u001b[0m\n\u001b[1;32m 218\u001b[0m mlefit = super(DiscreteModel, self).fit(start_params=start_params,\n\u001b[1;32m 219\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaxiter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmaxiter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfull_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfull_output\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 220\u001b[0;31m disp=disp, callback=callback, **kwargs)\n\u001b[0m\u001b[1;32m 221\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmlefit\u001b[0m \u001b[0;31m# up to subclasses to wrap results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/models/lib/python3.6/site-packages/statsmodels/base/model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, start_params, method, maxiter, full_output, disp, fargs, callback, 
retall, skip_hessian, **kwargs)\u001b[0m\n\u001b[1;32m 471\u001b[0m \u001b[0mHinv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcov_params_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxopt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'newton'\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfull_output\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 473\u001b[0;31m \u001b[0mHinv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mretvals\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'Hessian'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mnobs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 474\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mskip_hessian\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[0mH\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhessian\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxopt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/models/lib/python3.6/site-packages/numpy/linalg/linalg.py\u001b[0m in \u001b[0;36minv\u001b[0;34m(a)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0msignature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'D->D'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misComplexType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m 
\u001b[0;34m'd->d'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 550\u001b[0m \u001b[0mextobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_linalg_error_extobj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_raise_linalgerror_singular\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 551\u001b[0;31m \u001b[0mainv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_umath_linalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msignature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mextobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 552\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mainv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/models/lib/python3.6/site-packages/numpy/linalg/linalg.py\u001b[0m in \u001b[0;36m_raise_linalgerror_singular\u001b[0;34m(err, flag)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_raise_linalgerror_singular\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 97\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mLinAlgError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Singular 
matrix\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_raise_linalgerror_nonposdef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mLinAlgError\u001b[0m: Singular matrix" ] } ], "source": [ "df = pd.DataFrame({}, columns=res.params.keys())\n", "for agent in data.Agent.unique():\n", " mod = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart', data=data[data.Agent==agent])\n", " res = mod.fit()\n", " df = df.append(res.params.to_dict(), ignore_index=True)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 -1.215337\n", "1 -0.680725\n", "2 -1.007641\n", "3 -0.481838\n", "4 -1.121497\n", "5 -0.757686\n", "6 -0.867501\n", "7 -1.722767\n", "8 -1.446227\n", "9 -1.963610\n", "10 -0.885519\n", "11 -0.344840\n", "12 -1.178655\n", "13 -1.174598\n", "14 -1.515912\n", "15 -2.018533\n", "16 -1.317301\n", "17 -1.163151\n", "18 -0.597837\n", "19 -0.346625\n", "20 -0.938596\n", "21 -1.128959\n", "22 -1.376725\n", "23 -0.385504\n", "24 -0.693147\n", "25 -1.402043\n", "26 -1.709521\n", "27 -1.317707\n", "28 -0.435318\n", "29 -1.410987\n", "30 -1.544083\n", "31 -0.976010\n", "Name: PreviousReward, dtype: float64" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[\"PreviousReward\"]" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table class=\"simpletable\">\n", "<caption>Logit Regression Results</caption>\n", "<tr>\n", " <th>Dep. Variable:</th> <td>Action1</td> <th> No. 
Observations: </th> <td> 271</td> \n", "</tr>\n", "<tr>\n", " <th>Model:</th> <td>Logit</td> <th> Df Residuals: </th> <td> 263</td> \n", "</tr>\n", "<tr>\n", " <th>Method:</th> <td>MLE</td> <th> Df Model: </th> <td> 7</td> \n", "</tr>\n", "<tr>\n", " <th>Date:</th> <td>Wed, 29 Jul 2020</td> <th> Pseudo R-squ.: </th> <td>0.2475</td> \n", "</tr>\n", "<tr>\n", " <th>Time:</th> <td>15:23:33</td> <th> Log-Likelihood: </th> <td> -133.01</td> \n", "</tr>\n", "<tr>\n", " <th>converged:</th> <td>False</td> <th> LL-Null: </th> <td> -176.75</td> \n", "</tr>\n", "<tr>\n", " <th> </th> <td> </td> <th> LLR p-value: </th> <td>4.054e-16</td>\n", "</tr>\n", "</table>\n", "<table class=\"simpletable\">\n", "<tr>\n", " <td></td> <th>coef</th> <th>std err</th> <th>z</th> <th>P>|z|</th> <th>[0.025</th> <th>0.975]</th> \n", "</tr>\n", "<tr>\n", " <th>Intercept</th> <td> -0.7732</td> <td> 0.349</td> <td> -2.215</td> <td> 0.027</td> <td> -1.457</td> <td> -0.089</td>\n", "</tr>\n", "<tr>\n", " <th>SameStart[T.True]</th> <td> 0.6678</td> <td> 0.477</td> <td> 1.401</td> <td> 0.161</td> <td> -0.267</td> <td> 1.602</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward</th> <td> -0.9760</td> <td> 0.518</td> <td> -1.883</td> <td> 0.060</td> <td> -1.992</td> <td> 0.040</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:SameStart[T.True]</th> <td> -1.9391</td> <td> 0.948</td> <td> -2.046</td> <td> 0.041</td> <td> -3.797</td> <td> -0.081</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousAction</th> <td> 0.2426</td> <td> 0.530</td> <td> 0.458</td> <td> 0.647</td> <td> -0.796</td> <td> 1.281</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousAction:SameStart[T.True]</th> <td> -0.4474</td> <td> 0.737</td> <td> -0.607</td> <td> 0.544</td> <td> -1.893</td> <td> 0.998</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:PreviousAction</th> <td> 2.1426</td> <td> 0.773</td> <td> 2.772</td> <td> 0.006</td> <td> 0.628</td> <td> 3.657</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:PreviousAction:SameStart[T.True]</th> <td> 
22.1332</td> <td> 8544.817</td> <td> 0.003</td> <td> 0.998</td> <td>-1.67e+04</td> <td> 1.68e+04</td>\n", "</tr>\n", "</table>" ], "text/plain": [ "<class 'statsmodels.iolib.summary.Summary'>\n", "\"\"\"\n", " Logit Regression Results \n", "==============================================================================\n", "Dep. Variable: Action1 No. Observations: 271\n", "Model: Logit Df Residuals: 263\n", "Method: MLE Df Model: 7\n", "Date: Wed, 29 Jul 2020 Pseudo R-squ.: 0.2475\n", "Time: 15:23:33 Log-Likelihood: -133.01\n", "converged: False LL-Null: -176.75\n", " LLR p-value: 4.054e-16\n", "===================================================================================================================\n", " coef std err z P>|z| [0.025 0.975]\n", "-------------------------------------------------------------------------------------------------------------------\n", "Intercept -0.7732 0.349 -2.215 0.027 -1.457 -0.089\n", "SameStart[T.True] 0.6678 0.477 1.401 0.161 -0.267 1.602\n", "PreviousReward -0.9760 0.518 -1.883 0.060 -1.992 0.040\n", "PreviousReward:SameStart[T.True] -1.9391 0.948 -2.046 0.041 -3.797 -0.081\n", "PreviousAction 0.2426 0.530 0.458 0.647 -0.796 1.281\n", "PreviousAction:SameStart[T.True] -0.4474 0.737 -0.607 0.544 -1.893 0.998\n", "PreviousReward:PreviousAction 2.1426 0.773 2.772 0.006 0.628 3.657\n", "PreviousReward:PreviousAction:SameStart[T.True] 22.1332 8544.817 0.003 0.998 -1.67e+04 1.68e+04\n", "===================================================================================================================\n", "\"\"\"" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.summary()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 -1.215337\n", "1 -0.680725\n", "2 -1.007641\n", "3 -0.481838\n", "4 -1.121497\n", "5 -0.757686\n", "6 -0.867501\n", "7 -1.722767\n", "8 -1.446227\n", "9 -1.963610\n", "10 -0.885519\n", "11 
-0.344840\n", "12 -1.178655\n", "13 -1.174598\n", "14 -1.515912\n", "15 -2.018533\n", "16 -1.317301\n", "17 -1.163151\n", "18 -0.597837\n", "19 -0.346625\n", "20 -0.938596\n", "21 -1.128959\n", "22 -1.376725\n", "23 -0.385504\n", "24 -0.693147\n", "25 -1.402043\n", "26 -1.709521\n", "27 -1.317707\n", "28 -0.435318\n", "29 -1.410987\n", "30 -1.544083\n", "31 -0.976010\n", "Name: PreviousReward, dtype: float64" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['PreviousReward']" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Ttest_1sampResult(statistic=-13.497324667045547, pvalue=1.6042574528289812e-14)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from scipy.stats import ttest_ind\n", "from scipy.stats import ttest_1samp\n", "\n", "ttest_1samp(df['PreviousReward'],0)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Ttest_1sampResult(statistic=-4.5382898405157395, pvalue=8.029457107185251e-05)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ttest_1samp(df['PreviousReward:SameStart[T.True]'], 0)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on function ttest_1samp in module scipy.stats.stats:\n", "\n", "ttest_1samp(a, popmean, axis=0, nan_policy='propagate')\n", " Calculate the T-test for the mean of ONE group of scores.\n", " \n", " This is a two-sided test for the null hypothesis that the expected value\n", " (mean) of a sample of independent observations `a` is equal to the given\n", " population mean, `popmean`.\n", " \n", " Parameters\n", " ----------\n", " a : array_like\n", " sample observation\n", " popmean : float or array_like\n", " expected value in null hypothesis. 
If array_like, then it must have the\n", " same shape as `a` excluding the axis dimension\n", " axis : int or None, optional\n", " Axis along which to compute test. If None, compute over the whole\n", " array `a`.\n", " nan_policy : {'propagate', 'raise', 'omit'}, optional\n", " Defines how to handle when input contains nan. 'propagate' returns nan,\n", " 'raise' throws an error, 'omit' performs the calculations ignoring nan\n", " values. Default is 'propagate'.\n", " \n", " Returns\n", " -------\n", " statistic : float or array\n", " t-statistic\n", " pvalue : float or array\n", " two-tailed p-value\n", " \n", " Examples\n", " --------\n", " >>> from scipy import stats\n", " \n", " >>> np.random.seed(7654567) # fix seed to get the same result\n", " >>> rvs = stats.norm.rvs(loc=5, scale=10, size=(50,2))\n", " \n", " Test if mean of random sample is equal to true mean, and different mean.\n", " We reject the null hypothesis in the second case and don't reject it in\n", " the first case.\n", " \n", " >>> stats.ttest_1samp(rvs,5.0)\n", " (array([-0.68014479, -0.04323899]), array([ 0.49961383, 0.96568674]))\n", " >>> stats.ttest_1samp(rvs,0.0)\n", " (array([ 2.77025808, 4.11038784]), array([ 0.00789095, 0.00014999]))\n", " \n", " Examples using axis and non-scalar dimension for population mean.\n", " \n", " >>> stats.ttest_1samp(rvs,[5.0,0.0])\n", " (array([-0.68014479, 4.11038784]), array([ 4.99613833e-01, 1.49986458e-04]))\n", " >>> stats.ttest_1samp(rvs.T,[5.0,0.0],axis=1)\n", " (array([-0.68014479, 4.11038784]), array([ 4.99613833e-01, 1.49986458e-04]))\n", " >>> stats.ttest_1samp(rvs,[[5.0],[0.0]])\n", " (array([[-0.68014479, -0.04323899],\n", " [ 2.77025808, 4.11038784]]), array([[ 4.99613833e-01, 9.65686743e-01],\n", " [ 7.89094663e-03, 1.49986458e-04]]))\n", "\n" ] } ], "source": [ "help(ttest_1samp)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Optimization terminated successfully.\n", " Current function value: 0.498717\n", " Iterations 7\n" ] } ], "source": [ "mod = smf.logit(formula='Action1 ~ PreviousReward * PreviousAction * SameStart * Agent', data=data)\n", "res = mod.fit()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table class=\"simpletable\">\n", "<caption>Logit Regression Results</caption>\n", "<tr>\n", " <th>Dep. Variable:</th> <td>Action1</td> <th> No. Observations: </th> <td> 10840</td> \n", "</tr>\n", "<tr>\n", " <th>Model:</th> <td>Logit</td> <th> Df Residuals: </th> <td> 10824</td> \n", "</tr>\n", "<tr>\n", " <th>Method:</th> <td>MLE</td> <th> Df Model: </th> <td> 15</td> \n", "</tr>\n", "<tr>\n", " <th>Date:</th> <td>Wed, 29 Jul 2020</td> <th> Pseudo R-squ.: </th> <td>0.2800</td> \n", "</tr>\n", "<tr>\n", " <th>Time:</th> <td>15:23:35</td> <th> Log-Likelihood: </th> <td> -5406.1</td>\n", "</tr>\n", "<tr>\n", " <th>converged:</th> <td>True</td> <th> LL-Null: </th> <td> -7508.1</td>\n", "</tr>\n", "<tr>\n", " <th> </th> <td> </td> <th> LLR p-value: </th> <td> 0.000</td> \n", "</tr>\n", "</table>\n", "<table class=\"simpletable\">\n", "<tr>\n", " <td></td> <th>coef</th> <th>std err</th> <th>z</th> <th>P>|z|</th> <th>[0.025</th> <th>0.975]</th> \n", "</tr>\n", "<tr>\n", " <th>Intercept</th> <td> -0.4366</td> <td> 0.105</td> <td> -4.157</td> <td> 0.000</td> <td> -0.642</td> <td> -0.231</td>\n", "</tr>\n", "<tr>\n", " <th>SameStart[T.True]</th> <td> 0.4108</td> <td> 0.149</td> <td> 2.760</td> <td> 0.006</td> <td> 0.119</td> <td> 0.702</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward</th> <td> -1.0693</td> <td> 0.166</td> <td> -6.456</td> <td> 0.000</td> <td> -1.394</td> <td> -0.745</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:SameStart[T.True]</th> <td> -2.3645</td> <td> 0.356</td> <td> -6.638</td> <td> 0.000</td> <td> -3.063</td> <td> -1.666</td>\n", 
"</tr>\n", "<tr>\n", " <th>PreviousAction</th> <td> 0.5712</td> <td> 0.154</td> <td> 3.699</td> <td> 0.000</td> <td> 0.269</td> <td> 0.874</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousAction:SameStart[T.True]</th> <td> -0.8381</td> <td> 0.218</td> <td> -3.843</td> <td> 0.000</td> <td> -1.265</td> <td> -0.411</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:PreviousAction</th> <td> 2.3226</td> <td> 0.243</td> <td> 9.567</td> <td> 0.000</td> <td> 1.847</td> <td> 2.798</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:PreviousAction:SameStart[T.True]</th> <td> 4.4519</td> <td> 0.502</td> <td> 8.867</td> <td> 0.000</td> <td> 3.468</td> <td> 5.436</td>\n", "</tr>\n", "<tr>\n", " <th>Agent</th> <td> 0.0032</td> <td> 0.005</td> <td> 0.683</td> <td> 0.495</td> <td> -0.006</td> <td> 0.012</td>\n", "</tr>\n", "<tr>\n", " <th>SameStart[T.True]:Agent</th> <td> -0.0003</td> <td> 0.007</td> <td> -0.043</td> <td> 0.966</td> <td> -0.013</td> <td> 0.013</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:Agent</th> <td> -0.0011</td> <td> 0.007</td> <td> -0.144</td> <td> 0.885</td> <td> -0.015</td> <td> 0.013</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:SameStart[T.True]:Agent</th> <td> -0.0007</td> <td> 0.016</td> <td> -0.045</td> <td> 0.964</td> <td> -0.031</td> <td> 0.030</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousAction:Agent</th> <td> 0.0009</td> <td> 0.007</td> <td> 0.128</td> <td> 0.898</td> <td> -0.013</td> <td> 0.014</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousAction:SameStart[T.True]:Agent</th> <td> 0.0057</td> <td> 0.010</td> <td> 0.590</td> <td> 0.555</td> <td> -0.013</td> <td> 0.025</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:PreviousAction:Agent</th> <td> -0.0020</td> <td> 0.011</td> <td> -0.188</td> <td> 0.851</td> <td> -0.023</td> <td> 0.019</td>\n", "</tr>\n", "<tr>\n", " <th>PreviousReward:PreviousAction:SameStart[T.True]:Agent</th> <td> 0.0063</td> <td> 0.022</td> <td> 0.282</td> <td> 0.778</td> <td> -0.037</td> <td> 0.050</td>\n", "</tr>\n", 
"</table>" ], "text/plain": [ "<class 'statsmodels.iolib.summary.Summary'>\n", "\"\"\"\n", " Logit Regression Results \n", "==============================================================================\n", "Dep. Variable: Action1 No. Observations: 10840\n", "Model: Logit Df Residuals: 10824\n", "Method: MLE Df Model: 15\n", "Date: Wed, 29 Jul 2020 Pseudo R-squ.: 0.2800\n", "Time: 15:23:35 Log-Likelihood: -5406.1\n", "converged: True LL-Null: -7508.1\n", " LLR p-value: 0.000\n", "=========================================================================================================================\n", " coef std err z P>|z| [0.025 0.975]\n", "-------------------------------------------------------------------------------------------------------------------------\n", "Intercept -0.4366 0.105 -4.157 0.000 -0.642 -0.231\n", "SameStart[T.True] 0.4108 0.149 2.760 0.006 0.119 0.702\n", "PreviousReward -1.0693 0.166 -6.456 0.000 -1.394 -0.745\n", "PreviousReward:SameStart[T.True] -2.3645 0.356 -6.638 0.000 -3.063 -1.666\n", "PreviousAction 0.5712 0.154 3.699 0.000 0.269 0.874\n", "PreviousAction:SameStart[T.True] -0.8381 0.218 -3.843 0.000 -1.265 -0.411\n", "PreviousReward:PreviousAction 2.3226 0.243 9.567 0.000 1.847 2.798\n", "PreviousReward:PreviousAction:SameStart[T.True] 4.4519 0.502 8.867 0.000 3.468 5.436\n", "Agent 0.0032 0.005 0.683 0.495 -0.006 0.012\n", "SameStart[T.True]:Agent -0.0003 0.007 -0.043 0.966 -0.013 0.013\n", "PreviousReward:Agent -0.0011 0.007 -0.144 0.885 -0.015 0.013\n", "PreviousReward:SameStart[T.True]:Agent -0.0007 0.016 -0.045 0.964 -0.031 0.030\n", "PreviousAction:Agent 0.0009 0.007 0.128 0.898 -0.013 0.014\n", "PreviousAction:SameStart[T.True]:Agent 0.0057 0.010 0.590 0.555 -0.013 0.025\n", "PreviousReward:PreviousAction:Agent -0.0020 0.011 -0.188 0.851 -0.023 0.019\n", "PreviousReward:PreviousAction:SameStart[T.True]:Agent 0.0063 0.022 0.282 0.778 -0.037 0.050\n", 
"=========================================================================================================================\n", "\"\"\"" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.summary()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.17537174578342e-11" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.pvalues['PreviousReward:SameStart[T.True]']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }