clc
clear    % plain 'clear' empties the workspace; 'clear all' also unloads compiled functions for no benefit here
close all
%%
%%%%%%%%%%%%%%%%%%%%%%%%%%          SHUFFLE TEST       %%%%%%%%%%%%%%%
% Setup for the grid-field shuffle test: pick the dataset, load the
% trajectory and spike data, build the voxel grid, and load the per-voxel
% denominator (sum_den.mat) plus the dataset's props .mat file.
%%%load files%%%%%%
tilted_resp=1;   % 1 = use the tilted-lattice dataset
aligned_resp=0;  % 1 = use the aligned-lattice dataset (tilted wins if both are 1)

%%%%%%%%%%%%%%% SET FILENAMES %%%%%%%%
if tilted_resp
    traj_file='Trajectory_interpolated_tilted_lattice_no_diagAzra.csv';
    spike_file=readtable("encoded_TL_t20p3_b0.8pi_3d_testA.csv");
    tilted_props_filename="Tilted_data_props_b0.8pi_std3d1.5_std2d1.5GRID.mat";
elseif aligned_resp
    traj_file='Trajectory_interpolated_aligned_lattice_no_diag.csv';
    spike_file=readtable("encoded_AL_t20p3_b0.8pi_2d_testA.csv");
    % must be named aligned_props_filename: that is the name load() uses below
    aligned_props_filename="Aligned_data_props_b0.8pi_std3d1.5_std2d1.5A.mat";
else
    error('Set tilted_resp or aligned_resp to 1 to choose a dataset.');
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


table_data=csvread(traj_file); % trajectory positions, one row per trajectory point
ns_1=table2array(spike_file);  % spike data; column k is the spike train of neuron k
%plot3(table_data(:,1),table_data(:,2),table_data(:,3)) %to verify trajectory

vl=0.25;        % voxel length
thresh_std=1.5; % spike threshold, in standard deviations above the mean of a spike train


%%%%%%%%%%%%
% 20 voxel centres along each dimension of the lattice maze
x1=1.125:vl:5.875;
y1=1.125:vl:5.875;
z1=1.125:vl:5.875;
[X,Y,Z]=meshgrid(x1, y1 ,z1);
% meshgrid output is indexed (y,x,z), so in voxel_coords the y coordinate
% varies fastest, then x, then z
voxel_coords=[X(:) Y(:) Z(:)];
table = table_data(1:100000,:); % first 100000 trajectory points (NB: shadows the built-in table() from here on)
CC_voxel = {};
sum_den = [];
load("sum_den.mat");  % expected to define sum_den (one firing-rate denominator per voxel)
if numel(sum_den) < size(voxel_coords,1)
    error('sum_den.mat must define sum_den with at least %d entries (one per voxel); got %d.', ...
        size(voxel_coords,1), numel(sum_den));
end

if tilted_resp
    load(tilted_props_filename)
    neurons = [3 7 12 22 34 35 38 40 42 43 47 50];
else
    load(aligned_props_filename)
    % `neurons` is expected to come from the aligned props file; if that
    % file does not define it, fall back to every spike-train column
    if ~exist('neurons','var')
        neurons = 1:size(ns_1,2);
    end
end
%%
%%%%%%%%%%%%%% SET THESE PARAMS %%%%%%%%
% Field-detection settings. minv, tilted, plot_centroids and identify_placef
% are not read by the shuffle loop; they are kept for the user's reference.
minv            = 50;   % minimum number of voxels for a grid field
tilted          = 1;
plot_centroids  = 1;    % plot centroids of all fields of all neurons
identify_placef = 1;    % 1 = identify grid fields (checks minimum voxels)
vl              = 0.25; % voxel length
% Fields whose centroids are closer than min_dist are merged. The unit is
% voxels; 5 voxels is < 25 cm in the original experiment, about a quarter
% of the cube length.
min_dist        = 5;
%tilted_props_filename="Tilted_data_props_b0.8pi_std3d1.5_std2d1.5GRID.mat";
%aligned_props_filename=" " ;
%neurons=1:50; %neurons to check for
%firr_thresh=0.2;
%%

%%%%%%%%%%   SET THESE PARAMS    %%%%%%%%%
required_shuffles = 2;    % successful shuffles wanted per neuron (1000 in the original paper)
max_iterations    = 2000; % give up after this many attempts, even if required_shuffles is not reached
min_vol           = 26;   % a component must have volume > min_vol to count as a grid field
field_tolerance   = 20;   % accept a shuffle whose field count is within +/- this of the original (3 in the original paper)
neurons           = 1:4;  % neurons to test (overrides the list chosen during setup)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%





