clc
clear       % 'clear all' also flushes compiled functions/MEX and breakpoints; not needed here
close all
%% Choose which output to plot
% Set exactly ONE of these flags to 1 to select the dataset to analyse.
helix_2d = 1;
helix_down_layer1 = 0;
helix_up_layer1 = 0;
pegboard_2d_layer1 = 0;
pegboard_3d_layer1 = 0;

% Fail early with a clear message. The loaders below are mutually exclusive
% (a later one would silently overwrite an earlier one), and the pegboard
% loaders / cell lists are currently commented out, which would otherwise
% surface later as an "undefined variable" error.
dataset_flags = [helix_2d, helix_down_layer1, helix_up_layer1, ...
                 pegboard_2d_layer1, pegboard_3d_layer1];
if nnz(dataset_flags) ~= 1
    error('Set exactly one dataset flag to 1 (currently %d are set).', nnz(dataset_flags));
end
if pegboard_2d_layer1 || pegboard_3d_layer1
    error(['Pegboard datasets are not available: their file loaders and ', ...
           'cell lists are commented out in this script.']);
end

%% load the files
% out:       encoder output; each column is treated as one neuron below.
% traj_data: trajectory samples. NOTE(review): rows are assumed to align
%            one-to-one with the rows of out -- confirm against the exporter.
if helix_2d == 1
    out = readtable('encoded_helix_t20p1_b1.5pi_2d_train.csv');
    traj_data = readtable('traj_wo_20k.csv');
elseif helix_down_layer1 == 1
    out = readtable('encoded_helix_t20p1_b1.5pi_3d_test_down.csv');
    traj_data = readtable('Trajectory_fivecoil_down.csv');
elseif helix_up_layer1 == 1
    out = readtable('encoded_helix_t20p1_b1.5pi_3d_test_up.csv');
    traj_data = readtable('Trajectory_fivecoil_up.csv');
end
% if pegboard_2d_layer1 == 1
%     out = readtable('encoded_Pegboard_2d_l1.csv');
%     traj_data = readtable('traj_peg2d_50k.csv');
% end
% if pegboard_3d_layer1 == 1
%     out = readtable('encoded_Pegboard_2d3d_l1.csv');
%     traj_data = readtable('Trajectory_interpolated_pegboard.csv');
% end

out = table2array(out);
% Keep the first three trajectory columns; the plotting code reads only
% columns 1-2 as the 2-D position.
traj_data = traj_data(:,1:3);
% if pegboard_2d_layer1
%     traj_data = traj_data(:,1:2);
% end
% if pegboard_3d_layer1
%     traj_data = traj_data(:,2:3);
% end
pos = table2array(traj_data);
%% plotting output
% Hand-picked column indices of `out` (selected outside this script) for the
% helix datasets; only place_cells is consumed by the loop below.
if helix_2d || helix_down_layer1 || helix_up_layer1
    grid_cells = [2,3,29,35,39,40,42,44,47];
    place_cells = [1, 6, 8, 15, 18, 19, 20, 21, 23, 24, 27, 31, 34, 36, 41, 43, 45, 46, 48];
% else
%     grid_cells = [ 4  8 13 24 28 30 34 35 37 42 48];
%     place_cells = [ 1 2 3 5 9 10 11 12 14 15 16 19 21 22 26 27 29 33 36 39 40 43 44 46 47 48 49 50];
end
%% Rate maps and firing-field shape for each place cell
% For every column n in place_cells:
%   1. mark samples where out(:,n) exceeds mean + 1.5*std as "firing",
%   2. bin their positions (pos columns 1-2) onto a regular grid,
%   3. smooth the binned map with a Gaussian kernel,
%   4. measure each connected region of the map above 10% of its peak,
%   5. save a figure (trajectory + firing points, smoothed map) as .fig and .png.
% neuron_major / neuron_minor accumulate the axis lengths of every region
% found, across all place cells.
neuron_minor = [];
neuron_major = [];

