% This function takes two complete sets of input spike trains and a
% mixing probability pA as inputs.
%
% First, all the spikes in each complete input set are shuffled around
% among the trains of that set.
%
% Then each spike is tagged for inclusion, with probability pA for spikes
% from the first set and (1 - pA) for spikes from the second set.
%
% Finally, all the tagged spikes are merged together into a new complete
% set of spike trains, which is returned by the function.
%

function outTrain = mixTwoTrainsKeepCorr(trainA, trainB, pA)
% MIXTWOTRAINSKEEPCORR Mix two sets of spike trains, keeping repeated
% (shared) spike times together.
%
%   outTrain = mixTwoTrainsKeepCorr(trainA, trainB, pA)
%
%   trainA, trainB : spike-time matrices, one train per column. Entries
%                    equal to inf are ignored (presumably padding).
%   pA             : probability of keeping each distinct spike time of
%                    trainA; each distinct spike time of trainB is kept
%                    with probability 1 - pA.
%
%   outTrain       : matrix with size(trainA, 2) columns; each column holds
%                    the spike times assigned to that train in ascending
%                    order, padded at the bottom with inf.
%
% Spike times repeated within one input set (consecutive sorted times
% closer than 1e-8) form a group: the whole group is kept or dropped
% together, and the copies of a kept group are placed in distinct,
% randomly chosen output trains. An error is raised if a group has more
% copies than there are output trains.
%
% NOTE(review): the original author noted (in Swedish) "this function is
% wrong, test with pA = 0 or 1". As written, even for pA = 1 the output is
% not trainA column by column: every kept spike is reassigned to a random
% train, so only the pooled spike times and the per-time repetition counts
% are preserved, not the per-train spike counts of the input. Whether that
% is the intended behavior cannot be determined from this file.

tol = 1e-8;             % spikes closer than this are treated as copies
pB = 1 - pA;
nTrains = size(trainA, 2);

[spikesA, groupA, nGroupsA] = groupRepeatedSpikes(trainA, tol);
[spikesB, groupB, nGroupsB] = groupRepeatedSpikes(trainB, tol);

% Tag whole groups for inclusion. Drawing one number per group found with
% the tolerance (rather than per exactly-unique value) guarantees that
% every tag refers to an existing group.
keptA = find(rand(nGroupsA, 1) < pA);
keptB = find(rand(nGroupsB, 1) < pB);

% One column vector of collected spike times per output train.
tSpik = repmat({zeros(0, 1)}, 1, nTrains);
tSpik = distributeGroups(tSpik, spikesA, groupA, keptA);
tSpik = distributeGroups(tSpik, spikesB, groupB, keptB);

% Convert the cell array to an inf-padded matrix, each column ascending.
maxLen = max([0, cellfun(@numel, tSpik)]);
outTrain = inf(maxLen, nTrains);
for i = 1:nTrains
  outTrain(1:numel(tSpik{i}), i) = tSpik{i};
end
outTrain = sort(outTrain);


function [spikes, group, nGroups] = groupRepeatedSpikes(train, tol)
% Return the sorted non-inf spike times of TRAIN as a column, a group
% number (1, 2, ...) for each of them, and the number of groups.
% Consecutive sorted spikes closer than TOL share a group number; the
% grouping is chained, so a run of spikes each within TOL of its
% neighbour forms a single group.
spikes = sort(train(:));
spikes(spikes == inf) = [];
isNewGroup = true(size(spikes));
% Written as ~(... < tol) so that NaN comparisons start a new group,
% exactly like an if/else on (... < tol).
isNewGroup(2:end) = ~(abs(spikes(2:end) - spikes(1:end-1)) < tol);
group = cumsum(isNewGroup);
if isempty(group)
  nGroups = 0;
else
  nGroups = group(end);
end


function tSpik = distributeGroups(tSpik, spikes, group, keptGroups)
% Append the spikes of every group listed in KEPTGROUPS to the cell array
% TSPIK (one column vector per output train). The copies within one group
% go to distinct trains, each picked uniformly at random among the trains
% that have not yet received a copy of that group.
nTrains = numel(tSpik);
for g = keptGroups(:)'
  copies = spikes(group == g);
  if numel(copies) > nTrains
    error('mixTwoTrainsKeepCorr:tooManyCopies', ...
          ['Spike time %g occurs %d times in one input set, but there ' ...
           'are only %d output trains.'], copies(1), numel(copies), nTrains);
  end
  freeTrains = 1:nTrains;
  for j = 1:numel(copies)
    % rand lies in the open interval (0,1), so pick is in 1..numel(freeTrains).
    pick = ceil(numel(freeTrains) * rand(1));
    trainIdx = freeTrains(pick);
    freeTrains(pick) = [];
    tSpik{trainIdx} = [tSpik{trainIdx}; copies(j)];
  end
end