addpath('../UtilityFunctions');
%% - flags
% Each flag enables (1) or disables (0) one stage of the script.
make_weights = 1;                    % build the initial weight matrix w
simulate_network = 1;                % run induction + pattern-completion simulations
pattern_completion_curve_sim = 1;    % sweep the number of stimulated neurons
pattern_completion_curve_plot = 1;   % plot the result of that sweep
plot_results = 1;                    % produce and save all figures

%% - network params
% dt and tau are passed unchanged to pert_sim_short
% (presumably integration step and time constant -- confirm in that helper).
dt = 1;
tau = 10;
% Population sizes: rows 1:NE of w are treated as E, rows NE+1:N as I.
NE = 400;
NI = 400;
N = NE+NI;
conn_spar_EE = 1;   % E-E connection probability; also rescales JEE
N_itr = 30;         % number of weight-update iterations
eta = 1/5;          % factor applied to dw when updating the E-E weights
% Modulation depth of the cos(2*delta_po) term in each weight block.
m = 0;
mEE = m;
mEI = m/2;
mIE = m/2;
mII = m/2;
% One angle in [0, pi] per neuron, evenly spaced within each population.
po_exc = linspace(0,pi,NE);
po_inh = linspace(0,pi,NI);
po_all = [po_exc, po_inh];
pert_size = .1;     % reassigned inside the main loop before every use
l0 = 0.8;           % an update is kept only if max real eigenvalue stays below l0
ts_dur = [10,20,50,100];        % candidate stimulus durations (selected by k2_s)
N_rep = 10;                     % repetitions argument passed to pert_sim_short
t_betw = 50;                    % reassigned inside the main loop before every use
Ns_pert =[10,20,50,100,200];    % candidate sizes of the perturbed group (selected by k1_s)
% - sample plot
k1_s = 2;   % index into Ns_pert
k2_s = 3;   % index into ts_dur