% The grid and smoothing kernel do not depend on the neuron, so build them
% once. Pegboard is checked first to keep the original precedence.
if pegboard_2d_layer1 || pegboard_3d_layer1
    res = 4;
    x_mesh = 0:1/res:10;
    y_mesh = 0:1/res:10;
elseif helix_down_layer1 || helix_up_layer1 || helix_2d
    res = 25;
    x_mesh = -1:1/res:1;
    y_mesh = -1:1/res:1;
else
    error('No dataset flag is set; cannot choose the rate-map grid.');
end
gaussian = fspecial('gaussian',[10 10],3); % Gaussian filter for smoothing the map

for n = place_cells
    fprintf('Processing neuron %d\n', n);

    ot = out(:,n);
    ot_mean = mean(ot);
    ot_std = std(ot);
    thresh = ot_mean + 1.5*ot_std;
    firr = find(ot > thresh);
    firposgrid = pos(firr,:); % Firing positions on the trajectory
    firingvalue = abs(ot(firr));

    % Snap each firing position to its nearest grid node. Rows of firingmap
    % index x_mesh and columns index y_mesh; a later sample falling in the
    % same bin overwrites an earlier one.
    firingmap = zeros(numel(x_mesh), numel(y_mesh));
    for ii = 1:size(firposgrid,1)
        [~,q1] = min(abs(firposgrid(ii,1) - x_mesh));
        [~,q2] = min(abs(firposgrid(ii,2) - y_mesh));
        firingmap(q1,q2) = firingvalue(ii);
    end

    % Normalise to a peak of 1. If the neuron never fired the map stays
    % all-zero instead of turning into NaN (0/0).
    map_peak = max(firingmap(:));
    if map_peak > 0
        firingmap = firingmap/map_peak;
    end
    spikes_smooth = conv2(gaussian,firingmap);
    smooth_peak = max(spikes_smooth(:));
    if smooth_peak > 0
        spikes_smooth = spikes_smooth/smooth_peak;
    end
    spikes_smooth = imrotate(spikes_smooth,90);
    % Crop the rotated full-convolution result for display.
    spikes_smooth = spikes_smooth(end-length(y_mesh)-5:end,:);

    % Connected regions of the (unsmoothed) map above 10% of its peak.
    threshold = 0.1 * max(firingmap(:));
    binaryMap = firingmap > threshold;
    stats = regionprops(binaryMap, 'Area', 'Centroid', 'MajorAxisLength', 'MinorAxisLength', 'Orientation');
    majorAxis = [stats.MajorAxisLength];
    minorAxis = [stats.MinorAxisLength];
    neuron_major = [neuron_major majorAxis];
    neuron_minor = [neuron_minor minorAxis];

    % Select figure 1 first, then clear it, so we never wipe an unrelated figure.
    figure(1);
    clf;
    subplot(2,1,1);
    plot(pos(:,1),pos(:,2),'b');
    axis equal;
    title(['Neuron ',num2str(n)], 'Fontsize', 20);
    hold on;
    plot(firposgrid(:,1),firposgrid(:,2),'.r', 'markersize', 10);
    axis off
    subplot(2,1,2);
    imagesc(spikes_smooth);
    axis off
    colormap(jet)
    axis equal;
    savefig(['place_2D', num2str(n)]);
    saveas(gcf, ['place_2D', num2str(n), '.png'])
    % saveas(gcf,sprintf('place_helixD_%d.png',n))
    % saveas(gcf,sprintf('place_helixD_%d.fig',n))
    % pause;
    close(gcf);
