#include <math.h>
#include <cstddef>
#include <vector>
#include "mex.h"
#define MEX
#include "cpp/BallTreeDensity.h"
// Fallback erf() for toolchains whose C runtime does not provide one.
//
// Microsoft's C runtime only gained the C99 erf() in Visual Studio 2013
// (_MSC_VER 1800). Everywhere else <math.h> already declares erf(), and an
// unconditional definition here would collide with that declaration (glibc,
// for example, declares it non-throwing) and replace the real function, so
// the fallback is compiled only where it is actually missing.
//
// Uses Abramowitz & Stegun formula 7.1.26 (absolute error below 1.5e-7)
// rather than a constant, so callers get a usable value on old compilers too.
#if defined(_MSC_VER) && (_MSC_VER < 1800)
double erf(double c) {
  const double x = fabs(c);
  const double t = 1.0 / (1.0 + 0.3275911 * x);
  // Horner evaluation of a1*t + a2*t^2 + ... + a5*t^5.
  const double poly =
      t * (0.254829592 +
      t * (-0.284496736 +
      t * (1.421413741 +
      t * (-1.453152027 +
      t * 1.061405429))));
  const double y = 1.0 - poly * exp(-x * x);
  return (c < 0) ? -y : y;  // erf is odd: erf(-x) = -erf(x)
}
#endif
// Numerical constants, spelled out to (more than) full double precision so
// the compiler rounds each to the nearest representable value.
const double pi   = 3.14159265358979323846;  // previously truncated at 12 decimals
const double s2pi = 0.39894228040143267794;  // = 1/sqrt(2*pi)
const double s2   = 1.41421356237309504880;  // = sqrt(2)
// Forward declaration: mexFunction (below) calls entGrad_Resub before its definition.
void entGrad_Resub(const BallTreeDensity& dens, double* err);
// MEX gateway.
//
// Inputs : prhs[0] -- a MATLAB object of class "kde" whose kernel type is
//          Gaussian (any other kernel is rejected).
// Outputs: plhs[0] -- an Ndim x Npts double matrix filled by entGrad_Resub(),
//          one column per point of the density.
//
// mexErrMsgTxt() does not return: it aborts the MEX call and raises a MATLAB
// error, so no code after a failed check runs.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  /*********************************************************************
  ** Verify arguments and initialize variables
  *********************************************************************/
  if (nrhs != 1)
    mexErrMsgTxt("Takes 1 input arguments");
  if (nlhs > 1)
    mexErrMsgTxt("Outputs 1 results");
  if (!mxIsClass(prhs[0], "kde"))
    mexErrMsgTxt("Takes one KDE class variable");

  BallTreeDensity dens(prhs[0]);  // direct-init: no temporary to copy from
  if (dens.getType() != BallTreeDensity::Gaussian)
    mexErrMsgTxt("Sorry -- only Gaussian kernels supported");

  /*********************************************************************
  ** Allocate the output and compute
  *********************************************************************/
  plhs[0] = mxCreateDoubleMatrix(dens.Ndim(), dens.Npts(), mxREAL);
  double *err = mxGetPr(plhs[0]);
  entGrad_Resub(dens, err);
}
void entGrad_Resub(const BallTreeDensity& dens, double* err1) {
// Law-of-Large-Numbers Estimate of Entropy:
//
// for (j in #pts) { % Dij = delta Xi-Xj
// for (k in #pts) { % Kij = kernel of j at i
// Djk = (Xj-Xk)/(2*sig) %
// Kjk = exp(- Djk^2 / (2*sig)) % K'ij = Kij*Dij
// ERj = (Sum(K'jk,k)/Sum(Kjk,k)) % Error = Sum(K')/Sum(K)
//
BallTree::index i,j;
unsigned long jj;
unsigned int Ndim = dens.Ndim();
unsigned int k;
double *Kprime = new double[Ndim];
for (j=dens.leafFirst(dens.root());j<=dens.leafLast(dens.root());j++) {
double p = 0;
double* err = err1 + Ndim*dens.getIndexOf(j);
for (k=0;k<Ndim;k++) Kprime[k] = 0;
for (i=dens.leafFirst(dens.root());i<=dens.leafLast(dens.root());i++) {
double K = 0; // compute K(xi-yj)
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (dens.center(j)[k] - dens.center(i)[k]);
K -= .5* ((mDiff*mDiff) / dens.bw(j)[k] + log(dens.bw(j)[k]));
}
K = dens.weight(i) * exp(K); // yj^th kernel at xi
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (dens.center(j)[k] - dens.center(i)[k]);
Kprime[k] += K * mDiff / dens.bw(j)[k];
}
p += K;
}
for (k=0;k<Ndim;k++)
err[k] = dens.weight(j) * Kprime[k] / p; //
}
delete[] Kprime;
}