from itertools import product
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.stats import multivariate_normal
from tqdm import tqdm
from hippocampus import utils
from hippocampus.environments import SimpleMDP, HexWaterMaze, Environment, PlusMaze
from hippocampus.dynamic_programming_utils import generate_random_policy, value_iteration
from hippocampus.utils import to_agent_frame
class CombinedAgent(object):
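    """Agent that arbitrates between a hippocampal (HPC) successor-representation system
    (SRTD) and a dorsolateral striatal (DLS) landmark-based model-free system
    (LandmarkLearningAgent).

    On each step the two systems' Q-values are combined with weight p_sr, the probability of
    relying on the SR system, which relaxes towards a fixed point determined by the two
    systems' reliabilities (see update_p_sr). The lesion/inactivation arguments silence or
    dampen either system.

    Minimal usage sketch (assumes a HexWaterMaze environment with a platform state set and a
    landmark location defined, as in the __main__ block at the bottom of this file):

        env = HexWaterMaze(6)
        env.set_platform_state(30)
        agent = CombinedAgent(env=env, init_sr='rw')
        episode_log = agent.one_episode()
    """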
def __init__(self, env=Environment(), init_sr='rw', lesion_dls=False, lesion_hpc=False, gamma=.95, eta=.03,
inv_temp=10, learning_rate=.1, inact_hpc=0., inact_dls=0., A_alpha=1., A_beta=.5,
alpha1=.01, beta1=.1):
self.inv_temp = inv_temp
self.eta = eta
self.learning_rate = learning_rate
self.A_alpha = A_alpha
self.A_beta = A_beta
self.alpha1 = alpha1
self.beta1 = beta1
self.lesion_striatum = lesion_dls
self.lesion_hippocampus = lesion_hpc
#if self.lesion_hippocampus and self.lesion_striatum:
# raise ValueError('cannot lesion both')
self.env = env
self.HPC = SRTD(self.env, init_sr=init_sr, gamma=gamma, eta=self.eta)
self.DLS = LandmarkLearningAgent(self.env, eta=self.eta)
self.current_choice = None
self.weights = np.zeros((self.DLS.features.n_cells, self.env.nr_actions))
self.gamma = gamma
if inact_hpc:
self.max_psr = 1. - inact_hpc
self.p_sr = self.max_psr
self.inact_dls = 0.
elif inact_dls:
self.max_psr = 1
self.inact_dls = inact_dls
self.p_sr = .8
else:
self.max_psr = 1
self.inact_dls = 0.
self.p_sr = .9
def set_exploration(self, inv_temp):
self.inv_temp = inv_temp
def one_episode(self, random_policy=False, setp_sr=None, random_start_loc=True):
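        """Run a single episode (at most 1000 steps) and return a pandas DataFrame with one
        row per step (reward, prediction errors, reliabilities, P(SR), Q-values, ...).

        :param random_policy: if True, actions are chosen uniformly at random
        :param setp_sr: if given, P(SR) is fixed to this value instead of being updated
        :param random_start_loc: passed to env.reset to randomise the starting state
        """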
if self.lesion_striatum and self.lesion_hippocampus:
random_policy = True
time_limit = 1000
self.env.reset(random_loc=random_start_loc)
t = 0
s = self.env.get_current_state()
cumulative_reward = 0
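        # choose the initial heading: the allocentric action direction that points most
        # directly at the landmark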
possible_orientations = np.round(np.degrees(self.env.action_directions))
angles = []
for i, o in enumerate(possible_orientations):
angle = utils.angle_to_landmark(self.env.get_state_location(s), self.env.landmark_location, np.radians(o))
angles.append(angle)
orientation = possible_orientations[np.argmin(np.abs(angles))]
# get MF system features
f = self.DLS.get_feature_rep(s, orientation)
Q_mf = self.weights.T @ f
        # per-step log, converted to a DataFrame on return (DataFrame.append was removed in pandas 2.0)
        results = [{'time': t,
                    'reward': 0,
                    'SPE': 0,
                    'RPE': 0,
                    'HPC reliability': self.HPC.reliability,
                    'DLS reliability': self.DLS.reliability,
                    'alpha': self.get_alpha(self.DLS.reliability),
                    'beta': self.get_beta(self.DLS.reliability),
                    'state': s,
                    'P(SR)': self.p_sr,
                    'choice': self.current_choice,
                    'M_hat': self.HPC.M_hat.flatten(),
                    'R_hat': self.HPC.R_hat.copy(),
                    'Q_mf': Q_mf,
                    'platform': self.env.get_goal_state()}]
while not self.env.is_terminal(s) and t < time_limit:
if setp_sr is None:
self.update_p_sr()
else:
self.p_sr = setp_sr
# select action
Q_combined, Q_allo = self.compute_Q(s, orientation, self.p_sr)
if random_policy:
allo_a = np.random.choice(list(range(self.env.nr_actions)))
else:
allo_a = self.softmax_selection(s, Q_combined)
ego_a = self.get_ego_action(allo_a, orientation)
# act
next_state, reward = self.env.act(allo_a)
# get MF state representation
orientation = self.DLS.get_orientation(s, next_state, orientation)
next_f = self.DLS.get_feature_rep(next_state, orientation)
# SR updates
SPE = self.HPC.compute_error(next_state, s)
delta_M = self.HPC.learning_rate * SPE
self.HPC.M_hat[s, :] += delta_M
self.HPC.update_R(next_state, reward)
# MF updates
next_Q = self.weights.T @ next_f
if self.env.is_terminal(next_state):
RPE = reward - Q_mf[ego_a]
else:
RPE = reward + self.gamma * np.max(next_Q) - Q_mf[ego_a]
self.weights[:, ego_a] = self.weights[:, ego_a] + self.learning_rate * RPE * f
# Reliability updates
if self.env.is_terminal(next_state):
self.DLS.update_reliability(RPE)
self.HPC.update_reliability(SPE, s)
s = next_state
f = next_f
Q_mf = next_Q
t += 1
cumulative_reward += reward
            results.append({'time': t,
                            'reward': reward,
                            'SPE': SPE,
                            'RPE': RPE,
                            'HPC reliability': self.HPC.reliability,
                            'DLS reliability': self.DLS.reliability,
                            'alpha': self.get_alpha(self.DLS.reliability),
                            'beta': self.get_beta(self.DLS.reliability),
                            'state': s,
                            'P(SR)': self.p_sr,
                            'choice': self.current_choice,
                            'M_hat': self.HPC.M_hat.copy(),
                            'R_hat': self.HPC.R_hat.copy(),
                            'Q_mf': Q_mf,
                            'Q_allo': Q_allo,
                            'Q': Q_combined,
                            'features': f.copy(),
                            'weights': self.weights.copy(),
                            'platform': self.env.get_goal_state(),
                            'landmark': self.env.landmark_location})
        return pd.DataFrame(results)
def get_ego_action(self, allo_a, orientation):
ego_angle = round(utils.get_relative_angle(np.degrees(self.env.action_directions[allo_a]), orientation))
if ego_angle == 180:
ego_angle = -180
for i, theta in enumerate(self.env.ego_angles):
if theta == round(ego_angle):
return i
raise ValueError('Angle {} not in list.'.format(ego_angle))
def update_p_sr(self):
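        """Update the arbitration weight P(SR).

        P(SR) relaxes towards the fixed point (alpha + inact_dls * beta) * max_psr / (alpha + beta)
        with time constant tau = max_psr / (alpha + beta), where alpha decreases with DLS
        reliability and beta decreases with HPC reliability (see get_alpha and get_beta).
        Lesions pin P(SR) at 0 (no hippocampus) or 1 (no striatum).
        """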
if self.lesion_hippocampus:
self.p_sr = 0.
return
if self.lesion_striatum:
self.p_sr = 1.
return
alpha = self.get_alpha(self.DLS.reliability)
beta = self.get_beta(self.HPC.reliability)
tau = self.max_psr / (alpha + beta)
fixedpoint = (alpha + self.inact_dls * beta) * tau
dpdt = (fixedpoint - self.p_sr) / tau
new_p_sr = self.p_sr + dpdt
if new_p_sr > self.max_psr:
new_p_sr = self.max_psr
if new_p_sr < 0:
new_p_sr = 0
if new_p_sr < 0 or new_p_sr > 1:
raise ValueError('P(SR) is not a probability: {}'.format(new_p_sr))
self.p_sr = new_p_sr
def get_alpha(self, chi_mf):
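        """Transition rate towards the SR system: a decreasing logistic function of the DLS
        reliability chi_mf, scaled so that alpha(0) = A_alpha / 2 and alpha(1) = alpha1.
        """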
alpha1 = self.alpha1
A = self.A_alpha
B = np.log((alpha1 ** -1) * A - 1)
return A / (1 + np.exp(B * chi_mf))
def get_beta(self, chi_mb):
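        """Transition rate towards the model-free system: a decreasing logistic function of the
        HPC reliability chi_mb, scaled so that beta(0) = A_beta / 2 and beta(1) = beta1.
        """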
beta1 = self.beta1
A = self.A_beta
B = np.log((beta1 ** -1) * A - 1)
return A / (1 + np.exp(B * chi_mb))
def compute_Q(self, state_idx, orientation, p_sr):
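        """Combine the two controllers' action values.

        Q_SR: value of each candidate successor state under V = M_hat @ R_hat.
        Q_MF: egocentric Q-values from the landmark features, remapped to allocentric action
        indices. Returns the p_sr-weighted mixture and the allocentric MF Q-values.
        """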
# compute Q_SR
V = self.HPC.M_hat @ self.HPC.R_hat
next_state = [self.env.get_next_state(state_idx, a) for a in range(self.env.nr_actions)]
Q_sr = [V[s] for s in next_state]
# compute Q_MF
features = self.DLS.get_feature_rep(state_idx, orientation)
Q_ego = self.weights.T @ features
allocentric_idx = [self.DLS.get_allo_action(idx, orientation) for idx in range(self.env.nr_actions)]
Q_allo = np.empty(len(Q_ego))
for i in range(len(Q_ego)):
allo_idx = allocentric_idx[i]
Q_allo[allo_idx] = Q_ego[i]
Q_mf = Q_allo
Q = p_sr * np.array(Q_sr) + (1-p_sr) * np.array(Q_mf)
return Q, Q_mf
def softmax_selection(self, state_index, Q):
probabilities = utils.softmax(Q, self.inv_temp)
action_idx = np.random.choice(list(range(self.env.nr_actions)), p=probabilities)
return action_idx
class QLearningAgent(object):
"""Vanilla Q learning agent with tabular state representation.
"""
max_RPE = 1
def __init__(self, environment=SimpleMDP(), learning_rate=.1, gamma=.9, epsilon=.1, eta=.03,
anneal_epsilon=False, beta=10):
"""
:param environment:
:param learning_rate:
:param gamma:
:param epsilon:
"""
self.env = environment
self.learning_rate = learning_rate
self.gamma = gamma
self.epsilon = epsilon
self.beta = beta
self.eta = eta
self.reliability = 0
self.omega = 0.
self.Q = np.zeros((self.env.nr_states, self.env.nr_actions))
self.anneal_epsilon = anneal_epsilon
if anneal_epsilon:
self.epsilon = 1
def one_episode(self, time_limit=1000):
self.env.reset()
        # per-step log, converted to a DataFrame on return (DataFrame.append was removed in pandas 2.0)
        results = []
t = 0
cumulative_reward = 0
s = self.env.get_current_state()
while not self.env.is_terminal(s) and t < time_limit:
a = self.softmax_selection(s)
next_state, reward = self.env.act(a)
RPE = reward + self.gamma * np.max(self.Q[next_state]) - self.Q[s][a]
if self.env.is_terminal(next_state):
self.omega += .4 * (np.abs(RPE) - self.omega)
self.reliability += self.eta * ((1 - abs(RPE) / self.max_RPE) - self.reliability)
self.Q[s][a] = self.Q[s][a] + self.learning_rate * RPE
cumulative_reward += reward
s = next_state
t += 1
if self.anneal_epsilon and self.epsilon >= .1:
self.epsilon -= .9 / 50000
            results.append({'time': t,
                            'reward': reward,
                            'RPE': RPE,
                            'reliability': self.reliability,
                            'omega': self.omega})
        return pd.DataFrame(results)
def epsilon_greedy(self, state_idx):
if np.random.rand() < self.epsilon:
action_idx = np.random.choice(list(range(self.env.nr_actions)))
else:
action_idx = utils.random_argmax(self.Q[state_idx])
return action_idx
def softmax_selection(self, state_idx):
probabilities = utils.softmax(self.Q[state_idx], self.beta)
action_idx = np.random.choice(list(range(self.env.nr_actions)), p=probabilities)
return action_idx
class SRTD(object):
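    """TD learner of the successor representation (the 'hippocampal' system).

    Maintains an estimated SR matrix M_hat, updated from successor prediction errors (SPEs),
    and a reward vector R_hat; state values are V = M_hat @ R_hat. Also tracks a scalar
    reliability estimate based on the SPE at the current state.
    """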
def __init__(self, env=Environment(), init_sr='identity', beta=20, eta=.03, gamma=.99):
self.env = env
self.learning_rate = .1
self.epsilon = .1
self.gamma = gamma
self.beta = beta
self.eta = eta
self.reliability = .8
self.omega = 1. # np.ones(self.env.nr_states)
# SR initialisation
self.M_hat = self.init_M(init_sr)
self.identity = np.eye(self.env.nr_states)
self.R_hat = np.zeros(self.env.nr_states)
def init_M(self, init_sr):
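        """Initialise the SR estimate.

        'zero': all zeros; 'identity': identity matrix; 'rw': analytic SR of a uniform random
        walk policy; 'opt': analytic SR of the optimal policy obtained by value iteration.
        """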
M_hat = np.zeros((self.env.nr_states, self.env.nr_states))
if init_sr == 'zero':
return M_hat
if init_sr == 'identity':
M_hat = np.eye(self.env.nr_states)
elif init_sr == 'rw': # Random walk initalisation
random_policy = generate_random_policy(self.env)
M_hat = self.env.get_successor_representation(random_policy, gamma=self.gamma)
elif init_sr == 'opt':
optimal_policy, _ = value_iteration(self.env)
M_hat = self.env.get_successor_representation(optimal_policy, gamma=self.gamma)
return M_hat
def get_SR(self):
return self.M_hat
def one_episode(self, random_policy=False):
time_limit = 1000
self.env.reset()
t = 0
s = self.env.get_current_state()
cumulative_reward = 0
        # per-step log, converted to a DataFrame on return (DataFrame.append was removed in pandas 2.0)
        results = []
while not self.env.is_terminal(s) and t < time_limit:
if random_policy:
a = np.random.choice(list(range(self.env.nr_actions)))
else:
a = self.select_action(s)
next_state, reward = self.env.act(a)
SPE = self.compute_error(next_state, s)
self.update_reliability(SPE, s)
self.M_hat[s, :] += self.update_M(SPE)
self.update_R(next_state, reward)
s = next_state
t += 1
cumulative_reward += reward
            results.append({'time': t, 'reward': reward, 'SPE': SPE, 'reliability': self.reliability,
                            'state': s})
        return pd.DataFrame(results)
def update_R(self, next_state, reward):
RPE = reward - self.R_hat[next_state]
self.R_hat[next_state] += 1. * RPE
def update_M(self, SPE):
delta_M = self.learning_rate * SPE
return delta_M
def update_reliability(self, SPE, s):
        self.reliability += self.eta * ((1 - abs(SPE[s])) - self.reliability)
def compute_error(self, next_state, s):
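        """Successor prediction error for the transition s -> next_state.

        SPE = one_hot(s) + gamma * M_hat[next_state] - M_hat[s]; at terminal states the
        discounted successor row is replaced by one_hot(next_state), since the episode ends.
        """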
if self.env.is_terminal(next_state):
SPE = self.identity[s, :] + self.identity[next_state, :] - self.M_hat[s, :]
else:
SPE = self.identity[s, :] + self.gamma * self.M_hat[next_state, :] - self.M_hat[s, :]
return SPE
def select_action(self, state_idx, softmax=True):
        # TODO: consider selecting actions from a categorical distribution over next states
        #       (a local, hill-climbing style ascent on the value function), and possibly
        #       switch to a state-action SR M(sa, sa') for the two-step task.
V = self.M_hat @ self.R_hat
next_state = [self.env.get_next_state(state_idx, a) for a in range(self.env.nr_actions)]
Q = [V[s] for s in next_state]
probabilities = utils.softmax(Q, self.beta)
return np.random.choice(list(range(self.env.nr_actions)), p=probabilities)
class LandmarkLearningAgent(object):
"""Q learning agent using landmark features.
"""
max_RPE = 1
def __init__(self, environment=Environment(), learning_rate=.1, gamma=.9, eta=.03, beta=10):
"""
:param environment:
:param learning_rate:
:param gamma:
"""
self.env = environment
self.learning_rate = learning_rate
self.gamma = gamma
self.beta = beta
self.eta = eta
self.reliability = 0
self.features = LandmarkCells()
self.weights = np.zeros((self.features.n_cells, self.env.nr_actions))
def one_episode(self, time_limit=1000):
self.env.reset()
t = 0
cumulative_reward = 0
s = self.env.get_current_state()
orientation = 30 # np.random.choice([30, 90, 150, 210, 270, 330])
f = self.get_feature_rep(s, orientation)
Q = self.weights.T @ f
        # per-step log, converted to a DataFrame on return (DataFrame.append was removed in pandas 2.0)
        results = []
while not self.env.is_terminal(s) and t < time_limit:
a = self.softmax_selection(s, Q)
allo_a = self.get_allo_action(a, orientation)
next_state, reward = self.env.act(allo_a)
orientation = self.get_orientation(s, next_state, orientation)
next_f = self.get_feature_rep(next_state, orientation)
RPE, next_Q = self.compute_error(f, a, next_f, next_state, reward)
if self.env.is_terminal(next_state):
self.update_reliability(RPE)
self.update_weights(RPE, a, f)
cumulative_reward += reward
s = next_state
f = next_f
Q = next_Q
t += 1
            results.append({'time': t, 'reward': reward, 'RPE': RPE, 'reliability': self.reliability,
                            'state': s})
        return pd.DataFrame(results)
def update_reliability(self, RPE):
self.reliability += self.eta * ((1 - abs(RPE) / self.max_RPE) - self.reliability)
def update_weights(self, RPE, a, f):
self.weights[:, a] = self.weights[:, a] + self.learning_rate * RPE * f
def compute_error(self, f, a, next_f, next_state, reward):
Q = self.compute_Q(f)
next_Q = self.weights.T @ next_f
if self.env.is_terminal(next_state):
RPE = reward - Q[a]
else:
RPE = reward + self.gamma * np.max(next_Q) - Q[a]
return RPE, next_Q
def softmax_selection(self, state_index, Q):
probabilities = utils.softmax(Q, self.beta)
action_idx = np.random.choice(list(range(self.env.nr_actions)), p=probabilities)
return action_idx
def angle_to_landmark(self, state, orientation):
rel_pos = to_agent_frame(self.env.landmark_location, self.env.get_state_location(state), np.radians(orientation))
angle = np.arctan2(rel_pos[1], rel_pos[0])
return np.degrees(angle)
def get_feature_rep(self, state, orientation):
distance = self.get_distance_to_landmark(state)
angle = self.angle_to_landmark(state, orientation)
response = self.features.compute_response(distance, angle)
return response
def get_distance_to_landmark(self, state):
distance_to_landmark = np.linalg.norm(
np.array(self.env.landmark_location) - np.array(self.env.get_state_location(state)))
return distance_to_landmark
def get_orientation(self, state, next_state, current_orientation):
if state == next_state:
return current_orientation
s1 = self.env.get_state_location(state)
s2 = self.env.get_state_location(next_state)
return np.degrees(np.arctan2(s2[1] - s1[1], s2[0] - s1[0]))
def get_allo_action(self, ego_action_idx, orientation):
allo_angle = (orientation + self.env.ego_angles[ego_action_idx]) % 360
for i, theta in enumerate(self.env.allo_angles):
if theta == round(allo_angle):
return i
        raise ValueError('Angle {} not in list.'.format(allo_angle))
def compute_Q(self, features):
return self.weights.T @ features
class QLearningTwoStep(LandmarkLearningAgent):
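    """Tabular variant of the landmark learner: the feature representation is a one-hot state
    vector, and omega tracks a running average of the absolute RPE.
    """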
def __init__(self, env=Environment(), eta=.03):
super().__init__(environment=env, eta=eta)
self.omega = 1
def get_feature_rep(self, state_idx, orientation):
return np.eye(self.env.nr_states)[state_idx]
def update_omega(self, RPE):
self.omega += self.eta * (np.abs(RPE) - self.omega)
class LandmarkCells(object):
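    """Population of 'landmark cells' with Gaussian receptive fields over the distance and
    egocentric angle to the landmark (n_angles preferred angles x 10 preferred distances).
    """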
def __init__(self):
self.n_angles = 8
self.angles = np.linspace(-np.pi, np.pi, self.n_angles)
        self.preferred_distances = np.linspace(0.5, 6, 10)
self.field_length = 2.
self.field_width = np.radians(30)
self.receptive_fields = []
self.rf_locations = []
for r, th in product(self.preferred_distances, self.angles):
f = multivariate_normal([r, th], [[self.field_length, 0], [0, self.field_width]])
self.receptive_fields.append(f)
self.rf_locations.append((r, th))
self.n_cells = self.n_angles * len(self.preferred_distances)
def compute_response(self, distance, angle):
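        # evaluate each Gaussian field at (distance, angle); the sqrt((2*pi)^2 * det(cov))
        # factor rescales the pdf so that the response at a field's centre equals 1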
angle = np.radians(angle)
return np.array([f.pdf([distance, angle]) * np.sqrt((2*np.pi)**2 * np.linalg.det(f.cov)) for f in self.receptive_fields])
def plot_receptive_field(self, idx):
ax = plt.subplot(projection="polar")
n = 360
m = 100
rad = np.linspace(0, 10, m)
a = np.linspace(-np.pi, np.pi, n)
r, th = np.meshgrid(rad, a)
pos = np.empty(r.shape + (2,))
pos[:, :, 0] = r
pos[:, :, 1] = th
z = self.receptive_fields[idx].pdf(pos)
# plt.ylim([0, 2*np.pi])
plt.xlim([-np.pi, np.pi])
plt.pcolormesh(th, r, z)
ax.set_theta_zero_location('N')
ax.set_thetagrids(np.linspace(-180, 180, 6, endpoint=False))
plt.plot(a, r, ls='none', color='k')
plt.grid(True)
plt.colorbar()
return ax
if __name__ == '__main__':
from definitions import ROOT_FOLDER
import os
g = HexWaterMaze(6)
g.set_platform_state(30)
agent = SRTD(g, init_sr='rw')
agent_results = []
agent_ets = []
for ep in tqdm(range(50)):
res = agent.one_episode()
res['trial'] = ep
res['escape time'] = res.time.max()
agent_results.append(res)
agent_ets.append(res.time.max())
df = pd.concat(agent_results)
results_dir = os.path.join(ROOT_FOLDER, 'results')
if not os.path.exists(results_dir):
os.makedirs(results_dir)
df.to_csv(os.path.join(results_dir, 'SRwatermaze.csv'))
SPEs = np.array(df['SPE'].tolist())
np.save(os.path.join(results_dir, 'watermazeSPEs.npy'), SPEs)