/*
 * Copyright (C) 2004 Evan Thomas
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */

/*************************
  Controls the waveform relaxation (WR) loop and
  provides the interface to the numerical solvers.
**************************/

#define P3_MODULE
#include "ndl.h"

/* Forward declarations of the file-local helpers defined below. */
static void copy_state(Cell*);
static void solver_init(GD*);
static void startWindow(GD*);
static void integrateUntil(GD*);
static void endWindow(GD*);
static void tracecapture(Cell*);
static void dyntracecapture(Cell*);
static void initCells(GD*);
static void initNet(PyObject*);
static void freeNet(PyObject*);

/* Number of the current relaxation round within a window, used in debug
   messages.  parplex() sets it to 2 after the first round of each window
   and increments it for every further round. */
static int IterRound;

/* The cell currently being integrated.  solve() sets it before calling the
   cell's numerical method, and derivs() reads it to find the cell whose
   derivatives are wanted; setCurrentCell() lets other code change it.
   Extreme care is required when updating this pointer from callback
   routines (it may be accessed again at an earlier time). */
static Cell *currentCell;

static void interp(Cell *cell) {
  /* Advance the trace-sampling clock (cell->lasttime) by one sample period,
     then call the cell's interpolator at that new time with the Yinterp
     and DYDT buffers. */
  cell->lasttime = cell->lasttime + cell->timeSample;
  cell->interpolator(da_length(cell->Y), cell->lasttime,
                     cell->Yinterp->data, cell->DYDT->data);
}

static void solve(Cell *cell) {
  /* Advance one cell by a single step of its numerical method.
     Before stepping, each compartment's membrane potential at the cell's
     current time is saved in lastEm (required for AP detection).  After
     stepping, the step size proposed by the solver is handed to stepNext(),
     the accept/total step counters are updated, and the step-trace handler
     runs if the cell has stepTrace set. */
  double tol  = gd->tolerance;
  double hmin = gd->minStep;
  double proposed;
  Compartment *c;
  int pos;

  /* Callback routines find the cell being integrated through this global. */
  currentCell = cell;

  for( pos = 0; (c = il_seq_next(cell->compartments, &pos)) != NULL; )
    c->lastEm = GETEM_CMPT(c, cell->time);

  cell->laststepaccept = cell->solver(cell, tol, hmin, &proposed);
  stepNext(cell, proposed);

  cell->stepTotal++;
  cell->stepAccepts += cell->laststepaccept;

  if( cell->stepTrace )
    do_cell_handler(cell, gd->stepTrace_handler);
}

/* Non-local error exit.  parplex() arms it with setjmp(); code that has set
   a Python exception (e.g. initNet()) calls longjmp(excpt_exit, 1) so that
   parplex() releases the network references and returns NULL. */
jmp_buf excpt_exit;

void startGD(GD *gd) {
	/* Zero the per-run bookkeeping counters in gd so that parplex() can be
	   run more than once on the same global data. */
	gd->roundCnt = 0;
	gd->sendMsgCnt = 0;
	gd->recvMsgCnt = 0;
	gd->reDoCnt = 0;
	gd->roundReDoCnt = 0;
	gd->windowID = 0;
}

/* Window bookkeeping shared by parplex() and its helpers.  window is the
   nominal window length; the window being integrated runs from windowStart
   to windowEnd, and windowEnd is clipped to gd->duration. */
static double window;
static double windowStart;
static double windowEnd;


PyObject* parplex(void) {
  bool globalConvergence;

  window = min(gd->window, gd->duration);
  windowStart = 0;
  windowEnd = window;

  currentCell = (Cell*)Py_None;

  /* Propagate any exceptions raised by callbacks. */
  if( setjmp(excpt_exit) ) {
      freeNet(gd->network);
      return NULL;
  }
  
  startGD(gd);
  initMPI(gd);
  initNet(gd->network);
  initCells(gd);
  solver_init(gd);

  /* Main waveform relaxation loop. */
  while(windowStart < gd->duration) {
    message(debug, "New window at %g.\n", windowStart);

    gd->windowStart = windowStart;
    gd->firstRound = true;
    startWindow(gd);

    /* First and major integration of window. */
    gd->cellsToGo = pq_size();

    integrateUntil(gd);

    globalConvergence = exchangeMessages(gd);
    
    gd->firstRound = false;

    IterRound = 2;
    while(!globalConvergence) {
      message(debug, "Starting round %d.\n", IterRound);

	  /* Roll forward any cells that required redoing */
      integrateUntil(gd);
      
      /* Exchange events with other cells */
      globalConvergence = exchangeMessages(gd);

	  IterRound++;
	  gd->roundCnt++;
    }

    endWindow(gd);
    
    windowStart = windowEnd;
    windowEnd += window;
	windowEnd = min(windowEnd, gd->duration);
    gd->windowID++;
  }
  
  freeNet(gd->network);
  Py_INCREF(Py_None);
  return Py_None;
}

