function [dh, fith, p_fit, r2, spk_fun, outdata] = plotModelSpikeTransform(ah, t, vm, ifr)
% function [dh, fith, p_fit, r2, spk_fun, outdata] = plotModelSpikeTransform(ah, t, vm, ifr)
%
% Builds the trial-averaged relationship between membrane potential (Vm,
% relative to rest) and instantaneous firing rate (IFR), plots it with a
% +/- 1 SD band, and fits a rectified power law (or, optionally, a linear
% threshold function) to it.
%
% The function assumes that vm and ifr are taken at the same sampling rate
% as t. vm and ifr are matrices in which each COLUMN is a trial and each
% row is a time sample.
%
% Inputs
%   ah  - axes handle to plot into; if empty a new figure/axes is created
%   t   - time vector (presumably in ms, since the bin width below is
%         expressed in ms - confirm against the caller)
%   vm  - membrane potential, samples x trials
%   ifr - instantaneous firing rate, samples x trials
%
% Outputs
%   dh      - handles of the data lines: (1) mean, (2) mean+SD, (3) mean-SD
%   fith    - handle of the fitted curve
%   p_fit   - fitted parameters ([exponent scale], or
%             [exponent offset scale] when linear_thresh is set)
%   r2      - value returned by calcR2(mean_ifr, fitted_ifr)
%   spk_fun - function handle f(p, x) of the fitted model
%   outdata - struct with fields vm (bin centers), ifr_mean, ifr_std

% Option
linear_thresh = 0; % Makes it return a 3 parameter fit that is the linear threshold relationship, rather than a power law

if (isempty(ah))
    figure; axes;
    ah = gca;
end

% Resting potential is estimated from samples 200..2000 of every trial.
% Clamp the upper end so that shorter traces still work.
nsamp = size(vm, 1);
rest_last = min(2000, nsamp);
if rest_last < 200
    error('plotModelSpikeTransform:shortTrace', ...
        'vm needs at least 200 samples per trial to estimate the resting potential.');
end
vm_rest = vm(200:rest_last, :);
vm_rest = mean(vm_rest(:));
vm = vm - vm_rest;

dt = mean(diff(t)); % the sampling interval of the vectors
bint = 5; %ms, the width of the averaging bin
step = bint;
% Never let the step/bin collapse to zero samples when dt > bint.
stepi = max(1, floor(step/dt));
bini = max(1, floor(bint/dt));
nbins = floor(length(t)/stepi); %round down to not overrun the vector
ntrials = size(vm,2);
ifr_mean = zeros(nbins, ntrials);
vm_mean = zeros(nbins, ntrials);

vm_bin_w = .5;
%vm_bin_edges = (min(vm(:))-vm_bin_w/2):vm_bin_w:(max(vm(:))+vm_bin_w);
vm_bin_edges = 0:vm_bin_w:(max(vm(:))+vm_bin_w);
vm_bin_centers = vm_bin_edges(1:end-1) + vm_bin_w/2;
n_vm_bins = length(vm_bin_centers);
ifr_vm = zeros(n_vm_bins, ntrials);

for jj = 1:ntrials %averaging over periods of time.
    for ii = 1:nbins
        first_i = (ii-1)*stepi + 1;
        % A bin holds exactly bini samples: first_i .. first_i+bini-1.
        % (Using first_i+bini would make adjacent bins share a sample and
        % run one element past the end of the last bin.)
        last_i = min(length(t), first_i + bini - 1);
        ifr_mean(ii,jj) = mean(ifr(first_i:last_i, jj));
        vm_mean(ii,jj) = mean(vm(first_i:last_i, jj));
    end

    % after averaging a little, we now construct a regularly spaced
    % ifr(vm) relationship for each trial
    [~, max_vmi] = max(vm_mean(:,jj));
    tifr = ifr_mean(1:max_vmi, jj); %select the rising phase of the Vm
    tvm = vm_mean(1:max_vmi, jj);
    ifr_vm_temp = NaN*zeros(n_vm_bins,1);
    for ii = 1:n_vm_bins
        in_bin = (tvm >= vm_bin_edges(ii) & tvm < vm_bin_edges(ii+1));
        if any(in_bin)
            ifr_vm_temp(ii) = mean(tifr(in_bin)); %use the mean of any ifr's in the bin
        end
    end

    % Now we must deal with any holes that may be in the vector by
    % linearly interpolating the ifr from the neighbouring bins. A hole
    % whose neighbour is itself still NaN stays NaN here; such entries are
    % left for the NaN handling after the trial loop.
    nani = find(isnan(ifr_vm_temp));
    for ii = 1:length(nani)
        k = nani(ii);
        if (k > 1 && k < n_vm_bins) % deal w/ edge cases separately
            ifr_vm_temp(k) = (ifr_vm_temp(k-1) + ifr_vm_temp(k+1))./2; %average the two points on either side
        elseif (n_vm_bins >= 3 && k == 1)
            % extrapolate backwards from bins 2 and 3
            ifr_vm_temp(1) = ifr_vm_temp(2) - diff(ifr_vm_temp(2:3));
        elseif (n_vm_bins >= 3 && k == n_vm_bins)
            % extrapolate forwards from the last two interior bins
            ifr_vm_temp(end) = ifr_vm_temp(end-1) + diff(ifr_vm_temp((end-2):(end-1)));
        end
    end
    ifr_vm(:,jj) = ifr_vm_temp;
