close all
set(0,'defaultlinelinewidth',2,'DefaultAxesFontSize',24)
%% Simulation switches
rew1_flag=1; %1 = reward 1 (top-right corner, centre c) is active
rew2_flag=0; %1 = reward 2 (bottom-left corner, centre c2, moved location) is active
plot_flag = 1; %1 = plot trajectory, average reward, weights and policy after every trial
ACh_flag=0; %online cholinergic depression of the feed-forward weights (1=+ACh; 0=-ACh)
step = 1; %simulation time step [ms]; the main loop advances one step per iteration
%% Task parameters
Trials = 40; %number of trials
T_max=15*10^3; %maximum duration of a trial [ms]
starting_position= [0,0]; %starting position of the agent (centre of the open field)
t_rew=T_max; %time the reward is reached - initialised to the maximum trial time
t_extreme=t_rew+300; %time the trial ends (300 ms after the reward); recomputed at every step of the simulation
t_end = T_max; %last time step of the current trial that is plotted
c = [1.5,1.5]; %centre reward 1
r_goal=0.3; %radius goal area 1
c2 = [-1.5,-1.5]; %centre reward 2 (moved)
r_goal2=0.3; %radius goal area 2 (moved)
%% Place cells
space_pc = 0.4; %distance between neighbouring place-cell centres
bounds_x = [-2,2]; %bounds open field, x axis
bounds_y = [-2,2]; %bounds open field, y axis
x_pc = bounds_x(1):space_pc:bounds_x(2); %place-cell centres along x
n_x = length(x_pc); %nr of place cells along x
y_pc= bounds_y(1):space_pc:bounds_y(2); %place-cell centres along y
n_y = length(y_pc); %nr of place cells along y
pos = zeros(1,2); %position of the agent, updated at every timestep
%create grid: x varies fastest, i.e. cell k sits at (x_pc(ix), y_pc(iy)) with k = ix + (iy-1)*n_x
y = repmat(y_pc, n_x,1);
y= reshape(y,n_x*n_y,1);
x = repmat(x_pc,1,n_y);
pc = [x',y]; %place cells' centres, one (x,y) row per cell
%snap the centres to a fine decimal grid so that exact comparisons with the
%wall coordinates (pc == bound) are not defeated by floating-point round-off
%of the colon operator. Rounding at 1e-6 (instead of 0.1) keeps any spacing
%with up to six decimals intact, e.g. 0.25, and is identical for 0.4.
pc = round(pc*1e6)/1e6;
N_pc=size(pc,1); %number of place cells (size, not length: length(pc) would be 2 for a single-cell grid)
rho_pc=400*10^(-3); %peak place-cell spike probability per 1-ms step
sigma_pc=0.4; %width of the Gaussian place fields
%% Action neurons - neuron model
eps0 = 20;          % EPSP scaling constant
tau_m = 20;         % membrane time constant
tau_s = 5;          % synaptic rise time constant of the EPSP
chi = -5;           % scaling constant of the refractory effect
rho0 = 60*10^(-3);  % rate scaling constant
theta = 16;         % firing threshold
delta_u = 2;        % width of the escape noise
%% Action neurons - parameters
N_action = 40;      % number of action neurons
% Action selection: time constants of the filter that smooths the action-neuron
% spike trains. They are passed to convolution in the same argument slots as
% tau_m and tau_s are for the EPSP, so tau_gamma presumably plays the decay role
% and v_gamma the rise role (the original comments said the opposite) - confirm
% against convolution.m.
tau_gamma = 50;
v_gamma = 20;
theta_actor = 2*pi*(1:N_action)/N_action; % angle of each action [rad]
% winner-take-all lateral weights
psi = 20;           % the higher psi, the narrower the range of lateral excitation
w_minus = -300;     % uniform lateral inhibition, spread over all N_action neurons
w_plus = 100;       % scale of the direction-dependent lateral excitation
% diff_theta(i,j) = theta_actor(j) - theta_actor(i)
diff_theta = bsxfun(@minus, theta_actor, theta_actor.');
f = exp(psi*cos(diff_theta));   % lateral connectivity profile
f(logical(eye(N_action))) = 0;  % no self-connections
normalised = sum(f(:,1));       % by circular symmetry every column has the same sum
w_lateral = w_minus/N_action + w_plus*f/normalised; % lateral connectivity action neurons
% actions: action k displaces the agent by a0*(sin, cos) of its angle, i.e. the
% angles are measured clockwise from north (+y)
a0 = .08;
actions = a0*[sin(theta_actor); cos(theta_actor)]; % possible actions (x;y), one column per action
dx = 0.01;          % distance the agent is pushed back when it reaches a wall
%% synaptic plasticity parameters
% STDP window; amplitudes and time constants are equal, so the window is symmetric
A_pre_post = 1;      % amplitude of the pre-before-post branch
A_post_pre = 1;      % amplitude of the post-before-pre branch
tau_pre_post = 10;   % time constant of the pre-before-post branch
tau_post_pre = 10;   % time constant of the post-before-pre branch
tau_e = 2*10^3;      % time constant of the eligibility trace
eta_DA = 0.01;       % learning rate of the end-of-trial (eligibility-trace) update
eta_ACh = 10^-3*2;   % learning rate of the online acetylcholine update
% feed-forward weights (place cells -> action neurons), one row per action neuron
w_max = 3;           % upper bound of the feed-forward weights
w_min = 1;           % lower bound of the feed-forward weights
w_in = 2*ones(N_action, N_pc);              % initial feed-forward weights
trace_pre_post = zeros(N_action, N_pc);     % pre-before-post trace
trace_post_pre = zeros(N_action, N_pc);     % post-before-pre trace
trace_tot = zeros(N_action, N_pc);          % sum of the two traces
eligibility_trace = zeros(N_action, N_pc);  % eligibility trace
beta = 0.75;         % update rate of the running average of the reward
%% initialise variables
i = 0;   % simulation step counter [ms], runs across all trials
tr = 0;  % trial counter
w_tot = [w_in, w_lateral];  % all input weights of the action neurons: [feed-forward, lateral]
X = zeros(N_pc,1);          % place-cell spikes at the current step
% presynaptic spikes (place cells, then action neurons), one column per action neuron
X_cut = zeros(N_pc+N_action, N_action);
Y_action_neurons = zeros(N_action, 1); % action-neuron spikes at the current step
time_reward = zeros(Trials,1);  % per trial: time reward 1 is found
time_reward2 = time_reward;     % per trial: time reward 2 (moved) is found
time_reward_old = time_reward;  % per trial: time the agent enters the previously rewarded location
% EPSP convolution state, one column per action neuron
epsp_rise = zeros(N_action+N_pc, N_action);  % rise component
epsp_decay = zeros(N_action+N_pc, N_action); % decay component
epsp_tot = zeros(N_action+N_pc, N_action);   % total EPSP
% smoothed firing rates of the action neurons and their convolution state
rho_action_neurons = zeros(N_action,1);
rho_rise = rho_action_neurons;   % rise component
rho_decay = rho_action_neurons;  % decay component
% mask applied to X_cut and the EPSP state in the loop; it is returned by neuron()
% (the loop describes it as a reset after the last postsynaptic spike)
Canc = ones(N_action, N_pc+N_action);
last_spike_post = -1000*ones(N_action,1); % time of the last spike of each action neuron (far in the past)
store_pos = zeros(T_max*Trials,2);                 % position at every step (for plotting)
firing_rate_store = zeros(N_action,T_max*Trials);  % smoothed action-neuron rates at every step (for plotting)
%% initialize plot open field
figure('position', [0, 0, 1000, 2000])
subplot(2,2,1)
reward_plot = plot(c(1)+r_goal*cos(-pi:2*pi/100:pi), c(2)+r_goal*sin(-pi:2*pi/100:pi), 'color', 'black'); %plot reward area 1
hold on
point_plot = plot(starting_position(1),starting_position(2), '.r', 'MarkerSize',10); %plot initial starting point
%plot walls
line([bounds_x(1) bounds_x(1)], [bounds_y(1) bounds_y(2)], 'color','black'); %left wall
line([bounds_x(2) bounds_x(2)], [bounds_y(1) bounds_y(2)], 'color','black'); %right wall
line([bounds_x(1) bounds_x(2)], [bounds_y(1) bounds_y(1)], 'color','black'); %bottom wall
line([bounds_x(1) bounds_x(2)], [bounds_y(2) bounds_y(2)], 'color','black'); %top wall
%frame the axes on the open field itself, consistent with the walls drawn above
axis([bounds_x bounds_y])
%% delete actions that lead out of the maze
%indices of the place cells lying on each wall (pc is rounded, so exact comparison
%with the bounds is safe). Rows: 1 bottom, 2 top, 3 right (x=bounds_x(2)), 4 left (x=bounds_x(1)).
%Assigning the whole matrix at once also discards any stale 'sides' left in the
%workspace by a previous run (the script only calls 'close all').
sides = [find(pc(:,2) == bounds_y(1))'; ... %bottom wall
         find(pc(:,2) == bounds_y(2))'; ... %top wall
         find(pc(:,1) == bounds_x(2))'; ... %right wall
         find(pc(:,1) == bounds_x(1))'];    %left wall
%actions forbidden on each wall, derived from N_action. Action k points at angle
%2*pi*k/N_action measured clockwise from north (actions = a0*[sin; cos]), so with
%quarter = N_action/4:
%  south (cos<0): quarter < k < 3*quarter       -> forbidden on the bottom wall
%  north (cos>0): k < quarter or k > 3*quarter  -> forbidden on the top wall
%  east  (sin>0): 0 < k < 2*quarter             -> forbidden on the right wall
%  west  (sin<0): 2*quarter < k < N_action      -> forbidden on the left wall
%Actions exactly parallel to a wall stay allowed. For N_action = 40 this gives
%11:29, [1:9 31:40], 1:19 and 21:39.
if mod(N_action,4) ~= 0
    error('N_action must be a multiple of 4 so that every wall forbids the same number of actions');
end
quarter = N_action/4;
forbidden_actions = [quarter+1:3*quarter-1; ...
                     [1:quarter-1, 3*quarter+1:N_action]; ...
                     1:2*quarter-1; ...
                     2*quarter+1:N_action-1];
%kill connections between place cells on the walls and forbidden actions
w_walls = ones(N_action, N_pc+N_action);
for g=1:4
    w_walls(forbidden_actions(g,:), sides(g,:))=0;
end
%% start simulation
w_tot_old = w_tot(1:N_action,1:N_pc); %feed-forward weights at the start of the trial (baseline of the end-of-trial update)
average_reward = 0; %initialise average reward
t_after_reward = 300; %the trial ends this many ms after the reward (or the old reward location) is reached
while i<T_max*Trials
    i=i+1;
    t=mod(i,T_max); %time within the trial; 0 on the last step of a trial
    %% reset new trial
    if t==1
        pos = starting_position; %initialize position at the starting point
        reward=0; %flag that signals when the reward is found
        tr=tr+1; %trial number
        t_rew=T_max; %time of reward - T_max means it has not been reached in this trial yet
        %initialisation variables - reset between trials
        Y_action_neurons= zeros(N_action, 1);
        X_cut = zeros(N_pc+N_action, N_action);
        epsp_rise=zeros(N_action+N_pc,N_action);
        epsp_decay=zeros(N_action+N_pc,N_action);
        epsp_tot=zeros(N_action+N_pc, N_action);
        rho_action_neurons= zeros(N_action,1);
        rho_rise= zeros(N_action,1);
        rho_decay = zeros(N_action,1);
        Canc = ones(N_pc+N_action,N_action)';
        last_spike_post=zeros(N_action,1)-1000;
        trace_pre_post= zeros(N_action, N_pc); %same name as the pre-post trace initialised before the loop
        trace_post_pre= zeros(N_action,N_pc);
        trace_tot = zeros(N_action,N_pc);
        eligibility_trace = zeros(N_action, N_pc);
        %change reward location in the second half of the experiment
        %(floor keeps the switch reachable when Trials is odd)
        if tr== floor(Trials/2)+1
            rew1_flag=0;
            rew2_flag=1;
            subplot(2,2,1)
            delete(reward_plot)
            reward_plot_old = plot(c(1)+r_goal*cos(-pi:2*pi/100:pi), c(2)+r_goal*sin(-pi:2*pi/100:pi), 'color', 'black', 'linestyle', '--'); %plot old reward 1 (dashed)
            reward_plot_new = plot(c2(1)+r_goal2*cos(-pi:2*pi/100:pi), c2(2)+r_goal2*sin(-pi:2*pi/100:pi), 'color', 'black'); %plot reward 2
        end
        %running average of the reward: it decays by (1-beta) at the start of
        %every trial; beta*reward is added below when a reward is found
        average_reward = average_reward*(1-beta);
    end
    %% place cells
    %spike probability of each place cell: Gaussian in the distance to its centre
    rhos = rho_pc.*exp(-sum((repmat(pos,n_x*n_y,1)-pc).^2,2)/(sigma_pc^2));
    prob = rhos;
    %turn place cells off after reward is reached
    if t>t_rew
        prob=0;
    end
    X = (rand(1,N_pc)<=prob')'; %spike train pcs
    store_pos(i,:) = pos; %store position (for plotting)
    %% reward
    % agent enters reward 1 (first half of the experiment)
    if sum((pos-c).^2)<=r_goal^2 && reward==0 && rew1_flag==1
        reward=1; %reward found
        t_rew=t; %time of reward
        time_reward(tr) = t; %store time of reward
        average_reward = average_reward + beta*reward; %average reward increases if the reward is found
    end
    % agent enters reward 2 (second half of the experiment)
    if sum((pos-c2).^2)<=r_goal2^2 && reward==0 && rew2_flag==1
        reward=1; %reward 2 found
        t_rew=t; %time of reward 2
        time_reward2(tr) = t; %store time of reward 2
        average_reward = average_reward + beta*reward; %average reward increases if the reward is found
    end
    % agent enters reward 1 in the second half of the experiment (previously rewarded location).
    % Only the first entrance counts (t_rew still T_max): re-arming t_rew at every step
    % the agent stays inside would keep pushing the end of the trial forward (the agent
    % slows down once place cells are off) and overwrite the stored entrance time.
    if sum((pos-c).^2)<=r_goal^2 && rew1_flag==0 && rew2_flag==1 && t_rew==T_max
        t_rew=t; %the trial is ended, even though this location is no longer rewarded
        time_reward_old(tr)=t; %store time of entrance old reward location
    end
    %% action neurons
    % reset after last post-synaptic spike
    X_cut = repmat([X; Y_action_neurons],1,N_action);
    X_cut = X_cut.*Canc';
    epsp_rise=epsp_rise.*Canc';
    epsp_decay=epsp_decay.*Canc';
    % neuron model
    [epsp_tot, epsp_decay, epsp_rise] = convolution (epsp_decay, epsp_rise, tau_m, tau_s, eps0, X_cut, w_tot.*w_walls); %EPSP in the model * weights
    [Y_action_neurons,last_spike_post, Canc] = neuron(epsp_tot, chi, last_spike_post, tau_m, rho0, theta, delta_u, i); %sums EPSP, calculates potential and spikes
    % smooth firing rate of the action neurons
    [rho_action_neurons, rho_decay, rho_rise] = convolution (rho_decay, rho_rise, tau_gamma, v_gamma, 1, Y_action_neurons);
    firing_rate_store(:,i) = rho_action_neurons; %store action neurons' firing rates
    % select action: rate-weighted average of the action vectors
    a = ((rho_action_neurons'*actions')/N_action);
    a(isnan(a))=0;
    %% synaptic plasticity
    %STDP with symmetric window
    [ trace_pre_post, trace_post_pre,eligibility_trace, trace_tot, W] = weights_update_stdp(A_pre_post, A_post_pre, tau_pre_post, tau_post_pre, repmat(X',N_action,1) , repmat(Y_action_neurons,1, N_pc), trace_pre_post, trace_post_pre, trace_tot, tau_e);
    % online weights update (effective only with acetylcholine - ACh_flag=1)
    w_tot(1:N_action,1:N_pc)= w_tot(1:N_action,1:N_pc)-eta_ACh*W*(ACh_flag);
    %feed-forward weights limited between lower and upper bounds (lateral weights untouched)
    w_tot(:,1:N_pc) = min(max(w_tot(:,1:N_pc), w_min), w_max);
    %% position update
    pos = pos+a;
    %if the agent reached a wall, push it back by dx (one axis per step)
    if pos(1)<=bounds_x(1)
        pos = pos+dx*[1,0];
    elseif pos(1)>= bounds_x(2)
        pos = pos+dx*[-1,0];
    elseif pos(2)<=bounds_y(1)
        pos = pos+dx*[0,1];
    elseif pos(2)>=bounds_y(2)
        pos = pos+dx*[0,-1];
    end
    %the trial ends t_after_reward ms after the reward is found
    t_extreme = t_rew+t_after_reward;
    if t> t_extreme && t<T_max
        i = (ceil(i/T_max))*T_max-1; %set i counter to the end of the trial
        t_end = t_extreme; %for plotting
    end
    if t==0
        t=T_max;
        %% update weights - end of trial
        %weights are retroactively potentiated through an eligibility trace
        %the dynamic signal=(rew-average rew) determines magnitude and sign of the update
        w_tot(1:N_action,1:N_pc)=w_tot_old+eta_DA*eligibility_trace*(reward-average_reward);
        %feed-forward weights limited between lower and upper bounds
        w_tot(:,1:N_pc) = min(max(w_tot(:,1:N_pc), w_min), w_max);
        %store weights before the beginning of next trial (baseline for the next end-of-trial update)
        w_tot_old = w_tot(1:N_action,1:N_pc);
        %calculate policy
        ac =actions*(w_tot_old.*w_walls(:,1:N_pc))/a0; %vector of preferred actions according to the weights
        ac(:,unique(sides(:)))=0; %do not count actions AT the boundaries (just for plotting)
        %% plot
        if plot_flag==1
            %display trajectory of the agent in each trial
            trial_steps = (floor((i-1)/T_max))*T_max + (1:t_end); %indices of this trial's steps in store_pos
            subplot(2,2,1)
            f3=plot(store_pos(trial_steps,1), store_pos(trial_steps,2), 'red'); %trajectory
            hold on
            delete(point_plot)
            point_plot = plot(starting_position(1),starting_position(2), '.r', 'MarkerSize',10); %starting point
            title(['Trial ', num2str(tr)])
            %display average reward
            subplot(2,2,2)
            plot(tr, average_reward, '.r', 'MarkerSize',30)
            axis([1 Trials 0 1])
            hold on
            title('Average reward')
            %display weights over the open field, averaged over action neurons
            subplot(2,2,3)
            w_plot = mean(w_tot(:,1:N_pc)); %weights after this trial's update, i.e. those used in the next trial
            w_plot = reshape(w_plot,n_x,n_y); %place cell k = ix + (iy-1)*n_x -> (ix,iy)
            imagesc(w_plot')
            set(gca,'YDir','normal')
            colorbar
            title('Mean weights')
            %plot policy as a vector field
            subplot(2,2,4)
            f4=quiver(pc(:,1), pc(:,2), ac(1,:)', ac(2,:)', 'linewidth', 2, 'color', 'black');
            axis([bounds_x bounds_y])
            title('Agent''s policy')
            drawnow
            %pause
            delete(f3)
            delete(f4)
        end
        %%
        t_end = T_max;
    end
end