/***********************************************************************
** ISE evaluation MEX code (taken from multi-tree epsilon product)
**
**
***********************************************************************/
//
// Written by Alex Ihler and Mike Mandel
// Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt
//


#define MEX
#include <math.h>
#include "mex.h"
#include "cpp/BallTreeDensity.h"

// Forward declarations of helpers defined further down in this file.
void multiEval(void);        // sets every tree to its root node and runs multiEvalRecursive()
void computeSigVals(void);   // fills the SigValsMin / SigValsMax tables for the current nodes in `ind`
double normConstant(void);   // normalization factor that mexFunction multiplies into soFar

// Addressing formula for the precomputed pairwise constants:
//   SIGVALSMAX(a,b,c) / SIGVALSMIN(a,b,c) yield a POINTER to the entry for
//   dimension a of density pair (b,c).  Entries for one pair are contiguous in
//   the dimension index, so SIGVALSxxx(0,b,c) can be used as a length-Ndim array.
//   Each argument is parenthesized so that expressions such as `j+1` are scaled
//   correctly instead of being split by operator precedence.
//   The macros read the globals Ndim and Ndens at the point of use.
#define SIGVALSMAX(a,b,c) (SigValsMax + (a)+Ndim*(b)+Ndim*Ndens*(c))
#define SIGVALSMIN(a,b,c) (SigValsMin + (a)+Ndim*(b)+Ndim*Ndens*(c))
double *SigValsMax, *SigValsMin;   // base pointers of the two Ndim*Ndens*Ndens tables

// ---- Global state shared by all routines in this file ----
BallTreeDensity *trees;    // array of the Ndens densities being combined (allocated in mexFunction)
BallTree::index *ind;      // current node of each tree during the recursion (allocated in multiEval)

// Per-dimension scratch arrays of length Ndim (allocated in mexFunction):
//   C  = 1 / sum_i (1/bw_i)   (product-kernel variance)
//   sC = sqrt(C)
//   M  = accumulator for the product-kernel mean
double *C,*sC,*M;

// Cursors used only by the sampling loop inside multiEvalRecursive.
// mexFunction points the three random-number cursors at a single dummy value;
// `samples` and `indices` are not assigned anywhere in the visible code.
double *randunif1, *randunif2, *randnorm;  // random-number input cursors
double *samples;                           // output cursor for drawn sample values
BallTree::index *indices;    // output cursor for the selected kernel indices

double maxErr;                 // error tolerance; mexFunction stores 2*epsilon here
double total, soFar, soFarMin; // total: divisor in the sampling test (set to -1 by mexFunction);
                               // soFar / soFarMin: running sums of the midpoint / lower-bound contributions

unsigned int Ndim,Ndens;   // number of dimensions, number of densities
unsigned long Nsamp;       // set to 0 in mexFunction; not read in the visible code
bool bwUniform;            // true when every tree reports bwUniform()

