function [yR,yRR,S,R,indR,indRR,tr]=wavpackfilter(s,treethr,psig,prunethr,plottree)
%WAVPACKFILTER Split a signal into retained and rejected wavelet packet components.
%
%[yR,yRR,S,R,indR,indRR,tr]=wavpackfilter(s,treethr,psig,prunethr,plottree);
%
%Every argument after s is optional. An optional argument may be left off
%the end of the call or passed as [] to select its default, e.g.
%wavpackfilter(s,[],[],[],0) runs with all default thresholds and no plot.
%
%s: signal to be filtered
%treethr: threshold on interpeak interval variance for tree analysis
%(default = 0.005)
%psig: average signal power significance, applied as a fraction of the
%mean-square power of the root node reconstruction (default = 0.05)
%prunethr: threshold on interpeak interval variance for final pruning
%(default = treethr)
%plottree: 0 to turn off plot of wavelet packet tree during execution
%(default = 1)
%
%yR: filtered signal (built from the columns of R listed in indR)
%yRR: extracted signal (sum of the columns of R listed in indRR)
%S: synthesized signal with all nodes included
%R: matrix of terminal tree node reconstructions, one column per node
%indR: index of terminal tree nodes (columns of R) used in constructing yR
%indRR: index of terminal tree nodes (columns of R) used in constructing
%yRR; these are exactly the columns not listed in indR
%tr: returned wavelet packet binary tree
%
%
%Osbert Zalay, August 2007


%nargin = number of variables passed to the function,
%used to determine if the defaults are required.
%Each optional argument is resolved independently; isempty is only
%evaluated when the argument was actually supplied.

if nargin < 2 || isempty(treethr)
    treethr=0.005;
end
if nargin < 3 || isempty(psig)
    psig=0.05;
end
if nargin < 4 || isempty(prunethr)
    %resolved after treethr so that it follows a defaulted treethr too
    prunethr=treethr;
end
if nargin < 5 || isempty(plottree)
    plottree=1;
end

%calls three internal functions:
%1.wavpackdist, 2.treeprune, 3.extractRebuild

tr=wavpackdist(s,treethr,psig,plottree);      %grow the tree
[yR,R,S,indR]=treeprune(tr,prunethr);         %select terminal nodes for yR
[yRR,indRR]=extractRebuild(R,indR);           %remaining terminal nodes form yRR

function [T,v0,p0] = wavpackdist(x,treethr,pthr,pltflg)
%WAVPACKDIST Grow a wavelet packet tree, keeping only splits of nodes that pass the tests.
%
%[T,v0,p0]=wavpackdist(x,treethr,pthr,pltflg);
%
%x: signal vector (transposed first if it has more columns than rows)
%treethr: a node whose interpeak interval variance is below this value
%is not split
%pthr: a node whose mean-square power is below pthr*p0 is not split
%pltflg: nonzero to plot the tree and redraw it after every pass
%
%T: resulting wavelet packet tree, built with the 'dmey' wavelet
%v0: interpeak interval variance of the node 0 reconstruction
%p0: mean-square power of the node 0 reconstruction
%
%Each pass tentatively splits every current leaf and re-joins it at once
%if the leaf fails any test. Passes repeat until a pass leaves the set of
%leaves unchanged.

[m,n]=size(x); %m,n are the rows and columns of the signal
if n > m
    x=x.';
    m=n;
end
lg2m=log2(m);
T=wpdec(x,0,'dmey');
[v0,p0]=ithnodestat(T,0);
if (pltflg ~= 0)
    fig = plot(T);
end
stopval=10^10; %value intvar returns when it finds fewer than two intervals

%Leaves that have already failed the tests. The tests depend only on the
%node itself, on p0 and on the fixed thresholds, so a failed leaf would
%fail again on every later pass; skipping it avoids a redundant
%split/reconstruct/join cycle.
rejected=[];

npar2=leaves(T);
npar1=NaN; %guarantees the loop body runs at least once
while ~isequal(npar1,npar2)
    npar1=npar2;
    for i=1:length(npar1)
        node=npar1(i);
        if any(rejected == node)
            continue
        end
        T=wpsplt(T,node);
        kids=nodedesc(T,node);
        %NOTE(review): per the Wavelet Toolbox documentation nodedesc lists
        %the node itself first, so kids(1) should be the node just split and
        %kids(2) its first child - confirm. Only the statistics of kids(1)
        %enter the decision below.
        [v1,p1]=ithnodestat(T,kids(1));
        %if (v1 < treethr) | ((log2(kids(2)+2)-1) > lg2m)...
        %Undo the split when the interval variance is below threshold, the
        %index-derived level term exceeds log2 of the signal length, the
        %power is below the significant fraction of p0, or intvar
        %returned its sentinel value.
        if (v1 < treethr) | (ceil(log2(kids(2)+2)-1) > lg2m)...
                | (p1 < pthr*p0) | (v1 == stopval)
            T=wpjoin(T,node);
            rejected(end+1)=node;
        end
    end
    if (pltflg ~= 0)
        pause(0.1) %only needed to let the figure refresh
        plot(T,fig);
    end
    npar2=leaves(T);
end

