from tqdm import tqdm
import matplotlib.pyplot as plt
from definitions import ROOT_FOLDER
import os
import pandas as pd
from hippocampus.agents import CombinedAgent
from hippocampus.environments import HexWaterMaze
import numpy as np
import seaborn as sns
from datetime import datetime
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--gamma', type=float)
parser.add_argument('--inv_temp', type=float)
parser.add_argument('--n_agents')
parser.add_argument('--lesion_hippocampus')
parser.add_argument('--learning_rate', type=float)
args = parser.parse_args()
# set parameters
if args.gamma is None:
gamma = .99
else:
gamma = args.gamma
if args.n_agents is None:
n_agents = 100
else:
n_agents = args.n_agents
if args.inv_temp is None:
inv_temp = 5.
else:
inv_temp = args.inv_temp
if args.lesion_hippocampus is None:
lesion_hippocampus = False
else:
lesion_hippocampus = True
if args.learning_rate is None:
learning_rate = .07
else:
learning_rate = args.learning_rate
lesion_striatum = True
lesion_hippocampus = False
tqdm.write('Running {} agents with lr {}, gamma {}, inv temp {} and HPC lesioned {}'.format(n_agents,
learning_rate,
gamma,
inv_temp,
lesion_hippocampus))
if lesion_hippocampus and not lesion_striatum:
group = 'lesion'
elif not lesion_hippocampus and not lesion_striatum:
group = 'control'
elif lesion_striatum and not lesion_hippocampus:
group = 'lesion_DLS'
else:
group = 'other'
# save location
results_folder = os.path.join(ROOT_FOLDER, 'results', 'pearce', group, str(datetime.now()))
figure_folder = os.path.join(results_folder, 'figures')
if not os.path.exists(results_folder):
os.makedirs(results_folder)
os.makedirs(figure_folder)
params = pd.DataFrame({'n_agents': [n_agents],
'inv_temp': [inv_temp],
'gamma': [gamma],
'lesion HPC': [lesion_hippocampus],
'lesion DLS': [lesion_striatum]})
params.to_csv(os.path.join(results_folder, 'params.csv'))
# initialise environment
g = HexWaterMaze(10)
# determine platform sequence
possible_platform_states = np.array([192, 185, 181, 174, 216, 210, 203, 197]) # for the r = 10 case
def determine_platform_seq(platform_states):
indices = np.arange(len(platform_states))
usage = np.zeros(len(platform_states))
plat_seq = [np.random.choice(platform_states)]
for sess in range(1, 11):
distances = np.array([g.grid.distance(plat_seq[sess - 1], s) for s in platform_states])
candidates = indices[np.logical_and(usage < 2, distances > g.grid.radius)]
platform_idx = np.random.choice(candidates)
plat_seq.append(platform_states[platform_idx])
usage[platform_idx] += 1.
return plat_seq
for n_agent in tqdm(range(n_agents)):
# set random seed
np.random.seed(n_agent)
# determine sequence of platform locations
platform_sequence = determine_platform_seq(possible_platform_states)
# intialise agent
agent = CombinedAgent(g, init_sr='rw',
lesion_dls=lesion_striatum,
lesion_hpc=lesion_hippocampus,
inv_temp=inv_temp,
gamma=gamma,
learning_rate=learning_rate)
agent_results = []
agent_ets = []
session = 0
total_trial_count = 0
for ses in tqdm(range(11)):
for trial in tqdm(range(4), leave=False):
# every first trial of a session, change the platform location
if trial == 0:
g.set_platform_state(platform_sequence[ses])
res = agent.one_episode(random_policy=False)
res['trial'] = trial
res['escape time'] = res.time.max()
res['session'] = ses
res['total trial'] = total_trial_count
agent_results.append(res)
agent_ets.append(res.time.max())
total_trial_count += 1
agent_df = pd.concat(agent_results)
agent_df['total time'] = np.arange(len(agent_df))
agent_df['agent'] = n_agent
agent_df.to_csv(os.path.join(results_folder, 'agent{}.csv'.format(n_agent)))
# plot and save a prelim figure
first_and_last = agent_df[np.logical_or(agent_df.trial == 0, agent_df.trial == 3)]
fig = plt.figure()
ax = sns.lineplot(data=first_and_last, x='session', y='escape time', hue='trial')
plt.title('Agent n {}'.format(n_agent))
plt.savefig(os.path.join(figure_folder, 'agent{}.png'.format(n_agent)))
plt.close()
# plot averages
all_data = []
for ag in tqdm(range(n_agents), desc='loading data...'):
df = pd.read_csv(os.path.join(results_folder, 'agent{}.csv'.format(ag)))
summary = df.pivot_table(index=['agent', 'session', 'trial'], aggfunc='mean')
all_data.append(summary)
df = pd.concat(all_data)
df['platform location'] = df['platform'].astype('category')
df.to_csv(os.path.join(results_folder, 'summary.csv'))
# Plot the average escape time per platform
plt.figure()
sns.barplot(data=df.loc[(list(range(n_agents)), list(range(11)), 0)], x='platform location', y='escape time')
plt.savefig(os.path.join(figure_folder, 'et_per_platform.png'))
plt.close()
# plot the escape time per session for trials 1 and 4
plt.figure()
first_last = df.loc[(list(range(n_agents)), list(range(11)), (0, 3))]
sns.lineplot(data=first_last.reset_index(), x='session', y='escape time', hue='trial', ci=None)
plt.savefig(os.path.join(figure_folder, 'escape_time_firstlast.png'))
plt.close()
# plot the escape time per session for all trials
plt.figure()
sns.lineplot(data=df.reset_index(), x='session', y='escape time', hue='trial', ci=None)
plt.savefig(os.path.join(figure_folder, 'escape_time.png'))
plt.close()
# plot the escape time per trial with vlines indicating new sessions
plt.figure()
sns.lineplot(data=df, x='total trial', y='escape time')
for i in range(44):
if (i % 4) == 0:
plt.axvline(x=i, ymin=0, ymax=1, linewidth=1, color='r', alpha=.3)
plt.savefig(os.path.join(figure_folder, 'escape_time_pertrial.png'))
plt.close()