//
// Matlab MEX interface for "reduce" function QP solvers
//
// Written by Alex Ihler
// Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt
//
#define MEX
#include "mex.h"
#include "math.h"
// NOTE(review): redundant with math.h (included above), which already declares
// fabs; left in place because the quoted include could resolve to a local header.
double fabs(double);
// Forward declarations of the solvers/helpers defined later in this file.
// SMO: Sequential Minimal Optimisation solver; writes N weights into `weights`.
void SMO(double* Q, double* D, unsigned int N, double* weights);
// searchPoint: chooses a partner I1 for I2 and attempts a pairwise update; the
// last argument receives weights[I1] before the update. Returns updateWeight's result.
bool searchPoint(unsigned int& I1, unsigned int I2,double* weights,double* Q,double* D,unsigned int N, double&);
// updateWeight: moves weight between entries I1 and I2; returns true if it changed them.
bool updateWeight(unsigned int I1,unsigned int I2,double* weights,double dW, double* Q, unsigned int N);
// multUpdate: multiplicative-update solver; writes N weights into `weights`.
void multUpdate(double* Q, double* D, unsigned int N, double* weights);
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
// check for the right number of arguments
if((nrhs < 2)||(nrhs > 3))
mexErrMsgTxt("Takes 2-3 input arguments");
if(nlhs != 1)
mexErrMsgTxt("Outputs one result");
unsigned int N = mxGetN(prhs[0]);
unsigned int type;
if (nrhs == 3) type = (unsigned int) mxGetScalar(prhs[2]); else type = 2; // default SMO
double* Q = mxGetPr(prhs[0]);
double* D = mxGetPr(prhs[1]);
plhs[0] = mxCreateDoubleMatrix(1, N, mxREAL);
double* weights = mxGetPr(plhs[0]);
switch (type) {
case 1: mexErrMsgTxt("Sorry -- standard QP not implemented in MEX yet."); break;
case 2: SMO(Q,D,N,weights); break;
case 3: multUpdate(Q,D,N,weights); break;
}
}
// Tolerances for the SMO solver and its helpers.
// Weights below this are skipped as I2 candidates at the start of a sweep,
// as I1 partners in searchPoint, and by updateWeight.
static const double weightTolerance = 1e-6;
// SMO stops once the objective changes by less than this between sweeps.
static const double errorTolerance = 1e-5;
//
// Sequential Minimal Optimisation (SMO) algorithm for Reduced Set Density Estimation (RSDE).
// Finds weights to minimize : 0.5*wts*Q*wts'- wts*D'
//
// Q       : N x N matrix, read column-major as Q[i + N*j]
// D       : N entries; the weights start from D / sum(D)
// N       : problem size (0 is a no-op)
// weights : output, N entries written in place
//
// Each pass takes the largest still-unexamined weight as I2 and lets
// searchPoint() pair it with some I1 and shift weight between the two. Once no
// point is left to examine (end of a sweep), the objective is evaluated and the
// search stops if it got worse (undoing the last pairwise step), improved by
// less than errorTolerance, or no pairwise step changed anything in the sweep.
//
void SMO(double* Q, double* D, unsigned int N, double* weights)
{
  if (N == 0) return;  // nothing to optimise, and no valid index for I2

  unsigned int i, j, numChanged = 0, I1 = 0, I2 = 0;
  double wtMax, sD, error1, error2 = 1e10;
  bool done = false, loop = false, firstTime = true;
  bool* examine = (bool*) mxMalloc(N*sizeof(bool));
  double* weightsBACKUP = (double*) mxMalloc(N*sizeof(double));

  for (i=0,sD=0;i<N;i++) sD += D[i];
  for (i=0;i<N;i++) weights[i] = D[i]/sD;
  for (i=0;i<N;i++) examine[i] = true;

  while (!done) {
    // Choose I2: the largest weight among points still to be examined. At the
    // start of a sweep, points with negligible weight are dropped first.
    bool found = false;
    wtMax = -1;
    for (i=0;i<N;i++) {
      if (firstTime && weights[i] < weightTolerance) examine[i] = false;
      if (examine[i] && (wtMax < weights[i])) {
        wtMax = weights[i]; I2 = i; found = true;
      }
    }
    // No candidate (e.g. every weight below weightTolerance, or non-finite
    // weights): without this, I2 would be uninitialised or stale below.
    if (!found) break;

    double wI1_old;
    for (i=0;i<N;i++) weightsBACKUP[i] = weights[i];
    if (searchPoint(I1,I2,weights,Q,D,N,wI1_old)) numChanged++;
    loop = true;
    examine[I2] = false; // don't care about matching I1 & I2 now
    if (weights[I1] == wI1_old) examine[I1] = false;
    for (i=0;i<N;i++) { // check if we're done:
      if (weights[i] == wtMax) examine[i] = false; // don't care about the maximal weight
      if (weights[i] == weights[I1] && i!=I1) examine[i] = false;
      if (examine[i]) loop = false; // if still some to look at, not yet done with this set!
    }
    firstTime = false;

    if (loop) {
      // Sweep finished: evaluate 0.5*w'Qw - w'D.
      error1 = 0;
      for (i=0;i<N;i++) {
        double tmp = 0;
        for (j=0;j<N;j++) tmp += Q[i+N*j]*weights[j];
        error1 += .5*weights[i]*tmp - weights[i]*D[i];
      }
      if (error1 > error2) {
        // NOTE(review): this only undoes the last pairwise step, not the whole sweep.
        for (i=0;i<N;i++) weights[i] = weightsBACKUP[i];
        done = true;
      } else if (fabs(error1-error2) < errorTolerance) done = true;
      if (numChanged == 0) done = true; // if nothing changed, we can quit
      // `!done`, not `~done`: ~ on a bool yields -1 or -2, which are both true.
      if (!done) {
        numChanged = 0;                        // back to pairwise optimization steps
        for (i=0;i<N;i++) examine[i] = true;   // consider everything again
        firstTime = true;
        error2 = error1;                       // save this as the new error
      }
    }
  }
  mxFree(weightsBACKUP);
  mxFree(examine);
}
//%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
// Row k of (Q*w - D) with Q read column-major (Q[k + N*col]); this is the
// objective's gradient component k when Q is symmetric.
static double searchPointGradient(unsigned int k, const double* weights,
                                  const double* Q, const double* D, unsigned int N)
{
  double g = 0;
  for (unsigned int col = 0; col < N; col++)
    g += weights[col] * Q[k + N*col];
  return g - D[k];
}

