function [winner,A,O,step_counter] = GPR_engine(saliences,DA_sel,DA_cont,dt,tolerance,max_steps,theta,varargin)
% GPR_ENGINE solution engine for the Gurney, Prescott & Redgrave (2001a,b) BG model
%
%   W = GPR_ENGINE(S,DA1,DA2,DT,TOLERANCE,MAX_STEPS,THETA) where S is a vector of salience values
%   for the actions represented on the BG channels (the number of channels is implicitly defined
%   by the length of S; row and column vectors are both accepted). Dopamine levels in the
%   selection and control pathways are set by DA1 and DA2. The time-step is DT, and the model is
%   run until either the change in activation of every unit is no greater than TOLERANCE or
%   MAX_STEPS steps have been taken. If the output of a GPi channel is below THETA then that
%   channel is considered selected. Returns W, a column vector of the winning channel index(es),
%   or an empty result if there is no winner.
%
%   GPR_ENGINE(...,SWITCH) selects how winners are reported:
%       'soft' (default) - every channel whose GPi output is below THETA is returned.
%       'hard'           - at most one selection is made: of the channels whose GPi output is
%                          below THETA the one with the lowest output is returned; if no channel
%                          is below THETA, or the lowest output is shared by several channels,
%                          [] is returned.
%       'gate'           - returns a column vector with one entry per channel holding the
%                          proportion of output below THETA, where 0 indicates that the output
%                          is at or above THETA, and 1 indicates no output.
%   Any other value of SWITCH yields W = []. Pass SWITCH = [] to keep the default when later
%   optional arguments are needed.
%
%   [W,nA,nO,STEPS] = GPR_ENGINE(...,SWITCH,A,O) starts the simulation from the activation
%   matrix A and output matrix O of a previous competition. Both must have one row per channel
%   and five columns, ordered [SD1 SD2 STN GPe GPi]. nA and nO are the corresponding matrices
%   at the end of the current simulation, suitable for re-use as arguments to the next call.
%   STEPS is the number of time-steps actually taken. Pass A = [] and/or O = [] to start from
%   zeros.
%
%   GPR_ENGINE(...,SWITCH,A,O,FLAG) where FLAG is a string containing any combination of:
%       'g': includes the additional connections from the Gurney et al. (2004) Network paper.
%       'd': uses the new DA model from the technical report: Humphries, M.D. (2003). High level
%            modeling of dopamine mechanisms in striatal neurons. ABRG 3. Dept. Psychology,
%            University of Sheffield, UK.
%
%   Note#1: it is critical that saliences are input at time zero! This condition ensures that
%   striatal cells have non-zero activation changes, and thus convergence does not occur on the
%   first time-step.
%
%   Note#2: the weights of the additional connections from Gurney et al. (2004) are zero unless
%   the 'g' flag is given.
%
%   Note#3: if used, the parameters for the new DA model are taken from the optimally-performing
%   model as described in the technical report. However, this included a slightly elevated
%   dopamine level for the GPR model (DA = 0.3), which should be specified in this function call
%   if required.
%
%   Note#4: the model is solved using a zero-order hold method. As tau = 1/k = 0.04,
%   DT < 0.04 is required to at least ensure the possibility of an accurate simulation.
%
%   Requires RAMP_OUTPUT (and DA_RAMP_OUTPUT when the 'd' flag is used) on the path.
%
% Author: Mark Humphries 21/1/2005

%%% MODEL PARAMETERS
NUM_CHANNELS = length(saliences);
NUM_NUCLEI = 5;             % columns: 1 = SD1, 2 = SD2, 3 = STN, 4 = GPe, 5 = GPi

%% weight values as defined by GPR
W_SEL = 1;
W_CONT = 1;
W_STN = 1;
W_SEL_GPi = -1;
W_CONT_GPe = -1;
W_STN_GPi = 0.9;
W_STN_GPe = 0.9;
W_GPe_STN = -1;
W_GPe_GPi = -0.3;

%% additional weights are zero by default
W_GPi_GPi = 0;
W_GPe_GPe = 0;
W_SEL_GPe = 0;

%% thresholds as defined by GPR
e_SEL = 0.2;
e_CONT = 0.2;
e_STN = -0.25;
e_GPe = -0.2;
e_GPi = -0.2;

%%% INITIALISE ARRAYS
A = zeros(NUM_CHANNELS,NUM_NUCLEI);         % activations
O = zeros(NUM_CHANNELS,NUM_NUCLEI);         % outputs
% start with a change of one everywhere so the loop is entered for any tolerance below 1
delta_a = ones(NUM_CHANNELS,NUM_NUCLEI);

%%% OPTIONAL PARAMETERS
% set defaults
type = 'soft';
blnNewDA = 0;
gain_DA_sel = DA_sel;
gain_DA_cont = DA_cont;

% switching type
if nargin >= 8 && ~isempty(varargin{1})
    type = varargin{1};