#ifdef MEX
//////////////////////////////////////////////////////////////////////
// MEX WRAPPER
//////////////////////////////////////////////////////////////////////
// MATLAB entry point.
//   prhs[0], prhs[1] : the two densities; each is handed to the BallTreeDensity
//                      constructor.  Only Gaussian kernels are accepted.
//   prhs[2]          : scalar tolerance epsilon (the recursion uses 2*epsilon).
//   plhs[0]          : 1x1 double holding
//                        -2 * <cross term> + <self term of density 1> + <self term of density 2>
//                      where each term is soFar * normConstant() after one multiEval() pass.
// Raises a MATLAB error for a wrong argument count, non-Gaussian kernels, or
// densities whose dimensions differ.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  double *ISE;
  double rUnif = 1;            // dummy target for the random-number cursors (see below)
  BallTreeDensity tempTree;    // holds a displaced tree while trees[] is rearranged
  unsigned int i;

  /*********************************************************************
  ** Verify arguments and initialize variables
  *********************************************************************/

  if (nrhs != 3)
    mexErrMsgTxt("Takes 3 input arguments");
  if (nlhs >  1)
    mexErrMsgTxt("Outputs 1 results");

  Ndens = 2;
  // NOTE(review): the array is raw mxMalloc storage that is then assigned into;
  // this relies on BallTreeDensity's assignment being safe on unconstructed
  // memory -- confirm against the class definition.
  trees = (BallTreeDensity*) mxMalloc(Ndens*sizeof(BallTreeDensity));
  bwUniform = true;
  bool allGaussians = true;
  for (i=0;i<Ndens;i++) {                               // load densities
    trees[i] = BallTreeDensity( prhs[i] );
    if (trees[i].getType() != BallTreeDensity::Gaussian) allGaussians = false;
    bwUniform = bwUniform && trees[i].bwUniform();
  }
  if (!allGaussians)
    mexErrMsgTxt("Sorry -- only Gaussian kernels supported");

  Ndim  = trees[0].Ndim();                      // more accessible dimension variable
  // Every per-dimension loop in this file runs to Ndim, so a density with a
  // different dimension would be indexed out of range.
  for (i=1;i<Ndens;i++) {
    if (trees[i].Ndim() != Ndim)
      mexErrMsgTxt("All densities must have the same dimension");
  }
  maxErr= 2*mxGetScalar(prhs[2]);               // epsilon (we always use 2*epsilon)

  plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
  ISE = (double*) mxGetData(plhs[0]);

  // The recursion contains a sampling loop guarded by *randunif1 <= soFar/total.
  // Pointing the cursors at a positive dummy and using total = -1 is meant to
  // keep that loop from running here.  The cursors refer to a local, so they
  // must not be used after this function returns.
  Nsamp = 0; randunif1 =&rUnif; randunif2 =&rUnif; randnorm =&rUnif; // something positive
  SigValsMax = (double*) mxMalloc(Ndim*Ndens*Ndens*sizeof(double));  // precalc'd constants
  SigValsMin = (double*) mxMalloc(Ndim*Ndens*Ndens*sizeof(double));  // precalc'd constants
  C       = (double*) mxMalloc(Ndim*sizeof(double));
  sC      = (double*) mxMalloc(Ndim*sizeof(double));
  M       = (double*) mxMalloc(Ndim*sizeof(double));

  // Pass 1: trees = {A, B}  -> cross term, weighted by -2.
  total = -1; soFar = soFarMin = 0;
  multiEval();
  *ISE = -2*soFar*normConstant();

  // Pass 2: trees = {A, A}  -> self term of the first density.
  tempTree = trees[1];
  trees[1] = trees[0];
  total = -1; soFar = soFarMin = 0;
  multiEval();
  *ISE += soFar*normConstant();

  // Pass 3: trees = {B, B}  -> self term of the second density.
  trees[1] = tempTree;
  tempTree = trees[0];
  trees[0] = trees[1];
  total = -1; soFar = soFarMin = 0;
  multiEval();
  *ISE += soFar*normConstant();

  trees[0] = tempTree;                          // restore the original {A, B} order

  mxFree(trees);
  mxFree(C); mxFree(sC); mxFree(M); mxFree(SigValsMin); mxFree(SigValsMax);
}
#endif


// Normalization factor for the accumulated sum of (un-normalized) kernel products.
//
// Starting from 1, the result is
//   * divided by (2*pi)^(Ndim/2) once per density,
//   * multiplied by (2*pi)^(Ndim/2) once (normalizer of the resulting Gaussian),
// and, only when bwUniform is set,
//   * divided by sqrt(bw) for every dimension of every density, and
//   * divided by sqrt(sum_i 1/bw_i) for every dimension.
// Bandwidths are read from node 0 of each tree (bwMin(0)); presumably the root
// -- confirm against BallTreeDensity.  When bwUniform is false no bandwidth
// term is applied here.
//
// Reads the globals Ndim, Ndens, bwUniform and trees.
double normConstant(void) {
  unsigned int i,j;
  // Full double-precision value; the previous literal was truncated to 12 decimals.
  const double pi = 3.14159265358979323846;
  const double gaussNorm = pow(2*pi,((double)Ndim)/2);

  double normConst = 1;
  for (i=0;i<Ndens;i++) {                      // divide by norm factor of each indiv. gauss.
    normConst /= gaussNorm;
    if (bwUniform) {
      for (j=0;j<Ndim;j++)
        normConst /= sqrt(trees[i].bwMin(0)[j]);
    }
  }

  normConst *= gaussNorm;                      // times norm factor of resulting gaussian
  if (bwUniform) {
    for (j=0;j<Ndim;j++) {
      double invBw = 0;
      for (i=0;i<Ndens;i++) invBw += 1/trees[i].bwMin(0)[j];  // inverse of the result bandwidth
      normConst /= sqrt(invBw);                               // and its norm factor
    }
  }
  return normConst;
}


