function H=plot(x,varargin)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% KDE plotting function
%
%  plot(kde)            -- plot a KDE with various features.
%  plot(kde,style)      --  style is of the form
%  plot(kde,dim,style)  --    [STR1,STR2,...] where
%
%  style is of the form [STR1 STR2 STR3...] where
%
%  STR1 is the style to plot the line (1D) or kernel locations (2+D).  Note
%     that options must be specified in lowercase, e.g. 'ro' for red circles.
%     Default style is '-b'
%  STRN is a style for various optional plot features:
%    'W' : show kernel weights (by color: black = low, color in STR1 = high)
%    'S' : show relative kernel sizes, as circles around each center
%    'Bs': show KD-tree structure / bounding boxes, 's' is e.g. '-b' for
%             bounding boxes of solid blue lines
%    'Nd':show d levels of the bounding boxes (d an integer) (default all)
%  The default style is '.b'
%    Example styles:   '*kS-b'  -- black centers, blue variance circles
%                      '.gWS-kB-rN3' -- green dots, colored by weights,
%                          with black variances and 3 levels of bounding
%                          boxes in red.
%                      'S-k' -- just plot variances (no centers), in black
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt

dim = []; args = [];
if (nargin == 1)
  dim = 1:getDim(x);
  if (getDim(x) == 1) args = 'b-'; else args = 'b.'; end;
elseif (nargin == 2)
  dim = [1:getDim(x)];
  args = varargin{1};
else 
  dim = varargin{1};
  args = varargin{2};
end;
%
% Separate the arguments into:  [plot info][TAG taginfo][TAG taginfo] ...
%   where TAG is one of 'W','B','S'.
%
F = find(args ~= lower(args)); Fmin = min([F,length(args)+1]);
argsPlot = args(1:Fmin-1);
argsKDE = args(Fmin:length(args));

if (getDim(x) == 1)
  pts = getPoints(x);
  N = 200;   range = [min(pts),max(pts)];
  range(1) = range(1) - .05*(range(2)-range(1));
  range(2) = range(2) + .05*(range(2)-range(1));
  H=draw1D(x,linspace(range(1),range(2),N),argsPlot,argsKDE);
else
  H=drawAllPairs(x,dim,argsPlot,argsKDE);
end;

%
% Internal functions
%
function e=draw1D(x, bins, style, myStyle)
  y = evaluate(x,bins);
  e=plot(bins,y,style);
  
  mx = max(y);
  wts = getWeights(x);
    
  if(strfind(myStyle,'W'))
    holdf = ishold; hold on;
    subStyle = extract(myStyle,'W');
    subStyle=subStyle(2:end); if (length(subStyle)==0) subStyle = '^'; end;
    etmp = stem(getPoints(x), wts, subStyle);
    e = [e;etmp];
    if (~holdf) hold off; end;
  end
  
  if(strfind(myStyle,'S'))
    holdf = ishold; hold on;
    subStyle = extract(myStyle,'S');
    subStyle=subStyle(2:end); if (length(subStyle)==0) subStyle = '--b'; end;

    bw = getBW(x); pts = getPoints(x); type = getType(x);
    for i=1:length(pts)
      xtmp = kde(pts(i), bw(i), 1, type);
      etmp = plot(bins, evaluate(xtmp, bins) * wts(i), subStyle);
      e = [e;etmp];
    end
    
    % for plotting only gaussian kernels faster:
