/*
* pynestpycsa.cpp
*
* This file is part of NEST.
*
* Copyright (C) 2004 The NEST Initiative
*
* NEST is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* NEST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with NEST. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include "pynestpycsa.h"
#include <string>
#include <iostream>
// Python object handed to PyErr_SetString() by error(); assigned in PyCSA_init().
static PyObject *NESTError = NULL;
// Attribute "Mask" of the csa module, fetched by loadCSA().
static PyObject* pMask = 0;
// Attribute "ConnectionSet" of the csa module, fetched by loadCSA().
static PyObject* pConnectionSet = 0;
// Tuple (pMask, pConnectionSet) built by loadCSA(); PyPyCSA_Check() passes it
// to PyObject_IsInstance() and treats the value 0 as "CSA not loaded yet".
static PyObject* pCSAClasses = 0;
// Attribute "arity" of the csa module; called by the PyCSAGenerator constructor.
static PyObject* pArity = 0;
// Attribute "cross" of the csa module; called by PyCSAGenerator::setMask().
static PyObject* pCross = 0;
// Attribute "partition" of the csa module; called by PyCSAGenerator::setMask().
static PyObject* pPartition = 0;
/**
 * Set the pending Python error to NESTError with the given message.
 *
 * The call into the Python API is bracketed by the PYGILSTATE_ENSURE /
 * PYGILSTATE_RELEASE macros (defined outside this file).
 *
 * @param errstring  Text handed to PyErr_SetString() as the error message.
 */
static void
error (std::string errstring)
{
  const char* message = errstring.c_str ();

  PYGILSTATE_ENSURE (gstate);
  PyErr_SetString (NESTError, message);
  PYGILSTATE_RELEASE (gstate);
}
static bool
CSAimported ()
{
return PyMapping_HasKeyString (PyImport_GetModuleDict (), (char*)"csa");
}
static bool
loadCSA ()
{
PYGILSTATE_ENSURE (gstate);
PyObject* pModule = PyMapping_GetItemString (PyImport_GetModuleDict (), (char*)"csa");
pMask = PyObject_GetAttrString (pModule, "Mask");
if (pMask == NULL)
{
Py_DECREF (pModule);
PYGILSTATE_RELEASE (gstate);
error ("Couldn't find the Mask class in the CSA library");
return false;
}
pConnectionSet = PyObject_GetAttrString (pModule, "ConnectionSet");
if (pConnectionSet == NULL)
{
Py_DECREF (pModule);
PYGILSTATE_RELEASE (gstate);
error ("Couldn't find the ConnectionSet class in the CSA library");
return false;
}
pArity = PyObject_GetAttrString (pModule, "arity");
pCross = PyObject_GetAttrString (pModule, "cross");
pPartition = PyObject_GetAttrString (pModule, "partition");
Py_DECREF (pModule);
if (pArity == NULL)
{
PYGILSTATE_RELEASE (gstate);
error ("Couldn't find the arity function in the CSA library");
return false;
}
pCSAClasses = PyTuple_Pack (2, pMask, pConnectionSet);
PYGILSTATE_RELEASE (gstate);
return true;
}
/**
 * Test whether obj is an instance of one of the CSA classes collected in
 * pCSAClasses (csa.Mask or csa.ConnectionSet).
 *
 * The CSA classes are looked up on first use, and only if the "csa" module
 * has already been imported.
 *
 * @param obj  Object to test; NULL yields false.
 * @return true only when PyObject_IsInstance() reports a positive match.
 *         false is returned when obj is NULL, when csa is not imported,
 *         when loading the CSA classes fails (error() has then been called),
 *         or when the instance check itself fails (the Python error raised
 *         by the check is left pending).
 */
bool PyPyCSA_Check (PyObject* obj)
{
  if (obj == NULL)
    return false;

  if (pCSAClasses == 0)
    {
      if (!CSAimported ())
        return false;
      // load CSA library
      if (!loadCSA ())
        return false;
    }

  // PyObject_IsInstance returns 1 (match), 0 (no match) or -1 (error);
  // a plain conversion to bool would report the error case as a match.
  return PyObject_IsInstance (obj, pCSAClasses) == 1;
}
/**
 * Wrap a CSA object.
 *
 * Takes a new reference to obj and caches the result of calling the csa
 * "arity" function (pArity) on it in arity_.
 *
 * If the arity call or the integer conversion fails, arity_ is set to 0 and
 * the Python error raised by the failing call is left pending.
 *
 * @param obj  The CSA object to wrap; must not be NULL.
 */
