/***********************************************************************
** exact sampling MEX file
**
** Calculate the partition function and sample Nsamp times
** in O(N^d) time, O(d) storage
**
***********************************************************************/
//
// Written by Alex Ihler and Mike Mandel
// Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt
//
#define MEX
#include <assert.h>
#include "cpp/BallTreeDensity.h"
#include "mex.h"
// Forward declarations of the helpers defined further down in this file.
void exactEval(void);       // enumerates every index combination (sum or sample pass)
double normConstant(void);  // constant Gaussian normalization factor of the product
//void exactInit(void);
//
// Make recursion-independent variables global for simplicity
//
// These are filled in by mexFunction and then read (and, for the cursor
// pointers, advanced) by exactEval.
//
BallTreeDensity *trees; // structure of all trees (array of Ndens densities)
double *samples;        // write cursor into the output sample matrix
BallTree::index *indices; // return data -- samples & indices (write cursor)
double *randnorm, *randunif; // random numbers; gaussian & sorted uniform (read cursors)
double total, soFar; // partition function value & counter
unsigned int Ndim,Ndens; // common size variables (dimension, # of densities)
unsigned long Nsamp;     // requested number of samples
bool bwUniform;          // true iff every loaded density reports bwUniform()
#ifdef MEX
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
mxArray *rsize, *rUnif, *rNorm; double *rsizeP; // random # call vars
unsigned int i,j;
/*********************************************************************
** Verify arguments and initialize variables
*********************************************************************/
if (nrhs != 2)
mexErrMsgTxt("Takes 2 input arguments");
if (nlhs > 2)
mexErrMsgTxt("Outputs 2 results");
Ndens = (unsigned int) mxGetN(prhs[0]); // get # of densities
trees = new BallTreeDensity[Ndens];
bwUniform = true;
bool allGaussians = true;
for (i=0;i<Ndens;i++) { // load densities
trees[i] = BallTreeDensity( mxGetCell(prhs[0],i) );
if (trees[i].getType() != BallTreeDensity::Gaussian) allGaussians = false;
bwUniform = bwUniform && trees[i].bwUniform();
}
if (!allGaussians)
mexErrMsgTxt("Sorry -- only Gaussian kernels supported");
Ndim = trees[0].Ndim(); // globally accessible dimension variable
Nsamp = (unsigned long) mxGetScalar(prhs[1]); // get requested # of samples
// Create enough gaussian and (sorted) uniform samples to get us through
// the rest of the code:
rsize = mxCreateDoubleMatrix(1,2,mxREAL);
rsizeP= mxGetPr(rsize); rsizeP[0] = 1; rsizeP[1] = Nsamp+1;
rUnif = mxCreateDoubleMatrix(1,Nsamp+1,mxREAL);
mexCallMATLAB(1, &rNorm, 1, &rsize, "rand"); randunif = mxGetPr(rNorm);
randunif[Nsamp] = 100;
mexCallMATLAB(1, &rUnif, 1, &rNorm, "sort"); randunif = mxGetPr(rUnif);
mxDestroyArray(rNorm);
rsizeP[0] = Ndim; rsizeP[1] = Nsamp;
mexCallMATLAB(1, &rNorm, 1, &rsize, "randn"); randnorm = mxGetPr(rNorm);
// Make a return location for the samples
plhs[0] = mxCreateDoubleMatrix(Ndim,Nsamp,mxREAL);
samples = (double*) mxGetData(plhs[0]);
plhs[1] = mxCreateNumericMatrix(Ndens,Nsamp,mxUINT32_CLASS,mxREAL);
indices = (BallTree::index*) mxGetData(plhs[1]);
total = -1; soFar = 0; exactEval(); // recurse to get partition value
total = soFar; soFar = 0; exactEval(); // recurse on trees to sample
mxDestroyArray(rUnif); mxDestroyArray(rNorm); mxDestroyArray(rsize);
delete[] trees;
}
#endif
/*
 * Constant Gaussian normalization factor for the product of the Ndens
 * densities held in the global `trees`.
 *
 * Each input density contributes 1/(2*pi)^(Ndim/2); the resulting product
 * Gaussian contributes (2*pi)^(Ndim/2) back.  When all bandwidths are
 * uniform (bwUniform), the per-dimension bandwidth terms are folded in as
 * well, using trees[i].bwMin(0): one 1/sqrt(bw) per density and dimension,
 * and one 1/sqrt(sum_i 1/bw_i) per dimension for the product kernel.
 * When bandwidths are not uniform these terms are left out here; exactEval
 * accounts for them per component instead.
 *
 * Reads the globals Ndim, Ndens, bwUniform and trees; has no side effects.
 * Not called from the code visible in this file.
 */
