//
// cellsxfun.cpp
// Copyright Stuart Yarrow 2010/03/24 (s.yarrow@ed.ac.uk)
// All rights reserved.
//
// Based on mAryCellFcn.ccp by Michael Brost (michaelbrost@yahoo.com).
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
#include "mex.h"
// STD includes
#include <string>
#include <sstream>
using namespace std;
void usage(void)
{
mexPrintf("========================================================================================================================\n");
mexPrintf("| This function is a cell array-based generalization of bsxfun to arbitrary dimensions using cell arrays. |\n");
mexPrintf("| |\n");
mexPrintf("| usage: cell_array = cellsxfun(function_handle, cell_array_1, cell_array_2, ..., cell_array_N) |\n");
mexPrintf("| where function_handle corresponds to a function which accepts N simultaneous cell array contents as inputs and |\n");
mexPrintf("| which returns at most one matlab object. The dimensions of each array must either match or be singleton e.g: |\n");
mexPrintf("| |\n");
mexPrintf("| x = {rand(20), rand(20), rand(20)}; 1 x 3 cell array |\n");
mexPrintf("| y = {rand(20)}; 1 x 1 cell array |\n");
mexPrintf("| z = {rand(20) rand(20)}'; 2 x 1 cell array |\n");
mexPrintf("| |\n");
mexPrintf("| out = cellsxfun(@plus, x, y, z); 2 x 3 cell array |\n");
mexPrintf("| |\n");
mexPrintf("| Each argument is treated as if singleton dimensions are replicated to the size of the output. |\n");
mexPrintf("| |\n");
mexPrintf("| NOTE: there were limited opportunities to trap errors. I suggest that you defensively program your functions. |\n");
mexPrintf("| Anonymous function handles are supported. |\n");
mexPrintf("| |\n");
mexPrintf("| Copyright Stuart Yarrow 2010/03/24 (s.yarrow@ed.ac.uk) |\n");
mexPrintf("| Based on mAryCellFcn.ccp by Michael Brost (michaelbrost@yahoo.com) |\n");
mexPrintf("| |\n");
mexPrintf("| This program is free software: you can redistribute it and/or modify |\n");
mexPrintf("| it under the terms of the GNU General Public License as published by |\n");
mexPrintf("| the Free Software Foundation, either version 3 of the License, or |\n");
mexPrintf("| (at your option) any later version. |\n");
mexPrintf("| |\n");
mexPrintf("| This program is distributed in the hope that it will be useful, |\n");
mexPrintf("| but WITHOUT ANY WARRANTY; without even the implied warranty of |\n");
mexPrintf("| MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |\n");
mexPrintf("| GNU General Public License for more details. |\n");
mexPrintf("| |\n");
mexPrintf("| You should have received a copy of the GNU General Public License |\n");
mexPrintf("| along with this program. If not, see <http://www.gnu.org/licenses/>. |\n");
mexPrintf("========================================================================================================================\n");
}
// gateway function
void mexFunction( int nOutArgs, mxArray *outPtr[], int nInArgs, const mxArray *inPtr[] )
{
// string for error messages
string errMsg;
// usage statement
if(nInArgs == 0)
{
usage();
return;
}
// check input argument counts
if(nInArgs < 2)
{
errMsg = string("function requires at least 2 input args.");
mexErrMsgTxt(errMsg.c_str());
}
// check input argument counts
if(nOutArgs > 1)
{
errMsg = string("function supports at most 1 output arg.");
mexErrMsgTxt(errMsg.c_str());
}
// check arg 1
if(mxGetClassID(inPtr[0]) != mxFUNCTION_CLASS)
{
errMsg = string("function argument #1 must be a function handle.");
mexErrMsgTxt(errMsg.c_str());
}
// check input args - all MUST be cell arrays
for(int aIndex=1; aIndex<nInArgs; aIndex++)
{
if(mxGetClassID(inPtr[aIndex]) != mxCELL_CLASS)
{
stringstream ss;
ss << (aIndex + 1);
errMsg = "function argument #" + ss.str() + " must be a cell array.";
mexErrMsgTxt(errMsg.c_str());
}
}
// number of cell array inputs
int nInputs = nInArgs-1;
// calculate output dimensionality
int nDim = 0;
for(mwSize argIndex=1; argIndex < nInArgs; argIndex++)
{
int nArgDims = mxGetNumberOfDimensions(inPtr[argIndex]);
if(nArgDims > nDim)
{
nDim = nArgDims;
}
}
// use matlab memory allocation for vector counter
mwSize *endVec = static_cast<mwSize*>(mxCalloc(nDim, sizeof(mwSize)));
mwSize *startVec = static_cast<mwSize*>(mxCalloc(nDim, sizeof(mwSize)));
mwSize *thisVec = static_cast<mwSize*>(mxCalloc(nDim, sizeof(mwSize)));
mwSize *workVec = static_cast<mwSize*>(mxCalloc(nDim, sizeof(mwSize)));
mwSize *argDims = static_cast<mwSize*>(mxCalloc(nInputs*nDim, sizeof(mwSize)));
// calculate output sizes
for(mwSize argIndex=0; argIndex < nInputs; argIndex++)
{
int nArgDims = mxGetNumberOfDimensions(inPtr[argIndex+1]);
const mwSize *argD = mxGetDimensions(inPtr[argIndex+1]);
for(mwSize dimIndex=0; dimIndex < nDim; dimIndex++)
{
if(dimIndex <= nArgDims)
{
argDims[argIndex*nDim + dimIndex] = argD[dimIndex];
if(argD[dimIndex] > endVec[dimIndex])
{
endVec[dimIndex] = argD[dimIndex];
}
}
else
{
argDims[argIndex*nDim + dimIndex] = 1;
}
}
}
// check input dims
for(mwSize argIndex=0; argIndex < nInputs; argIndex++)
{
for(mwSize dimIndex=0; dimIndex < nDim; dimIndex++)
{
// dims need to either match output dimensionality or be singleton
if(argDims[argIndex*nDim + dimIndex] != endVec[dimIndex] && argDims[argIndex*nDim + dimIndex] != 1)
{
errMsg = "cellsxfun: dimension mismatch\n";
mexErrMsgTxt(errMsg.c_str());
}
}
}
// create the output cell array
// normally we'd subtract the startVec but it is all 0s so we just use endVec
outPtr[0] = mxCreateCellArray(nDim, endVec);
// create an array of pointers to the data
mxArray **dataPtrArray = (mxArray **)mxCalloc(nInputs+1, sizeof(mxArray *));
// copy and set the function handle in the argument array
//dataPtrArray[0] = mxDuplicateArray(inPtr[0]);
dataPtrArray[0] = (mxArray *)(inPtr[0]);
// pointer to the results
mxArray *resultArray;
// using thisVec, count between startVec and endVec
for(;;)
{
// vector of pointers to cell's contents
for(mwSize vIndex=0; vIndex < nInputs; vIndex++)
{
for(mwSize dIndex=0; dIndex < nDim; dIndex++)
{
// calculate which cell we are using from input array
if(argDims[vIndex*nDim + dIndex] == 1)
workVec[dIndex] = 0;
else
workVec[dIndex] = thisVec[dIndex];
}
// get the contents of the input cell arrays as determined by
// the vector counter's contents
mwIndex inputIndex = mxCalcSingleSubscript(inPtr[vIndex+1], nDim, workVec);
dataPtrArray[vIndex+1] = mxGetCell(inPtr[vIndex+1], inputIndex);
// if the cell array was not initialized correctly, the return pointer
// will be null.
if(!(dataPtrArray[vIndex+1]))
{
errMsg = "cellsxfun: could not access data from input cell array\n";
mexErrMsgTxt(errMsg.c_str());
}
}
// calculate the offset to the output cell to update
mwIndex index = mxCalcSingleSubscript(outPtr[0], nDim, thisVec);
// invoke the user's function here - accept only one output and discrard the others
int status = false;
status = mexCallMATLAB(1, &resultArray, (int)nInputs+1, dataPtrArray, "feval");
// was there an error which did not abort directly from mexCallMATLAB first?
if(status)
{
errMsg = "cellsxfun: an error occured while invoking the user's function.\n";
mexErrMsgTxt(errMsg.c_str());
}
// copy the results from calcResult[0] into the cell contents located at offset index
mxSetCell(outPtr[0], index, mxDuplicateArray(resultArray));
// clear the temporarily allocated memory block
if(resultArray) mxDestroyArray(resultArray);
// update the counters - we will skip to here if the data was bad (uninitialized ?)
thisVec[0]++;
for(mwSize vIndex=0; vIndex < (nDim-1); vIndex++)
{
if(thisVec[vIndex] >= endVec[vIndex])
{
thisVec[vIndex] = startVec[vIndex];
thisVec[vIndex+1]++;
}
}
// terminal condition - here we break out of the loop
if(thisVec[nDim-1] >= endVec[nDim-1]) break;
}
// deallocation of allocated stuff
//if(dataPtrArray[0]) mxFree(dataPtrArray[0]);
if(dataPtrArray) mxFree(dataPtrArray);
if(endVec) mxFree(endVec);
if(startVec) mxFree(startVec);
if(thisVec) mxFree(thisVec);
if(workVec) mxFree(workVec);
if(argDims) mxFree(argDims);
}