%
% Shuffle analysis of 3-D place-field elongation (use this file, not the other one).
% Builds per-voxel firing-rate maps for the shuffled neurons in shuffl_data,
% extracts connected active regions with regionprops3, and compares the
% elongation of the original place fields with the shuffled distribution.
clc; clear all; close all;
%%
%%%%%% Input files %%%%%%
% Earlier inputs, kept for reference:
% traj_file='H:\.shortcut-targets-by-id\1_HPNBtvegcfjF9AB6-jAJHtWwqwC4YCF\3d\3d_lattice_maze\Trajectory_interpolated_aligned_lattice.csv';
% spike_file=readtable("H:\.shortcut-targets-by-id\1_HPNBtvegcfjF9AB6-jAJHtWwqwC4YCF\3d\3d_lattice_maze\encoded_AL_phi2_b1.8_lin.csv"); % spike data (table of trajectory point vs spike)
% props_file="H:\.shortcut-targets-by-id\1_HPNBtvegcfjF9AB6-jAJHtWwqwC4YCF\3d\3d_lattice_maze\props_AL_phi2_b1.8_lin_t0.4_shuffl_bhar.mat"; % regionprops output for further statistical analysis
% spike_traj = "H:\.shortcut-targets-by-id\1_HPNBtvegcfjF9AB6-jAJHtWwqwC4YCF\3d\3d_lattice_maze\Field_elongation\shuffl_data";
% load(spike_traj)
% ns_1=table2array(spike_file); % loads spike data
% table=table(1:size(ns_1)-1,:);

