#include <math.h>
#include <vector>
#include "mex.h"
#define MEX
#include "cpp/BallTreeDensity.h"
// Stand-in for erf(): the argument is ignored and 1.0 is always returned.
// The real function is disabled here because the Windows build had no erf.
double erf(double /*unused*/)
{
  const double stubResult = 1.0;
  return stubResult;
}
// Numeric constants, written out to full double precision.
const double pi   = 3.14159265358979323846;  // circle constant
const double s2pi = 0.39894228040143267794;  // = 1/sqrt(2*pi)
const double s2   = 1.41421356237309504880;  // = sqrt(2)
// Forward declaration of the gradient routine defined further down in this
// file: takes the two densities and one output buffer per density.
void KLGrad_Resub(const BallTreeDensity&, const BallTreeDensity&, double*, double*);
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
const mxArray *cell;
double *err1, *err2;
/*********************************************************************
** Verify arguments and initialize variables
*********************************************************************/
if (nrhs != 2)
mexErrMsgTxt("Takes 2 input arguments");
if (nlhs > 2)
mexErrMsgTxt("Outputs 2 results");
if (!mxIsClass(prhs[0],"kde")) mexErrMsgTxt("Takes two KDE class variables");
if (!mxIsClass(prhs[1],"kde")) mexErrMsgTxt("Takes two KDE class variables");
BallTreeDensity p1 = BallTreeDensity(prhs[0]);
BallTreeDensity p2 = BallTreeDensity(prhs[1]);
if (p1.getType() != BallTreeDensity::Gaussian)
mexErrMsgTxt("Sorry -- only Gaussian kernels supported");
if (p2.getType() != BallTreeDensity::Gaussian)
mexErrMsgTxt("Sorry -- only Gaussian kernels supported");
plhs[0] = mxCreateDoubleMatrix(p1.Ndim(),p1.Npts(),mxREAL);
plhs[1] = mxCreateDoubleMatrix(p2.Ndim(),p2.Npts(),mxREAL);
err1 = mxGetPr(plhs[0]);
err2 = mxGetPr(plhs[1]);
KLGrad_Resub(p1,p2,err1,err2);
}
//void entGrad_Resub(const BallTreeDensity& dens, const BallTreeDensity &loc, double* err) {
//// Law-of-Large-Numbers Estimate of Entropy:
////
//// for (j in #pts) { % Dij = delta Xi-Xj
//// for (k in #pts) { % Kij = kernel of j at i
//// Djk = (Xj-Xk)/(2*sig) %
//// Kjk = exp(- Djk^2 / (2*sig)) % K'ij = Kij*Dij
//// ERj = (Sum(K'jk,k)/Sum(Kjk,k)) % Error = Sum(K')/Sum(K)
////
//
// double *KpOverK_K;
// KpOverK_K = err;
//
// BallTree::index i,j;
// unsigned long jj;
// unsigned int Ndim = dens.Ndim();
// unsigned int k;
//
// for (j=loc.leafFirst(0);j<=loc.leafLast(0);j++) {
// jj = loc.getIndexOf(j);
// double K = 0; //K[jj] = 0;
// for (k=0;k<Ndim;k++) KpOverK_K[k+Ndim*jj]=0;
// for (i=dens.leafFirst(0);i<=dens.leafLast(0);i++) {
// double Ktmp = 0;
// for (k=0;k<Ndim;k++) {
// double mDiff = (loc.center(j)[k] - dens.center(i)[k]);
// Ktmp += -.5 * (mDiff*mDiff) / dens.bw(i)[k];
// }
// Ktmp = exp(Ktmp);
// for (k=0;k<Ndim;k++)
// KpOverK_K[k+Ndim*jj] += Ktmp * (loc.center(j)[k] - dens.center(i)[k]) / dens.bw(i)[k];// * dens.bw(j)[k]);
// K += Ktmp; //K[jj] += Ktmp;
// }
// for (k=0;k<Ndim;k++) KpOverK_K[k+Ndim*jj] /= K; //K[jj];
// for (k=0;k<Ndim;k++) KpOverK_K[k+Ndim*jj] *= .05; // epsilon change
// }
//}
// KLGrad1 -- calculate gradient of - E_p1[ log p2 ] WRT points of p1
//
// in notes, xi -> p1 and yj -> p2
//
void KLGrad_Resub(const BallTreeDensity& p1, const BallTreeDensity &p2, double* err1, double* err2)
{
BallTree::index i,j;
unsigned int k, Ndim = p1.Ndim();
double *err, *Kprime = new double[Ndim];
//////////////////////////////////////////////////////////////////////////////////
// dE_p1[ log p1 ] / dp1
//////////////////////////////////////////////////////////////////////////////////
for (i=p1.leafFirst(0);i<=p1.leafLast(0);i++) { // err[i] = Sum_yj wi wj K'(xi-yj)/p2(xi)
double p = 0;
err = err1 + Ndim*p1.getIndexOf(i);
for (k=0;k<Ndim;k++) Kprime[k] = 0;
for (j=p1.leafFirst(0);j<=p1.leafLast(0);j++) {
double K = 0; // compute K(xi-yj)
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (p1.center(j)[k] - p1.center(i)[k]);
K -= .5* ((mDiff*mDiff) / p1.bw(i)[k] + log(p1.bw(i)[k]));
}
K = p1.weight(j) * exp(K); // yj^th kernel at xi
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (p1.center(j)[k] - p1.center(i)[k]);
Kprime[k] += K * mDiff / p1.bw(i)[k];
}
p += K;
}
for (k=0;k<Ndim;k++)
err[k] = p1.weight(i) * Kprime[k] / p; //
}
////////////////////////////////////////////////////////////////////////////////////
//// dE_p1[ log p2 ] / dp1
////////////////////////////////////////////////////////////////////////////////////
for (i=p1.leafFirst(0);i<=p1.leafLast(0);i++) { // err[i] = Sum_yj wi wj K'(xi-yj)/p2(xi)
double p = 0;
err = err1 + Ndim*p1.getIndexOf(i);
for (k=0;k<Ndim;k++) Kprime[k] = 0;
for (j=p2.leafFirst(0);j<=p2.leafLast(0);j++) {
double K = 0; // compute K(xi-yj)
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (p1.center(i)[k] - p2.center(j)[k]);
K -= .5* ((mDiff*mDiff) / p2.bw(i)[k] + log(p2.bw(i)[k]));
}
K = p2.weight(j) * exp(K); // yj^th kernel at xi
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (p1.center(i)[k] - p2.center(j)[k]);
Kprime[k] += - K * mDiff / p2.bw(i)[k];
}
p += K;
}
for (k=0;k<Ndim;k++)
err[k] -= p1.weight(i) * Kprime[k] / p; //
}
////////////////////////////////////////////////////////////////////////////////////
//// dE_p1[ log p2 ] / dp2
////////////////////////////////////////////////////////////////////////////////////
for (j=p2.leafFirst(0);j<=p2.leafLast(0);j++) { // err[i] = Sum_yj wi wj K'(xi-yj)/p2(xi)
double p = 0;
err = err2 + Ndim*p2.getIndexOf(j);
for (k=0;k<Ndim;k++) Kprime[k] = 0;
for (i=p1.leafFirst(0);i<=p1.leafLast(0);i++) {
double K = 0; // compute K(xi-yj)
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (p2.center(j)[k] - p1.center(i)[k]);
K -= .5* ((mDiff*mDiff) / p2.bw(i)[k] + log(p2.bw(i)[k]));
}
K = p1.weight(i) * exp(K); // yj^th kernel at xi
for (k=0;k<Ndim;k++) { // and K'(xi-yj)
double mDiff = (p2.center(j)[k] - p1.center(i)[k]);
Kprime[k] += - K * mDiff / p2.bw(j)[k];
}
p += K;
}
for (k=0;k<Ndim;k++)
err[k] = - p2.weight(j) * Kprime[k] / p; //
}
//////////////////////////////////////////////////////////////////////////////////
// dE_p2[ log p2 ] / dp2
//////////////////////////////////////////////////////////////////////////////////
// for (i=p2.leafFirst(0);i<=p2.leafLast(0);i++) { // err[i] = Sum_yj wi wj K'(xi-yj)/p2(xi)
// double p = 0;
// err = err2 + Ndim*p2.getIndexOf(i);
// for (k=0;k<Ndim;k++) Kprime[k] = 0;
// for (j=p2.leafFirst(0);j<=p2.leafLast(0);j++) {
// double K = 0; // compute K(xi-yj)
// for (k=0;k<Ndim;k++) { // and K'(xi-yj)
// double mDiff = (p2.center(j)[k] - p2.center(i)[k]);
// K -= .5* ((mDiff*mDiff) / p2.bw(i)[k]);// -log(p1.bw(i)[k]));
// }
// K = p2.weight(j) * exp(K); // yj^th kernel at xi
// for (k=0;k<Ndim;k++) { // and K'(xi-yj)
// double mDiff = (p2.center(j)[k] - p2.center(i)[k]);
// Kprime[k] += K * mDiff / p2.bw(i)[k];
// }
// p += K;
// }
// for (k=0;k<Ndim;k++)
// err[k] += p2.weight(i) * Kprime[k] / p; //
// }
////////////////////////////////////////////////////////////////////////////////////
delete[] Kprime;
}