%% Shuffle test
% For each neuron in `neurons`:
%   1. circularly shift its spike train by a random offset;
%   2. rebuild the per-voxel firing rate;
%   3. extract fields as 26-connected components of voxels with firing rate > 0;
%   4. keep components larger than min_vol;
%   5. merge fields whose centroids are closer than min_dist;
%   6. if the shuffle is acceptable, record the CV (std/mean) of the unique
%      3-nearest-neighbour inter-field distances.
% A shuffle is acceptable when it has at least 4 fields and its field count
% is within field_tolerance of fields_count(neuron). fields_count is not
% defined in this section; it is presumably loaded from the props .mat (confirm).
%
% Results:
%   cv_shuffles(neuron).cv_shuffle - CV values of accepted shuffles only
%   cv_shuffles_mean(neuron)       - mean of those values (NaN if none were accepted)
n=100000; %number of traj points (spike trains are shifted within this length)
cv_shuffles=struct();
cv_shuffles_mean=[];
n_vox=size(voxel_coords,1);
for neuron=neurons
    % get the spike train
    spike_train = ns_1(:,neuron);

    % Accepted CV values for this neuron. The vector grows only on success,
    % so numel() counts successful shuffles. Indexing by iter instead would
    % leave zero-filled gaps that count as successes and bias the mean.
    cv_shuffle=[];
    for iter = 1:max_iterations
        if numel(cv_shuffle)>=required_shuffles % enough successful shuffles
            disp('entered5')
            break;
        end
        fprintf('neuron = %d, iter = %d\n', neuron, iter);

        % circular shift of at least 1000 samples in either direction
        shift_amount = randi([1000, n-1000]);
        spike_train_shifted = circshift(spike_train, shift_amount);

        % Spike positions depend only on the shifted train, so compute them
        % once per shuffle rather than once per voxel
        spikepos = spike_pos_shuffle(table,spike_train_shifted,thresh_std);

        % firing rate of each voxel (reset every shuffle)
        firr = zeros(1,n_vox);
        for i=1:n_vox
            sum_num = num2(voxel_coords(i,:),vl,spikepos);
            firr(i)=sum_num/sum_den(i);
        end

        % Voxel i of voxel_coords is element i (column-major) of the
        % meshgrid arrays, so reshaping to size(X) puts every voxel back at
        % its (row,col,page) position in the grid
        binary_volume = reshape(firr>0, size(X));

        CC = bwconncomp(binary_volume);    % default 3-D connectivity is 26
        stats = regionprops3(CC,"basic");  % Volume, Centroid, BoundingBox

        % only keep components with more than min_vol voxels
        centroids_shuffle = stats.Centroid(stats.Volume > min_vol,:);

        % Merge fields whose centroids are closer than min_dist. Repeatedly
        % take the first close pair in column-major order of the distance
        % matrix and replace it with its mean, until no pair is closer than
        % min_dist.
        pts = centroids_shuffle;
        ds = pdist2(pts,pts,'euclidean');
        ds(eye(size(ds),'logical')) = inf; % ignore self-distances
        while any(ds(:)<min_dist)
            [id1,kk] = find(ds<min_dist,1,'first');
            mpoint = mean(pts([kk id1],:),1,'omitnan');
            pts([kk id1],:) = [];
            pts = [pts; mpoint]; %#ok<AGROW>
            ds = pdist2(pts,pts,'euclidean');
            ds(eye(size(ds),'logical')) = inf;
        end
        centroids_shuffle=pts;

        total_fields=size(centroids_shuffle,1); %fields surviving the volume filter and merging

        if total_fields>0
            disp('entered')
            % reject if the count differs from the original neuron's count by more than field_tolerance
            if abs(total_fields-fields_count(neuron))>field_tolerance
                disp('entered1')
                continue;
            end
        end
        if total_fields==0 %if there are no fields
            disp('entered2')
            continue;
        end
        if total_fields<4 % need at least 4 fields to have 3 nearest neighbours
            disp('entered4')
            continue;
        end

        disp('entered3')

        % distances to the 3 nearest other centroids (column 1 is the distance to itself)
        [~, distances] = knnsearch(centroids_shuffle, centroids_shuffle, 'K', 4);
        distances = distances(:, 2:end);

        % a pair of mutual neighbours contributes the same distance twice;
        % keep each distinct value once
        distances_unique = unique(distances(:))';

        % CV of the inter-field distances
        mean_distance = mean(distances_unique);
        std_distance = std(distances_unique);
        cv_shuffle(end+1) = std_distance / mean_distance; %#ok<AGROW>
    end
    cv_shuffles(neuron).cv_shuffle=cv_shuffle;   % accepted shuffle CV values for this neuron
    cv_shuffles_mean(neuron)=mean(cv_shuffle);   % NaN when no shuffle was accepted

end

%%
%%%%%%%% spike_pos_shuffle_func%%%%%
function [spikepos]=spike_pos_shuffle(table,ns,thresh_std)
% SPIKE_POS_SHUFFLE  Rows of TABLE at which the spike signal NS exceeds
% mean(NS) + THRESH_STD*std(NS).
%   table      - position data; row r is paired with sample ns(r)
%   ns         - spike signal (a single neuron's column in this script)
%   thresh_std - threshold, in standard deviations above the mean of ns
%   spikepos   - the selected rows of table, in ascending row order
%                (empty if no sample crosses the threshold)
cutoff = thresh_std*std(ns) + mean(ns);
above_idx = find(ns > cutoff);
spikepos = table(above_idx, :);
end