/*
 * Copyright (C) 2004 Evan Thomas
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */


/*******************************************************************
   A subclass of Dynamics that can be, in turn, subclassed to make
   a Dynamics implemented in Python.
********************************************************************/

#define P3_MODULE
#include "ndl.h"

/* Standard headers after ndl.h (presumably it pulls in Python.h, which
   should come first). */
#include <stdint.h>

/* Route a C-level current(t) request to the Python method `current`.
   Returns 0 when the object has no `current` attribute.  If the call
   raises, or returns something that is neither a float nor an int,
   control leaves through longjmp(excpt_exit, 1). */
static double route_C2PyCurrent(PyObject *self, double t) {
    PyObject *f;
    double rc;

    if( !PyObject_HasAttrString(self, "current") ) return 0;

    f = PyObject_CallMethod(self, "current", "d", t);
    if( !f ) {
        message(fatal, "Exception in route_PyCurrent()\n");
        longjmp(excpt_exit, 1);
    }
    if( !PyFloat_Check(f) && !PyInt_Check(f) ) {
        /* Release the result before unwinding; longjmp skips the
           Py_DECREF below and would otherwise leak it. */
        Py_DECREF(f);
        PyErr_SetString(PyExc_TypeError,
            "current must return a float");
        longjmp(excpt_exit, 1);
    }

    rc = PyFloat_AsDouble(f);
    Py_DECREF(f);

    return rc;
}

/* Route a C-level voltage(t) request to the Python method `voltage`.
   Returns 0 when the object has no `voltage` attribute.  If the call
   raises, or returns something that is neither a float nor an int,
   control leaves through longjmp(excpt_exit, 1). */
static double route_C2PyVoltage(PyObject *self, double t) {
    PyObject *f;
    double rc;

    if( !PyObject_HasAttrString(self, "voltage") ) return 0;

    f = PyObject_CallMethod(self, "voltage", "d", t);
    if( !f ) {
        message(fatal, "Exception in route_PyVoltage()\n");
        longjmp(excpt_exit, 1);
    }
    if( !PyFloat_Check(f) && !PyInt_Check(f) ) {
        /* Release the result before unwinding; longjmp skips the
           Py_DECREF below and would otherwise leak it. */
        Py_DECREF(f);
        PyErr_SetString(PyExc_TypeError,
            "voltage must return a float");
        longjmp(excpt_exit, 1);
    }

    rc = PyFloat_AsDouble(f);
    Py_DECREF(f);

    return rc;
}

/* Forward a synaptic event (synapse s, time t, window id i) to the Python
   method `accepter`, if the object defines one.  A Python exception
   aborts through longjmp(excpt_exit, 1). */
static void route_C2PyAccepter(PyObject *self, Synapse *s, double t, int i) {
    PyObject *result;

    if( PyObject_HasAttrString(self, "accepter") ) {
        result = PyObject_CallMethod(self, "accepter", "Odi", s, t, i);
        if( result == NULL ) {
            message(fatal, "Exception in route_PyAccepter()\n");
            longjmp(excpt_exit, 1);
        }
        Py_DECREF(result);
    }
}

/* Forward an enq(t, s) event to the Python method `enq`, if the object
   defines one.  A Python exception aborts through longjmp(excpt_exit, 1). */
static void route_C2PyEnq(PyObject *self, double t, double s) {
    PyObject *ret;

    if( PyObject_HasAttrString(self, "enq") ) {
        ret = PyObject_CallMethod(self, "enq", "dd", t, s);
        if( ret == NULL ) {
            message(fatal, "Exception in route_PyEnq()\n");
            longjmp(excpt_exit, 1);
        }
        Py_DECREF(ret);
    }
}

/* Ask the Python method `derivs` to compute derivatives at time t, if the
   object defines one.  A Python exception aborts through
   longjmp(excpt_exit, 1). */
static void route_C2PyDerivs(PyObject *self, double t) {
    PyObject *res;

    if( PyObject_HasAttrString(self, "derivs") ) {
        res = PyObject_CallMethod(self, "derivs", "d", t);
        if( res == NULL ) {
            message(fatal, "Exception in route_PyDerivs()\n");
            longjmp(excpt_exit, 1);
        }
        Py_DECREF(res);
    }
}

/* Pass the GD object to the Python method `cleanup`, if the object
   defines one.  A Python exception aborts through longjmp(excpt_exit, 1). */
