function [y, stimcurr, hcurr, r] = simulate_glm(x,dt,k,h,dc,runs,SpikeTrains,BasisTab2,softRect,ind_plot)
% SIMULATE_GLM  Simulate spike trains from a (previously fitted) Poisson GLM.
%
%  [y, stimcurr, hcurr, r] = simulate_glm(x,dt,k,h,dc,runs)
%  [y, stimcurr, hcurr, r] = simulate_glm(x,dt,k,h,dc,runs,SpikeTrains,BasisTab2,softRect,ind_plot)
%
%  --- Code modified from Weber & Pillow (2017) ---
%  Weber AI, Pillow JW. Capturing the Dynamical Repertoire of Single Neurons
%  with Generalized Linear Models. Neural Comput. 2017 (12):3260-3289.
%
%  This function does NOT fit anything: it takes the filters of an already
%  fitted GLM, drives them with the stimulus x, and draws spikes bin by bin,
%  feeding the post-spike filter back into the drive after every spike.
%
%  Inputs:
%   x:           stimulus (used as a column vector by the simulation loop)
%   dt:          time step of x and y in ms
%   k:           stimulus filter
%   h:           post-spike filter (any orientation; converted to a column)
%   dc:          dc offset
%   runs:        number of trials to simulate
%   SpikeTrains: recorded spike trains (time x trials, 0s and 1s); only used
%                for the comparison raster. Optional, default [].
%   BasisTab2:   table shown in a uitable on the figure; only used when
%                plotting. Optional, default [] (no table drawn).
%   softRect:    0 uses exponential nonlinearity; 1 uses soft-rectifying
%                nonlinearity (logexp1). Optional, default 0.
%   ind_plot:    0 for plotting, 1 for no plotting. Optional, default 1.
%
%  Outputs:
%   y:        spike train (0s and 1s), nTimePts x runs
%   stimcurr: output of stimulus filter (without DC current added)
%   hcurr:    output of post-spike filter, nTimePts x runs
%   r:        firing rate (stimcurr + hcurr + dc passed through nonlinearity)
%
%  Requires the external helpers sameconv and (for softRect) logexp1.

%% defaults for the optional arguments
% Defaults reproduce what empty inputs did before: exp nonlinearity, no plot.
if nargin < 7,  SpikeTrains = []; end
if nargin < 8,  BasisTab2 = [];   end
if nargin < 9  || isempty(softRect), softRect = 0; end
if nargin < 10 || isempty(ind_plot), ind_plot = 1; end

%% generate data with fitted GLM
spikebinary = SpikeTrains;
nTimePts = length(x);
h = h(:);                 % the in-loop additions below need a column vector
nh = length(h);
refreshRate = 1000/dt;    % bins per second (dt is in ms)
if softRect
    NL = @logexp1;
else
    NL = @exp;
end

g = zeros(nTimePts+nh,runs);   % total drive: filtered stimulus + dc + post-spike current (zero padded)
y = zeros(nTimePts,runs);      % spike trains
r = zeros(nTimePts,runs);      % firing rate (output of nonlinearity)
hcurr = zeros(size(g));        % post-spike current (zero padded, trimmed below)

stimcurr = sameconv(x,k);      % convolve the stimulus with the fitted stimulus filter
% Optional additive noise on the filtered stimulus. It is currently not
% injected, but the draw is kept so that the random stream is consumed
% exactly as before and seeded simulations stay reproducible.
noise = 0 + 0.1.*randn(length(stimcurr),1); %#ok<NASGU>
Iinj = stimcurr + dc; % + noise;   % injected current includes DC drive

for runNum = 1:runs

    g(:,runNum) = [Iinj; zeros(nh,1)];

    %%% loop over time bins; each spike adds the post-spike filter to the
    %%% drive of the current and following bins
    for tBin = 1:nTimePts
        r(tBin,runNum) = NL(g(tBin,runNum));            % conditional intensity (spikes/s)
        pSpike = 1 - exp(-r(tBin,runNum)/refreshRate);  % 1 - P(0 spikes in this bin)

        % spike generation: the 0.001 shift is a deliberate bias that
        % suppresses spurious spikes when the spike probability is tiny
        if 0.001+rand < pSpike
            win = tBin:tBin+nh-1;
            y(tBin,runNum) = 1;
            g(win,runNum) = g(win,runNum) + h;          % add post-spike filter
            hcurr(win,runNum) = hcurr(win,runNum) + h;
        end
    end
end