%    gr = repmat(bins, length(wts), 1);
%    wts = repmat(wts', 1, size(gr, 2));
%    pts = repmat(getPoints(x)', 1, size(gr,2));
%    bw = repmat(getBW(x)', 1, size(gr, 2));
%    etmp = plot( bins, (wts.*(1./(sqrt(2*pi)*bw)).*exp(-(gr-pts).^2./(2*bw.^2)))', subStyle);

    e = [e;etmp];
    if (~holdf) hold off; end;
  end
  
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function e=drawAllPairs(x,dims,style,myStyle)
  pts = getPoints(x);
  holdf = ishold;
  e = [];
  Nout = length(dims);
  PlotI2 = triu(repmat(dims ,[Nout,1]), 1);
  PlotI1 = triu(repmat((dims)',[1,Nout]), 1);
  PlotI1 = PlotI1(find(PlotI1)); PlotI2 = PlotI2(find(PlotI2));
  Ncol = fix(sqrt(length(PlotI2))); Nrow = ceil(length(PlotI2)/Ncol);
  for iT=1:length(PlotI2)                   %  output all dimension pairs:
    subplot(Nrow,Ncol,iT);
    holdfs = ishold;
    if (holdf) hold on; end;
    etmp = drawPair(x,[PlotI1(iT),PlotI2(iT)],style,myStyle);
    if (~holdfs) hold off; end;
    e = [e;etmp];
  end;   
  drawnow;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function etmp = drawPair(x,dims,style,myStyle)
  pts = getPoints(x);
  etmp = [];
  
  if (strfind(myStyle,'W'))  %plot weight info with color
    subStyle = extract(myStyle,'W');
    %wts = getWeights(x);  wts = wts - min(wts); 
    %if (max(wts)==0) wts = wts+1; else wts = .75*wts/max(wts) + .25; end;
    wts = getWeights(x);  
    wts = .75*wts/max(wts) + .25;
    holdf = ishold;
    for i=1:size(pts,2)
      etmp2 = plot(pts(dims(1),i),pts(dims(2),i),style);   % plot each location and
      hold on;
      etmp = [etmp;etmp2]; ctmp = get(etmp2,'Color');      %  its color in prop.
      set(etmp2,'Color',[1 1 1] - wts(i)*([1 1 1] - ctmp));%  to its weight
    end;
    if (~holdf) hold off; end;
  else                                                    % otherwise, just plot
    if (length(style)~=0)
      etmp = plot(pts(dims(1),:),pts(dims(2),:),style);   %   all locations
  end; end;
  
  if (strfind(myStyle,'S'))  %plot BW info with circles
    subStyle = extract(myStyle,'S');
    subStyle=subStyle(2:end); if (length(subStyle)==0) subStyle = '-b'; end;
    pts = getPoints(x); sig=getBW(x);
	meanX = pts(dims(1),:);
	meanY = pts(dims(2),:);
	sigX = sig(dims(1),:);
	sigY = sig(dims(2),:);
	
	holdf = ishold;	hold on;
	theta = linspace(0,2*pi,100);
	for (i = 1:length(meanX))
      etmp2 = plot(meanX(i)+sigX(i)*cos(theta), meanY(i)+sigY(i)*sin(theta),subStyle);
      etmp = [etmp;etmp2];
	end	
	if (~holdf) hold off; end;    
  end;
  
  if (strfind(myStyle,'B'))  %plot bounding box info with rectangles
    levels = [];
    if (isempty(strfind(myStyle,'N'))) levels = ceil(log2(getNpts(x)));
    else
      subStyle = extract(myStyle,'N');              % get # of balls if spec'd
      levels = sscanf(subStyle(2:end),'%d');
      levels = min(ceil(log2(getNpts(x))), levels);
    end;
    subStyle = extract(myStyle,'B');              % get plot style if spec'd
    subStyle=subStyle(2:end);
    if(isempty(subStyle))     subStyle = '-b';    end
      
    N = getNpts(x);
    indices = []; tmp = [1];
    for i=1:levels
      indices = [indices tmp];
      % get rid of NO_CHILD right children
      rt = double(x.rightch(tmp)) + 1;
      rt = rt .* (rt < N+1) + 1 * (rt > N);
      tmp = [double(x.leftch(tmp))+1   rt];
    end
    
    rX = x.ranges(dims(1),indices);
    rY = x.ranges(dims(2),indices);
    nodesX = x.centers(dims(1),indices);
    nodesY = x.centers(dims(2),indices);
    leavesX = x.centers(dims(1),N+1:end);
    leavesY = x.centers(dims(2),N+1:end);

    squaresX(1,:) = nodesX + rX;
    squaresX(2,:) = nodesX + rX;
    squaresX(3,:) = nodesX - rX;
    squaresX(4,:) = nodesX - rX;
    squaresX(5,:) = nodesX + rX;
    
    squaresY(1,:) = nodesY + rY;
    squaresY(2,:) = nodesY - rY;
    squaresY(3,:) = nodesY - rY;
    squaresY(4,:) = nodesY + rY;
    squaresY(5,:) = nodesY + rY;
    
    holdf = ishold;	hold on;
    for i = 1:length(indices)
      etmp2 = plot(squaresX, squaresY, subStyle);
      etmp = [etmp;etmp2];
    end
%   axis square	
    if (~holdf) hold off; end;
    
  end;  
  
  titlestr = ['Dim ',int2str(dims(1)),' v. ',int2str(dims(2))];
  title(titlestr);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function substr=extract(str,tag)
  substr = [];
  loc = strfind(str,tag);
  if (loc)
    substr = str(loc(1):length(str));
    loc = find(substr(2:end) ~= lower(substr(2:end)));
    if (loc) substr = substr(1:loc(1)); end;
  end;