static void endWindow(GD *gd) {
    /* Close out a converged window.  For every queued cell: call apout(),
       call traceout() when the cell traces Em or dynamics, run the cleanup
       hook of every dynamics object in every compartment, then
       stepCommit().  Finally call the endWindow handler, if one is set. */
    Cell *c;

    pq_seq_start();
    while( (c = pq_seq_next()) != NULL ) {
        Compartment *cp;
        int pos = 0;

        apout(c);
        if( c->hasEmTrace || c->hasDynTrace )
            traceout(c);

        while( (cp = il_seq_next(c->compartments, &pos)) != NULL ) {
            int k;
            for( k = 0; k < il_length(cp->dynamics); k++ ) {
                Dynamics *dyn = (Dynamics*)il_get(cp->dynamics, k);
                if( dyn->cleanup )
                    dyn->cleanup(dyn, gd);
            }
        }

        stepCommit(c);
    }

    if( gd->endWindow_handler )
        do_void_handler(gd->endWindow_handler, gd);
}

static void resetTraces(Dynamics *d) {
    /* Empty a traced dynamics object's recorded trace: its time array and
       each of its d->n data arrays.  Untraced objects are left untouched. */
    if( d->trace ) {
        int k;
        da_clear(d->traceTimes);
        for( k = 0; k < d->n; k++ )
            da_clear(d->traceData[k]);
    }
}

static void startWindow(GD *gd) {
  /* Prepare every queued cell for a new window.  First call the
     startWindow handler, if set.  Then, per cell:
       - make it step onto the window edge (windowEnd);
       - snapshot its state when more than one MPI process is running;
       - empty each compartment's AP lists, its Em trace and the traces of
         its dynamics objects. */
  Cell *c;

  if( gd->startWindow_handler )
    do_void_handler(gd->startWindow_handler, gd);

  pq_seq_start();
  while( (c = pq_seq_next()) != NULL ) {
    Compartment *cp;
    int pos = 0;

    stepOn(c, windowEnd);

    if( mpi_size > 1 )
      copy_state(c);

    while( (cp = il_seq_next(c->compartments, &pos)) != NULL ) {
      int k;

      da_clear(cp->previousAP);
      da_clear(cp->currentAP);
      da_clear(cp->traceTimes);
      da_clear(cp->traceData);

      for( k = 0; k < il_length(cp->dynamics); k++ )
        resetTraces(il_get(cp->dynamics, k));
    }
  }
}

static void integrateUntil(GD *gd) {
  /* Integrate the queued cells until gd->cellsToGo reaches zero.
     Each pass dequeues the next cell and advances it one solver step.
     If the step was accepted, it runs AP detection and trace capture.
     The cell's time and state vector are then checked for inf/NaN (ABEND
     on failure).  gd->cellsToGo is decremented once the cell has reached
     windowEnd, and the cell is re-inserted into the queue.
     NOTE(review): the global currentCell is re-read after each call rather
     than cached in a local.  If a callback (e.g. via detectAP) ever left it
     pointing at another cell, that cell would be counted and re-inserted
     instead -- confirm callbacks restore it. */
  double x;   /* scratch value for the state-vector sanity check */

  /* Reset the per-round redo counter (not updated in this function). */
  gd->roundReDoCnt = 0;

  while( gd->cellsToGo ) {
    /* Get the next cell */
    currentCell = pq_lazy_dequeue();

    /* Integrate to next time step */
    solve(currentCell);

    if( currentCell->laststepaccept ) {
      /* Action potential detection, and action */
      detectAP(gd, currentCell);

      /* If tracing is on ... */
      if( currentCell->hasEmTrace || currentCell->hasDynTrace ) {
        if( currentCell->timeSample <= 0 ) {
          /* Unsampled tracing: record the state at every accepted step.
             A non-positive period counts as unsampled, just as it does in
             tracecapture().  Otherwise a negative period would make
             interp() move lasttime backwards and the loop below would
             never end. */
          if( currentCell->hasEmTrace ) tracecapture(currentCell);
          if( currentCell->hasDynTrace ) dyntracecapture(currentCell);
        } else {
          /* Interpolate at each sample time covered by this step */
          while( currentCell->lasttime <= currentCell->time - currentCell->timeSample ) {
            interp(currentCell);
            if( currentCell->hasEmTrace ) tracecapture(currentCell);
            if( currentCell->hasDynTrace ) dyntracecapture(currentCell);
          }
        }
      }
    }

    /* Sanity check: the cell's time and every state variable must be finite. */
    if( !finite(currentCell->time) || isnan(currentCell->time) )
      ABEND("Integration error detected at sanity check time", currentCell);
    da_seq_start(currentCell->Y);
    while( da_seq_next(currentCell->Y, &x) )
      if( !finite(x) || isnan(x) )
        ABEND("Membrane insane - This is usually an error in the equations or parameters", currentCell);

    if( currentCell->time >= windowEnd ) {
      gd->cellsToGo--;
    }

    /* Re-insert for next time */
    pq_lazy_insert(currentCell);
  }
}