double normConstant(void) {
  unsigned int i,j;
  // Full double-precision value; a truncated constant loses digits that the
  // pow() below then amplifies.
  const double pi = 3.14159265358979323846;
  const double twoPiPow = pow(2*pi,((double)Ndim)/2);

  double normConst = 1;              // precalculate influence of normalization
  for (i=0;i<Ndens;i++) {            // divide by norm fact of each indiv. gauss.
    normConst /= twoPiPow;
    if (bwUniform) {
      for (j=0;j<Ndim;j++)
        normConst /= sqrt(trees[i].bwMin(0)[j]);
    }
  }
  normConst *= twoPiPow;             // times norm factor of resulting gaussian

  if (bwUniform) {
    for (j=0;j<Ndim;j++) {
      double invBw = 0;              // inverse of the product kernel's bandwidth
      for (i=0;i<Ndens;i++) invBw += 1/trees[i].bwMin(0)[j];
      normConst /= sqrt(invBw);      // and its norm factor
    }
  }
  return normConst;
}
/*
 * Enumerate every combination of one leaf index per density and accumulate
 * the (unnormalized) weight of the corresponding product Gaussian in the
 * global `soFar`.
 *
 * For each combination the product kernel has, per dimension j,
 *     C[j] = 1 / sum_i (1 / bw_i[j])            (bandwidth of the product)
 *     M[j] = C[j] * sum_i (mean_i[j] / bw_i[j]) (mean of the product)
 * and weight p = prod_i weight_i * exp(-0.5 * sum_j (center_i[j]-M[j])^2 / bw_i[j]),
 * with extra bandwidth-dependent factors when bandwidths are not uniform.
 *
 * Whenever the running fraction soFar/total reaches the next sorted uniform
 * in `randunif`, a sample M + sqrt(C) .* randnorm is written to `samples`
 * and the 1-based source indices are written to `indices`; all four global
 * cursors are advanced.  The caller first runs this with total = -1, which
 * makes soFar/total non-positive so that (assuming strictly positive
 * uniform draws) no sample is emitted and only soFar is computed; the
 * second call with total set to that sum does the sampling.
 *
 * Reads globals trees, Ndim, Ndens, bwUniform, total; updates soFar,
 * randunif, randnorm, samples, indices.
 */
void exactEval(void) {
  unsigned int i,j;

  // With no densities there is nothing to enumerate, and Ndens-1 below
  // would wrap around (unsigned) and index far outside `trees`.
  if (Ndens == 0) return;

  BallTree::index *ind = new BallTree::index[Ndens];  // current data indices
  double *M = new double[Ndim];                       // product mean
  double *C = new double[Ndim];                       // product bandwidth
  double *sC= new double[Ndim];                       // sqrt of product bandwidth

  if (bwUniform) {               // IF THIS IS THE SAME FOR ALL INDICES
    for (j=0;j<Ndim;j++) {       // Find variance of each product kernel
      double tmp = 0;            //   (computed once, from each tree's first leaf)
      for (i=0;i<Ndens;i++)
        tmp += 1/trees[i].bw(trees[i].leafFirst(trees[i].root()))[j];
      C[j] = 1/tmp;
      sC[j] = sqrt(C[j]);        // also find std. deviation value
    }
  }

  for (i=0;i<Ndens;i++)
    ind[i] = trees[i].leafFirst( trees[i].root() );   // initialize indices

  do {                           // for all combos of input indices
    if (!bwUniform) {            // IF THIS IS NOT THE SAME FOR ALL INDICES
      for (j=0;j<Ndim;j++) {     // Find variance of each product kernel
        double tmp = 0;          //
        for (i=0;i<Ndens;i++)
          tmp += 1/trees[i].bw(ind[i])[j];
        C[j] = 1/tmp;
        sC[j] = sqrt(C[j]);      // also find std. deviation value
      }
    }

    for (j=0;j<Ndim;j++) {
      M[j] = 0;                  // Find mean of the product kernel
      for (i=0;i<Ndens;i++)
        M[j] += trees[i].mean(ind[i])[j] / trees[i].bw(ind[i])[j];
    }
    for (j=0;j<Ndim;j++) M[j] *= C[j];

    double p = 1;
    for (i=0;i<Ndens;i++) {
      p *= trees[i].weight(ind[i]);        // calculate contribution of
      double sum = 0;                      //   each component to weight
      for (j=0;j<Ndim;j++) {               //   of this product element
        double tmp = trees[i].center(ind[i])[j] - M[j];
        sum -= tmp*tmp / trees[i].bw(ind[i])[j];
        // non-uniform case: fold in each kernel's own 1/sqrt(bw) factor
        if (!bwUniform) sum -= log(trees[i].bw(ind[i])[j]);
      }
      p *= (double) exp(sum/2);  // p is prop. to weight of this gaussian
    }
    // non-uniform case: and the product kernel's sqrt(C) factor
    if (!bwUniform) for (j=0;j<Ndim;j++) p*=sC[j];

    soFar += p;                  // keep a running tab on the total weight

    // The sorted uniforms end with a sentinel larger than any fraction, so
    // this loop cannot run past the end of the array.
    while (*randunif <= soFar/total) {     // if this index is a sample
      randunif++;                          //   (or more than one)
      for (i=0;i<Ndim;i++)                 // draw from its gaussian
        *(samples++) = M[i] + sC[i] * (*(randnorm++));
      for (i=0;i<Ndens;i++)                // save its indices
        *(indices++) = trees[i].getIndexOf(ind[i])+1;   // +1: 1-based for MATLAB
    }

    // Odometer-style increment: bump the first index and carry into the
    // next one whenever an index runs off the end of its tree's leaves.
    ind[0]++;                              // increment indices
    for (i=0;i<Ndens-1;i++) {              //   checking for wrap-around
      if (!trees[i].validIndex(ind[i])) {
        ind[i] = trees[i].leafFirst( trees[i].root() );
        ind[i+1]++;
      }
    }
  } while (trees[Ndens-1].validIndex(ind[Ndens-1]));  // test for end-of-loop

  delete[] ind;                            // free allocated memory
  delete[] M; delete[] C; delete[] sC;
}