function [X_out,idc,p_res] = SetCon_CommonNeighbour_Cross(Nsyn, X_11,X_22,pCon,para_p_input,para_p_output,set_flag,X)
% Generates a random connectivity matrix X[Nout, Nin] across two different
% neuron populations with connection probability pCon and implements the
% "common neighbour rule" (Perin et al. 2011, PNAS).
%
% INPUT:
% Nsyn: Number of synapses that are already created
% X_11: Connection matrix within the output neurons [Nout, Nout]
% X_22: Connection matrix within the input neurons [Nin, Nin]
% pCon: Connection probability
% para_p_input: Common neighbour rule parameters [slope offset] for the
% input neurons (optional, [] selects the default)
% para_p_output: Common neighbour rule parameters [slope offset] for the
% output neurons (optional, [] selects the default)
% set_flag: Change connection matrix according to the common
% neighbour rule if set_flag==TRUE (optional, default TRUE)
% X: Preset connection matrix [Nout, Nin] between the input and
% output neurons (optional, drawn randomly if omitted or [])
%
% OUTPUT:
% X_out: Connection matrix [Nout, Nin] containing idc at non-zero entries
% idc: Synapse indices (starting with Nsyn+1)
% p_res: Connection statistics of X_out as a function of the number of
% common neighbours (layout as returned by p_calc_cross)
Nin = length(X_22(1,:));
Nout = length(X_11(:,1));
% Set probabilities according to the number of neighbours
% (derived from Kampa, Letzkus and Stuart, 2006)
if ~exist('para_p_input', 'var') || isempty(para_p_input)
para_p_input = [0.1411 0.1677];
end
if ~exist('para_p_output', 'var') || isempty(para_p_output)
para_p_output = [-0.1621 0.3514];
end
% NOTE(review): the factor vectors below start at 1 neighbour, whereas
% p_calc_cross reads entry N+1 for N neighbours - confirm whether the range
% was meant to start at 0.
max_neigh_input = floor((1/pCon-para_p_input(2))/para_p_input(1)); % ensures p <= 1
p_Neighbours_input = para_p_input(1)*(1:max_neigh_input) + para_p_input(2);
max_neigh_output = floor(-para_p_output(2)/para_p_output(1)); % ensures p >= 0
p_Neighbours_output = para_p_output(1)*(1:max_neigh_output) + para_p_output(2);
% Rows index the output-side count, columns the input-side count
p_Neighbours = p_Neighbours_output'*p_Neighbours_input;
% Set random connectivity with probability pCon (if not preset)
if ~exist('X', 'var') || isempty(X)
perm = randperm(Nin*Nout);
k = perm(1:round(pCon*Nin*Nout)); % linear indices of the connected pairs
X = zeros(Nout,Nin); % same [Nout, Nin] layout as X_out
idc = Nsyn+(1:length(k));
X(k) = idc;
else
syn_ids = X(X>0);
idc = min(syn_ids):max(syn_ids); % range of synapse ids present in X
k = reshape(find(X>0),1,[]); % linear indices (not ids) of the connected pairs
end;
% Compute the number of common neighbours for each pair in X
[pair_id, N_neigh_input, N_neigh_output] = count_cross_neighbours(X, X_11, X_22);
% Set new matrix, if desired
if ~exist('set_flag', 'var') || isempty(set_flag) || set_flag
% Normalize pCon such that p values add up to original pCon
p = p_calc_cross(X, pair_id, N_neigh_input, N_neigh_output, pCon, p_Neighbours);
% Select connections randomly according to p
pair_id_selected = [];
for i=1:size(p,1)
for j=1:size(p,2)
pair_act = pair_id(N_neigh_output == p(i,j,1) & N_neigh_input == p(i,j,2));
% number of connections wanted in this class, never more than available
n_target = min(round(p(i,j,6)*length(pair_act)), length(pair_act));
pair_old = k(ismember(k, pair_act));
if length(pair_old) > n_target % more existing connections than wanted: keep a random subset of them
order = randperm(length(pair_old));
chosen = pair_old(order(1:n_target));
else % otherwise draw the wanted number at random from all pairs of this class
order = randperm(length(pair_act));
chosen = pair_act(order(1:n_target))';
end;
pair_id_selected = [pair_id_selected, chosen];
end;
end;
pair_id_selected = sort(pair_id_selected);
% Set connectivity matrix
rand_ind = randperm(length(pair_id_selected)); % randomize indices
X_out = zeros(Nout,Nin);
idc = Nsyn+(1:length(pair_id_selected));
X_out(pair_id_selected) = idc(rand_ind);
else
X_out = X;
end;
% Analyse connectivity of the resulting matrix
[pair_id_res, N_neigh_input_res, N_neigh_output_res] = count_cross_neighbours(X_out, X_11, X_22);
p_res = p_calc_cross(X_out, pair_id_res, N_neigh_input_res, N_neigh_output_res, pCon, p_Neighbours);