hcurr = hcurr(1:nTimePts,:);  % trim zero padding for post-spike current

%% plot filters and compare real to model spikes
if ind_plot == 0
    axisLabelFontSize = 12;
    axisTickLabelFontSize = 12;
    axisWidth = 1;
    spikeHeight = .7;
    nReal = size(spikebinary,2);

    % time axis: skip the first ms of samples; round/clamp so the indices
    % are valid integers for any dt (1/dt is fractional when dt > 1)
    minT = max(1, round(1/dt));
    maxT = nTimePts;
    tIdx = minT:maxT;
    t = (tIdx-minT)*dt;

    fig = figure; fig.Position = [10 10 1200 650];

    %%% stimulus filter
    subplot(3,4,8); hold on;
    plot([0 length(k)],[0 0],'k--','linewidth',1.5);
    plot(k,'b','linewidth',2);
    set(gca,'xtick',0:length(k)/4:length(k),'xticklabel',round(-length(k)*dt:length(k)/4*dt:0))
    set(gca,'tickdir','out','linewidth',axisWidth,'fontsize',axisTickLabelFontSize)
    text(length(k)/15,max(k)-.05*(max(k)-min(k)),['\mu = ' num2str(round(dc*10)/10)],'fontsize',axisLabelFontSize);
    xlim([-5 length(k)])
    ylim([min(k)-.05*(max(k)-min(k)) max(k)+.05*(max(k)-min(k))])
    box off

    %%% post-spike filter
    subplot(3,4,12); hold on;
    plot([0 length(h)],[0 0],'k--','linewidth',1.5);
    % NOTE: the first sample of h is left out of the curve (as before), so
    % the curve is shifted by one sample relative to the tick labels.
    plot(h(2:end),'r','linewidth',2);
    set(gca,'tickdir','out','xtick',0:length(h)/4:length(h),'xticklabel',round(0:(length(h)/4*dt):length(h)*dt))
    set(gca,'linewidth',axisWidth,'fontsize',axisTickLabelFontSize)
    xlabel('time (ms)','fontsize',axisLabelFontSize)
    xlim([-5 length(h)])
    ylim([min(h)-.05*(max(h)-min(h)) max(h)+.05*(max(h)-min(h))])
    box off

    %%% stimulus
    ax = gobjects(1,2);   % axes to link; kept separate from the figure handle
    ax(1) = subplot(3,4,1:3);
    plot(t,x(tIdx),'color','k','linewidth',2)
    xlim([min(t) max(t)])
    ylim([min(x(tIdx))-.05*abs(min(x(tIdx))) max(x(tIdx))+.05*abs(max(x(tIdx)))])
    box off
    title('stimulus')

    %%% rasters: simulated trials at the bottom, recorded trials above them
    ax(2) = subplot(3,4,[5:7,9:11]); hold on;

    for runNum = 1:size(y,2)   % for each run of glm simulation
        spt = find(y(tIdx,runNum));
        if ~isempty(spt)
            % one NaN-separated line per trial instead of one plot call per
            % spike; spike times use the same axis t as the stimulus panel
            tt = reshape(t(spt),1,[]);
            xs = [tt; tt; nan(size(tt))];
            ys = repmat([runNum-.5; runNum-.5+spikeHeight; NaN],1,numel(tt));
            plot(xs(:),ys(:),'color',[.5 .5 .5],'linewidth',1.25)
        end
    end

    for f = 1:nReal            % for each recorded spike train
        spt = find(spikebinary(tIdx,f));
        if ~isempty(spt)
            tt = reshape(t(spt),1,[]);
            xs = [tt; tt; nan(size(tt))];
            ys = repmat([runs+f-.5; runs+f-.5+spikeHeight; NaN],1,numel(tt));
            plot(xs(:),ys(:),'color',[0 0 0],'linewidth',1.25)
        end
    end

    % leave room for every recorded trial (at least one row, as before)
    ylim([0 runs+max(nReal,1)+spikeHeight])
    xlim([0 max(t)])
    xlabel('time (ms)')
    title('spikes (black=real, grey=GLM)')
    linkaxes(ax,'x')

    %%% table of basis parameters (skipped when none was supplied)
    if ~isempty(BasisTab2)
        uit = uitable('Data', table2cell(BasisTab2),'ColumnName',BasisTab2.Properties.VariableNames,...
            'Units', 'Normalized', 'Position',[0.72,0.75,0.2,0.15]);
        set(uit,'ColumnWidth',{50},'FontSize',13)
    end
end