%% - result containers (first index: position gi in ks)
r_all = {};
w_all = {};
% Largest real eigenvalue before training (column 1) and after each of the
% N_itr iterations, hence N_itr+1 columns.
evm = zeros(2,N_itr+1);
evecm = zeros(2, N);
evecmall = zeros(2,N_itr+1,N);
ccm = {};
ccm0 = {};
% Initialise explicitly so a stale workspace variable of another type cannot
% break the brace-indexed assignments later in the script.
r_pc = {};
r_pc0 = {};
ks = [1, 4];   % values of k, the factor scaling JEI, JIE and JII relative to J0
%% - build, train and probe one network per value of k
% The original processing order (ks(2) first, then ks(1)) is kept.
for gi = [2, 1]
    k = ks(gi)   % deliberately echoed to the console
    J0 = 1/NE*2;
    JEE = J0 /conn_spar_EE;
    JEI = J0 *k;
    JIE = -J0 *k;
    JII = -J0 *k;

    %% - weight matrix
    if make_weights
        % Each block is a baseline weight modulated by cos(2 * angle difference).
        wEE = zeros(NE,NE);
        wEI = zeros(NE,NI);
        for i = 1:NE
            wEE(i,:) = (1 + mEE * cos(2*(po_exc(i) - po_exc))) * JEE;
            wEI(i,:) = (1 + mEI * cos(2*(po_exc(i) - po_inh))) * JEI;
        end
        wIE = zeros(NI,NE);
        wII = zeros(NI,NI);
        for i = 1:NI
            wIE(i,:) = (1 + mIE * cos(2*(po_inh(i) - po_exc))) * JIE;
            wII(i,:) = (1 + mII * cos(2*(po_inh(i) - po_inh))) * JII;
        end
        % Random E-E connectivity mask (all ones when conn_spar_EE == 1).
        wEE = wEE .* binornd(1,conn_spar_EE,[NE,NE]);
        w = [wEE, wEI
            wIE, wII];
        % Multiplicative jitter, uniform in [0.5, 1.5].
        w = w .*(1+1*(rand(N,N)-.5));
        % rectify is an external helper; presumably it clips negatives to zero,
        % so rows 1:NE stay >= 0 and rows NE+1:N stay <= 0 -- confirm.
        w(1:NE,:) = rectify(w(1:NE,:));
        w(NE+1:end,:) = -rectify(-w(NE+1:end,:));
        w(eye(N)==1) = 0;   % no self-connections
    end

    %% - simulate
    if simulate_network
        N_pert = Ns_pert(k1_s);
        t_dur = ts_dur(k2_s);
        t_betw = t_dur;

        % V, b and evh always describe the leading (largest real part)
        % eigenpair of the *current* wp; they are refreshed only when an
        % update is accepted, so each matrix is decomposed exactly once.
        wp = w;
        [V,D] = eig(w);
        ev_before = diag(D);
        [evh,b] = max(real(ev_before));
        evecmall(gi,1,:) = V(:,b);

        % - perturb.
        pert_ids = 1:N_pert;
        pert_size = .1;   % same induction amplitude for every k
        w_all{gi}{1} = w;
        evm(gi,1) = evh;
        for kk = 1:N_itr
            [r,~, dw] = pert_sim_short(N, NE, wp, dt, tau, pert_ids, pert_size, t_dur, N_rep, t_betw);
            r_all{gi}{kk} = r;
            zz = dw(1:NE,1:N);
            zz(isnan(zz))=0;
            % Candidate update: only the E-E block is changed.
            wp_test = wp;
            wp_test(1:NE,1:NE) = wp_test(1:NE,1:NE) + eta * zz(1:NE,1:NE);
            [V_test,D_test] = eig(wp_test);
            [evh_test,b_test] = max(real(diag(D_test)));
            % Accept the candidate only while the leading eigenvalue stays
            % below l0; otherwise wp and its eigenpair are left untouched.
            if evh_test < l0
                wp = wp_test;
                V = V_test;
                b = b_test;
                evh = evh_test;
            end
            w_all{gi}{kk+1} = wp;
            evm(gi, kk+1) = evh;
            evecmall(gi,kk+1,:) = V(:,b);
        end
        evecm(gi,:) = V(:,b);

        % - pattern completion
        % A separate repetition count is used here so that N_rep keeps its
        % configured value for the induction phase of the next network
        % (previously N_rep was overwritten with 1 and leaked into the next
        % pass of the gi loop).
        N_rep_pc = 1;
        t_dur = 300;
        t_betw = 0;
        if k == 1
            pert_size = 1;
        else
            pert_size = .1;
        end
        % - just one: stimulate the first half of the perturbed group
        pert_ids = 1:N_pert/2;
        % before induction/learning
        [r_pc0{gi},~, ~] = pert_sim_short(N, NE, w, dt, tau, pert_ids, pert_size, t_dur, N_rep_pc, t_betw);
        % after induction/learning
        [r_pc{gi},~, ~] = pert_sim_short(N, NE, wp, dt, tau, pert_ids, pert_size, t_dur, N_rep_pc, t_betw);

        % - for different fractions
        if pattern_completion_curve_sim
            ccm{gi} = zeros(1,N_pert-1);
            ccm0{gi} = zeros(1,N_pert-1);
            for i = 1:N_pert-1
                pert_ids = 1:i;
                [r,~, ~] = pert_sim_short(N, NE, wp, dt, tau, pert_ids, pert_size, t_dur, N_rep_pc, t_betw);
                % Mean activity over columns 300..300+t_dur and 50..300 of r
                % (presumably a reference window and the stimulation window;
                % confirm against the time layout of pert_sim_short).
                r0 = nanmean(r(:,300:300+t_dur),2);
                rp = nanmean(r(:,50:300),2);
                dr = (rp - r0);
                % Mean response change of the non-stimulated group members
                % (ccm) and of E neurons outside the group (ccm0), each
                % relative to the directly stimulated neurons.
                ccm{gi}(i) = nanmean(dr(i+1:N_pert)) ./ nanmean(dr(1:i));
                ccm0{gi}(i) = nanmean(dr(N_pert+1:NE)) ./ nanmean(dr(1:i));
            end
        end
    end
