%% set parameters
clearvars
%----------------------network params
options=[2 4 4]; %number of identity options for each task variable
remove_doubles=1; % if 1, don't allow 1st and 2nd cue to be the same
th_var=.27; %threshold parameter
icon_prob=.25; %connection probability
baseline=50; %baseline number of cells per population
incs_vec=[1.6,1,1.2]; %relative number of cells in the populations of each task variable
imean_weight=.207; %mean of weight distribution
iwsig_perc=1; %ratio of std of weight distribution to mean
dist_var='gaus'; %distribution of weights. Options: 'gaus','powr','expo','logn','unif'
trials=10; %trials per condition
%-----------------learning params
Nl=3; %Number of populations to increase weights to
max_step=8; %total number of learning steps (1st step=random network)
freelearn=1; %1 for free learning, 0 for constrained
delr=.2; %learning rate
%-----------------Activity params
mFR=4.9; %desired mean firing rate
gnois_var=.88; %multiplicative noise parameter
addsig_perc=5.8; %additive noise parameter
Nuse=90; %number of "PFC" cells to use in analysis. total PFC population size is equal to total input population size
%------------necessary computations
inputs=length(options);
cue_perms=prod(options); %total number of task conditions
CellsOptsInps=ones(size(options))*baseline;
OptsInt=options; %number of populations (option identities) per task variable
CellsOptsInps=ceil(CellsOptsInps.*incs_vec); %cells per population for each task variable
InpN=sum(OptsInt.*CellsOptsInps); RanN=InpN; %total input cells; "PFC" population matched in size
rw_cells=1:RanN; %cells subject to learning (restrict to a subset to limit which cells learn)
% make matrix of input activity for all conditions
cue_mat=NaN(cue_perms, inputs);
for oi=1:length(options)
n=options(oi);
if oi<length(options)
reps=prod(options(oi+1:end));
else
reps=1;
end
chunk_reps=cue_perms/(reps*n);
for cri=1:chunk_reps
for pn=1:n
starti=1+(cri-1)*n*reps+(pn-1)*reps;
cue_mat(starti:starti+reps-1,oi)=pn*ones(reps,1);
end
end
end
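% With options=[2 4 4] this enumerates all 32 conditions, the first task
% variable varying slowest: cue_mat(1,:)=[1 1 1], cue_mat(2,:)=[1 1 2], ...,
% cue_mat(5,:)=[1 2 1], ..., cue_mat(32,:)=[2 4 4].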
if remove_doubles
lcue=cue_mat(:,end);
samecue_inds=find(cue_mat(:,end-1)==lcue);
cue_mat(samecue_inds,:)=[];
cue_perms=size(cue_mat,1);
end
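% e.g., with options=[2 4 4] this drops the 2*4=8 rows whose two cue columns
% match, leaving cue_perms=24 conditions.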
clear inp_starts inp_ends
inp_starts(1)=1; ii=1; %find input boundaries
for j=1:sum(OptsInt)
inp_ends(j)=inp_starts(j)+CellsOptsInps(ii)-1;
inp_starts(j+1)=inp_ends(j)+1;
if j==sum(OptsInt(1:ii))
ii=ii+1;
end
end
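% With baseline=50 and incs_vec=[1.6,1,1.2], CellsOptsInps=[80 50 60], giving
% 2+4+4=10 populations: cells 1-80 and 81-160 (task variable 1), 161-210
% through 311-360 (variable 2), and 361-420 through 541-600 (variable 3), so
% InpN=600. Sanity check: inp_ends(end)==InpN.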
addsig=imean_weight*addsig_perc; %additive noise std, scaled relative to the mean weight
iwsig=iwsig_perc*imean_weight; %std of the weight distribution
%% do learning and create activity for each condition
for rwi=1:max_step %through all learning steps
disp(['learning step: ',num2str(rwi)])
if rwi==1 %make initial random weight matrix
[W, imean_weight, iwsig]=weight_maker(RanN,InpN,imean_weight,iwsig,dist_var);
W=W.*double(rand(RanN,InpN)<=icon_prob); %connection probability
W(W<0)=0; %only positive weights
threshes=thvar_er(th_var, W); %calculate thresholds for each cell
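% The step-1 network is purely random: weights drawn from dist_var, rectified
% at zero, and pruned to connection probability icon_prob; thvar_er (external)
% then sets each cell's threshold from th_var and the realized weights.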
else %apply learning step
[cx, cy]=find(W>0);
if freelearn %increase input populations without regard to task variable
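% Free learning: each cell ranks all sum(OptsInt) input populations by the
% total weight it receives from them and boosts its Nl strongest populations,
% regardless of which task variable they belong to.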
for c1=1:length(rw_cells)
c=rw_cells(c1);
c_inps=cy(cx==c); %column indices of this cell's nonzero inputs
cinps_val=W(c,c_inps); %weights of those inputs
hvec=NaN(1,length(inp_starts)-1);
for ii=1:length(inp_starts)-1 %total weight received from each input population
ind_vec=(c_inps>=inp_starts(ii)) & (c_inps<=inp_ends(ii)); %inputs belonging to population ii
hvec(ii)=sum(cinps_val(ind_vec));
end
[val(c,:), mi(c,:)]=sort(hvec,'descend'); %sort populations according to total input weight
mi_inps=squeeze(mi(c,:));
add_inds=[];
for ci=1:Nl
add_inds=[add_inds, [inp_starts(mi_inps(ci)):inp_ends(mi_inps(ci))]]; %create list of inputs to be increased
end
t_inp=sum(W(c,:)); %total input weight before the update
W(c,add_inds)=W(c,add_inds).*(1+delr); %increase weights to the selected populations
W(c,:)=(W(c,:)./sum(W(c,:))).*t_inp; %renormalize so total input weight is conserved
end
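% With delr=.2 the selected populations' weights grow by 20% before the row is
% rescaled to its original sum, so boosted populations gain relative weight at
% the expense of all other inputs while each cell's total input stays fixed.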
else %constrained learning
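% Constrained learning: the boosted populations must span different task
% variables where possible. The strongest population within each task variable
% is ranked first, then the 2nd strongest within each variable, so with Nl=3
% and three task variables each cell boosts one population per variable.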
for c1=1:length(rw_cells)
c=rw_cells(c1);
c_inps=cy(cx==c); %column indices of this cell's nonzero inputs
cinps_val=W(c,c_inps); %weights of those inputs
class1=1; clear classmax_ii classmax_val classmax_ii2 classmax_val2
for oii=1:length(OptsInt) %go over all input types
class_end=class1+OptsInt(oii)-1;
clear hvec
ii1=0;
for ii=class1:class_end %find max based on weights
ii1=ii1+1;
ind_vec=(c_inps>=inp_starts(ii)) & (c_inps<=inp_ends(ii)); %inputs belonging to population ii
hvec(ii1)=sum(cinps_val(ind_vec));
end
ii_inds=[class1:class_end];
[val, mi]=sort(hvec,'descend');
classmax_ii(oii)=ii_inds(mi(1));
classmax_val(oii)=val(1);
classmax_ii2(oii)=ii_inds(mi(2)); %if Nl exceeds the number of task variables, the 2nd strongest population within each variable is also needed
classmax_val2(oii)=val(2);
class1=class_end+1;
end
[valB, miB]=sort(classmax_val,'descend'); %rank each task variable's strongest population
[valB2, miB2]=sort(classmax_val2,'descend'); %then rank each task variable's 2nd strongest
mi_inps=[classmax_ii(miB), classmax_ii2(miB2)]; %boost order: one population per task variable before any repeats
add_inds=[];
for ci=1:Nl
add_inds=[add_inds, [inp_starts(mi_inps(ci)):inp_ends(mi_inps(ci))]];
end
t_inp=sum(W(c,:)); %total input weight before the update
W(c,add_inds)=W(c,add_inds).*(1+delr); %increase weights to the selected populations
W(c,:)=(W(c,:)./sum(W(c,:))).*t_inp; %renormalize so total input weight is conserved
end
end %end of learning type conditional
end
W(W<0)=0; %weights should never go negative during learning, but rectify just in case
W_all(rwi,:,:)=W; %store the weight matrix at every learning step
%% run simulation with current weights
activity1=runRCN_track( W, threshes, cue_mat, addsig, RanN, InpN, OptsInt,trials,inp_starts, inp_ends); %unscaled activity
tnum=(mFR*.9)/mean(activity1(:)); %scale to ~90% of the desired mean rate, leaving headroom for the additive offset applied after noise
activity2=activity1.*tnum;
g_add=normrnd(activity2,activity2.*gnois_var); %multiplicative noise: Gaussian with std proportional to the rate
g_add(g_add<0)=0; %the noise can produce negative rates; rectify them
rinds=1:Nuse; %subset of "PFC" cells used in the analysis
addFR=mFR-mean(g_add(:)); %offset needed to reach the desired mean rate after noise and rectification
activity=g_add(rinds,:,:)+max(addFR,0); %shift up only if the mean fell short
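% Net effect of the pipeline above: activity is scaled toward 0.9*mFR,
% corrupted by rate-proportional Gaussian noise, rectified at zero, and then
% shifted up (never down) so its grand mean approaches the target mFR.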
selec_mats=perf_ANOV(activity,cue_mat);
sel_matfull(rwi,:,:)=selec_mats.full; sel_matinter(rwi,:,:)=selec_mats.inter; %using 3-way or 2-way ANOVA
clustval=calc_clustering(activity, cue_mat);
ClustVals(rwi)=clustval(2);
activity_all(rwi,:,:)=mean(activity,3);
meanz=mean(activity,3); varz=std(activity,[],3).^2; %per-condition mean and variance across trials
ffts=NaN(1,size(activity,1));
for c=1:size(activity,1)
ffts(c)=nanmean(varz(c,:)./meanz(c,:)); %trial-to-trial Fano factor, averaged over conditions
end
meanzz=mean(meanz,2); varzz=std(meanz,[],2).^2; RVs=(varzz./meanzz); %across-condition variance over mean, per cell
FRMs=mean(mean(activity,3),2); %grand-mean firing rate per cell
FFAs_rwi(rwi,:)=[mean(RVs), mean(ffts), mean(FRMs)]; %RV, FF_T, and mean FR
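% FF_T: trial-to-trial Fano factor (variance/mean across trials, averaged over
% conditions, then over cells). RV: variance of trial-averaged responses
% across conditions divided by their mean, averaged over cells. Mean FR: each
% cell's grand-mean firing rate, averaged over cells.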
end
%% plot properties over the course of learning
figure; % FF_T, RV, Mixed, Pure, Clust
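% In each panel the dotted line marks a fixed reference value, presumably an
% empirical target that the learned network should approach.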
subplot(2,3,1)
hold all
plot(FFAs_rwi(:,2))
plot([1,max_step],[2.8,2.8],'k:')
title('FF_T')
subplot(2,3,2)
hold all
plot(FFAs_rwi(:,1))
plot([1,max_step],[1.1,1.1],'k:')
title('RV')
subplot(2,3,3)
hold all
plot(sel_matfull(:,end,end))
plot([1,max_step],[.51,.51],'k:')
title('% Mixed')
subplot(2,3,4)
hold all
plot(sel_matfull(:,end-1,end-1))
plot([1,max_step],[.86,.86],'k:')
title('% Pure')
subplot(2,3,5)
hold all
plot(ClustVals)
plot([1,max_step],[186,186],'k:')
xlabel('Learning Steps')
title('Clustering Value')
subplot(2,3,6)
hold all
bar([0.6556, 0.3333, 0.5333, 0.2667, 0.2778, 0.1444, 0.0778],'k') %hard-coded reference selectivity profile
plot(diag(squeeze(sel_matfull(end,1:end-2,1:end-2))))
title('Selectivity at End') %full selectivity profile after the last learning step