%% Session setup
clc
clear all
close all
%% Select aligned lattice or tilted lattice for which properties are to be generated
% Set exactly one of these two flags to 1.
aligned_resp = 0;
tilted_resp = 1;
%% Load files
% Exactly one lattice must be selected: with both flags set the choice would be
% ambiguous, and with neither set the file variables below would be undefined
% and the script would fail later with an unrelated-looking error.
if logical(aligned_resp) == logical(tilted_resp)
    error('Set exactly one of aligned_resp and tilted_resp to 1.');
end
if aligned_resp
    traj_file = 'Trajectory_interpolated_aligned_lattice_no_diag.csv';
    % Note: spike_file holds the loaded table itself, not a file name.
    spike_file = readtable("encoded_AL_t20p3_b0.8pi_2d_testA.csv");
    props_file = "Aligned_data_props_b0.8pi_std3d1.5_std2d1.5A.mat";   % output .mat for FIRR_3dprops
else
    traj_file = 'Trajectory_interpolated_tilted_lattice_no_diagAzra.csv';
    spike_file = readtable("encoded_TL_t20p3_b0.8pi_3d_testA.csv");
    props_file = "Tilted_data_props_b0.8pi_std3d1.5_std2d1.5GRID.mat";  % output .mat for FIRR_3dprops
end

table_data = csvread(traj_file);   % trajectory samples as a numeric matrix (columns 1-3 plotted as x, y, z)
ns_1 = table2array(spike_file);    % spike data as a numeric array
plot3(table_data(:,1), table_data(:,2), table_data(:,3))  % visual check of the trajectory
%% Optional frame rotation (currently disabled)
% To express the trajectory in a rotated frame, supply rotm and uncomment:
% rotm = ;
% %rotm_inv = 
% table_data = rotm'*table_data';
% table_data = table_data';
% plot3(table_data(:,1),table_data(:,2),table_data(:,3)) %to verify trajectory
%% Parameters
vl = 0.25;          % voxel edge length
thresh_std = 1.5;   % spike threshold (passed to spike_pos)
neuron_start = 1;   % neuron index range; not used by the active
neuron_end = 50;    % neuron selection below
%% Neuron selection
% When both lattice flags are set the tilted list takes precedence; when
% neither is set, `neurons` is not assigned here.
if tilted_resp
    neurons = sort([33, 36, 5, 6, 7, 39, 41, 10, 14, 46, 18, 25, 26]);  % For PLACE
    % Alternative selections kept for reference:
    %neurons = [3	7	12	22	34	35	38	40	42	43	47	50];  %For GRID
    %neurons = [ 1  3  4  6  9 10 11 13 14 15 17 21 26 27 32 33 34 35 37 39 41 42 43 44 46 48 49];
    %neurons = [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 14, 19, 20, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 40, 41, 42, 43, 45, 46, 47, 48, 49]+1; %With aligned phi case
    %avoid_neurons=[2,3,5,6,9,10,12,13,19,23,24,25,28,33,34,36,41,42,43,44,45,50]; %neurons that dont look like place cells
elseif aligned_resp
    neurons = [1 5 8 9 10 11 14 18 20 25 26 27 33 36 39];
end
%% Run switches
plot_firr = 0;     % 1 = draw a 3-D voxel plot of each neuron's firing-rate map
regionprops = 1;   % 1 = compute connected-region properties and save them
% NOTE(review): this flag shadows the Image Processing Toolbox function
% regionprops() for the rest of the script (regionprops3 is unaffected).
%% Voxel grid
% Voxel centres run from 1.125 to 5.875 in steps of vl, i.e. 20 voxels along
% each dimension of the lattice maze when vl = 0.25.
x1 = 1.125:vl:5.875;
y1 = 1.125:vl:5.875;
z1 = 1.125:vl:5.875;
[X,Y,Z] = meshgrid(x1, y1, z1);
% meshgrid returns arrays of size numel(y1) x numel(x1) x numel(z1), so with
% column-major X(:) the rows of voxel_coords advance along y first, then x,
% then z. Per-voxel vectors (sum_den, firr_arr) follow this same ordering.
voxel_coords = [X(:) Y(:) Z(:)];
%voxel_coords = voxel_coords*rotm;
%% Trajectory subset and per-voxel denominator
% Only the first n_traj_rows trajectory samples are used.
n_traj_rows = 100000;
if size(table_data, 1) < n_traj_rows
    error('Trajectory has only %d samples; %d are required.', ...
          size(table_data, 1), n_traj_rows);