void setCurrentCell(Cell *c) {
    /* Set the cell that derivs() and solver callbacks will operate on. */
    currentCell = c;
}

static void solver_init(GD* gd) {
  /* Let each queued cell's numerical method initialise itself by calling
     cell->solverInit(gd, cell) once per cell, before the first window.
     Declared static to match the file-local prototype at the top of this
     file. */
  Cell *cell;

  pq_seq_start();
  while( (cell = pq_seq_next()) )
    cell->solverInit(gd, cell);
}

static void copy_state(Cell *cell) {
  /* Snapshot a cell at the start of a window: its current step size,
     its trace-sampling time (lasttime), the Y and DYDT arrays, and the
     lastEm value of every compartment. */
  Compartment *cp;
  int pos;

  cell->last_copy = cell->lasttime;
  cell->step_copy = cell->step;
  da_copy(cell->Y_copy, cell->Y);
  da_copy(cell->DYDT_copy, cell->DYDT);

  for( pos = 0; (cp = il_seq_next(cell->compartments, &pos)) != NULL; )
    cp->lastEm_copy = cp->lastEm;
}

void derivs(double t, double *y, double *dydt) {
    /* Derivative callback for the numerical method: evaluate the
       derivatives of currentCell at time t for state vector y, with
       results going into dydt.
       The cell's Y and DYDT buffers are temporarily pointed at y and dydt
       while the dynamics hooks and GETEM/SETDEMDT macros run (presumably
       those read/write through cell->Y and cell->DYDT -- confirm), and the
       original buffers are restored before returning.
       For CurrentClamp compartments, dEm/dt is the sum of the dynamics'
       current() contributions plus the axial terms, divided by the
       capacitance, and is stored with SETDEMDT_CMPT.  For other
       compartments the summed current is computed but not stored here.
       NOTE(review): if a hook exits via longjmp (e.g. on a Python
       exception), the restore at the end is skipped and Y/DYDT keep
       pointing at the solver's arrays -- confirm hooks cannot do that. */
    Dynamics *d;
    Compartment *cmpt, *adjcmpt;
    double *Yhold    = currentCell->Y->data;
    double *DYDThold = currentCell->DYDT->data;
    double g;
    int i, seq1, seq2;

    currentCell->Y->data = y;
    currentCell->DYDT->data = dydt;

    seq1 = 0;
    while( (cmpt=il_seq_next(currentCell->compartments, &seq1)) ) {
        double dEmdt = 0;

        for(i=0; i<il_length(cmpt->dynamics); i++) {
            d = (Dynamics*)il_get(cmpt->dynamics, i);
            /* Call the derivative */
            if( d->derivs )
                d->derivs(d, t);

            /* Add this dynamics object's current contribution */
            if( d->current )
                dEmdt += d->current(d, t);
        }

		if( cmpt->ClampMode==CurrentClamp ) {
	        /* Contribution from adjoining compartments.  The conductance
	           array and the adjoining-compartment list are advanced in
	           lockstep; the loop ends when the compartment list runs out. */
		    seq2 = 0;
			da_seq_start(cmpt->axial_conductance);
	        while( da_seq_next(cmpt->axial_conductance, &g), 
		           adjcmpt = il_seq_next(cmpt->compartments, &seq2) ) {
			    dEmdt += (GETEM_CMPT(adjcmpt, t) - GETEM_CMPT(cmpt, t)) * g;
			}

	        /* Fix up dEmdt: convert the summed current to dEm/dt */
		    dEmdt /= cmpt->capacitance;

	        SETDEMDT_CMPT(cmpt, dEmdt);
		}
    }

    /* Restore the cell's own buffers */
    currentCell->Y->data = Yhold;
    currentCell->DYDT->data = DYDThold;
}