end
%% -
% Produce and save four figures per network: sample weight matrices,
% activity traces after and before induction, and the pattern-completion curve.
if plot_results
    for gi = 1:2
        k = ks(gi)   % deliberately echoed to the console

        % - sample weights
        % Snapshots at the start, middle and end of training. The middle index
        % is rounded so that an odd N_itr still yields a valid cell index.
        ws_id = round([0, N_itr/2, N_itr])+1;
        ttls = {'Initial', 'Middle', 'Final'};
        figure('Position',[100,100,500,140])
        for i = 1:3
            subplot(1,3,i);
            if k == 1
                title(ttls{i}, 'FontWeight', 'normal');
            end
            hold on
            % Only the block among the perturbed neurons is shown.
            imagesc(w_all{gi}{ws_id(i)}(1:N_pert,1:N_pert));
            colorbar()
            axis image
            xticks([1,N_pert]);
            yticks([1,N_pert]);
            if i ~= 1
                xticklabels([]); yticklabels([]);
            end
            if i == 3
                xlabel('post #');
                ylabel('pre #');
            end
            set(gca, 'LineWidth', 1, 'FontSize', 15, 'Box', 'off', 'TickDir', 'out', 'ydir', 'normal')
        end
        print(['k' num2str(k) '_sampleW.png'], '-dpng', '-r300');

        % - pattern completion: after induction first, then before induction
        trace_sets = {r_pc{gi}, r_pc0{gi}};
        trace_files = {'_patternCompl.png', '_patternCompl_beforeInduction.png'};
        for ti = 1:2
            rr = trace_sets{ti};
            % The before-induction figure for k == 1 is slightly narrower.
            if ti == 2 && k == 1
                fig_w = 250;
            else
                fig_w = 270;
            end
            figure('Position',[100,100,fig_w,225])
            hold on
            % grey: E neurons outside the group; red: stimulated half;
            % orange: non-stimulated half of the group
            plot(rr(1+N_pert:NE,:)', '-', 'color', [.7,.7,.7]);
            plot(rr(1:N_pert/2,:)', '-', 'color', 'r');
            plot(rr(1+N_pert/2:N_pert,:)', '-', 'color', [1,0.5,0]);
            xlabel('Time')
            ylabel('Activity')
            if k == 4
                yticks([0, .1, .2, .3])
                ylim([0, .3]);
            else
                ylim([0,8])
            end
            set(gca, 'LineWidth', 1, 'FontSize', 15, 'Box', 'off', 'TickDir', 'out', 'ydir', 'normal')
            print(['k' num2str(k) trace_files{ti}], '-dpng', '-r300');
        end

        % - pattern completion curve
        if pattern_completion_curve_plot
            pct_active = (1:N_pert-1)/N_pert*100;
            figure('Position',[100,100,225,225])
            hold on
            h1 = plot(pct_active, ccm{gi}*100, '-o', 'linewidth',2, 'color', [1,0.5,0]);
            h2 = plot(pct_active, ccm0{gi}*100, '-o', 'linewidth',2, 'color', [.7,.7,.7]);
            if k == 4
                % Overlay the curves of the first network (ks(1)) as dashed
                % lines for comparison; only the solid lines go in the legend.
                h3 = plot(pct_active, ccm{1}*100, '--', 'linewidth',2, 'color', [1,0.5,0]);
                h4 = plot(pct_active, ccm0{1}*100, '--', 'linewidth',2, 'color', [.7,.7,.7]);
                legend([h1,h2], {'within', 'outside'}, 'location', 'best')
                legend boxoff
            end
            xlabel('Partial activ. (%)');
            ylabel('Fraction resp. (%)');
            yticks([0, 25, 50, 75, 100])
            ylim([-10,80])
            set(gca, 'LineWidth', 1, 'FontSize', 15, 'Box', 'off', 'TickDir', 'out')
            print(['k' num2str(k) '_patternCompl_curve.png'], '-dpng', '-r300');
        end
    end
end