//////////////////////////////////////////////////////////////////////////////////////
// BallTreeDensity.h  --  class definition for a tree-based kernel density estimate
//
// A few functions are defined only for MEX calls (construction & load from matlab)
// Most others can be used more generally.
// 
//////////////////////////////////////////////////////////////////////////////////////
//
// Written by Alex Ihler and Mike Mandel
// Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt
//
//////////////////////////////////////////////////////////////////////////////////////
#ifndef __BALL_TREE_DENSITY_H
#define __BALL_TREE_DENSITY_H

#include "BallTree.h"
#include <assert.h>
#include <float.h>

class BallTreeDensity : public BallTree {
 public:
  enum KernelType { Gaussian, Epanetchnikov, Laplacian };
  KernelType getType(void) const { return type; };

  enum Gradient { WRTMean, WRTVariance, WRTWeight };

  /////////////////////////////
  // Constructors
  /////////////////////////////

  //BallTreeDensity( unsigned int d, index N, double* points_,
  //     double* weights_, double* bandwidths_);
#ifdef MEX             // for loading ball trees from matlab
  BallTreeDensity() : BallTree() { bandwidth = bandwidthMax = bandwidthMin = NULL; }
  BallTreeDensity(const mxArray* structure);
  static mxArray* createInMatlab(const mxArray* pts, const mxArray* wts, const mxArray* bw, BallTreeDensity::KernelType _type=Gaussian);
#endif

  /////////////////////////////
  // Accessor Functions  
  /////////////////////////////
  const double* mean(BallTree::index i)     const { return means+i*dims; }
  const double* variance(BallTree::index i) const { return bandwidth+i*dims; } // !!! only works for Gaussian

  const double* bw(BallTree::index i)     const { return bandwidth   +i*dims; }
  const double* bwMax(BallTree::index i)  const { return bandwidthMax+i*dims*multibandwidth; }
  const double* bwMin(BallTree::index i)  const { return bandwidthMin+i*dims*multibandwidth; }
  bool bwUniform(void) const { return multibandwidth==0; };
  //   -- Others inherited from BallTree --

///////////////////////////////
//
// Evaluation of the density at a set of points:
//   pre-constructed balltree version
//   array of doubles version
//   leave-one-out cross-validation version
//
  void evaluate(const BallTree& atPoints, double* values, double maxErr=0) const;
//  void evaluate(index Npts, const double* atPoints, double* values, double maxErr=0) const;
  void evaluate(double* p, double maxErr) const {  evaluate(*this,p,maxErr); }

  void llGrad(const BallTree& locations, double* gradDens, double* gradAt, double tolEval, double tolGrad, Gradient) const;
//  void llGrad(index Npts, const double* atPoints, double* gradDens, double* gradAt, double tolEval, double tolGrad) const;
 
  bool updateBW(const double*, index);




  /////////////////////////////
  // Private object functions
  /////////////////////////////
 protected:
#ifdef MEX
  static mxArray* matlabMakeStruct(const mxArray* pts, const mxArray* wts, const mxArray* bw, BallTreeDensity::KernelType type);
#endif
  virtual void swap(BallTree::index, BallTree::index);// leaf-swapping function
  virtual void   calcStats(BallTree::index root);     // recursion for computing BW ranges

  KernelType type;
  unsigned int multibandwidth;    // flag: is bandwidth uniform?

  double *means;                  // Weighted mean of points from this level down     
  double *bandwidth;              // Variance or other multiscale bandwidth
  double *bandwidthMax,*bandwidthMin; // Bounds on BW in non-uniform case
  
  // Internal evaluate functions:
  //   Recursive tree evaluation
  const static index DirectSize = 100;        // if N*M is less than this, just compute.

  void evaluate(BallTree::index myRoot, const BallTree& atTree, BallTree::index aRoot, double maxErr) const;
  void evalDirect(BallTree::index myRoot, const BallTree& atTree, BallTree::index aRoot) const;

  void llGradDirect(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot, Gradient) const;
  void llGradRecurse(BallTree::index dRoot,const BallTree& atTree, BallTree::index aRoot, double tolGrad, Gradient) const;
  void llGradWDirect(index dRoot, const BallTree& atTree, index aRoot) const;
  void llGradWRecurse(index dRoot,const BallTree& atTree, index aRoot, double tolGrad) const;


  //   Bounds on kernel values between points in this subtree & another
  double maxDistKer(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot) const { 
    switch(getType()) 
      { case Gaussian:       return maxDistGauss(dRoot,atTree,aRoot);
        case Laplacian:      return maxDistLaplace(dRoot,atTree,aRoot); 
        case Epanetchnikov:  return maxDistEpanetch(dRoot,atTree,aRoot); 
      }
    };
  double minDistKer(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot) const {
    switch(getType()) 
      { case Gaussian:       return minDistGauss(dRoot,atTree,aRoot);
        case Laplacian:      return minDistLaplace(dRoot,atTree,aRoot); 
        case Epanetchnikov:  return minDistEpanetch(dRoot,atTree,aRoot); 
      }
    };

  // Types of kernels supported
  double maxDistLaplace(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot) const;
  double minDistLaplace(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot) const;
  double maxDistGauss(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot) const;
  double minDistGauss(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot) const;
  double maxDistEpanetch(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot, int dim=-1) const;
  double minDistEpanetch(BallTree::index dRoot, const BallTree& atTree, BallTree::index aRoot, int dim=-1) const;

  void  dKdX_p(BallTree::index dRoot,const BallTree& atTree, BallTree::index aRoot, bool bothLeaves, Gradient) const;

};

#endif