static double getCurrent(Compartment *cmpt, double t) {
    /* Return minus the sum of every dynamics object's current() at time t
       and of the axial terms (Em_adjacent - Em_this) * g for each adjoining
       compartment.  tracecapture() records this for compartments that are
       not in CurrentClamp mode. */
    double total = 0;
    double gax;
    Compartment *nbr;
    int k, pos;

    for( k = 0; k < il_length(cmpt->dynamics); k++ ) {
        Dynamics *dyn = (Dynamics*)il_get(cmpt->dynamics, k);
        if( dyn->current )
            total -= dyn->current(dyn, t);
    }

    /* Conductances and adjoining compartments are walked in lockstep. */
    da_seq_start(cmpt->axial_conductance);
    pos = 0;
    while( da_seq_next(cmpt->axial_conductance, &gax),
           (nbr = il_seq_next(cmpt->compartments, &pos)) != NULL ) {
        total -= (GETEM_CMPT(nbr, t) - GETEM_CMPT(cmpt, t)) * gax;
    }

    return total;
}

static double getInterpCurrent(Compartment *cmpt, double t) {
    /* getCurrent() evaluated on the interpolated state.  The owning cell's
       Y buffer is temporarily pointed at Yinterp, then restored. */
    Cell *owner = cmpt->owner;
    double *saved = owner->Y->data;
    double result;

    owner->Y->data = owner->Yinterp->data;
    result = getCurrent(cmpt, t);
    owner->Y->data = saved;

    return result;
}

static void tracecapture(Cell *cell) {
    /* Append one sample to the Em trace of every compartment with emtrace
       set.  With a positive sampling period the sample is taken at
       cell->lasttime from the interpolated state; otherwise it is taken at
       cell->time from the current state.  CurrentClamp compartments record
       Em; other compartments record the value from getCurrent() or
       getInterpCurrent(). */
    Compartment *cp;
    int pos = 0;

    while( (cp = il_seq_next(cell->compartments, &pos)) != NULL ) {
        double value;

        if( !cp->emtrace )
            continue;

        if( cell->timeSample > 0 ) {
            da_append(cp->traceTimes, cell->lasttime);
            if( cp->ClampMode == CurrentClamp ) {
                value = GETEMINTERP_CMPT(cp);
            } else {
                value = getInterpCurrent(cp, cell->lasttime);
            }
        } else {
            da_append(cp->traceTimes, cell->time);
            if( cp->ClampMode == CurrentClamp ) {
                value = GETEM_CMPT(cp, cell->time);
            } else {
                value = getCurrent(cp, cell->time);
            }
        }

        da_append(cp->traceData, value);
    }
}

static void dyntracecapture(Cell *cell) {
    /* Call addTraceEntry() on every dynamics object of every compartment
       in the cell. */
    Compartment *cp;
    int pos = 0;

    while( (cp = il_seq_next(cell->compartments, &pos)) != NULL ) {
        int k;
        int count = il_length(cp->dynamics);

        for( k = 0; k < count; k++ )
            addTraceEntry((Dynamics*)il_get(cp->dynamics, k));
    }
}

static void initCells(GD *gd) {
    /* Last-minute per-run initialisation of every queued cell:
       - attach gd and reset time and the solver statistics counters;
       - set hasEmTrace / hasDynTrace from the compartments' emtrace flags
         and the dynamics' trace flags;
       - in VoltageClamp compartments, record as cmpt->Vclamper the
         dynamics object with d->voltage set (the last one, if several). */
    Cell *cell;
    Compartment *cmpt;
    int i, seq;

    pq_seq_start();
    while( (cell = pq_seq_next()) ) {

        cell->hasDynTrace = false;
        cell->hasEmTrace = false;
        cell->gd = gd;

        cell->time = 0;
        cell->stepAccepts = 0;
        cell->stepTotal = 0;
        cell->functionCnt = 0;
        cell->newtonCnt = 0;
        cell->jacobianCnt = 0;
        cell->maxStepCnt = 0;
        cell->badStepCnt = 0;

        seq = 0;
        while( (cmpt=il_seq_next(cell->compartments, &seq)) ) {
            cell->hasEmTrace |= cmpt->emtrace;

            /* Visit every dynamics object.  Stopping at the first traced
               one would skip any later voltage-clamp dynamics, and also a
               traced voltage-clamp dynamics itself, leaving Vclamper unset. */
            for(i=0; i<il_length(cmpt->dynamics); i++) {
                Dynamics *d = (Dynamics*)il_get(cmpt->dynamics, i);
                if( d->trace )
                    cell->hasDynTrace = true;
                if( d->voltage && cmpt->ClampMode==VoltageClamp )
                    cmpt->Vclamper = d;
            }
        }

    }
}