static void route_C2PyCleanup(PyObject *self, GD *gd) {
    PyObject *outcome;

    if( PyObject_HasAttrString(self, "cleanup") ) {
        outcome = PyObject_CallMethod(self, "cleanup", "O", gd);
        if( outcome == NULL ) {
            message(fatal, "Exception in route_PyCleanup()\n");
            longjmp(excpt_exit, 1);
        }
        Py_DECREF(outcome);
    }
}
    
/* Forward (s, i1, i2) to the Python method `deleter`, if the object
   defines one.  A Python exception aborts through longjmp(excpt_exit, 1).
   NOTE(review): initPyDynamics in this file does not install this router. */
static void route_C2PyDeleter(PyObject *self, Synapse *s, int i1, int i2) {
    PyObject *answer;

    if( PyObject_HasAttrString(self, "deleter") ) {
        answer = PyObject_CallMethod(self, "deleter", "Oii", s, i1, i2);
        if( answer == NULL ) {
            message(fatal, "Exception in route_PyDeleter()\n");
            longjmp(excpt_exit, 1);
        }
        Py_DECREF(answer);
    }
}

/* Getter for a generated state-variable attribute.  The closure `y`
   carries the variable's index.  It is converted back through intptr_t so
   the pointer->integer conversion is well defined even where pointers are
   wider than int. */
PyObject *p3_generic_get_statevariable(Dynamics *d, void *y) {
    int i = (int)(intptr_t)y;
    return PyFloat_FromDouble(GETSTATE_DYN(d,i));
}

/* Setter for a generated state-variable attribute.  The closure `y`
   carries the variable's index.  Accepts floats, ints and longs.

   Returns 0 on success.  Returns -1 with a Python exception set when
   `data` is not a number, when a long is out of range for a double, or on
   attribute deletion (CPython passes data == NULL for `del obj.attr`). */
int p3_generic_set_statevariable(Dynamics *d, PyObject *data, void *y) {
    int i = (int)(intptr_t)y;
    double value;

    if( data == NULL ) {
        PyErr_SetString(PyExc_TypeError,
            "Cannot delete a state variable");
        return -1;
    }

    if( PyFloat_Check(data) ) {
        value = PyFloat_AsDouble(data);
    } else if( PyInt_Check(data) ) {
        value = PyInt_AsLong(data);
    } else if( PyLong_Check(data) ) {
        value = PyLong_AsDouble(data);
        if( value == -1.0 && PyErr_Occurred() ) return -1;
    } else {
        PyErr_SetString(PyExc_TypeError, 
            "State variables must be numbers");
        return -1;
    }

    SETSTATE_DYN(d, i, value);
    return 0;
}

/* Getter for a generated derivative attribute.  The closure `y` carries
   the variable's index.  It is converted back through intptr_t so the
   pointer->integer conversion is well defined even where pointers are
   wider than int. */
PyObject *p3_generic_get_derivvariable(Dynamics *d, void *y) {
    int i = (int)(intptr_t)y;
    return PyFloat_FromDouble(GETDERIV_DYN(d,i));
}

/* Setter for a generated derivative attribute.  The closure `y` carries
   the variable's index.  Accepts floats, ints and longs.

   Returns 0 on success.  Returns -1 with a Python exception set when
   `data` is not a number, when a long is out of range for a double, or on
   attribute deletion (CPython passes data == NULL for `del obj.attr`). */
int p3_generic_set_derivvariable(Dynamics *self, PyObject *data, void *y) {
    int i = (int)(intptr_t)y;
    double value;

    if( data == NULL ) {
        PyErr_SetString(PyExc_TypeError,
            "Cannot delete a derivative variable");
        return -1;
    }

    if( PyFloat_Check(data) ) {
        value = PyFloat_AsDouble(data);
    } else if( PyInt_Check(data) ) {
        value = PyInt_AsLong(data);
    } else if( PyLong_Check(data) ) {
        value = PyLong_AsDouble(data);
        if( value == -1.0 && PyErr_Occurred() ) return -1;
    } else {
        PyErr_SetString(PyExc_TypeError, 
            "State variables must be integers or floats");
        return -1;
    }

    SETDERIV_DYN(self, i, value);
    return 0;
}

/* Getter for a generated trace attribute.  The closure `y` carries the
   variable's index into d->traceData.  Returns a new reference to that
   entry.  Raises AttributeError when the Dynamics is not being traced
   (d->trace is false). */