PyCSAGenerator::PyCSAGenerator (PyObject* obj)
  : pCSAObject (obj), pPartitionedCSAObject (NULL), pIterator (NULL)
{
  PYGILSTATE_ENSURE (gstate);
  Py_INCREF (pCSAObject);

  arity_ = 0;
  PyObject* a = PyObject_CallFunctionObjArgs (pArity, pCSAObject, NULL);
  if (a != NULL)
    {
      const long n = PyInt_AsLong (a);
      Py_DECREF (a);
      // PyInt_AsLong signals failure with -1 plus a pending error.
      if (!(n == -1 && PyErr_Occurred ()))
        arity_ = n;
    }

  PYGILSTATE_RELEASE (gstate);
}
/**
 * Release the references held by this generator: the iterator and the
 * partitioned CSA object (either may be NULL), then the wrapped CSA object.
 */
PyCSAGenerator::~PyCSAGenerator ()
{
  PYGILSTATE_ENSURE (gstate);
  Py_CLEAR (pIterator);
  Py_CLEAR (pPartitionedCSAObject);
  Py_DECREF (pCSAObject);
  PYGILSTATE_RELEASE (gstate);
}
int
PyCSAGenerator::arity ()
{
return arity_;
}
PyObject*
PyCSAGenerator::makeIntervals (IntervalSet& iset)
{
PyObject* ivals = PyList_New (0);
if (iset.skip () == 1)
{
for (IntervalSet::iterator i = iset.begin (); i != iset.end (); ++i)
PyList_Append (ivals,
PyTuple_Pack (2,
PyInt_FromLong (i->first),
PyInt_FromLong (i->last)));
}
else
{
for (IntervalSet::iterator i = iset.begin (); i != iset.end (); ++i)
{
int last = i->last;
for (int j = i->first; j < last; j += iset.skip ())
PyList_Append (ivals,
PyTuple_Pack (2,
PyInt_FromLong (j),
PyInt_FromLong (j)));
}
}
return ivals;
}
void
PyCSAGenerator::setMask (std::vector<Mask>& masks, int local)
{
PYGILSTATE_ENSURE (gstate);
PyObject* pMasks = PyList_New (masks.size ());
for (size_t i = 0; i < masks.size (); ++i)
{
PyObject* pMask
= PyObject_CallFunctionObjArgs (pCross,
makeIntervals (masks[i].sources),
makeIntervals (masks[i].targets),
NULL);
PyList_SetItem (pMasks, i, pMask);
}
Py_XDECREF (pPartitionedCSAObject);
pPartitionedCSAObject = PyObject_CallFunctionObjArgs (pPartition,
pCSAObject,
pMasks,
PyInt_FromLong (local),
NULL);
if (pPartitionedCSAObject == NULL)
{
PYGILSTATE_RELEASE (gstate);
std::cerr << "Failed to create masked CSA object" << std::endl;
return;
}
Py_INCREF (pPartitionedCSAObject); //*fixme* check if necessary!
PYGILSTATE_RELEASE (gstate);
}
/**
 * @return the value of PySequence_Size() for the wrapped CSA object,
 *         converted to int.
 */
int
PyCSAGenerator::size ()
{
  PYGILSTATE_ENSURE (gstate);
  const int length = PySequence_Size (pCSAObject);
  PYGILSTATE_RELEASE (gstate);
  return length;
}
/**
 * (Re)start iteration over the partitioned CSA object.
 *
 * Drops any previous iterator and asks Python for a fresh one. When no
 * partitioned object exists yet (pPartitionedCSAObject is NULL), an error
 * is reported through error() and nothing else happens.
 */
void
PyCSAGenerator::start ()
{
  if (pPartitionedCSAObject != NULL)
    {
      PYGILSTATE_ENSURE (gstate);
      Py_XDECREF (pIterator);
      pIterator = PyObject_GetIter (pPartitionedCSAObject);
      PYGILSTATE_RELEASE (gstate);
    }
  else
    error ("CSA connection generator not properly initialized");
}
bool
PyCSAGenerator::next (int& source, int& target, double* value)
{
if (pIterator == NULL)
{
error ("Must call start() before next()");
return false;
}
PYGILSTATE_ENSURE (gstate);
PyObject* tuple = PyIter_Next (pIterator);
PyObject* err = PyErr_Occurred ();
if (err)
{
PYGILSTATE_RELEASE (gstate);
return false;
}
if (tuple == NULL)
{
Py_DECREF (pIterator);
pIterator = NULL;
PYGILSTATE_RELEASE (gstate);
return false;
}
source = PyInt_AsLong (PyTuple_GET_ITEM (tuple, 0));
target = PyInt_AsLong (PyTuple_GET_ITEM (tuple, 1));
for (int i = 0; i < arity_; ++i)
{
PyObject* v = PyTuple_GET_ITEM (tuple, i + 2);
if (!PyFloat_Check (v))
{
Py_DECREF (tuple);
PYGILSTATE_RELEASE (gstate);
error ("NEST cannot handle non-float CSA value sets");
return false;
}
value[i] = PyFloat_AsDouble (v);
}
Py_DECREF (tuple);
PYGILSTATE_RELEASE (gstate);
return true;
}
void
PyCSA_init(void)
{
NESTError = Py_BuildValue("s", "NESTError");
}