% Required output variables:
% cv               - array with one CV value per neuron for the normal (unshuffled) grid
%                    cells; cv(i) corresponds to unique_neurons(i)
% cv_shuffles      - structure holding, for each shuffled neuron, the array of CV values of
%                    all its shuffles (only successful shuffles are stored)
% cv_shuffles_mean - array indexed by neuron number holding the mean of the CV values of
%                    that neuron's successful shuffles
% To simply test the shuffle code, set shift_amount = 0.
% All distances are in voxels.


clc 
clear all
close all

%%

%%%%%%%%%%%%%% SET THESE PARAMS %%%%%%%%
minv=50;          % a cluster counts as a grid field only if its volume (in voxels) is > minv
tilted=0;         % 1 = use tilted-lattice props, 0 = use aligned-lattice props
plot_centroids=1; % plot centroids of all fields of all neurons
tilted_props_filename="";  % set to the tilted-lattice props .mat file before using tilted=1
aligned_props_filename="Aligned_data_props_b2.51_t1.7_2Dt1.5.mat" ;
%neurons=1:50; %neurons to check for
min_dist=5; % merge fields whose centroids are closer than this distance (in voxels);
            % 5 voxels is < 25 cm in the original experiment, about one quarter of the cube length

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

identify_placef=1; % set this to 1 to identify grid fields (see minv)
vl=0.25;           % voxel length

%firr_thresh=0.2;


%%
% Load the props file for the chosen lattice and select the neurons to analyse.
if tilted
    if strlength(tilted_props_filename) == 0
        error('tilted=1 but tilted_props_filename is not set.');
    end
    load(tilted_props_filename)

    neurons=[5,6,7,10,14,18,25,26,33,36,39,41,46];
else
    load(aligned_props_filename)

    neurons =[2, 4, 6, 8, 11, 16, 18, 24, 26, 30, 33, 34, 41] + 1;
end

% Rotation used below to draw the tilted lattice (applied to column points as rotm*p).
rotm = [[ 0.57357644,  0.,          0.81915204],
        [ 0.57922797,  0.70710678, -0.40557979],
        [-0.57922797,  0.70710678,  0.40557979]];
XYZ = eye(3);     % unrotated unit axes
ABC = rotm*XYZ;   % column k is rotm applied to the k-th unit axis
axes_cords = [XYZ; ABC];

%% %%%%%%% IDENTIFY GRID FIELDS %%%%%%%%%%%%

% Voxel-centre coordinates along each axis (20 voxels of length 0.25).
x1=1.125:0.25:5.875;
y1=1.125:0.25:5.875;
z1=1.125:0.25:5.875;

[X,Y,Z]=meshgrid(x1, y1 ,z1);
voxel_coords=[Y(:) X(:) Z(:)]; % first column varies fastest, then the second, then the third

% Accumulators; centroids is filled in the plotting section below.
orientation=double.empty;
orientationplot=double.empty;

elongation=double.empty;

yfl=double.empty;
xfl=double.empty;
zfl=double.empty;
centroids=double.empty;

% Collect every cluster with volume > minv of the selected neurons into pf
% (pf should really be called gf). pf is a 1-by-N struct array with fields:
%   props  - the cluster's row of the neuron's props table
%   neuron - the neuron number
%   firr   - the neuron's firing-rate data (FIRR_3dprops(neuron).firr)
% pf starts as an empty struct array with these fields so that a run with
% no qualifying clusters still yields a well-formed (empty) pf.
if identify_placef
    pf=struct('props', {}, 'neuron', {}, 'firr', {});

    for neuron = neurons
        propstable=FIRR_3dprops(neuron).props; % 3D properties table of the neuron
        firrtable=FIRR_3dprops(neuron).firr;   % firing rate of the neuron
        for i = find(propstable{:, "Volume"} > minv).'
            pf(end+1).props=propstable(i,:); %#ok<SAGROW>
            pf(end).neuron=neuron;
            pf(end).firr=firrtable;
        end
    end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%      
%%%%%%%%%%%%%       plot centroids of all the fields   %%%%%%%%%%%%%%%%%%%%%%%%%%%
% Scatter the centroids of all grid fields in pf (coloured by z) and mark
% their mean in red. For the tilted lattice, the points, the maze cube
% outline and the axes are rotated by rotm.