function [pair_id, N_neigh_input, N_neigh_output] = count_cross_neighbours(X, X_11, X_22)
% Counts, for every (output, input) pair of X[Nout, Nin], the distinct
% common neighbours in the output layer and in the input layer.
% find_neigh concatenates X with its second argument, so the output-layer
% searches (in_flag = 0) need X_11 and the input-layer searches
% (in_flag = 1) need X_22.
[~, ~, neigh_00] = find_neigh(X,X_11,0,0); % neurons projecting to a common neighbour in the output layer
[~, ~, neigh_01] = find_neigh(X,X_11,0,1); % the same, but the output neuron in the pair receives input from the neighbour instead of projecting to it
[~, ~, neigh_10] = find_neigh(X,X_22,1,0); % neurons receiving input from a common neighbour in the input layer
[pair_id, ~, neigh_11] = find_neigh(X,X_22,1,1); % the same, but the input neuron in the pair projects to the neighbour instead of receiving input from it
N_neigh_input = zeros(length(pair_id),1);
N_neigh_output = zeros(length(pair_id),1);
for i=1:length(pair_id)
% zeros are padding in the neighbour lists and must not be counted
neigh_act = unique([neigh_00(i,:), neigh_01(i,:)]);
N_neigh_output(i) = sum(neigh_act>0);
neigh_act = unique([neigh_10(i,:), neigh_11(i,:)]);
N_neigh_input(i) = sum(neigh_act>0);
end;
% -------------------------------------------------------------------------
% --------------------- Auxiliary functions -----------------------------
% -------------------------------------------------------------------------
function [pair_id, N_neigh, neigh_out] = find_neigh(X,X_rec,in_flag,switch_flag)
% Lists, for every linear index of X, the neurons that are a common
% neighbour of the corresponding pair.
% in_flag = 1: X is joined with X_rec along the side whose indices run up
% to the number of rows of X; otherwise along the column side
% switch_flag = 1: X enters the joined matrix transposed
%
% pair_id: all linear indices 1..numel(X), ascending
% N_neigh: number of common-neighbour entries found for each pair
% neigh_out: one row per pair holding the neighbour indices, zero padded
Nin = numel(X(1,:));
Nout = numel(X(:,1));
% Join X and X_rec into one logical adjacency matrix; dim is the dimension
% along which the connections of a single candidate neighbour are listed
if in_flag == 1
if switch_flag == 1
X_act = [X' X_rec] > 0;
dim = 2;
else
X_act = [X; X_rec] > 0;
dim = 1;
end;
else
if switch_flag == 1
X_act = [X'; X_rec] > 0;
dim = 1;
else
X_act = [X X_rec] > 0;
dim = 2;
end;
end;
running = cumsum(X_act,dim);
if dim == 1
degree = running(end,:); % connections per column
else
degree = running(:,end); % connections per row
end;
% Every neuron with at least two connections turns each pair of its
% partners into a candidate pair
pairs = zeros(1,2);
neigh = [];
n_found = 0;
for deg=2:max(max(running)) % loop up to maximal number of connections of any cell
hubs = find(degree == deg);
n_new = nchoosek(deg,2); % pairs contributed by one neuron of this degree
for h=1:length(hubs)
if dim == 1
members = find(X_act(:,hubs(h)))';
else
members = find(X_act(hubs(h),:));
end;
rows = n_found+1:n_found+n_new;
pairs(rows,:) = nchoosek(members,2);
neigh(rows,:) = hubs(h);
n_found = n_found + n_new;
end;
end;
% Use only those pairs which cross the layers: one member inside the X
% part of the joined matrix, the other inside the X_rec part
if in_flag == 1
n_first = Nout;
else
n_first = Nin;
end;
crossing = (pairs(:,1)<=n_first & pairs(:,2)>n_first) | (pairs(:,1)>n_first & pairs(:,2)<=n_first);
neigh = neigh(crossing);
pairs = pairs(crossing,:);
% Transform neuron indices of the X_rec part back to 1..N and build the
% linear index of each pair in X
trafo = 1:Nin+Nout;
if in_flag == 1
trafo((1:Nin)+Nout) = 1:Nin;
pairs(:,2) = trafo(pairs(:,2))';
pair_id = (pairs(:,2)-1)*Nout + pairs(:,1);
else
trafo((1:Nout)+Nin) = 1:Nout;
pairs(:,2) = trafo(pairs(:,2))';
pair_id = (pairs(:,1)-1)*Nout + pairs(:,2);
end;
% Compute number of neighbours per distinct pair
[pair_id, ~, group] = unique(pair_id);
N_neigh = accumarray(group,1);
% Determine neighbours for each pair_id
neigh_out = zeros(length(pair_id), max(N_neigh));
for g=1:length(pair_id)
neigh_out(g,1:N_neigh(g)) = neigh(group==g);
end;
% Merge with the pairs that have no common neighbour and sort by pair_id
missing = ~ismember(1:(Nin*Nout), pair_id);
n_missing = nnz(missing);
N_neigh = [zeros(n_missing,1); N_neigh];
neigh_out = [zeros(n_missing,max(N_neigh)); neigh_out];
[pair_id, order] = sort([find(missing)'; pair_id]);
N_neigh = N_neigh(order);
neigh_out = neigh_out(order,:);
function p = p_calc_cross(X, pair_id, N_neigh_input, N_neigh_output, pCon, p_Neighbours)
% function to analyse connectivity
%
% X: connection matrix, entries > 0 mark existing connections
% pair_id: linear index into X for every pair
% N_neigh_input: input-side common-neighbour count per pair
% N_neigh_output: output-side common-neighbour count per pair
% pCon: overall connection probability
% p_Neighbours: probability factors; entry (a+1, b+1) is used for a
% output-side and b input-side common neighbours
%
% p(i,j,:) describes the class of pairs with the i-th distinct output count
% and the j-th distinct input count:
% p(i,j,1) output-side neighbour count of the class
% p(i,j,2) input-side neighbour count of the class
% p(i,j,3) number of pairs in the class
% p(i,j,4) number of connected pairs in the class
% p(i,j,5) connected fraction (0 if the class is empty)
% p(i,j,6) target connection probability, normalized and capped at 1
pair_id_selected = find(X>0);
is_selected = ismember(pair_id, pair_id_selected);
N_neigh_input_unique = unique(N_neigh_input);
N_neigh_output_unique = unique(N_neigh_output);
N_neigh_input_selected = N_neigh_input(is_selected);
N_neigh_output_selected = N_neigh_output(is_selected);
n_out = length(N_neigh_output_unique);
n_in = length(N_neigh_input_unique);
[n_rows_pN, n_cols_pN] = size(p_Neighbours);
p = zeros(n_out, n_in, 6);
for i=1:n_out
for j=1:n_in
out_val = N_neigh_output_unique(i);
in_val = N_neigh_input_unique(j);
n_pairs = sum(N_neigh_output == out_val & N_neigh_input == in_val);
n_connected = sum((N_neigh_output_selected == out_val) & (N_neigh_input_selected == in_val));
p(i,j,1) = out_val;
p(i,j,2) = in_val;
p(i,j,3) = n_pairs;
p(i,j,4) = n_connected;
if n_pairs > 0
p(i,j,5) = n_connected / n_pairs;
end;
% classes beyond the factor table keep probability 0; layer 6 holds the
% bare factor until it is scaled below
if out_val+1 <= n_rows_pN && in_val+1 <= n_cols_pN
p(i,j,6) = p_Neighbours(out_val+1, in_val+1);
end;
end;
end;
% Normalize pCon such that p values add up to original pCon
NN = min(n_rows_pN, n_out);
MM = min(n_cols_pN, n_in);
n_class = p(1:NN,1:MM,3);
% pair-count weighted mean of pCon*factor over the classes covered by the table
p_mean = sum(sum(pCon*p_Neighbours(1:NN,1:MM) .* n_class / sum(sum(n_class))));
if p_mean > 0
pCon_norm = pCon * pCon / p_mean;
else
pCon_norm = pCon; % nothing to normalize against (empty table or pCon == 0)
end;
% cap element-wise at 1 so every entry stays a valid probability
p(:,:,6) = min(p(:,:,6) * pCon_norm, 1);
% (c) 2016 J. Hass, L. Hertaeg and D. Durstewitz,
% Central Institute of Mental Health, Mannheim University of Heidelberg
% and BCCN Heidelberg-Mannheim