//////////////////////////////////////////////////////////////////////////
// Bounds on the kernel-product exponent for a pair of ball-tree nodes.
//   Both routines return the UN-exponentiated value.
//
// minDistProd: uses the SMALLEST possible per-dimension separation between
//   node i of bt1 and node j of bt2, i.e. |center difference| minus the sum
//   of the two ranges, clamped at zero.
//     SigValIJ : per-dimension weights applied to the squared separation
//     SigNIJ   : per-dimension terms whose log is added when bandwidths are
//                not uniform
//   Returns  ( sum_k [ -gap_k^2 * SigValIJ[k]  (+ log SigNIJ[k]) ] ) / 2.
//   NOTE: the original author left a remark here questioning whether the
//   min or the max normalization table should be passed as SigNIJ.
//
double minDistProd(const BallTreeDensity& bt1, BallTree::index i,
                   const BallTreeDensity& bt2, BallTree::index j,
                   const double* SigValIJ,const double* SigNIJ)             //  precomp'd weighting factors
{
  const double *cA = bt1.center(i);
  const double *cB = bt2.center(j);
  double acc = 0;

  for (unsigned int d = 0; d < Ndim; d++) {
    double gap = fabs( cA[d] - cB[d] ) - (bt1.range(i)[d] + bt2.range(j)[d]);
    gap = (gap < 0) ? 0 : gap;                 // overlapping nodes: separation is zero
    acc -= (gap*gap) * SigValIJ[d];
    if (!bwUniform)
      acc += log(SigNIJ[d]);
  }
  return acc / 2;
}

// maxDistProd: counterpart of minDistProd that uses the LARGEST possible
//   per-dimension separation between node i of bt1 and node j of bt2, i.e.
//   |center difference| plus the sum of the two ranges (no clamping needed).
//     SigValIJ : per-dimension weights applied to the squared separation
//     SigNIJ   : per-dimension terms whose log is added when bandwidths are
//                not uniform
//   Returns  ( sum_k [ -span_k^2 * SigValIJ[k]  (+ log SigNIJ[k]) ] ) / 2,
//   un-exponentiated.
//   NOTE: the original author left a remark here questioning whether the
//   max or the min normalization table should be passed as SigNIJ.
double maxDistProd(const BallTreeDensity& bt1, BallTree::index i,
                   const BallTreeDensity& bt2, BallTree::index j,
                   const double* SigValIJ,const double* SigNIJ)             //  precomp'd weighting factors
{
  const double *cA = bt1.center(i);
  const double *cB = bt2.center(j);
  double acc = 0;

  for (unsigned int d = 0; d < Ndim; d++) {
    const double span = fabs( cA[d] - cB[d] ) + (bt1.range(i)[d] + bt2.range(j)[d]);
    acc -= (span*span) * SigValIJ[d];
    if (!bwUniform)
      acc += log(SigNIJ[d]);
  }
  return acc / 2;
}

