from tqdm import tqdm
import matplotlib.pyplot as plt
from definitions import ROOT_FOLDER
import os
import pandas as pd
from hippocampus.agents import CombinedAgent
from hippocampus.environments import HexWaterMaze
import numpy as np
import seaborn as sns
from datetime import datetime
import argparse
# Seed NumPy's global RNG so platform orders (and, presumably, any np.random use
# inside the agent/environment) are reproducible across runs.
np.random.seed(0)

# Save location for per-agent CSV results and figures.
results_folder = os.path.join(ROOT_FOLDER, 'results', 'cue_vs_place_watermaze')
figure_folder = os.path.join(results_folder, 'figures')
# makedirs creates missing parents (results_folder) too; exist_ok=True makes this
# safe on reruns and also covers the case where results_folder already exists but
# figures/ does not (which the previous exists-check left uncreated).
os.makedirs(figure_folder, exist_ok=True)
# Single environment shared by every group and agent.
# NOTE(review): the meaning of the constructor argument 6 (maze size/radius?) is
# defined by HexWaterMaze -- confirm there.
g = HexWaterMaze(6)
# Fixed start state; episodes in this script run with random_start_loc=False.
g.starting_state = g.current_state = 42
# The two candidate platform states; each agent visits them in a random order.
possible_platform_states = np.array([51, 57])
# Lesion configuration per experimental group: (lesion_dls, lesion_hpc).
# Dict insertion order fixes the order in which the groups are simulated.
LESION_GROUPS = {
    'dls': (True, False),
    'hpc': (False, True),
    'sham': (False, False),
    'double': (True, True),
}
N_AGENTS = 15  # simulated agents per group
n_training_trials = 15  # training trials before the single probe trial


def _run_trial(agent, trial, total_trial, trial_type):
    """Run one episode from the fixed start state and annotate its result.

    Args:
        agent: agent whose ``one_episode`` is called with
            ``random_policy=False`` and ``random_start_loc=False``.
        trial: value stored in the 'trial' column.
        total_trial: value stored in the 'total trial' column.
        trial_type: value stored in the 'trial type' column ('train'/'probe').

    Returns:
        The object returned by ``agent.one_episode`` (presumably a pandas
        DataFrame with a ``time`` column, since it is later passed to
        ``pd.concat``), with the annotation columns added and
        'escape time' set to the episode's maximum ``time``.
    """
    res = agent.one_episode(random_policy=False, random_start_loc=False)
    res['trial'] = trial
    res['escape time'] = res.time.max()
    res['total trial'] = total_trial
    res['trial type'] = trial_type
    return res


for group, (l_dls, l_hpc) in LESION_GROUPS.items():
    tqdm.write('\nRunning {} lesioned group... \n'.format(group))
    for agent_n in tqdm(range(N_AGENTS)):
        # Drop the extra terminal state added by the previous agent's probe trial.
        g.other_terminals = []
        filename = os.path.join(results_folder, '{}_agent{}.csv'.format(group, agent_n))
        # Resume support: agents whose CSV already exists are not re-simulated.
        # NOTE(review): skipped agents consume no random numbers, so a resumed run
        # does not reproduce the random draws of an uninterrupted run.
        if os.path.exists(filename):
            continue
        # Random order [training platform, probe platform] for this agent.
        # np.random.permutation shuffles a copy (same RNG draws as
        # np.random.shuffle); the module-level possible_platform_states is no
        # longer aliased and mutated in place.
        platform_sequence = np.random.permutation(possible_platform_states)
        agent = CombinedAgent(g, init_sr='rw',
                              lesion_dls=l_dls,
                              lesion_hpc=l_hpc,
                              inv_temp=10.,
                              gamma=.99,
                              learning_rate=.01,
                              eta=.03)
        agent_results = []

        # Training: the platform stays at the first location for all trials
        # (there is a single session, so it is set once, after the agent is built).
        g.set_platform_state(platform_sequence[0])
        for trial in tqdm(range(n_training_trials), leave=False):
            # Single session: the overall trial index equals the trial index.
            res = _run_trial(agent, trial, trial, 'train')
            agent_results.append(res)

        # Probe trial: move the platform to the other location and keep the
        # previous platform location terminal as well.
        g.set_platform_state(platform_sequence[1])
        g.add_terminal(platform_sequence[0])
        # The probe reuses the last training trial's 'trial' index (unchanged
        # from before); the 'trial type' column distinguishes it.
        res = _run_trial(agent, n_training_trials - 1, n_training_trials, 'probe')
        agent_results.append(res)

        agent_df = pd.concat(agent_results)
        # Consecutive row index across all of this agent's episodes.
        agent_df['total time'] = np.arange(len(agent_df))
        # Label rows with this agent's index (previously hard-coded to 0 for all).
        agent_df['agent'] = agent_n
        agent_df.to_csv(filename)