% Start from a clean session: clear the command window, the workspace and all figures.
clc; clear all; close all;
%% Load the trajectory and the encoder output
% traj : trajectory samples; its first two columns are used as (x, y) positions.
% out  : encoder output; the sections below index neurons by column of out.
% NOTE(review): absolute, machine-specific paths - edit before running elsewhere.
% Inputs used in earlier runs:
% traj=csvread('D:\Papers\2D paper data\_Submission_files\code and models\trajectories\trajTRIthetavert_20k.csv');
% ot=csvread('D:\Papers\2D paper data\_Submission_files\code and models\experiemental studies\disto codes\encoded_triangular_beta_4_l2.csv');
% out=csvread('E:\WindowsCopy260623\GDrive\3d-paperII\EJN-20230703T083853Z-001\EJN\Revision2\encoded_AL_t20p3_b2.51_2D_cossin.csv');
%testcsv=csvread('encoded1_test.csv')
% ns_1=table2array(spike_file);
traj = csvread('E:\WindowsCopy260623\GDrive\3d-paperII\EJN-20230703T083853Z-001\EJN\Revision2\traj_latcheck_30k.csv');
out = csvread('E:\WindowsCopy260623\GDrive\3d-paperII\EJN-20230703T083853Z-001\EJN\Revision2\encoded_TL_t20p3_b0.8pi_2d_trainA.csv');

pos = traj(:, 1:2);   % (x, y) position of every trajectory sample

%% Firing fields and rate maps of selected neurons
% For each column of out listed in place_cells: mark the trajectory positions
% where the activity exceeds mean + thresh_std*std, bin those activities onto a
% grid with centres 1:1/res:6 in x and y, smooth with a Gaussian kernel and show
% the result as a rate map.
%
% Earlier exploratory version (neurons 3, 5, 22, 25, 31, 37, 42), kept for reference:
%for v = [3 5 22 25 31 37 42]
%neuron_number=3; % Select the neuron number
%w=v(:,end-neuron_number);  %Weights from PI to the selected SC neuron
%ot=w'*piosc_thresh;  %Output activity of the SC neuron
%ot=ot';

place_cells = [5,]; %TL  (column indices into out)
% place_cells = [2, 4, 6, 8, 11, 16, 18, 24, 26, 30, 33, 34, 41] + 1;

% Settings that do not change between neurons (hoisted out of the loop).
thresh_std = 1.5;                           % firing threshold, in standard deviations above the mean
res = 10;                                   % rate-map bins per unit of position
x = 1:1/res:6;                              % bin centres along x
y = 1:1/res:6;                              % bin centres along y
[fx,fy] = meshgrid(x,y);                    % fx varies along columns, fy along rows
%[fx,fy] = meshgrid(0:1/res:4,0:1/res:1);
gaussian = fspecial('gaussian',[10 10],3);  % 10x10 Gaussian smoothing kernel, sigma = 3 bins

for neuron_number = place_cells
    ot = out(:,neuron_number);
    thresh = mean(ot) + thresh_std*std(ot);
    % thresh=max(ot)*.7;
    firr = find(ot > thresh);       % samples where the neuron is above threshold
    % firr=find(ot(1:20000,neuron_number)>thresh(neuron_number));
    firposgrid = pos(firr,:);       % firing positions on the trajectory
    firingvalue = abs(ot(firr));

    % Trajectory with the firing positions overlaid in red.
    figure; plot(pos(:,1),pos(:,2)); hold on;
    plot(firposgrid(:,1),firposgrid(:,2),'.r','markersize',15);
    %axis off
    %title(sprintf('test neuron %d.png',v))
    %saveas(f,sprintf('plots/response_test_%d.png',v))

    % Rate map. firingmap(ix,iy): the first index is the x bin and the second
    % the y bin (transposed w.r.t. fx/fy); imrotate(...,90) below puts x on the
    % horizontal axis with y increasing upwards. A later sample falling in the
    % same bin overwrites an earlier one.
    firingmap = zeros(length(fx));
    %firingmap = zeros(36, 141);
    % size(...,1) rather than length(): length() of a 1x2 matrix is 2, which
    % would index past a single firing position.
    for ii = 1:size(firposgrid,1)
        [~,q1] = min(abs(firposgrid(ii,1)-fx(1,:)));
        [~,q2] = min(abs(firposgrid(ii,2)-fy(:,1)));
        firingmap(q1,q2) = firingvalue(ii);
    end
    % Normalise to a peak of 1; an all-zero map (no firing) is left unchanged
    % instead of turning into 0/0 = NaN.
    peak = max(firingmap(:));
    if peak > 0
        firingmap = firingmap/peak;
    end

    spikes_smooth = conv2(gaussian,firingmap);   % 'full' convolution (60x60 with the current settings)
    % NOTE(review): this keeps rows end-length(fy)-5:end (rows 4:60 here) and
    % all columns, so the crop is asymmetric and only applied to the first (x)
    % dimension; conv2(...,'same') would return the centred 51x51 region.
    % Confirm the crop is intended.
    spikes_smooth = spikes_smooth(end - length(fy) - 5:end,:);
    smooth_peak = max(spikes_smooth(:));
    if smooth_peak == 0
        smooth_peak = 1;   % empty map: display zeros rather than NaN
    end
    figure; imagesc(imrotate(spikes_smooth/smooth_peak,90)); axis off
    colormap(jet)

    %title('Rate map of the selected neuron')