// Fill the SigValsMin / SigValsMax tables for the nodes currently selected
// in `ind` (the (1 over) \Lambda_(i,j) values used by the distance bounds).
//
// For every dimension d:
//   normMin = 1 / sum_j ( 1 / bwMin_j[d] )
//   normMax = 1 / sum_j ( 1 / bwMax_j[d] )
// and for every pair (j,k), stored symmetrically:
//   SIGVALSMIN(d,j,k) = normMax / ( bwMin_j[d] * bwMin_k[d] )
//   SIGVALSMAX(d,j,k) = normMin / ( bwMax_j[d] * bwMax_k[d] )
//
// This runs once per recursion node when bandwidths are not uniform, so the
// per-dimension sums are kept in scalars instead of heap-allocated scratch
// arrays; the stored values are the same as with the two-pass formulation.
void computeSigVals(void) {
  unsigned int d,j,k;
  for (d=0;d<Ndim;d++) {
    double invSumMin = 0, invSumMax = 0;
    for (j=0;j<Ndens;j++) invSumMin += 1/trees[j].bwMin(ind[j])[d];   // compute \Lambda_L
    for (j=0;j<Ndens;j++) invSumMax += 1/trees[j].bwMax(ind[j])[d];
    const double normMax = 1/invSumMax;
    const double normMin = 1/invSumMin;

    for (j=0;j<Ndens;j++) {                                  //  then compute pairwise leave-
      for (k=j;k<Ndens;k++) {                                //  two-out normalized values
        const double vMin = normMax / (trees[j].bwMin(ind[j])[d]*trees[k].bwMin(ind[k])[d]);
        const double vMax = normMin / (trees[j].bwMax(ind[j])[d]*trees[k].bwMax(ind[k])[d]);
        *SIGVALSMIN(d,k,j) = vMin;
        *SIGVALSMAX(d,k,j) = vMax;
        *SIGVALSMIN(d,j,k) = vMin;                           //  make symmetric
        *SIGVALSMAX(d,j,k) = vMax;
      }
    }
  }
}