// Find the second point to be updated.
// Among points whose weight exceeds weightTolerance, picks as I1 the one whose
// gradient differs most (in absolute value) from I2's, then asks updateWeight
// to move weight between I1 and I2. If no point qualifies, I1 is left equal to
// I2 (no-op). wI1_old receives weights[I1] before the update; the return value
// is updateWeight's "weights changed" flag.
bool searchPoint(unsigned int& I1, unsigned int I2,double* weights,double* Q,double* D,unsigned int N, double& wI1_old)
{
  const double gradI2 = searchPointGradient(I2, weights, Q, D, N);
  double bestDiff = 0;
  double bestAbs = 0;
  unsigned int best = I2;  // default is, do nothing

  for (unsigned int cand = 0; cand < N; cand++) {
    if (weights[cand] <= weightTolerance) continue;
    const double diff = searchPointGradient(cand, weights, Q, D, N) - gradI2;
    if (fabs(diff) > bestAbs) {
      bestAbs = fabs(diff);
      bestDiff = diff;
      best = cand;
    }
  }

  I1 = best;
  wI1_old = weights[I1];
  return updateWeight(I1, I2, weights, bestDiff, Q, N);
}
//%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
// adjust the weights of I1,I2
// Takes the step weights[I2] += dW / (Q11 - 2*Q12 + Q22), clipped so both
// entries stay >= 0 while weights[I1] + weights[I2] is unchanged.
// Returns false without touching weights when I1 == I2, dW == 0, or
// weights[I1] < weightTolerance; otherwise writes both entries and returns true.
bool updateWeight(unsigned int I1,unsigned int I2,double* weights,
                  double dW, double* Q, unsigned int N)
{
  const bool nothingToDo = (I1 == I2) || (dW == 0) || (weights[I1] < weightTolerance);
  if (nothingToDo) return false;

  const double curvature = Q[I1+N*I1] - 2*Q[I1+N*I2] + Q[I2+N*I2];
  const double pairTotal = weights[I1] + weights[I2];

  double newW2 = weights[I2] + dW / curvature;
  if (newW2 < 0) newW2 = 0;
  double newW1 = pairTotal - newW2;
  if (newW1 < 0) {
    // I1 would go negative: give the whole pair's mass to I2 instead.
    newW1 = 0;
    newW2 = pairTotal;
  }

  weights[I1] = newW1;
  weights[I2] = newW2;
  return true;
}
//
// Multiplicative update optimisation algorithm for Reduced Set Density Estimation (RSDE).
// Minimising 0.5*alpha'*Q*alpha-alpha'*D, keeping sum(alpha)==1 and alpha>=0.
// Updating rule: a = alpha./(Q*alpha); alpha = a.*(D' + lambda), where
// lambda = (1 - sum(a.*D'))/sum(a) restores sum(alpha)==1. Entries at or below
// alpha_tolerance are then zeroed and the rest renormalised.
#define alpha_tolerance 1e-6
#define error_tolerance 1e-5