static int cntCell(PyObject *network, Cell *cell) {
    /* Return how many entries of the network list are this exact cell
       object (pointer identity). */
    int hits = 0;
    Py_ssize_t idx;

    for( idx = PyList_Size(network) - 1; idx >= 0; idx-- ) {
        if( (Cell*)PyList_GetItem(network, idx) == cell )
            hits++;
    }
    return hits;
}

static void initNet(PyObject *network) {
    /* Load the cells of the Python list gd.network into the priority queue.
       Every list item gets a reference (Py_INCREF) up front, so freeNet()
       can release them all unconditionally if we bail out later.  Each item
       is then checked:
         - it is a Cell;
         - it has a numerical method;
         - it appears only once in the list;
         - its compartments have non-zero capacitance;
         - its synapse targets are in the network.
       On any failure a Python exception is set and control returns to
       parplex() via longjmp(excpt_exit, 1).  Every cell gets a sequential
       id, but only cells with isLocal set are queued. */
    int i, size;
    int cellid = 0;
    Cell *cell;
    Compartment *cmpt;
    int cmpt_seq, syn_seq;

    if( !network ) {
        /* No network, no cells! */
        message(info, "No network.\n");
        return;
    }

    /* Reject anything that is not a list.  PyList_Size() would return -1
       with an exception pending, and the run would continue with no cells
       and that stale exception still set.  No references have been taken
       at this point. */
    if( !PyList_Check(network) ) {
        PyErr_SetString(PyExc_TypeError,
            "The network must be a list of cells.");
        longjmp(excpt_exit, 1);
    }

    size = PyList_Size(network);
    if( size <= 0 ) {
        message(info, "Network with no cells.\n");
        return;
    }

    /* Increment everything blindly so that decrementing
       can be done blindly if we bail out at any point */
    for(i=0; i<size; i++)
        Py_INCREF(PyList_GetItem(network, i));

    /* Initialise the Q */
    pq_init(size);

    /* Load the Q, type check and sanity check */
    for(i=0; i<size; i++) {
        cell = (Cell*)PyList_GetItem(network, i);
        if( !PyObject_TypeCheck((PyObject*)cell, &p3_CellType) ) {
            PyErr_SetString(PyExc_TypeError,
                "An item in the network list is not a cell");
            longjmp(excpt_exit, 1);
        }
        /* Check for a numerical method */
        if( !cell->solver ) {
            PyErr_SetString(PyExc_AttributeError,
                "Cell has no numerical method assigned.");
            longjmp(excpt_exit, 1);
        }
        /* Check that the cell is only in the network once */
        if( cntCell(network, cell) != 1 ) {
            PyErr_SetString(PyExc_AttributeError,
                "A cell appears in the network more than once");
            longjmp(excpt_exit, 1);
        }
        /* Check that each cell that this cell is connected to
           is also in the network */
        cmpt_seq = 0;
        while( (cmpt=il_seq_next(cell->compartments, &cmpt_seq)) ) {
            Synapse *s;
            /* Check for zero capacitance (derivs() divides by it) */
            if( cmpt->capacitance==0 ) {
                PyErr_SetString(PyExc_AttributeError,
                    "The compartment has zero capacitance.");
                longjmp(excpt_exit, 1);
            }
            syn_seq = 0;
            while( (s=il_seq_next(cmpt->synapse, &syn_seq)) ) {
                Cell *tgt = s->target;
                if( cntCell(network, tgt) == 0 ) {
                    PyErr_SetString(PyExc_AttributeError,
                        "A cell has a connection to another cell that is not in the network");
                    longjmp(excpt_exit, 1);
                }
            }
        }

        /* We made it!  Add the cell to the Q */
        cell->id = cellid++;
        if( cell->isLocal ) pq_insert(cell);
    }

    return;
}

static void freeNet(PyObject *network) {
    /* Release the reference initNet() took on every network item before
       returning to Python.  A missing or non-list network is skipped:
       initNet() takes no references in that case.  Also, PyList_Size() on
       a non-list raises SystemError, which would overwrite any exception
       already being propagated back to Python. */
    Py_ssize_t i, size;

    if( !network || !PyList_Check(network) )
        return;

    size = PyList_Size(network);
    for( i = 0; i < size; i++ )
        Py_DECREF(PyList_GetItem(network, i));
}

/* Changes

  EAT 02Oct07
  - added gd->roundCnt++;

  EAT 06Feb08
  - first steps to making parplex reusable, added startGD() and
    reset some cell variables in initCells()

  EAT 28Apr08
  - finally, parplex will stop after gd->duration even if it is
    not a multiple of the window size!

  EAT 01May08
  - added stepCommit call to endWindow
*/
