clc
clear      % plain clear: 'clear all' also clears functions/MEX links, which MATLAB documents as slowing later calls
close all
%%   
% Load the trajectory and the encoded activity.
% pos holds columns 1-2 of traj; the sections below index rows of out with the
% same row indices as pos and treat each column of out as one neuron.
% Alternative dataset:
% traj=csvread('D:\Papers\2D paper data\_Submission_files\code and models\trajectories\trajTRIthetavert_20k.csv');
% ot=csvread('D:\Papers\2D paper data\_Submission_files\code and models\experiemental studies\disto codes\encoded_triangular_beta_4_l2.csv');
data_dir = 'E:\WindowsCopy260623\GDrive\3d-paperII\EJN-20230703T083853Z-001\EJN\Revision2\Helix';
traj = csvread(fullfile(data_dir,'traj_wo_20k.csv'));
out = csvread(fullfile(data_dir,'encoded_helix_t20p1_b1.5pi_2d_train.csv'));
pos = traj(:,1:2);

%% Firing fields and field-shape statistics of the selected place cells
% For each neuron in place_cells:
%   1. threshold its activity at mean + thresh_std*std,
%   2. plot the trajectory with the supra-threshold (firing) positions in red,
%   3. bin the firing positions into a map over [-1,1] x [-1,1] and show a
%      Gaussian-smoothed version of it,
%   4. append the major/minor axis lengths of every connected region of the
%      unsmoothed, normalized map that exceeds 10% of its maximum to
%      neuron_major / neuron_minor (regions of all neurons are concatenated).
% Uses pos and out from the loading section.
grid_cells = [2,3,29,35,39,40,42,44,47];
place_cells = [1, 6, 8, 15, 18, 19, 20, 21, 23, 24, 27, 31, 34, 36, 41, 43, 45, 46, 48];
neuron_minor = [];
neuron_major = [];

% Loop-invariant settings
thresh_std = 1.5;          % firing threshold = mean + thresh_std*std
res = 20;                  % rate-map bins per unit length
x = -1:1/res:1;            % bin centres along x
y = -1:1/res:1;            % bin centres along y
gaussian = fspecial('gaussian',[10 10],3); % smoothing kernel for the displayed map

for neuron_number = place_cells
    ot = out(:,neuron_number);
    thresh = mean(ot) + thresh_std*std(ot);
    firr = find(ot > thresh);
    firposgrid = pos(firr,:); % Firing positions on the trajectory
    figure; plot(pos(:,1),pos(:,2)); hold on;
    plot(firposgrid(:,1),firposgrid(:,2),'.r', 'markersize', 15);

    %%% Rate map: rows index x bins, columns index y bins
    firingmap = zeros(numel(x), numel(y));
    firingvalue = abs(ot(firr));
    % size(...,1) rather than length(...): length of a 1-by-2 matrix is 2,
    % which would index past the only firing position.
    for ii = 1:size(firposgrid,1)
        [~,xbin] = min(abs(firposgrid(ii,1)-x));
        [~,ybin] = min(abs(firposgrid(ii,2)-y));
        firingmap(xbin,ybin) = firingvalue(ii); % later samples in a bin overwrite earlier ones
    end
    peak = max(firingmap(:));
    if peak == 0
        % No non-zero supra-threshold activity: normalizing would give an all-NaN map.
        continue;
    end
    firingmap = firingmap/peak;
    spikes_smooth = conv2(gaussian,firingmap); % 'full' convolution
    % NOTE(review): keeps only the last numel(y)+6 rows and all columns, an
    % asymmetric crop; conv2(firingmap,gaussian,'same') may be what was intended.
    spikes_smooth = spikes_smooth(end - numel(y) - 5:end,:);
    figure; imagesc(imrotate(spikes_smooth/max(spikes_smooth(:)),90)); axis off
    colormap(jet)
    title('Rate map of the selected neuron')

    %%% Field shape: axis lengths of each region above 10% of the map maximum
    binaryMap = firingmap > 0.1 * max(firingmap(:));
    stats = regionprops(binaryMap, 'Area', 'Centroid', 'MajorAxisLength', 'MinorAxisLength', 'Orientation');
    neuron_major = [neuron_major, [stats.MajorAxisLength]];
    neuron_minor = [neuron_minor, [stats.MinorAxisLength]];
end

%% Hexagonal and square gridness of every neuron
% For each neuron (at most 50): build a rate map (res = 25, uncropped), compute
% its autocorrelation with correlation_map, keep an annulus of it (outer radius
% 3 bins beyond the farthest of the up-to-five nearest non-central peaks, inner
% radius radius2), and score the annulus with gridscore (hexagonal, hgs) and
% squaregridscore (square, sgs).
% A neuron is counted as hexagonal when hgs > score_thresh and sgs < score_thresh,
% and as square in the mirrored case. Every scored neuron is plotted in figure 5.
hex = 0;
squ = 0;
neu_hex = [];
neu_squ = [];
neu_hgs = [];
neu_sgs = [];
% grid_cells = [2,3,29,35,39,40,42,44,47];