PyObject *p3_generic_get_tracevariable(Dynamics *d, void *y) {
    int i = (int)(intptr_t)y;
    PyObject *t;

    if( !d->trace ) {
        PyErr_SetString(PyExc_AttributeError,
            "Trace variable requested for a Dynamics that is not being traced.");
        return NULL;
    }

    /* NOTE(review): assumes traceData[i] is a non-NULL PyObject* whenever
       d->trace is set -- confirm against the tracing code. */
    t = (PyObject*)d->traceData[i];
    Py_INCREF(t);

    return t;
}

static int addPyStateVars(PyObject *self) {
    /* The 'interesting' feature of this initialiser is that it looks
       for a class variable "stateVars" which is a list of the state 
       variables.  It then adds PyGetSetDefs for them and sets flag
       in the type dictionary so that it is only done once. */
    PyObject *state, *dxdt, *trace;
    int firstTimeFlag;
    Dynamics *d = (Dynamics*)self;
    int i, size;

    /* Have we already done this */
    firstTimeFlag = 
        PyDict_GetItemString(self->ob_type->tp_dict, "stateVarsFlag")==0;

    /* Set flag in the class dictionary so we don't do this again */
    if( firstTimeFlag ) {
        Py_INCREF(Py_None);
        if( PyDict_SetItemString(self->ob_type->tp_dict, "stateVarsFlag", Py_None)!=0 )
            return -1;
    }

    state = PyDict_GetItemString(self->ob_type->tp_dict, "stateVars");
    if( state ) size = PyList_Size(state);
    else size = 0;

    dxdt = PyDict_GetItemString(self->ob_type->tp_dict, "stateDerivs");
    if( state && (!dxdt || size!=PyList_Size(dxdt)) ) {
        PyErr_SetString(PyExc_AttributeError,
            "If stateVars are specified then a matching stateDerivs must be specified");
        return -1;
    }

    trace = PyDict_GetItemString(self->ob_type->tp_dict, "stateTrace");
    if( state &&  trace && size!=PyList_Size(trace) ) {
        PyErr_SetString(PyExc_AttributeError,
            "If stateVars are specified then stateTrace must be the same size");
        return -1;
    }

    /* Add the GetSetDefs for the state variables and their derivatives */
    if( state ) {        
        d->y = addState(d->owner->owner, size);
        d->n = size;
    }

    if( firstTimeFlag && state ) {
        /* Some type checking */
        if( !PyList_Check(state) || !PyList_Check(dxdt) ) {
            PyErr_SetString(PyExc_TypeError,
                "stateVars and derivVars must be lists of strings naming the variables");
            return -1;
        }
        for(i=0; i<size; i++) {
            if( !PyString_Check(PyList_GetItem(state, i)) || !PyString_Check(PyList_GetItem(dxdt, i)) ) {
                PyErr_SetString(PyExc_TypeError,
                    "stateVars and derivVars must be lists of strings naming the variables");
                return -1;
            }
        }

        for(i=0; i<size; i++) {
            PyGetSetDef *gs;
            PyObject *descr;
            char *name;

            /* ith state variable */
            name = PyString_AsString(PyList_GetItem(state, i));
            gs = PyMem_Malloc(sizeof(*gs));
            if( !gs ) {
                PyErr_NoMemory();
                return -1;
            }
            gs->name = name;
            gs->doc  = "State variable";
            gs->get  = (getter)p3_generic_get_statevariable;
            gs->set  = (setter)p3_generic_set_statevariable;
            gs->closure = (void*)i; /* Legal and portable methinks! */
            descr = PyDescr_NewGetSet(self->ob_type, gs);
            if( !descr ) return -1;
            if( PyDict_SetItemString(self->ob_type->tp_dict, name, descr)!=0 )
                return -1;

            /* ith derivative variable */
            name = PyString_AsString(PyList_GetItem(dxdt, i));
            gs = PyMem_Malloc(sizeof(*gs));
            if( !gs ) {
                PyErr_NoMemory();
                return -1;
            }
            gs->name = name;
            gs->doc  = "Derivative of a state variable";
            gs->get  = (getter)p3_generic_get_derivvariable;
            gs->set  = (setter)p3_generic_set_derivvariable;
            gs->closure = (void*)i; /* Legal and portable methinks! */
            descr = PyDescr_NewGetSet(self->ob_type, gs);
            if( !descr ) return -1;
            if( PyDict_SetItemString(self->ob_type->tp_dict, name, descr)!=0 )
                return -1;
        }
    }

    if( firstTimeFlag && state && trace) {
        /* Some type checking */
        if( !PyList_Check(trace) ) {
            PyErr_SetString(PyExc_TypeError,
                "stateTrace must be a list of strings naming the variables");
            return -1;
        }
        for(i=0; i<size; i++) {
            if( !PyString_Check(PyList_GetItem(trace, i)) ) {
                PyErr_SetString(PyExc_TypeError,
                    "stateTrace and derivVars must be a list of strings naming the variables");
                return -1;
            }
        }

        for(i=0; i<size; i++) {
            PyGetSetDef *gs;
            PyObject *descr;
            char *name;

            /* ith state variable */
            name = PyString_AsString(PyList_GetItem(trace, i));
            gs = PyMem_Malloc(sizeof(*gs));
            if( !gs ) {
                PyErr_NoMemory();
                return -1;
            }
            gs->name = name;
            gs->doc  = "State trace variable";
            gs->get  = (getter)p3_generic_get_tracevariable;
            gs->set  = 0;
            gs->closure = (void*)i; /* Legal and portable methinks! */
            descr = PyDescr_NewGetSet(self->ob_type, gs);
            if( !descr ) return -1;
            if( PyDict_SetItemString(self->ob_type->tp_dict, name, descr)!=0 )
                return -1;
        }
    }

    if( !state ) {
        d->y = -1;
        d->n = 0;
    }

    return 0;
 }