end

%dh = plot(ah, vm_bin_centers, ifr_vm);
mean_ifr_vm = nanmean2(ifr_vm,2);
sd_ifr_vm = nanstd2(ifr_vm, 0, 2);
nani = isnan(mean_ifr_vm);
mean_ifr_vm = mean_ifr_vm(~nani); % eliminate any remaining nans
sd_ifr_vm = sd_ifr_vm(~nani); sd_ifr_vm(sd_ifr_vm < 1) = 1; %set a 1Hz lower bound on SD.
vm_bin_centers = vm_bin_centers(~nani);

%also, we only want to fit those points that are lower than the peak
[~, peaki] = max(mean_ifr_vm);
mean_ifr_vm = mean_ifr_vm(1:peaki);
sd_ifr_vm = sd_ifr_vm(1:peaki);
vm_bin_centers = vm_bin_centers(1:peaki);

% The first plot honours the caller's hold state (it clears the axes if
% hold was off). Hold is then switched on so that the later plot calls add
% to the axes instead of deleting the lines whose handles we return.
was_held = ishold(ah);
dh = plot(ah, vm_bin_centers, mean_ifr_vm);
hold(ah, 'on');
dh(2) = plot(ah, vm_bin_centers, sd_ifr_vm+mean_ifr_vm, ':');
dh(3) = plot(ah, vm_bin_centers, mean_ifr_vm-sd_ifr_vm, ':');

if linear_thresh %a linear threshold!
    % 3 fitted parameters
    powfitf = @(p,x)p(3).*(rectify(x(:)-p(2)).^p(1)); %parameters 1) exponent, 2) offset 3)scale
    p0 = [1 1 1];
else %power law relationship
    %2 fitted parameters
    powfitf = @(p,x)p(2).*(rectify(x(:)).^p(1)); %parameters 1) exponent, 2)scale
    p0 = [1 1];
end

% SD-weighted sum of squared errors; everything is forced to a column so
% the element-wise operations cannot silently broadcast into a matrix.
powErr = @(p) sum((powfitf(p,vm_bin_centers) - mean_ifr_vm(:)).^2 ./ (sd_ifr_vm(:).^2));
options = optimset('TolFun', 1e-12,'TolX', 1e-12, 'TolCon', 1e-12, 'MaxFunEvals', 8000, 'MaxIter', 4000, 'Display', 'final', 'algorithm', 'interior-point');
if linear_thresh
    % exponent is pinned to 1 by equal lower/upper bounds
    p_fit = fmincon(powErr, p0, [], [], [], [], [1, 0, 0], [1, 50, 100], [], options); %fitting using 3 params
else
    p_fit = fmincon(powErr, p0, [], [], [], [], [0, 0], [100, 100], [], options); % fitting for 2
end

yfit = powfitf(p_fit, vm_bin_centers);
fith = plot(ah, vm_bin_centers, yfit, '--k');

% Label the target axes explicitly, and only after plotting: a plot call
% on an axes with hold off resets the axes and would discard the labels.
xlabel(ah, 'Cell Input','FontSize', 12); ylabel(ah, 'Cell Output','FontSize', 12);
if ~was_held
    hold(ah, 'off'); % restore the caller's hold state
end

r2 = calcR2(mean_ifr_vm, yfit);
spk_fun = powfitf;
outdata.vm = vm_bin_centers;
outdata.ifr_mean = mean_ifr_vm;
outdata.ifr_std = sd_ifr_vm;
% -------------------------------------
function rect_vect = rectify(in_vect)
% function rect_vect = rectify(in_vect)
%
% Support function performing half-wave rectification: every element of
% in_vect that is strictly below zero is replaced by zero, all other
% elements (including NaN) are returned unchanged. The output has the
% same size as the input.
% ----------------------------------------
below_zero = (in_vect < 0);   % logical mask of the negative entries
rect_vect = in_vect;
rect_vect(below_zero) = 0;