// Objective 0.5*w'*Q*w - w'*D with Q read column-major (Q[i + N*j]).
static double multUpdateObjective(const double* Q, const double* D,
                                  unsigned int N, const double* w)
{
  double err = 0;
  for (unsigned int i = 0; i < N; i++) {
    double Qw = 0;
    for (unsigned int j = 0; j < N; j++) Qw += Q[i + N*j] * w[j];
    err += .5*w[i]*Qw - w[i]*D[i];
  }
  return err;
}

// Q: N x N (column-major), D: N entries, weights: output, N entries.
// Iterates until the objective changes by at most error_tolerance.
void multUpdate(double* Q, double* D, unsigned int N, double* weights)
{
  unsigned int i, j;
  double sumA, sumAD, kept, tmp;
  double *a = (double*) mxMalloc(N*sizeof(double));

  // Start from the normalised right-hand side.
  for (i=0,sumA=0;i<N;i++) sumA += D[i];
  for (i=0;i<N;i++) weights[i] = D[i]/sumA;
  double err = multUpdateObjective(Q, D, N, weights);  // ISE error value

  double dE = 1;  // improvement in ISE each iteration
  while (fabs(dE) > error_tolerance) {
    // a = w ./ (Q*w); the shift (1-sumAD)/sumA makes the new weights sum to 1.
    for (i=0,sumA=0,sumAD=0;i<N;i++) {
      for (j=0, tmp=0;j<N;j++) tmp += Q[i+N*j]*weights[j];
      a[i] = weights[i]/tmp;
      sumA += a[i]; sumAD += a[i]*D[i];
    }
    for (i=0;i<N;i++) weights[i] = a[i]*(D[i] + (1-sumAD)/sumA);

    // Zero weights at or below alpha_tolerance and renormalise the survivors.
    for (i=0,kept=0;i<N;i++)
      if (weights[i] > alpha_tolerance) kept += weights[i];
    // If nothing survives, renormalising would compute 0/0 and silently return
    // NaN weights; stop with the current (un-thresholded) weights instead.
    if (!(kept > 0)) break;
    for (i=0;i<N;i++)
      weights[i] = (weights[i] > alpha_tolerance) ? weights[i]/kept : 0;

    double errNew = multUpdateObjective(Q, D, N, weights);
    dE = errNew - err; err = errNew;  // and check for convergence
  }
  mxFree(a);
}