// One step of the recursive multi-tree evaluation for the nodes currently
// selected in `ind`.
//
//  1. Bound the kernel product over this block of nodes: for every pair of
//     densities, minDistProd gives the upper bound and maxDistProd the lower
//     bound of the exponent; the bounds are summed over pairs and exponentiated.
//  2. If every node is a leaf, or the gap between the bounds is within
//     maxErr * (soFarMin + lower bound), add the midpoint (times the node
//     weights) to soFar and the lower bound (times the node weights) to
//     soFarMin.  A sampling loop inherited from the product-sampling code
//     follows; it only runs while *randunif1 <= soFar/total, which is not
//     expected with the total = -1 / dummy value 1 set up by mexFunction.
//  3. Otherwise split one tree and recurse on both children: take the pair with
//     the largest bound discrepancy and split the member whose node has the
//     larger maximum range.
//
// Reads and updates the globals ind, soFar, soFarMin (and the sampling cursors).
// `ind` is restored to its entry value before returning.
void multiEvalRecursive(void) {
  unsigned int i,j;
  double minVal=0, maxVal=0;                    // for computing bounds and
  // Initialized so they are well defined even if no pair updates them
  // (e.g. the discrepancy is NaN, which fails every comparison below).
  unsigned int maxInd0=0, maxInd1=0;            //  determining which tree to split

  //
  // find min/max values of product
  //
  if (!bwUniform) computeSigVals();

  double maxDiscrep = -1;
  bool allLeaves = true;
  for (i=0; i<Ndens; i++) {                       // For each pair of densities, bound
    for (j=i+1;j<Ndens;j++) {                     //   the total weight of their product:
      double maxValT = minDistProd(trees[i],ind[i],trees[j],ind[j],SIGVALSMAX(0,i,j),SIGVALSMIN(0,i,j));  // compute min & max
      double minValT = maxDistProd(trees[i],ind[i],trees[j],ind[j],SIGVALSMIN(0,i,j),SIGVALSMAX(0,i,j));  // dist = max/min values
      maxVal += maxValT; minVal += minValT;

      if ((maxValT - minValT) > maxDiscrep) {           // also find which pair
        maxDiscrep = maxValT - minValT;                 //   has the largest
        maxInd0=i; maxInd1=j;                           //   discrepancy (A/B)
      }
    }
    allLeaves = allLeaves && trees[i].isLeaf(ind[i]);
  }
  maxVal = exp(maxVal); minVal = exp(minVal);

  // If the approximation is good enough,
  if (allLeaves || fabs(maxVal - minVal) <= maxErr * (soFarMin+minVal) ) {  // APPROXIMATE
    double add = (maxVal + minVal)/2;                   // compute contribution
    for (i=0;i<Ndens;i++) add *= trees[i].weight(ind[i]);
    soFar += add;
    add = minVal; for (i=0;i<Ndens;i++) add *= trees[i].weight(ind[i]);
    soFarMin += add;

    while (*randunif1 <= soFar/total) {                 // for all the samples coming from this block
      randunif1++;
      for (j=0;j<Ndim;j++) M[j] = 0;                    // clear out M
      if (!bwUniform) for (j=0;j<Ndim;j++) C[j] = 0;    // clear out C if necc.

      for (i=0;i<Ndens;i++) {                           // find an index within this block
        double SumTmp = 0;
        BallTree::index index = trees[i].leafFirst(ind[i]);  // start with 1st leaf and
        for (;index <= trees[i].leafLast(ind[i]);index++) {  //  walk the cumulative weights
          SumTmp += trees[i].weight(index) / trees[i].weight(ind[i]);
          if (SumTmp > *randunif2) break;
        }
        randunif2++;
        for (j=0;j<Ndim;j++)                                 // compute product mean:
          M[j] += trees[i].center(index)[j] / trees[i].bw(index)[j];
        *(indices++) = trees[i].getIndexOf(index);           // and save selected indices
        if (!bwUniform) for (j=0;j<Ndim;j++)                 // compute covariance
            C[j] += 1/trees[i].bw(index)[j];                 //  contribution of each dens.
      }
      if (!bwUniform) for (j=0;j<Ndim;j++) {                 // finish computing covar and
          C[j] = 1/C[j];                                     //  std dev. of product kernel
          sC[j] = sqrt(C[j]);
      }

      for (j=0;j<Ndim;j++) M[j] *= C[j];
      for (j=0;j<Ndim;j++)                              // sample from the product dist.
        *(samples++) = M[j] + sC[j] * (*(randnorm++));
    }

  // Otherwise, we need to subdivide at least one tree:
  } else {                                              // RECURSION
    unsigned int split;
    double size0 = trees[maxInd0].range(ind[maxInd0])[0];  // from the pair with the largest
    double size1 = trees[maxInd1].range(ind[maxInd1])[0];  // pairwise max-min discrepancy term,

    for(BallTree::index k=0; k<trees[maxInd0].Ndim(); k++)
      if(trees[maxInd0].range(ind[maxInd0])[k] > size0)
        size0 = trees[maxInd0].range(ind[maxInd0])[k];
    for(BallTree::index k=0; k<trees[maxInd1].Ndim(); k++)
      if(trees[maxInd1].range(ind[maxInd1])[k] > size1)
        size1 = trees[maxInd1].range(ind[maxInd1])[k];

    split = (size0 > size1) ? maxInd0 : maxInd1;        // take the largest.

    // The preferred node may already be a leaf.  Not every node is a leaf in
    // this branch, so fall back to a tree that can still be subdivided rather
    // than silently dropping this block's contribution.
    if (trees[split].isLeaf(ind[split])) {
      unsigned int other = (split == maxInd0) ? maxInd1 : maxInd0;
      if (!trees[other].isLeaf(ind[other])) {
        split = other;
      } else {
        for (i=0;i<Ndens;i++)
          if (!trees[i].isLeaf(ind[i])) { split = i; break; }
      }
    }

    BallTree::index current = ind[split];
    if (!trees[split].isLeaf(current)) {
      ind[split] = trees[split].left(current);
      multiEvalRecursive();                             // recurse into the left child
      ind[split] = trees[split].right(current);
      multiEvalRecursive();                             //   and into the right child
      ind[split] = current;                             // restore index for the caller
    }
  }
}


void multiEval(void) {
  unsigned int i,j,k;
//  ind = new BallTree::index[Ndens];               // construct index array  
  ind = (BallTree::index*) mxMalloc(Ndens*sizeof(BallTree::index));    // construct index array  
  for (i=0;i<Ndens;i++) ind[i] = trees[i].root(); //  & init to root node

  if (bwUniform) {                                     // if all one kernel size, do this in
    computeSigVals();                                  //   one operation.
    for (i=0;i<Ndim;i++) {                             // compute covariance and
      double tmp = 0;                                  //   std. deviation of a
      for (j=0;j<Ndens;j++)                            // resulting product kernel 
        tmp += 1/trees[j].bw(trees[j].leafFirst(trees[j].root()))[i]; 
      C[i] = 1/tmp;
      sC[i] = sqrt(C[i]);
    }
  }

  multiEvalRecursive();

//  delete[] ind;
  mxFree(ind);
}