clear all
close all
clc

%% System parameters
% Simulation horizon and forward-Euler step.
T    = 600;             % total simulated time
dt   = 1e-2;            % integration time step
nt   = round(T/dt);     % number of integration steps
time = (1:nt)*dt;       % time stamp of every step (time vector, not a total)

% Reservoir.
N     = 2000;                    % number of neurons
gstar = 1.5;                     % coupling strength used for training (scaled)
omega = randn(N,N)/sqrt(N);      % static random recurrent weights
z0    = randn(N,1);              % initial state for the training run

% RLS window, expressed in integration steps.
imin = round(50/dt);             % step at which RLS starts
imax = round(300/dt);            % step at which RLS stops

% Threshold power-law nonlinearity.
m = 1/2;                         % power of the threshold power law
q = 1/(m-1);                     % exponent used to rescale state/weights when g changes
%% Generate inputs for the pitchfork system
% Piecewise-constant input. The time axis is split into windows of
% hold_steps steps; each window draws one N(0,1) level, and every second
% window (the 2nd, 4th, ...) is forced to zero.
hold_steps = round(5/dt);  % steps per constant-input window (5 time units)
inp = zeros(nt,1);         % preallocate instead of growing inside the loop
k = 0;                     % index of the current window
for i = 1:nt
    if mod(i-1,hold_steps) == 0   % first step of a window (also valid when hold_steps == 1)
        k = k + 1;
        z = randn;                % random input level for this window
    end
    inp(i,1) = z*mod(k,2);        % odd windows keep z, even windows are 0
end

%% Supervisor system
% Forward-Euler integration of the forced pitchfork
%   dx/dt = mu*x - x^3 + inp(t),  x(0) = 0.
mu = 0.5;                  % pitchfork (bifurcation) parameter
sup = zeros(nt,1);
for j = 2:nt
    sup(j,1) = sup(j-1,1) + dt*(sup(j-1,1)*mu - sup(j-1,1)^3 + inp(j-1,1));
end

%% Dimensions/parameters
dim = size(sup,2);          % readout dimension = number of supervisor columns
% Random weights, uniform on [-1, 1]; drawn in this order: encoder, then input.
eta = 2*rand(N,dim) - 1;    % encoder weights
win = 2*rand(N,1) - 1;      % weights of the external input
phi = zeros(N,dim);         % decoder, zero-initialised; RLS_net2 returns the trained one

%% RLS parameters
alpha = 1;          % RLS regularisation: P starts as I/alpha
P = eye(N)/alpha;   % initial RLS matrix P
train = 1;          % flag for RLS_net2: 1 = RLS on, 0 = testing only

%% Train with RLS
% Run the network from z0 at coupling gstar against supervisor sup with
% input inp. train is 1 here, so RLS is on (presumably only between steps
% imin and imax — confirm in RLS_net2). The returned phi replaces the
% zero-initialised decoder.
[storexl,storerl,phi] = RLS_net2(N, z0, gstar, q, omega, phi, eta, ...
                                 nt, imax, imin, P, m, dt, ...
                                 sup, train, inp, win);
%% Test with a new initial condition.
%% Generate a test input
% Same construction as the training input, with fresh random levels: one
% N(0,1) level per window of hold_steps steps, every second window zero.
hold_steps = round(5/dt);  % steps per constant-input window (5 time units)
inp1 = zeros(nt,1);        % preallocate instead of growing inside the loop
k = 0;                     % index of the current window
for i = 1:nt
    if mod(i-1,hold_steps) == 0   % first step of a window (also valid when hold_steps == 1)
        k = k + 1;
        z = randn;                % random input level for this window
    end
    inp1(i,1) = z*mod(k,2);       % odd windows keep z, even windows are 0
end
%% Generate the test supervisor
% Forward-Euler integration of the forced pitchfork
%   dx/dt = mu*x - x^3 + inp1(t),  x(0) = 0.
mu = 0.5;                  % pitchfork (bifurcation) parameter
sup1 = zeros(nt,1);
for j = 2:nt
    sup1(j,1) = sup1(j-1,1) + dt*(sup1(j-1,1)*mu - sup1(j-1,1)^3 + inp1(j-1,1));