function [yR,R,S,indR,va,pa] = treeprune(T,thr)
%TREEPRUNE Select the terminal node reconstructions that make up the filtered signal.
%
%[yR,R,S,indR,va,pa]=treeprune(T,thr);
%
%T: wavelet packet tree
%thr: threshold on interpeak interval variance
%
%yR: sum of the accepted columns of R
%R: matrix of terminal node reconstructions (one column per node)
%S: signal returned by wpreconstruct alongside R
%indR: accepted column indices of R, in ascending order
%va: interpeak interval variance of each terminal node
%pa: mean-square power of each terminal node
%
%Columns are visited from last to first. A column is a candidate when its
%own variance is not below thr; it is accepted unless the running sum
%including it has an interval variance below thr.

[R,S]=wpreconstruct(T);
[va,pa]=endnodestat(T);
[m,lenR]=size(R);
indR=zeros(lenR,1);
yR=zeros(m,1);
count=0;
for ind=lenR:-1:1
    if va(ind) >= thr
        %Evaluate a trial sum instead of adding and then subtracting the
        %column: (yR+c)-c is not exactly yR in floating point, so the
        %add/subtract form left rounding residue from rejected columns.
        trial=yR+R(:,ind);
        vR=intvar(trial,1);
        if ~(vR < thr)
            yR=trial;
            count=count+1;
            indR(count)=ind;
        end
    end
end
%indices were collected in descending order; return them ascending
indR=flipud(indR(1:count,1));

function [va,pa] = endnodestat(wtr)
%ENDNODESTAT Interval variance and power of every leaf of a wavelet packet tree.
%
%wtr: wavelet packet tree
%
%va: 1-by-L row of interpeak interval variances, one per leaf
%pa: 1-by-L row of mean-square powers, one per leaf
%Entries follow the order in which leaves(wtr) lists the nodes.

leafNodes=leaves(wtr);
nLeaves=length(leafNodes);
pa=zeros(1,nLeaves);
va=zeros(1,nLeaves);
for k=1:nLeaves
    [nodeVar,nodePow]=ithnodestat(wtr,leafNodes(k));
    va(k)=nodeVar;
    pa(k)=nodePow;
end

function [v p] = ithnodestat(wtr,N)
%ITHNODESTAT Interval variance and mean-square power of one node's reconstruction.
%
%wtr: wavelet packet tree
%N: node index passed to wprcoef
%
%v: interpeak interval variance of the reconstruction (intvar, one
%smoothing pass)
%p: sum of squared samples divided by the number of samples

nodeSig=wprcoef(wtr,N);
p=sum(nodeSig.^2)/length(nodeSig);
v=intvar(nodeSig,1);

function [varint]=intvar(s,flt)
%INTVAR Variance of the mean-normalised intervals between local maxima.
%
%s: signal vector
%flt: number of 3-point moving-average smoothing passes applied before
%the maxima are located
%
%varint: variance of the peak-to-peak intervals after division by their
%mean, or 10^10 when fewer than two intervals are found
%
%Only round(0.9*length(s)) samples, selected by wkeep, are analysed.
%A sample is a local maximum when it is strictly greater than both of
%its neighbours; the first and last samples are never counted.

x=wkeep(s,round(length(s)*0.9));
lenx=length(x);

%smoothing: the two end samples are reset to their unsmoothed values
%after every pass
headVal=x(1);
tailVal=x(lenx);
kernel=[1 1 1]/3;
for pass=1:flt
    smoothed=conv(kernel,x);
    x=smoothed(2:lenx+1);
    x(1)=headVal;
    x(lenx)=tailVal;
end

%interior samples that exceed both neighbours
inner=x(2:lenx-1);
isPeak=(x(1:lenx-2) < inner) & (x(3:lenx) < inner);
peakPos=find(isPeak)+1;

%spacing between consecutive maxima, as a column
gaps=diff(peakPos(:));

if length(gaps) < 2
    varint=10^10;
else
    gaps=gaps./mean(gaps);
    varint=var(gaps);
end

function [R,S,indx] = wpreconstruct(wtr)
%WPRECONSTRUCT Reconstruct every leaf of a wavelet packet tree into a column.
%
%wtr: wavelet packet tree
%
%R: length(S)-by-L matrix; column k is wprcoef(wtr,node) for the k-th
%leaf listed by leaves(wtr)
%S: result of wprcoef(wtr) called without a node argument
%indx: L-by-2 table pairing each column number of R with its node index

leafNodes=leaves(wtr);
nLeaves=length(leafNodes);
S=wprcoef(wtr);
R=zeros(length(S),nLeaves);
for col=1:nLeaves
    R(:,col)=wprcoef(wtr,leafNodes(col));
end
colNumbers=(1:nLeaves).';
indx=[colNumbers leafNodes];

function [yRR indRR]=extractRebuild(R,indR)
%EXTRACTREBUILD Sum the columns of R that are not listed in indR.
%
%R: matrix with one component per column
%indR: column indices to leave out
%
%yRR: m-by-1 sum of the remaining columns (all zeros if none remain)
%indRR: column vector of the remaining column indices, ascending

[m,nCols]=size(R);

%complement of indR within 1..nCols
allCols=(1:nCols).';
indRR=allCols(~ismember(allCols,indR));

%accumulate column by column, in ascending index order
yRR=zeros(m,1);
for k=indRR.'
    yRR=yRR+R(:,k);
end