if plot_centroids

    %%%%%%%%%%%%% FIG 4 a1) Aligned or FIG 4 b1) Tilted %%%%%%%%%%%%%%
    set(0, 'DefaultFigureRenderer', 'painters');

    % One row per grid field, in voxel-index coordinates. Assigned rather than
    % appended, so re-running this section does not duplicate points.
    centroids = zeros(numel(pf), 3);
    for i = 1:numel(pf)
        centroids(i, :) = pf(i).props.Centroid;
    end

    % Voxel index k has its centre at x1(k) = x1(1) + (k-1)*vl (1.125 ... 5.875),
    % so map indices through that same grid. The former (5/20)*k + 1 placed
    % every point half a voxel too high.
    centroids_scaled = x1(1) + vl*(centroids - 1);

    if tilted
        centroids_scaled1 = (rotm * centroids_scaled')';

        % Closed path through the corners of the [1,6]^3 maze cube that covers
        % all 12 edges (a few edges are traced twice).
        C = [1 1 1; 6 1 1; 6 1 6; 1 1 6; 1 6 6; 1 6 1; 1 6 6; 6 6 6; ...
             6 6 1; 1 6 1; 1 1 1; 1 1 6; 6 1 6; 6 6 6; 6 6 1; 6 1 1];
        E = (rotm*C')';
        figure
        plot3(E(:,1),E(:,2),E(:,3),'black')
        hold on
        scatter3(centroids_scaled1(:,1), centroids_scaled1(:,2), centroids_scaled1(:,3), 30, centroids_scaled1(:,3), "filled")
        title("Centroids of place fields- Tilted lattice")

        % Draw the rotated axes (columns of ABC) from the centre of the cube.
        centroid = mean(E);
        line_directions = ABC';
        line_length = 5;
        line_ends = centroid + line_length * line_directions;
        line_colors = {'red', 'green', 'blue'};
        for i = 1:3
            line_coords = [centroid; line_ends(i, :)];
            plot3(line_coords(:, 1), line_coords(:, 2), line_coords(:, 3), 'Color', line_colors{i}, 'LineWidth', 2);
        end
        daspect([1 1 1]);

        plotted = centroids_scaled1;
    else
        scatter3(centroids_scaled(:,1), centroids_scaled(:,2), centroids_scaled(:,3), 30, centroids_scaled(:,3), "filled")
        title("Centroids of place fields- Aligned lattice")

        plotted = centroids_scaled;
    end
    xlabel('x','Fontsize',15)
    ylabel('y','Fontsize',15)
    zlabel('z','Fontsize',15)

    % Mean position of all plotted centroids (named med*, but it is the mean).
    medx = mean(plotted(:,1));
    medy = mean(plotted(:,2));
    medz = mean(plotted(:,3));
    hold on
    scatter3(medx,medy,medz,100,'filled','red')
    hold off

    %%%%%%%%%%% Fig 4a2) Aligned or Fig 4b2) Tilted  %%%%%%%%%%%

    if tilted
        % Apply the inverse rotation (solve rotm*cc' = p instead of forming inv(rotm)).
        cc = (rotm \ centroids_scaled')';
    end

end
%%%%%%%%%%%%%%
%%
%%%%%%%%%%%%%%%%%%%%%  calculate cv of all neurons    %%%%%%%%%%
% For every neuron in pf: merge field centroids closer than min_dist, then
% compute the coefficient of variation (std/mean) of the distances from each
% centroid to its (up to) three nearest neighbouring centroids (in voxels).
% cv(i) and fields_count(i) refer to unique_neurons(i), not to neuron number i.
pf_neurons = [pf(:).neuron];
unique_neurons = unique(pf_neurons);
cv = nan(1, numel(unique_neurons));
fields_count = zeros(1, numel(unique_neurons));

for i = 1:numel(unique_neurons)

    % Rows of pf belonging to this neuron (one row per field). Looked up
    % directly, so pf does not need to be grouped or sorted by neuron.
    field_idx = find(pf_neurons == unique_neurons(i));
    % Number of fields before merging; compared later with the shuffled field counts.
    % NOTE(review): the shuffle compares against its field count AFTER merging - confirm intended.
    fields_count(i) = numel(field_idx);

    centroids = zeros(numel(field_idx), 3);
    for j = 1:numel(field_idx)
        centroids(j, :) = pf(field_idx(j)).props.Centroid;
    end
    centroids_old = centroids;

    % Merge fields closer than min_dist: repeatedly replace the first close pair
    % (first column with a close partner, first such row in it) by its midpoint.
    pts = centroids;
    ds = pdist2(pts,pts,'euclidean');
    ds(eye(size(ds),'logical')) = inf;
    while any(ds(:)<min_dist)
        [id1, kk] = find(ds < min_dist, 1);
        mpoint = mean(pts([kk id1],:), 1, 'omitnan');
        pts([kk id1],:) = [];
        pts = [pts; mpoint]; %#ok<AGROW>
        ds = pdist2(pts,pts,'euclidean');
        ds(eye(size(ds),'logical')) = inf;
    end
    centroids=pts;

    % Distances to the nearest neighbours; column 1 is each point's distance to
    % itself and is discarded. K is capped so neurons with < 4 fields still work.
    k = min(4, size(centroids, 1));
    [~, distances] = knnsearch(centroids, centroids, 'K', k);
    distances = distances(:, 2:end);

    % Mutual neighbours appear twice; unique() removes repeated values (this also
    % collapses equal distances that belong to different pairs).
    distances_unique = unique(distances(:))';

    mean_distance = mean(distances_unique(:));
    std_distance = std(distances_unique(:));

    cv(i) = std_distance / mean_distance;

end
%histogram(cv) %plot histogram of cv values





%%
%%%%%%%%%%%%%%%%%%%%%%%%%%          SHUFFLE TEST       %%%%%%%%%%%%%%%
%%%load files%%%%%%
% Choose the dataset for the shuffle test (if both flags are 1, tilted is used).
tilted_resp=1;
aligned_resp=0;

%%%%%%%%%%%%%%% SET FILENAMES %%%%%%%%
if tilted_resp
    traj_file='Trajectory_interpolated_tilted_lattice_no_diagAzra.csv';
    spike_file=readtable("encoded_TL_t20p3_b0.8pi_3d_testA.csv");%
    %props_file="propsGRID.mat";
elseif aligned_resp
    traj_file='Trajectory_interpolated_aligned_lattice_no_diag.csv';
    spike_file=readtable("encoded_AL_t20p3_b0.8pi_2d_testA.csv");%
    props_file="Aligned_data_props_b0.8pi_std3d1.5_std2d1.5A.mat";
else
    error('Set tilted_resp or aligned_resp to 1 to choose the shuffle-test data.');
end
if logical(tilted_resp) ~= logical(tilted)
    warning('tilted_resp (%d) differs from tilted (%d): the shuffle uses a different lattice than the CV computed above.', tilted_resp, tilted);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


table_data=csvread(traj_file); % trajectory data, one row per time point
ns_1=table2array(spike_file);  % spike data; column k is used as neuron k below
%plot3(table_data(:,1),table_data(:,2),table_data(:,3)) %to verify trajectory

vl=0.25;        % voxel length
thresh_std=1.5; % spike threshold = mean + thresh_std*std of the spike train


%%%%%%%%%%%%
% 20 voxels along each dimension within the lattice maze (voxel centres)
x1=1.125:vl:5.875;
y1=1.125:vl:5.875;
z1=1.125:vl:5.875;
[X,Y,Z]=meshgrid(x1, y1 ,z1);
% Rows follow the column-major order of the meshgrid arrays: the Y coordinate
% (2nd column) varies fastest, then X, then Z.
voxel_coords=[X(:) Y(:) Z(:)];
table = table_data(1:100000,:); % NOTE: this variable shadows MATLAB's table() function
CC_voxel = {};
sum_den = [];
% Load only sum_den (denominator of each voxel's firing rate in the shuffle loop)
% so the MAT-file cannot overwrite other workspace variables.
load("sum_den.mat", "sum_den");
%%%%%%%% spike_pos_shuffle_func%%%%%
%%

%%%%%%%%%%   SET THESE PARAMS    %%%%%%%%%
required_shuffles=4; % required number of successful shuffles (1000 in the original paper)
max_iterations=100;  % stop after this many attempts, whether or not required_shuffles was reached
min_vol=26;          % a shuffled cluster counts as a grid field only if its volume (voxels) is > min_vol
                     % NOTE(review): unshuffled fields use minv instead - confirm the thresholds should differ
field_tolerance=20;  % accept a shuffle only if its field count is within +-field_tolerance of the
                     % neuron's unshuffled field count (3 in the original paper); higher = more tolerant
neurons=7;           % neuron numbers (columns of ns_1) to shuffle
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

n=size(table,1);               % number of trajectory points
n_voxels=size(voxel_coords,1);
cv_shuffles=struct();
for neuron=neurons

    % fields_count comes from the CV section and is indexed by position in
    % unique_neurons, not by neuron number, so look the neuron up first.
    ref_idx = find(unique_neurons == neuron, 1);
    if isempty(ref_idx)
        error('Neuron %d has no unshuffled grid fields, so its field count is unknown.', neuron);
    end
    ref_fields = fields_count(ref_idx);

    % get the spike train
    spike_train = ns_1(:,neuron);

    cv_shuffle=[]; % CVs of the successful shuffles of this neuron only
    for iter = 1:max_iterations
        if numel(cv_shuffle) >= required_shuffles % enough successful shuffles collected
            disp('entered5')
            break;
        end
        fprintf('neuron = %d, iter = %d\n', neuron, iter);

        % Circularly shift the spike train by a random integer
        shift_amount = randi([100, n-100]);
%         shift_amount = 0;%for testing
        spike_train_shifted = circshift(spike_train, shift_amount);

        % Firing rate of each voxel for the shifted train. The spike positions do
        % not depend on the voxel, so they are computed once per shuffle.
        spikepos = spike_pos_shuffle(table,spike_train_shifted,thresh_std);
        firr = zeros(1, n_voxels);
        for i = 1:n_voxels
            sum_num = num2(voxel_coords(i,:),vl,spikepos);
            firr(i) = sum_num/sum_den(i);
        end
        firr_arr = firr.';

        % Binary volume of voxels with positive firing rate. voxel_coords follows
        % the linear order of the meshgrid arrays, so reshaping to size(X) puts
        % every voxel at its grid position. Connected components (default
        % 26-connectivity) are the candidate fields.
        binary_volume = reshape(firr_arr > 0, size(X));
        CC = bwconncomp(binary_volume);
        stats = regionprops3(CC, "Centroid", "Volume");
        centroids_shuffle = stats.Centroid;
        volume = stats.Volume;
        % only keep fields with more than min_vol voxels
        centroids_shuffle = centroids_shuffle(volume > min_vol, :);

        % Merge fields closer than min_dist (same rule as for the unshuffled
        % fields): replace the first close pair by its midpoint until none remain.
        pts = centroids_shuffle;
        ds = pdist2(pts,pts,'euclidean');
        ds(eye(size(ds),'logical')) = inf;
        while any(ds(:)<min_dist)
            [id1, kk] = find(ds < min_dist, 1);
            mpoint = mean(pts([kk id1],:), 1, 'omitnan');
            pts([kk id1],:) = [];
            pts = [pts; mpoint]; %#ok<AGROW>
            ds = pdist2(pts,pts,'euclidean');
            ds(eye(size(ds),'logical')) = inf;
        end
        centroids_shuffle=pts;

        total_fields=size(centroids_shuffle,1);

        if total_fields==0 % no fields at all
            disp('entered2')
            continue;
        end
        disp('entered')
        if abs(total_fields - ref_fields) > field_tolerance % field count too far from the unshuffled count
            disp('entered1')
            continue;
        end
        if total_fields<4 % knnsearch with K=4 needs at least 4 fields
            disp('entered4')
            continue;
        end

        disp('entered3')

        % Distances to the three nearest neighbours (column 1 is the point itself).
        [~, distances] = knnsearch(centroids_shuffle, centroids_shuffle, 'K', 4);
        distances = distances(:, 2:end);

        % discard repeated distances
        distances_unique = unique(distances(:))';

        mean_distance = mean(distances_unique(:));
        std_distance = std(distances_unique(:));

        % Append rather than index by iter, so failed iterations leave no zero
        % entries that would bias the mean and the stopping criterion.
        cv_shuffle(end+1) = std_distance / mean_distance; %#ok<SAGROW>

    end
    cv_shuffles(neuron).cv_shuffle=cv_shuffle; % CVs of the successful shuffles, indexed by neuron number
    cv_shuffles_mean(neuron)=mean(cv_shuffle); % mean shuffle CV, indexed by neuron number

end


%%
%%% spike_pos_shuffle function %%%

function [spikepos]=spike_pos_shuffle(table,ns,thresh_std) %%return position whose spike data crosses threshold
% SPIKE_POS_SHUFFLE  Rows of TABLE at which the spike signal NS exceeds its threshold.
%   spikepos = spike_pos_shuffle(table, ns, thresh_std) sets the threshold to
%   mean(ns) + thresh_std*std(ns) and returns, in their original order, the
%   rows of TABLE whose corresponding entries of NS are strictly above it.
    cutoff = mean(ns) + thresh_std*std(ns);
    hit_rows = find(ns > cutoff);
    spikepos = table(hit_rows, :);
end