%     savefig(['FF2D_AL', num2str(neuron_number)]);
%     saveas(gcf, ['FF2D_AL', num2str(neuron_number), '.png'])
end
%% Gridness scores of the output neurons
% For each of the first num_neurons columns of out:
%   - build a smoothed rate map;
%   - compute its autocorrelation with correlation_map;
%   - estimate the grid scale from the autocorrelation peaks;
%   - mask the outer peaks and the central peak;
%   - score the result with gridscore (hexagonal) and squaregridscore (square).
% correlation_map, gridscore and squaregridscore are not defined in this script.
% Results:
%   hex / squ         - number of neurons classified hexagonal / square
%   neu_hex / neu_squ - their column indices
%   neu_hgs / neu_sgs - their scores
hex = 0;
squ = 0;
neu_hex = [];
neu_squ = [];
neu_hgs = [];
neu_sgs = [];
%grid_cells = [3	7	12	22	34	35	38	40	42	43	47	50];

num_neurons = 50;         % columns of out that are scored
score_thresh = 0.3;       % HGS/SGS cut-off used for the classification
res = 10;                 % resolution of the rate map (bins per unit of position)
bins = 1:1/res:6;         % bin centres, shared by x and y
gaussian = fspecial('gaussian',[10 10],3);  % 10x10 Gaussian smoothing kernel, sigma = 3 bins
radius2 = 12;             % radius (autocorrelation pixels) of the central-peak mask

for n = 1:num_neurons
%for n = grid_cells
    ot = out(:,n);
%     thresh=max(ot)*.80;
    thresh = mean(ot) + 0.1*std(ot);
    firr = find(ot > thresh);
    if isempty(firr)
        continue;   % constant output: no firing positions, nothing to score
    end
    firposgrid = pos(firr,:); %Firing positions on the trajectory
    firingvalue = abs(ot(firr));

    % Rate map indexed firingmap(x bin, y bin); a later sample in the same bin
    % overwrites an earlier one. size(...,1) rather than length(): length()
    % of a 1x2 matrix is 2.
    firingmap = zeros(numel(bins));
    for ii = 1:size(firposgrid,1)
        [~,ix] = min(abs(firposgrid(ii,1)-bins));
        [~,iy] = min(abs(firposgrid(ii,2)-bins));
        firingmap(ix,iy) = firingvalue(ii);
    end
    peak = max(firingmap(:));
    if peak > 0
        firingmap = firingmap/peak;   % guard avoids 0/0 = NaN
    end
    spikes_smooth = conv2(gaussian,firingmap);
    smooth_peak = max(spikes_smooth(:));
    if smooth_peak == 0
        smooth_peak = 1;
    end
    figure(1); imagesc(imrotate(spikes_smooth/smooth_peak,90)); axis off
    colormap(jet)
    title(['Rate map of neuron ',num2str(n)]);

    % Autocorrelation
    Rxy = correlation_map(spikes_smooth,spikes_smooth);
    figure(2); imagesc(Rxy); axis off; colormap(jet)
    title(['Autocorrelation map of neuron ',num2str(n)]);

    % Grid scale computation. centralpeakindex is the geometric centre of Rxy
    % (a half-integer when a dimension is even).
    [s1,s2] = size(Rxy);
    centralpeakindex = 0.5*[s1+1 s2+1];
    [peakx,peaky] = find(imregionalmax(Rxy));
    peakmatrix = [peakx peaky];
    if numel(peakx) >= 20000
        continue;   % too many regional maxima to be a meaningful grid
    end
    % Euclidean distance of each peak from the centre, computed row-wise
    % (diag(pdist2(...)) would build an N-by-N matrix only to keep its diagonal).
    distfromcentralpeak = sqrt(sum(bsxfun(@minus,peakmatrix,centralpeakindex).^2,2));
    % distascending(2) is needed below, so at least two peaks are required.
    if numel(distfromcentralpeak) < 2
        continue;
    end
    distascending = sort(distfromcentralpeak);
    % NOTE(review): distascending(1) is assumed to be the central peak itself.
    gridscale = distascending(2); %Minimum grid scale
    nearest = distascending(2:min(6,numel(distascending)));  % up to 5 nearest non-central peaks
    mediangridscale = median(nearest); %Median grid scale

    % Grid score computation: keep the disc that contains the nearest peaks
    % (radius1), then remove the central peak (radius2).
    radius1 = max(nearest) + 3;
    % NOTE(review): the masks are built on a length(Rxy)-square grid, so Rxy is assumed square.
    [pc,pr] = meshgrid(1:length(Rxy));
    distmap = sqrt((pc-centralpeakindex(1)).^2 + (pr-centralpeakindex(2)).^2);
    c2 = Rxy.*(distmap <= radius1);
    figure(3); imagesc(c2)
    colormap(jet)
    title('Masking the outer peaks of the autocorrelation map')
    axis off
    c3 = c2.*(distmap > radius2);
    figure(4); imagesc(c3)
    colormap(jet)
    title('Masking the central peak of the autocorrelation map')
    axis off

    hgs = gridscore(c3); %Hexagonal Gridness Score of the neuron
    sgs = squaregridscore(c3); %Square Gridness Score of the neuron
    if hgs > score_thresh && sgs < score_thresh
        hex = hex + 1;
        neu_hex = [neu_hex, n];
        neu_hgs = [neu_hgs, hgs];
    end
    if sgs > score_thresh && hgs < score_thresh
        squ = squ + 1;
        neu_squ = [neu_squ, n];
        neu_sgs = [neu_sgs, sgs];
    end

    % Running scatter of (HGS, SGS) over all neurons. The axis lines are given
    % as 2-point vectors: a zeros(21) argument is a 21x21 matrix and draws 21
    % overlapping lines on every call.
    figure(5); plot([-1 1],[0 0],'k',[0 0],[-1 1],'k',hgs,sgs,'*r','markersize',10); hold on;
    title('HGS and SGS scores','FontSize',20);
    %pause;
end