end
% %% firing rate maps 
% if pegboard_2d_layer1
% hex = 0;
% squ = 0;
% for n = 1:50
%     ot = out(:,n); 
%     %ot = out(:,n);
%     ot_mean = mean(ot);
%     ot_std = std(ot);
%     thresh = ot_mean + 1.5*ot_std;
%     %ot = out(n,:);  
%     %thresh=max(ot)*lim;  
%     firr=find((ot)>thresh); 
%     %thresh=max(ot)*.80; 
%     %firr=find((ot)>thresh);
%     firposgrid=pos(firr,:); %Firing positions on the trajectory
% 
%     res = 10; %resolution of the rate map
%     fx = meshgrid(0:1/res:10);
%     fy = meshgrid(0:1/res:10);
%     firingmap = zeros(length(fx));
%     gridpoint = [reshape(fx,prod(size(fx)),1) reshape(fy,prod(size(fx)),1)];
%     roundinggridpoint = round(gridpoint);
%     firposround = round(firposgrid);
%     firingvalue = abs(ot(firr));
%     for ii = 1:length(firposgrid)
%         [~,q1]=min(abs(firposgrid(ii,1)-fx(1,:)));
%         [~,q2]=min(abs(firposgrid(ii,2)-fy(1,:)));
%         firingmap(q1,q2) = firingvalue(ii);
%     end
%     firingmap = firingmap/max(max(firingmap));
%     gaussian = fspecial('gaussian',[10 10],3); % Design the gaussian filter for smoothing the map
%     spikes_smooth=conv2(gaussian,firingmap);
%     figure(1); imagesc(imrotate(spikes_smooth/max(max(spikes_smooth)),90)); axis off
%     colormap(jet)
%     title(['Rate map of neuron ',num2str(n)]);
%  
% % Autocorrelation
% Rxy = correlation_map(spikes_smooth,spikes_smooth);
% figure(2); imagesc(Rxy); axis off; colormap(jet)
% title(['Autocorrelation map of neuron ',num2str(n)]);
% % Grid scale computation
% % Computing the central peak index of the autocorrelation map
% c = Rxy;
% [s1,s2] = size(c);
% centralpeakindex = 0.5*[s1+1 s2+1]; cpx = centralpeakindex(1); cpy = centralpeakindex(2);
% cc=imregionalmax(c);
% [peakx,peaky] = find(cc ==1); peakmatrix = [peakx peaky];
% centralpeakrepmat = repmat(centralpeakindex,length(peakx),1);
% % Minimum and median grid scale computation
% distfromcentralpeak = diag(pdist2(peakmatrix,centralpeakrepmat));
% if distfromcentralpeak == 0
%     continue;
% end
% [distascending,index] = sort(distfromcentralpeak);
% gridscale = distascending(2); %Minimum grid scale
% mediangridscale = median(distascending(2:min(6,length(distascending)))); %Medial grid scale
% % Grid score computation 
% c=Rxy;
% radius1 = max(distascending(2:min(6,length(distascending)))) + 3;
% [p1 q1] = meshgrid(1:length(c));
% circlemask1 = sqrt((p1-centralpeakindex(1)).^2 + (q1-centralpeakindex(2)).^2) <= radius1;
% c2 = c; c2 = c2.*circlemask1;
% figure(3); imagesc(c2)
% colormap(jet)
% title('Masking the outer peaks of the autocorrelation map')
% axis off
% radius2 = 12;
% [p2 q2] = meshgrid(1:length(c));
% circlemask2 = sqrt((p2-centralpeakindex(1)).^2 + (q2-centralpeakindex(2)).^2) <= radius2;
% circlemask2 = not(circlemask2);
% c3 = c2.*circlemask2;
% figure(4); imagesc(c3)
% colormap(jet)
% title('Masking the central peak of the autocorrelation map')
% axis off
% hgs = gridscore(c3); %Hexagonal Gridness Score of the neuron
% sgs = squaregridscore(c3); %Square Gridness Score of the neuron
% if hgs > 0 && sgs < 0
%     hex = hex + 1;
%     n
% end
% if sgs > 0 && hgs < 0
%     squ = squ + 1;
% end
% figure(5);plot(-1:0.1:1,zeros(21),'k',zeros(21),-1:0.1:1,'k',hgs,sgs,'*r','markersize', 10);%hold on;
% title('HGS and SGS scores','FontSize',20);
% hold on;
% %pause;
% end
% end