/* tp_init for PyDynamics.  It runs the Dynamics initialiser, then sets up
   any Python-declared state variables.  Finally it points the C-level
   callbacks at the route_C2Py* routers, so that C calls reach the methods
   defined in the user's Python subclass. */
static int initPyDynamics(Dynamics *self, PyObject *args, PyObject *kw) {
    PyObject *obj = (PyObject*)self;
    int status;

    status = p3_DynamicsType.tp_init(obj, args, kw);
    if( status == 0 )
        status = addPyStateVars(obj);
    if( status != 0 )
        return status;

    self->accepter = (accepterfcn*)route_C2PyAccepter;
    self->cleanup  = (cleanupfcn*)route_C2PyCleanup;
    self->current  = (currentfcn*)route_C2PyCurrent;
    self->voltage  = (currentfcn*)route_C2PyVoltage;
    self->enq      = (enqfcn*)route_C2PyEnq;
    /* Only hook derivs when addPyStateVars allocated state (y >= 0). */
    if( self->y >= 0 )
        self->derivs = (derivsfcn*)route_C2PyDerivs;

    return 0;
}

/* If these methods are called, the user's subclass has not provided
   its own version, so provide some mild default behaviour. Most of
   them also set the corresponding C-level callback to zero so that we
   don't waste time calling back into Python in the future. */
/* Default Python-level current(): clears the C current callback so it is
   not routed here again, and returns 0.0.  PyFloat_FromDouble already
   returns a new reference owned by the caller, so no extra Py_INCREF is
   needed; adding one would leak the float.  Returns NULL with an
   exception set if allocation fails. */
static PyObject *route_Py2Ccurrent(Dynamics *self, PyObject *args) {
    self->current = 0;
    return PyFloat_FromDouble(0);
}
/* Default Python-level voltage(): clears the C voltage callback so it is
   not routed here again, and returns 0.0.  PyFloat_FromDouble already
   returns a new reference owned by the caller, so no extra Py_INCREF is
   needed; adding one would leak the float.  Returns NULL with an
   exception set if allocation fails. */
static PyObject *route_Py2CVoltage(Dynamics *self, PyObject *args) {
    self->voltage = 0;
    return PyFloat_FromDouble(0);
}
/* Default Python-level enq(time[, strength=1]).  It queues the event on
   the owning cell via stepOnInt(..., enqmsg) and returns None.  Returns
   NULL if the arguments cannot be parsed. */
static PyObject *route_Py2Cenq(Dynamics *self, PyObject *args) {
    double when;
    double strength = 1;

    if( !PyArg_ParseTuple(args, "d|d", &when, &strength) ) return NULL;

    stepOnInt(self->owningCell, when, self, 0, strength, -1, enqmsg);

    Py_INCREF(Py_None);
    return Py_None;
}
/* Default Python-level derivs(): the subclass supplies no derivatives, so
   clear the C derivs callback (as route_Py2Caccepter and
   route_Py2Ccleanup do for theirs) and return None.  The enq callback is
   left alone, so exogenous events keep being delivered. */