% Loop-invariant settings
thresh_std = 1.5;      % firing threshold = mean + thresh_std*std
score_thresh = 0.3;    % gridness-score classification threshold
res = 25;              % rate-map bins per unit length
bins = -1:1/res:1;     % bin centres (same along x and y)
gaussian = fspecial('gaussian',[10 10],3); % smoothing kernel for the rate map
radius2 = 12;          % radius (in autocorrelation bins) of the masked central peak
max_peaks = 20000;     % skip autocorrelation maps with this many regional maxima or more

% Score plane with x/y axes drawn once; each neuron's (hgs, sgs) is added in the loop.
figure(5); hold on;
plot(-1:0.1:1, zeros(1,21), 'k', zeros(1,21), -1:0.1:1, 'k');
title('HGS and SGS scores','FontSize',20);

for n = 1:min(50, size(out,2))
    fprintf('Neuron %d\n', n);
    ot = out(:,n);
    thresh = mean(ot) + thresh_std*std(ot);
    firr = find(ot > thresh);
    firposgrid = pos(firr,:); % Firing positions on the trajectory

    % Rate map: rows index x bins, columns index y bins
    firingmap = zeros(numel(bins));
    firingvalue = abs(ot(firr));
    for ii = 1:size(firposgrid,1)   % not length(): a 1-by-2 matrix has length 2
        [~,xbin] = min(abs(firposgrid(ii,1)-bins));
        [~,ybin] = min(abs(firposgrid(ii,2)-bins));
        firingmap(xbin,ybin) = firingvalue(ii);
    end
    peak = max(firingmap(:));
    if peak == 0
        continue; % no activity to map; normalizing would produce NaNs
    end
    firingmap = firingmap/peak;
    spikes_smooth = conv2(gaussian,firingmap);
    figure(1); imagesc(imrotate(spikes_smooth/max(spikes_smooth(:)),90)); axis off
    colormap(jet)
    title(['Rate map of neuron ',num2str(n)]);

    % Autocorrelation
    Rxy = correlation_map(spikes_smooth,spikes_smooth);
    figure(2); imagesc(Rxy); axis off; colormap(jet)
    title(['Autocorrelation map of neuron ',num2str(n)]);

    % Regional maxima of the autocorrelation map and their distances from its centre
    c = Rxy;
    [s1,s2] = size(c);
    centralpeakindex = 0.5*[s1+1 s2+1]; cpx = centralpeakindex(1); cpy = centralpeakindex(2);
    [peakx,peaky] = find(imregionalmax(c));
    peakmatrix = [peakx peaky];
    % Need the central peak plus at least one other peak (distascending(2) below);
    % with fewer, the old code either skipped or crashed on the index.
    if numel(peakx) >= max_peaks || numel(peakx) < 2
        continue;
    end
    % Direct O(N) Euclidean distances (diag(pdist2(...)) built an N-by-N matrix).
    distfromcentralpeak = sqrt((peakx - cpx).^2 + (peaky - cpy).^2);
    distascending = sort(distfromcentralpeak);
    % Index 1 is assumed to be the central peak itself (distance 0).
    nearestpeaks = distascending(2:min(6,numel(distascending)));
    gridscale = distascending(2);            % Minimum grid scale
    mediangridscale = median(nearestpeaks);  % Median grid scale

    % Grid score computation: keep an annulus of the autocorrelation map
    [colidx,rowidx] = meshgrid(1:s2, 1:s1);
    distfromcentre = sqrt((rowidx-cpx).^2 + (colidx-cpy).^2);
    radius1 = max(nearestpeaks) + 3;
    c2 = c.*(distfromcentre <= radius1);    % mask the outer peaks
    figure(3); imagesc(c2)
    colormap(jet)
    title('Masking the outer peaks of the autocorrelation map')
    axis off
    c3 = c2.*(distfromcentre > radius2);    % mask the central peak
    figure(4); imagesc(c3)
    colormap(jet)
    title('Masking the central peak of the autocorrelation map')
    axis off
    hgs = gridscore(c3); %Hexagonal Gridness Score of the neuron
    sgs = squaregridscore(c3); %Square Gridness Score of the neuron
    if hgs > score_thresh && sgs < score_thresh
        hex = hex + 1;
        neu_hex = [neu_hex, n];
        neu_hgs = [neu_hgs, hgs];
    end
    if sgs > score_thresh && hgs < score_thresh
        squ = squ + 1;
        neu_squ = [neu_squ, n];
        neu_sgs = [neu_sgs, sgs];
    end
    figure(5); plot(hgs,sgs,'*r','markersize', 10);
end