traj_file = 'Trajectory_interpolated_tilted_lattice_no_diagAzra.csv';
% Whole CSV as a numeric matrix, read from row 0, column 0 (csvread's default offsets).
table_data = csvread(traj_file, 0, 0);
% plot3(table_data(:,1),table_data(:,2),table_data(:,3)) % to verify trajectory
%%
% Parameters
neuron_start = 1;
neuron_end = 50;
vl = 0.25;           % voxel edge length
thresh_spike = 0.5;  % spike threshold
%firr_thresh=0.2;    % firing-rate threshold
% Hand-picked neuron exclusions (currently commented out):
% aligned case
% avoid_neurons = [1,2,5,6,8,9,10,11,12,15,18,19,23,25,27,28,30,32,33,35,37,38,41,42,46,48,49,50];
% tilted case
%avoid_neurons=[2,5,6,10,13,15,20,25,26,32,35,44,45]; % neurons that don't look like place cells

% Switches (1 = on, 0 = off)
plot_firr = 0;    % draw the 3-D firing-rate voxels of each neuron
regionprops = 1;  % run the connected-region analysis; NOTE this name shadows MATLAB's regionprops function

% Voxel centres from 1.125 to 5.875 in steps of vl (20 per axis when vl = 0.25).
x1 = 1.125:vl:5.875;
y1 = x1;
z1 = x1;
% ndgrid(y1,x1,z1) returns exactly meshgrid(x1,y1,z1), with outputs in (Y,X,Z) order.
[Y, X, Z] = ndgrid(y1, x1, z1);
% One row per voxel: column 1 varies fastest, then column 2, then column 3.
voxel_coords = [Y(:) X(:) Z(:)];
%%
% spikecurr = table2array(spike_file(:,7));
% firr = spikecurr>thresh_spike*max(spikecurr);
% plot3(table(:,1), table(:,2), table(:,3))
% hold on
% 
% plot3(table(firr,1), table(firr,2), table(firr,3),'.r')
%% 

% Trajectory samples (at most the first 99999 rows) for the occupancy
% computation. Named traj_table rather than "table" so MATLAB's table()
% function is not shadowed; min() avoids an index error on shorter files.
traj_table = table_data(1:min(99999, size(table_data, 1)), :);
CC_voxel = {};   % connected-component voxel index lists, one cell per neuron
sum_den = [];
tic
% Slow path: recompute the per-voxel denominator instead of loading it.
% for i=1:size(voxel_coords,1)
%     sum_den = [sum_den, den(voxel_coords(i,:),vl,traj_table)];
% end
load("sum_den.mat");   % expected to provide sum_den (indexed per voxel later) - confirm file contents
toc

%% Firing-rate maps and connected firing regions for every shuffled neuron
% for neuron=setdiff(neuron_start:neuron_end,avoid_neurons) %for neuron=1:size(ns_1,2)
tic
% Load only shuffl_data so the file cannot overwrite other workspace variables.
load('shuffl_data.mat', 'shuffl_data');
n_neurons = length(shuffl_data);
n_vox = size(voxel_coords, 1);

% One entry per neuron, sized from the data rather than a hard-coded 8100,
% so later loops over FIRR_3dprops never reach unfilled (empty) entries.
FIRR_3dprops = struct('props', cell(1, n_neurons), 'firr', cell(1, n_neurons));
firr = zeros(1, n_vox);

for neuron = 1:n_neurons
    neuron   % echo progress to the command window
    % Spike positions of this neuron; identical for every voxel, so fetched once.
    spikepos = shuffl_data{1, neuron};
    for i = 1:n_vox
        % Per-voxel rate: num2 numerator over the precomputed denominator sum_den.
        % (Older numerator: num(voxel_coords(i,:),vl,spikepos).)
        firr(i) = num2(voxel_coords(i,:), vl, spikepos) / sum_den(i);
    end
    firr_arr = firr.';   % column vector, one entry per voxel_coords row

    if plot_firr
        firr_arr(isnan(firr_arr)) = 0;   % note: also affects the firr stored below
        col = vals2colormap(firr_arr, 'jet', [0.2 1]);

        figure
        plotcube([5 5 5], [0 0 0], 0, [1 1 1])
        hold on
        L = [vl, vl, vl];   % voxel edge lengths
        for i = 1:n_vox   % draw every voxel with a positive firing rate
            %if firr_arr(i)>firr_thresh*max(firr_arr)
            if firr_arr(i) > 0
                O = voxel_coords(i,:) - L/2;   % voxel centre shifted back by half an edge
                plotcube(L, O, 0.4, col(i,:))
                hold on
            end
        end
        colormap(jet)
        % Char concatenation gives a single-line title; a ["..." , num2str(...)]
        % string array would be drawn as two separate title lines.
        title(['Firing rate, neuron: ', num2str(neuron)])
        xlabel("x")
        ylabel("y")
        zlabel("z")
    end

    if regionprops   % script flag (this variable shadows the regionprops function)
        % firr_arr follows the linear (column-major) order of the size(X)
        % voxel grid, so reshape reproduces a voxel-by-voxel fill of the volume.
        binary_volume = reshape(firr_arr > 0, size(X));
        CC = bwconncomp(binary_volume);   % contiguous groups of active voxels
        stats = regionprops3(CC, "all");
        FIRR_3dprops(neuron).props = stats;   % all 3-D region properties
        FIRR_3dprops(neuron).firr = firr_arr;
        CC_voxel{neuron} = CC.PixelIdxList;

        % save(props_file,'FIRR_3dprops');
    end
end
toc

%%

% % Calculate number of place field per neuron
% pfs = [1];
% k_sum = 0;
% for i=1:length(FIRR_3dprops)
%     k = FIRR_3dprops(i).props.Volume;
%     for j=1:length(k)
%         if k(j)>50
%             k_sum = k_sum+1;
%         end
%     end
% 
%     if ~mod(i,100)
%         i;
%         pfs = [pfs, k_sum];
%         k_sum = 0;
%     end
% end
%% elongation index
% elongation = [];
% for neuron=1:length(FIRR_3dprops)
%     t=FIRR_3dprops(neuron).props;
%     arr=t.PrincipalAxisLength;
%     for i = 1:length(t.Volume) 
%         if t.Volume(i) > 50
% 
%             elong=2*arr(i,1)/(arr(i,2)+arr(i,3));
%             if elong>2
%                 neuron ;
%                 i;
%                 arr(i,:);
%                 elong;
%             end
%             elongation=[elongation ;elong];
%         end
%     end
% end
% histogram(elongation, 200)
%%
% Calculate number of place field per neuron
% pfs = [1];
% k_sum = 0;
% for i=1:length(FIRR_3dprops)
%     k = FIRR_3dprops(i).props.Volume;
%     for j=1:length(k)
%         if k(j)>50
%             k_sum = k_sum+1;
%         end
%     end
% 
%     if ~mod(i,100)
%         i;
%         pfs = [pfs, k_sum];
%         k_sum = 0;
%     end
% end

% Elongation index of the shuffled fields, collected per group of neurons.
% Elongation = 2*a1/(a2+a3) with a1..a3 the PrincipalAxisLength values of a
% component; only components with Volume > min_field_volume count as fields.
% elongations{g} holds the elongations of group g only, i.e. neurons
% (g-1)*group_size+1 .. g*group_size. Trailing neurons that do not fill a
% complete group are not stored (as before).
group_size = 100;
min_field_volume = 50;
elongations = {};
elongation = [];
for neuron = 1:length(FIRR_3dprops)
    t = FIRR_3dprops(neuron).props;
    if ~isempty(t)   % skip neurons with no region table (or no regions)
        is_field = t.Volume > min_field_volume;
        arr = t.PrincipalAxisLength(is_field, :);
        elongation = [elongation; 2*arr(:,1) ./ (arr(:,2) + arr(:,3))]; %#ok<AGROW>
    end
    if ~mod(neuron, group_size)
        elongations{1, neuron/group_size} = elongation;
        % Reset so the next group's statistics use only that group's fields.
        elongation = [];
    end
end
% histogram(elongation, 200)
%%
% Summary statistics of the shuffled elongation indices: one 95th
% percentile, mean and standard deviation per element of elongations.
%
% Earlier approach, kept for reference:
% elongation indices per neuron
% elongations = {};
% for i=2:length(pfs)
%     elongations{1,i-1} = elongation(pfs(i-1):(pfs(i-1)+pfs(i)-1));
% end
%
% create probability distribution
% mmin = 1;
% mmax = max(neu1)+1;
% bins = 20;
% bins_l = linspace(mmin, mmax, bins);
% count = [];
% for i=1:(length(bins_l)-1)
%     r_ = find(neu1>bins_l(i) & neu1<bins_l(i+1));
%     count = [count, length(r_)];
% end
percentiles = cellfun(@(group_elong) prctile(group_elong, 95), elongations);
means = cellfun(@mean, elongations);
stdevs = cellfun(@std, elongations);
% figure; histogram(elongations{1}, 200)

%% Elongation of the original place fields and comparison with the shuffles
FIRR_3dpropsOriginal = load("Tilted_data_props_b0.8pi_std3d1.5_std2d1.5.mat");
neurons = [5,6,7,10,14,18,25,26,33,36,39,41,46];
minv = 50;   % minimum Volume for a component to count as a place field
% Empty struct array with the fields filled below. A plain `struct` is a
% 1x1 field-less struct, so length(pf) would be 1 and pf(1).props would
% error when no component passes the volume threshold.
pf = struct('props', {}, 'neuron', {}, 'firr', {});
pf_counter = 1;
for neuron = neurons
    propstable = FIRR_3dpropsOriginal.FIRR_3dprops(neuron).props;   % 3-D properties table of the neuron
    firrtable = FIRR_3dpropsOriginal.FIRR_3dprops(neuron).firr;
    for i = 1:height(propstable)
        if propstable{i, "Volume"} > minv   % keep components larger than minv
            pf(pf_counter).props = propstable(i,:);
            pf(pf_counter).neuron = neuron;
            pf(pf_counter).firr = firrtable;
            pf_counter = pf_counter + 1;
        end
    end
end

% Elongation index 2*a1/(a2+a3) of each field (row vector, one per pf entry).
elongation_org = zeros(1, numel(pf));
for i = 1:numel(pf)
    arr = pf(i).props.PrincipalAxisLength;
    elongation_org(i) = 2*arr(1) / (arr(2) + arr(3));
end

% compare(i) pairs field i with percentiles(i) by position, so the counts
% must match; implicit expansion would otherwise silently broadcast a scalar.
if numel(elongation_org) ~= numel(percentiles)
    error('Found %d original place fields but %d shuffle percentiles.', ...
        numel(elongation_org), numel(percentiles));
end
compare = elongation_org <= percentiles;
%% Save the per-index summary
% One row per index i; columns: shuffled 95th percentile, shuffled mean,
% shuffled std, original elongation, comparison flag (stored as 0/1).
field_elong_stats = [percentiles
                     means
                     stdevs
                     elongation_org
                     compare].';
save('field_elong_stats_TL.mat', 'field_elong_stats');