end
%% Test at gstar from a fresh initial condition (RLS off).
z0 = randn(N,1);  % new random initial state for the test runs
train = 0;        % 0 = testing only, no RLS updates
disp('Testing')
[storext,storet] = RLS_net2(N,z0,gstar,q,omega,phi,eta,nt,imax,imin,P,m,dt,sup1,train,inp1,win);
disp('Testing Complete')
%% Change g / rescale the reservoir dynamics
% Move to a new coupling g and rescale weights and initial state by powers
% of gstar/g built from q and m. The rescaled network is then tested
% against the supervisor alongside the original one.
g = 1.9;                          % new coupling strength
eta_hat = (gstar/g)^q*eta;        % rescaled encoder
win_hat = (gstar/g)^q*win;        % rescaled input weights
phi_hat = (g/gstar)^(q*m)*phi;    % rescaled decoder
z1 = (gstar/g)^q*z0;              % rescaled test initial condition
%% Test with the rescaled reservoir
train = 0; % flag to have RLS on, or do full testing (0 = testing only)

% Reference run: original weights at gstar, started from the rescaled
% initial condition z1. Its readout is stored under its own name so the
% rescaled run below does not overwrite it; the second output is not kept.
disp('Testing')
[storext_z1,~] = RLS_net2(N,z1,gstar,q,omega,phi,eta,nt,imax,imin,P,m,dt,sup1,train,inp1,win);
disp('Testing Complete')

% Rescaled run: coupling g with rescaled decoder, encoder and input weights.
disp('Testing at rescaled g')
[storext2,store2] = RLS_net2(N,z1,g,q,omega,phi_hat,eta_hat,nt,imax,imin,P,m,dt,sup1,train,inp1,win_hat);
disp('Testing Complete')
%% Save the whole workspace (v7.3 MAT-file, supports large arrays)
save('sim_2.mat', '-v7.3')

%% Plotting.
% Figure 1, a 4x2 grid:
%   row 1: training run (supervisor vs network output, RLS window shaded)
%   row 2: training input
%   row 3: test run, full view (left) and zoom on t in [350,400] (right)
%   row 4: test input, full view (left) and zoom (right)
f1 = figure(1);
clf
mx = 4;  % subplot rows
my = 2;  % subplot columns

% Training run.
subplot(mx,my,1:2)
plot(time,sup,'k','LineWidth',2), hold on
plot(time,storexl,'LineWidth',2)
title('Training a threshold power-law RNN on the pitchfork system')
% Shade the interval [imin, imax] during which RLS is on.
patch([dt*imin,dt*imax,dt*imax,dt*imin],[-2,-2,2,2],'b','edgealpha',0,'facealpha',0.1)
text(100,1,'RLS ON')
xlim([0,400])
xlabel('Time')
legend('Supervisor','Network Output')
ylim([-1.5,1.5])

% Training input.
subplot(mx,my,3:4)
plot(time,inp,'LineWidth',2)
title('Input Signal')
xlabel('Time')
xlim([0,400])
ylim([-3,3])

% Test run: unscaled network (blue) vs rescaled network (red dashed).
subplot(mx,my,5)
plot(time,sup1,'k','LineWidth',2), hold on
plot(time,storext,'b','LineWidth',2)
plot(time,storext2,'r--','LineWidth',2), hold off
legend('Supervisor','Network 1','Network 2 (Rescaled)')
ylim([-1.7,1.7])
xlim([0,400])
xlabel('Time')

% Test input.
subplot(mx,my,7)
plot(time,inp1,'LineWidth',2)
title('Input')
xlabel('Time')
xlim([0,400])
ylim([-4,4])

% Zoom on the end of the test run.
subplot(mx,my,6)
plot(time,sup1,'k','LineWidth',2), hold on
plot(time,storext,'b','LineWidth',2)
plot(time,storext2,'r--','LineWidth',2), hold off
xlabel('Time')
title('Zoom')
ylim([-1.5,1.5])
xlim([350,400])

% Zoom on the matching part of the test input.
subplot(mx,my,8)
plot(time,inp1,'LineWidth',2)
xlim([350,400])
ylim([-3,3])
xlabel('Time')
title('Zoom')

% Size the figure and export it (f1 already refers to figure 1).
set(f1,'position',[0,0,800,1000])
print(f1,'f1.svg','-dsvg','-painters')