% Author: Etay Hay
% Orientation processing by synaptic integration across first-order tactile neurons (Hay and Pruszynski 2020)
function [models,errs] = derive_model_ga(stim_list,sim_param,ga_param)
max_iter = ga_param.max_iter;
p_mutate = ga_param.p_mutate;
p_cross = ga_param.p_cross;
Nmodels = ga_param.Nmodels;
err_tol = ga_param.err_tol;
Nmr = ga_param.Nmr;
s_group = ga_param.s_group;
mr_loc = get_valid_locations(sim_param,ga_param.d_mr);
[r_loc,c_loc] = find(mr_loc == 1);
for k = 1:Nmodels
mr_subset = [];
for k2 = 1:Nmr
rvec = randperm(length(r_loc));
mr_subset(k2,1) = r_loc(rvec(1));
mr_subset(k2,2) = c_loc(rvec(1));
if (ga_param.mr_wmin == ga_param.mr_wmax)
mr_w(k2,1) = ga_param.mr_wmin;
else
mr_w(k2,1) = ga_param.mr_wmin + rand()*(ga_param.mr_wmax - ga_param.mr_wmin);
end
end
if (ga_param.mr_r1_min == ga_param.mr_r1_max)
mr_r1 = ga_param.mr_r1_min;
mr_r2 = ga_param.mr_r2_max;
else
mr_r1 = ga_param.mr_r1_min + rand()*(ga_param.mr_r1_max - ga_param.mr_r1_min);
mr_r2 = ga_param.mr_r2_min + rand()*(ga_param.mr_r2_max - ga_param.mr_r2_min);
end
if (ga_param.mr_wmin == ga_param.mr_wmax)
m_maxrate = ga_param.m_maxrate_min;
else
m_maxrate = ga_param.m_maxrate_min + rand()*(ga_param.m_maxrate_max - ga_param.m_maxrate_min);
end
models{k} = struct('mr_subset',mr_subset,'mr_w',mr_w,'mr_loc',mr_loc,'mr_r1',mr_r1,'mr_r2',mr_r2,...
's_thresh',ga_param.s_thresh,'m_maxrate',m_maxrate,'d_mr',ga_param.d_mr,'spiking_type',ga_param.spiking_type);
[errs(:,k),m_spike_times,m_spike_rate,o_spike_times,o_spike_rate,m_t] = test_model(models{k},stim_list,sim_param);
end
errs_ug = errs;
for k=1:max(s_group)
inds = find(s_group == k);
errs2(k,:) = mean(errs(inds,:),1);
end
errs = errs2;
iter_i = 1;
while min(max(errs,[],1))>err_tol && iter_i<max_iter
Nunique = Nmodels;
for k = 2:Nmodels
found = 0;
for k2 = 1:(k-1)
if isequal(models{k},models{k2})
found = 1;
end
end
if found
Nunique = Nunique - 1;
end
end
disp([int2str(Nunique),' unique models'])
mutation_size = 1 - iter_i/max_iter;
models2 = new_models(models,ga_param,mutation_size);
errs2 = [];
parfor (k=1:Nmodels,3)
[errs2(:,k),m_spike_times,m_spike_rate,o_spike_times,o_spike_rate,m_t] = test_model(models2{k},stim_list,sim_param);
end
errs2_ug = errs2;
errs2g = [];
for k=1:max(s_group)
inds = find(s_group == k);
errs2g(k,:) = mean(errs2(inds,:),1);
end
errs2 = errs2g;
errs3 = [errs,errs2];
[vs,is] = sort(mean(errs3,1),'ascend');
inds = is(1:Nmodels);
inds1 = inds(find(inds<=Nmodels));
inds2 = inds(find(inds>Nmodels)) - Nmodels;
models3 = {models{inds1},models2{inds2}};
models = models3;
errs4 = [errs(:,inds1),errs2(:,inds2)];
errs = errs4;
errs4_ug = [errs_ug(:,inds1),errs2_ug(:,inds2)];
errs_ug = errs4_ug;
[vmin,imin] = min(mean(errs,1));
disp(['Iteration ',int2str(iter_i),': Minimal average error = ',num2str(round(vmin,2))])
str = '';
for k=1:size(errs,1)
str = [str,' ',num2str(round(errs(k,imin),2))];
end
disp(['Iteration ',int2str(iter_i),': best model errors = ',str])
str = '';
for k=1:size(errs_ug,1)
str = [str,' ',num2str(round(errs_ug(k,imin),2))];
end
disp(['Iteration ',int2str(iter_i),': best model errors, ungrouped = ',str])
save(['models/c',int2str(sim_param.cellnum),'_models_',ga_param.model_type,'_',int2str(ga_param.Nmr)],'models','errs');
iter_i = iter_i+1;
end
end