static PyObject *route_Py2Cderivs(Dynamics *self, PyObject *args) {
    self->derivs = 0;
    Py_INCREF(Py_None);
    return Py_None;
}
/* Default Python-level accepter(): there is nothing to do, so unhook the
   C accepter callback and return None. */
static PyObject *route_Py2Caccepter(Dynamics *self, PyObject *args) {
    Py_INCREF(Py_None);
    self->accepter = NULL;
    return Py_None;
}
/* Default Python-level cleanup(): there is nothing to do, so unhook the
   C cleanup callback and return None. */
static PyObject *route_Py2Ccleanup(Dynamics *self, PyObject *args) {
    Py_INCREF(Py_None);
    self->cleanup = NULL;
    return Py_None;
}

/* Python-level methods of PyDynamics.  They replace the same-named methods
   provided by Dynamics, and a Python subclass may override any of them.
   Every entry takes positional arguments (METH_VARARGS). */
static PyMethodDef methods[] = {
    {"enq",      (PyCFunction)route_Py2Cenq,      METH_VARARGS,
        "enq(time, strength): add an exogenous event"},
    {"current",  (PyCFunction)route_Py2Ccurrent,  METH_VARARGS,
        "current(time): current passing through this Dynamics"},
    {"derivs",   (PyCFunction)route_Py2Cderivs,   METH_VARARGS,
        "derivs(time): derivative function"},
    {"accepter", (PyCFunction)route_Py2Caccepter, METH_VARARGS,
        "accepter(fromSynapse, strength, windowid): accept a synaptic event"},
    {"cleanup",  (PyCFunction)route_Py2Ccleanup,  METH_VARARGS,
        "cleanup(gd): perform end of window resource management"},
    {"voltage",  (PyCFunction)route_Py2CVoltage,  METH_VARARGS,
        "voltage(time): voltage clamp waveform"},
    {NULL, NULL, 0, NULL}  /* end-of-table marker */
};


/* Type object for _p3.PyDynamics.  It derives from p3_DynamicsType
   (tp_base) and may itself be subclassed in Python (Py_TPFLAGS_BASETYPE).
   Its tp_init (initPyDynamics) wires the C callbacks to the Python-level
   methods, and `methods` supplies default implementations of those
   methods.  Slots are positional, so their order must not change.
   NOTE(review): slots left as 0, and ob_type (NULL here), are presumably
   filled in or inherited from tp_base by PyType_Ready, which is not called
   in this file -- TODO confirm. */
PyTypeObject p3_PyDynamicsType = {
    PyObject_HEAD_INIT(NULL)
    0,                         /*ob_size*/
    "_p3.PyDynamics",             /*tp_name*/
    sizeof(Dynamics),      /*tp_basicsize: same layout as Dynamics, no extra fields*/
    0,                         /*tp_itemsize*/
    0,                         /*tp_dealloc*/
    0,                         /*tp_print*/
    0,                         /*tp_getattr*/
    0,                         /*tp_setattr*/
    0,                         /*tp_compare*/
    0,                         /*tp_repr*/
    0,                         /*tp_as_number*/
    0,                         /*tp_as_sequence*/
    0,                         /*tp_as_mapping*/
    0,                         /*tp_hash */
    0,                         /*tp_call*/
    0,                         /*tp_str*/
    0,                         /*tp_getattro*/
    0,                         /*tp_setattro*/
    0,                         /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,        /*tp_flags: subclassable from Python*/
    "Parplex PyDynamics object",   /* tp_doc */
    0,                         /* tp_traverse */
    0,                         /* tp_clear */
    0,                         /* tp_richcompare */
    0,                         /* tp_weaklistoffset */
    0,                         /* tp_iter */
    0,                         /* tp_iternext */
    methods,             /* tp_methods: default Python-level callback methods */
    0,             /* tp_members */
    0,          /* tp_getset: state-variable descriptors are added per class at init time */
    &p3_DynamicsType,                         /* tp_base */
    0,                         /* tp_dict */
    0,                         /* tp_descr_get */
    0,                         /* tp_descr_set */
    0,                         /* tp_dictoffset */
    (initproc)initPyDynamics,      /* tp_init: base init + state vars + C->Python routers */
    0,                         /* tp_alloc */
    PyType_GenericNew                 /* tp_new */
};