end

% activation matrix from a previous competition
if nargin >= 9 && ~isempty(varargin{2})
    [rA,cA] = size(varargin{2});
    if cA ~= NUM_NUCLEI
        error('Activation matrix must have 5 columns for GPR model');
    end
    if rA ~= NUM_CHANNELS
        error('Activation matrix must have the same number of rows as specified saliences');
    end
    A = varargin{2};
end

% output matrix from a previous competition
if nargin >= 10 && ~isempty(varargin{3})
    [rO,cO] = size(varargin{3});
    if cO ~= NUM_NUCLEI
        error('Output matrix must have 5 columns for GPR model');
    end
    if rO ~= NUM_CHANNELS
        error('Output matrix must have the same number of rows as specified saliences');
    end
    O = varargin{3};
end

% model-variant flags. The original condition lacked the logical AND between the two tests,
% so the emptiness check was not part of the condition.
if nargin >= 11 && ~isempty(varargin{4})
    flags = varargin{4};
    if any(flags == 'g')    % include weights from Gurney et al (2004) Network paper
        W_GPi_GPi = -0.2;
        W_GPe_GPe = -0.2;
        W_SEL_GPe = -0.25;
    end
    if any(flags == 'd')    % include new DA model
        % dopamine is applied inside the striatal output functions instead of as an input gain
        gain_DA_sel = 0;
        gain_DA_cont = 0;
        blnNewDA = 1;       % set flag to use new output functions
        % parameter values from tech. report
        e_SEL = 0.1;
        e_CONT = 0.1;
        gain_SEL = 0.8;
        gain_CONT = 0.8;
        pivot = 0.1;
    end
end

%%% ARTIFICIAL UNIT PARAMETERS
k = 25;     % gain
m = 1;      % slope
decay_constant = exp(-k*dt);

%%% SIMULATE MODEL
step_counter = 0;
c = saliences(:);       % saliences as a column vector, one row per channel

while step_counter < max_steps && any(delta_a(:) > tolerance)
    step_counter = step_counter + 1;
    old_A = A;

    %% calculate activation changes; every input uses the outputs of the previous step
    %% STRIATUM D1
    u_SEL = c .* W_SEL .* (1 + gain_DA_sel);
    A(:,1) = (A(:,1) - u_SEL) * decay_constant + u_SEL;

    %% STRIATUM D2
    u_CONT = c .* W_CONT .* (1 - gain_DA_cont);
    A(:,2) = (A(:,2) - u_CONT) * decay_constant + u_CONT;

    %% STN
    u_STN = c .* W_STN + O(:,4) .* W_GPe_STN;
    A(:,3) = (A(:,3) - u_STN) * decay_constant + u_STN;

    %% GPe
    other_GPe = sum(O(:,4)) - O(:,4);   % summed GPe output of all other channels
    u_GPe = sum(O(:,3)) .* W_STN_GPe + O(:,2) .* W_CONT_GPe + O(:,1) .* W_SEL_GPe + other_GPe .* W_GPe_GPe;
    A(:,4) = (A(:,4) - u_GPe) * decay_constant + u_GPe;

    %% GPi
    other_GPi = sum(O(:,5)) - O(:,5);   % summed GPi output of all other channels
    u_GPi = sum(O(:,3)) .* W_STN_GPi + O(:,4) .* W_GPe_GPi + O(:,1) .* W_SEL_GPi + other_GPi .* W_GPi_GPi;
    A(:,5) = (A(:,5) - u_GPi) * decay_constant + u_GPi;

    %% calculate outputs
    if blnNewDA
        O(:,1) = DA_ramp_output(A(:,1),e_SEL,m,DA_sel,1,gain_SEL,pivot)';
        O(:,2) = DA_ramp_output(A(:,2),e_CONT,m,DA_cont,2,gain_CONT)';
    else
        O(:,1) = ramp_output(A(:,1),e_SEL,m)';
        O(:,2) = ramp_output(A(:,2),e_CONT,m)';
    end
    O(:,3) = ramp_output(A(:,3),e_STN,m)';
    O(:,4) = ramp_output(A(:,4),e_GPe,m)';
    O(:,5) = ramp_output(A(:,5),e_GPi,m)';

    delta_a = abs(A - old_A);
end

%%% DETERMINE WINNER(S) FROM GPi OUTPUT
winner = [];    % also the result for an unrecognised switching type
GPi_out = O(:,5);
switch type
    case 'soft'
        winner = find(GPi_out < theta);
    case 'hard'
        selected = find(GPi_out < theta);
        if ~isempty(selected)
            winner = find(GPi_out == min(GPi_out(selected)));
        end
        if length(winner) > 1
            winner = [];    % can only return one winner
        end
    case 'gate'
        winner = zeros(NUM_CHANNELS,1);
        idxs = find(GPi_out < theta);
        winner(idxs) = (theta - GPi_out(idxs)) / theta;
end