end
% NOTE(review): the name `table` shadows MATLAB's table() function from here on.
table = table_data(1:n_traj_rows, :);
CC_voxel = {};
% sum_den needs one denominator per row of voxel_coords. The commented loop
% shows how it can be computed from `table`; a cached copy is loaded instead.
% sum_den = [];
% for i = 1:size(voxel_coords,1)
%     sum_den = [sum_den, den(voxel_coords(i,:),vl,table)];
% end
den_cache = load("sum_den.mat", "sum_den");   % load only the expected variable
sum_den = den_cache.sum_den;
if numel(sum_den) ~= size(voxel_coords, 1)
    error('sum_den.mat has %d entries but the voxel grid has %d voxels.', ...
          numel(sum_den), size(voxel_coords, 1));
end
%% Per-neuron firing-rate maps and region properties
% For each selected neuron: compute a firing rate for every voxel, optionally
% plot it, and (if enabled) extract properties of connected groups of voxels
% with a positive rate, saving them to props_file.
% for neuron=setdiff(neuron_start:neuron_end,avoid_neurons) %for neuron=1:size(ns_1,2)
n_vox = size(voxel_coords, 1);
for neuron = neurons
    fprintf('neuron = %d\n', neuron);

    % Spike positions depend only on the neuron, not on the voxel, so they
    % are computed once per neuron rather than once per voxel.
    spikepos = spike_pos(table, ns_1, neuron, thresh_std);
    % spikepos = shuffl_data{1,neuron};

    % Rate of each voxel = num2(...) divided by that voxel's cached
    % denominator; may be NaN (e.g. 0/0) where sum_den is zero.
    firr = zeros(1, n_vox);
    for i = 1:n_vox
        sum_num = num2(voxel_coords(i,:), vl, spikepos);
        firr(i) = sum_num / sum_den(i);
    end
    firr_arr = firr.';

    if plot_firr
        % Plot from a copy so that zeroing NaNs for display does not alter the
        % firing rates stored in FIRR_3dprops below.
        firr_plot = firr_arr;
        firr_plot(isnan(firr_plot)) = 0;
        col = vals2colormap(firr_plot, 'jet', [0.2 1]);

        figure
        plotcube([5 5 5], [0 0 0], 0, [1 1 1])
        hold on
        L = [vl, vl, vl];   % edge lengths of one voxel cube
        for i = 1:n_vox      % draw every voxel with a positive rate
            if firr_plot(i) > 0
                O = voxel_coords(i,:) - L/2;   % cube origin = centre minus half an edge
                plotcube(L, O, 0.4, col(i,:))
                hold on
            end
        end
        colormap(jet)
        % sprintf yields a single-line title (a string array would give two lines).
        title(sprintf('Firing rate, neuron: %d', neuron))
        xlabel("x")
        ylabel("y")
        zlabel("z")
    end

    if regionprops
        % Binary volume laid out like the meshgrid arrays X/Y/Z: the linear
        % index of each element equals the row of voxel_coords it represents
        % (column-major). NaN rates compare false and stay 0.
        binary_volume = double(reshape(firr_arr > 0, size(X)));

        CC = bwconncomp(binary_volume);      % connected groups of active voxels
        stats = regionprops3(CC, "all");
        FIRR_3dprops(neuron).props = stats;  % all 3-D region properties
        FIRR_3dprops(neuron).firr = firr_arr;
        CC_voxel{neuron} = CC.PixelIdxList;

        % Written after every neuron, so results for finished neurons are on
        % disk even if the run is interrupted.
        save(props_file, 'FIRR_3dprops');
    end
end