/*
 * Copyright (C) 2004 Evan Thomas
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */

/******************************************************************
  This file contains the logic for sending and receiving synaptic
  messages, both within this node and to remote nodes.  It also
  determines global and local convergence for the current window.
*******************************************************************/

#define P3_MODULE
#include "ndl.h"

#ifdef MPI
#include <mpi.h>
#endif

/* One outgoing synaptic event, queued on the source compartment
   (Compartment.msgQ) by QMsg() until exchangeMessages() either delivers
   it locally or copies it into a per-rank MPI send buffer (gatherMsg). */
typedef struct qmsgtag {
  struct qmsgtag *next;   /* next entry in the compartment's singly linked queue */
  double msgtime;         /* owning cell's time when the message was queued */
  int    synapse_id;      /* receiver looks the synapse up via synapseIndex[synapse_id] */
  int    tgtrank;         /* MPI rank that owns the synapse's target */
  int    cell_id;         /* id of s->owner->owner (presumably the source cell) */
  double strength;        /* synaptic strength captured at queue time */
  bool  retract;          /* true: withdraw this synapse's earlier steps for the window */
  Synapse *s;             /* local synapse pointer; not copied into sendMsg */
} Qmsg;

/* Wire format of one synaptic event exchanged between MPI ranks: the
   subset of Qmsg fields copied by gatherMsg() (next, tgtrank and s are
   local-only and are not sent).

   WARNING: the MPI datatype sendMsg_MPI is built field-by-field from this
   struct in initMPI(), using the blocklen[]/disp[]/type[] tables.  If a
   field is added, removed, reordered or retyped here, update those too.
   In particular, the MPI description of `retract` must match sizeof(bool). */
typedef struct {
  double msgtime;     /* source-side time copied from Qmsg.msgtime */
  int    synapse_id;  /* index into synapseIndex on the receiving rank */
  int    cell_id;     /* copied from Qmsg.cell_id (used for logging on receipt) */
  bool  retract;      /* true: retract rather than deliver */
  double strength;    /* synaptic strength copied from Qmsg.strength */
} sendMsg;

/* MPI send buffers: one growable array of sendMsg per rank, indexed by
   rank.  Allocated in initMPI() (only when mpi_size > 1), appended to by
   gatherMsg() and emptied by sendRemoteMessages(). */
/* Initial capacity and growth step of each send buffer, in entries. */
#define BUF_INCR  1000
static sendMsg **msgBuffers;
/* Capacity (entries) of msgBuffers[rank]. */
static int  *BufSize;
/* Number of entries currently filled in msgBuffers[rank]. */
static int  *BufEnd;

#ifdef MPI
/* MPI receive buffers: a single buffer reused for every peer and grown
   on demand by receiveRemoteMessages(); recvBufSize is its capacity in
   entries. */
static int recvBufSize;
static sendMsg *recvBuf;

static void  dispatchMsgs(GD*, int);
#endif


/* Forward declarations of this module's file-local helpers. */
static Qmsg *newQmsg(void);
static void  sendLocalMsg(Synapse*, int, double, double);
static void  QMsg(Synapse*, Compartment*, GD*);
static void  gatherCellMsgs(GD*, Cell*);
static void  gatherMsg(GD*, Qmsg*);
static void  exchangeRemoteMessages(GD*);
static void  sendRemoteMessages(GD*, int);
static void  receiveRemoteMessages(GD*, int);
static void  freeCellMsgs(Cell *c);
static bool  hasConverged(GD*, Cell*);
static void  revertCell(GD*, Synapse*);
static void  recordAP(GD*, Compartment*);
static bool  globalConvergence(GD*, bool);
static void  stripOldMessages(Synapse*, int);
static void  retractLocalMsg(GD*, Synapse*, int);


/* MPI message tags used in this module.  SYNAPTIC_MSG_CNT carries the
   message count that precedes each SYNAPTIC_MSG payload.
   NOTE(review): CONVERGENCE_MSG is not referenced in this file. */
#define SYNAPTIC_MSG     314
#define CONVERGENCE_MSG  315
#define SYNAPTIC_MSG_CNT 316

static Qmsg * newQmsg(void) {
  /* Allocate one uninitialised queue entry from the project allocator;
     the caller fills in every field. */
  Qmsg *fresh = getmain(sizeof *fresh);
  return fresh;
}

#ifdef MPI
/* Field tables used by initMPI() to build the sendMsg_MPI derived
   datatype: one entry per sendMsg field in declaration order (msgtime,
   synapse_id, cell_id, retract, strength).  Per the original author,
   these must remain available for as long as the datatype is in use,
   which is why they live at file scope. */
static int          blocklen[5] = {1, 1, 1, 1, 1};
static MPI_Aint     disp[5];
static MPI_Datatype type[5]     = {MPI_DOUBLE, MPI_INT, MPI_INT, MPI_INT,
				   MPI_DOUBLE};
static MPI_Datatype sendMsg_MPI;
#endif

extern Synapse **synapseIndex;
void dumpbuffer(int num, sendMsg *buf, char *s) {
  /* Debug helper: log the first num entries of buf, each prefixed with
     the label s, together with the id of the cell its synapse targets
     (looked up through synapseIndex). */
  int k;

  for( k = 0; k < num; k++ ) {
    const sendMsg *m  = buf + k;
    Synapse      *syn = synapseIndex[m->synapse_id];

    debug_msg("%s i=%d msg->msgtime=%g, msg->synapse_id=%d, "
	    "msg->cell_id=%d, msg->retract=%d tgt_cell_id=%d\n",
	    s, k, m->msgtime, m->synapse_id,
	    m->cell_id, m->retract, syn->target->id);
  }
}

void initMPI(GD *gd) {
#ifdef MPI
  /*
    Per-process MPI setup for synaptic message passing:
    1) Build and commit sendMsg_MPI, the derived datatype describing one
       sendMsg, from the blocklen[]/disp[]/type[] tables.
    2) Allocate one send buffer of BUF_INCR entries for every rank.
    Returns immediately when running on a single rank.  gd is unused.

    Uses the MPI-2 calls MPI_Get_address / MPI_Type_create_struct: the
    MPI-1 names MPI_Address / MPI_Type_struct were deprecated in MPI-2
    and removed in MPI-3, so modern MPI libraries no longer provide them.
  */
  int i;
  int nfields = (int)(sizeof(disp)/sizeof(disp[0]));
  MPI_Aint base;
  MPI_Datatype packed;
  sendMsg foo;   /* only its field addresses are taken */

  if( mpi_size == 1 ) return;  /* Not running in parallel. */

  /* Byte offset of each field, relative to the start of the struct. */
  i = 0;
  MPI_Get_address(&foo.msgtime,    &disp[i++]);
  MPI_Get_address(&foo.synapse_id, &disp[i++]);
  MPI_Get_address(&foo.cell_id,    &disp[i++]);
  MPI_Get_address(&foo.retract,    &disp[i++]);
  MPI_Get_address(&foo.strength,   &disp[i++]);
  base = disp[0];
  for(i=0; i<nfields; i++) disp[i] -= base;

  /* retract is a bool.  Describing it as MPI_INT is only correct when
     bool is int-sized; otherwise MPI would read/write the padding after
     it (and on big-endian hosts the received flag would always be 0).
     Describe it by its real size in that case. */
  if( sizeof(foo.retract) != sizeof(int) ) {
    blocklen[3] = (int)sizeof(foo.retract);
    type[3]     = MPI_BYTE;
  }

  MPI_Type_create_struct(nfields, blocklen, disp, type, &packed);
  /* Pin the extent to sizeof(sendMsg) so arrays of sendMsg are strided
     exactly as the C compiler lays them out (including tail padding). */
  MPI_Type_create_resized(packed, 0, (MPI_Aint)sizeof(sendMsg), &sendMsg_MPI);
  MPI_Type_free(&packed);
  MPI_Type_commit(&sendMsg_MPI);

  /* Allocate send buffers */
  BufSize    = getmain(sizeof(*BufSize)*mpi_size);
  BufEnd     = getmain(sizeof(*BufEnd )*mpi_size);
  msgBuffers = getmain(sizeof(sendMsg*)*mpi_size);
  for(i=0; i<mpi_size; i++) {
    BufSize[i] = BUF_INCR;
    BufEnd[i]  = 0;
    msgBuffers[i] = getmain(sizeof(sendMsg)*BUF_INCR);
  }
#endif
}

void detectAP(GD *gd, Cell *cell) {
    /* Look for an upward threshold crossing in every compartment of cell,
       comparing cmpt->lastEm with the membrane potential at cell->time.
       For each compartment that fired: record the AP, then for each of
       its synapses either deliver immediately via sendLocalMsg() (target
       on this rank during the first round, or a synapse whose target is
       its own owning cell) or queue it with QMsg() for exchangeMessages(). */
    Compartment *cmpt;
    int cursor = 0;

    while( (cmpt = il_seq_next(cell->compartments, &cursor)) ) {
        double thresh = cmpt->APthreshold;
        double em     = GETEM_CMPT(cmpt, cell->time);
        int    k;

        /* Skip unless the potential rose through the threshold.  Kept in
           this negated form so non-finite potentials behave as before. */
        if( cmpt->lastEm >= thresh || em <= thresh )
            continue;

        recordAP(gd, cmpt);

        for( k = 0; k < il_length(cmpt->synapse); k++ ) {
            Synapse *syn = (Synapse*)il_get(cmpt->synapse, k);
            bool deliverNow = (syn->tgtrank == mpi_rank && gd->firstRound)
                              || syn->target->id == syn->owningCell->id;

            if( deliverNow )
                sendLocalMsg(syn, gd->windowID, 0, 0);
            else
                QMsg(syn, cmpt, gd);
        }
    }
}

static void sendLocalMsg(Synapse *s, int window_id, double msgtime, double msgstrength) {
  /* Schedule delivery of synapse s's message on its target via stepOnInt()
     as an acceptermsg.  When the owning (source) cell is local, its own
     time and the synapse's nominal strength are used; otherwise the
     caller-supplied msgtime/msgstrength (received from the remote rank)
     are used.  Arrival time = source time + s->trans_time.
     Calls ABEND if the target dynamics has no accepter. */
  Cell     *src = s->owningCell;
  Dynamics *dyn = s->target_dynamics;
  double    when, weight;

  if( !dyn->accepter )
	  ABEND("Attempt to send message to a Dynamics without an acceptor method.\n",
	  s->owningCell);

  when   = (src->isLocal ? src->time : msgtime) + s->trans_time;
  weight =  src->isLocal ? s->nominal_strength : msgstrength;

  stepOnInt(s->target, when, dyn, s, weight, window_id, acceptermsg);
}

void sendLocalMsgNow(Dynamics *d, Synapse *s, double strength, int window_id, MsgType msgtype) {
	/* Deliver a message to dynamics d right away, by kind:
	   enqmsg      -> d->enq at the owning cell's current time
	   acceptermsg -> d->accepter with the synapse and window id
	   nomsg       -> logic error (ABEND)
	   Any other value is ignored. */
	if( msgtype == enqmsg ) {
		d->enq(d, d->owningCell->time, strength);
	} else if( msgtype == acceptermsg ) {
		d->accepter(d, s, strength, window_id);
	} else if( msgtype == nomsg ) {
		ABEND("Logic error in sendLocalMsgNow.\n", d->owningCell);
	}
}

static void QMsg(Synapse *s, Compartment *c, GD *gd) {
  /* Build a non-retracting queue entry for synapse s, stamped with the
     owning cell's current time and the synapse's nominal strength, and
     push it onto the head of compartment c's outgoing queue.  gd is unused. */
  Qmsg *q = newQmsg();

  q->synapse_id = s->synapse_id;
  q->tgtrank    = s->tgtrank;
  q->cell_id    = s->owner->owner->id;
  q->s          = s;
  q->msgtime    = c->owner->time;
  q->strength   = s->nominal_strength;
  q->retract    = false;

  /* Push onto the head of the list. */
  q->next = c->msgQ;
  c->msgQ = q;
}

bool exchangeMessages(GD *gd) {
  /*
     Run one round of synaptic message exchange for the current window.
     1) Walk every cell in the queue and decide whether it converged:
        - first round: converged only if no compartment queued a message;
        - later rounds: cells flagged reDo are checked with hasConverged(),
          all other cells count as converged.
        Unconverged cells have their messages gathered (local ones are
        delivered, remote ones buffered per rank).
     2) Free every cell's queued messages and clear its reDo flag.
     3) Exchange the buffered messages with the other ranks.
     4) Return the global convergence computed from local convergence.
  */
  bool localConvergence = true;
  Cell *cell;

  pq_seq_start();
  while( (cell = pq_seq_next()) ) {
    bool converged = true;

    if( gd->firstRound ) {
      Compartment *cmpt;
      int cursor = 0;
      while( (cmpt = il_seq_next(cell->compartments, &cursor)) )
        converged &= cmpt->msgQ==0;
    } else if( cell->reDo ) {
      converged = hasConverged(gd, cell);
    }

    localConvergence = localConvergence && converged;

    if( !converged )
      gatherCellMsgs(gd, cell);

    freeCellMsgs(cell);
    cell->reDo = false;
  }

  exchangeRemoteMessages(gd);

  if( gd->roundReDoCnt )
    message(info, "Redoing %d cells.\n", gd->roundReDoCnt);

  return globalConvergence(gd, localConvergence);
}

static void gatherCmptMsgs(GD *gd, Compartment *cmpt) {
    /* Dispatch the outgoing messages queued on one compartment.
       - Queued message for this rank: revert the target cell and deliver
         it immediately with sendLocalMsg().
       - Queued message for another rank: copy it into that rank's send
         buffer (gatherMsg) for exchangeRemoteMessages().
       - Empty queue (the compartment did not fire this round): retract
         this window's messages on every synapse, locally via
         revertCell()+retractLocalMsg(), remotely via a retract message. */
    Qmsg *m = cmpt->msgQ;
    int wid = gd->windowID;

    if( !m ) { /* Cell fired previous round, but not this round. */
        Synapse *s;
        int i;
        message(debug, "Retracting messages sent from cell id=%d.\n", cmpt->owner->id);
        for(i=0; i<il_length(cmpt->synapse); i++) {
            s = (Synapse*)il_get(cmpt->synapse, i);
            if( s->tgtrank == mpi_rank ) {
                revertCell(gd, s);
                retractLocalMsg(gd, s, wid);
            } else {
                /* Zero-initialise: gatherMsg copies msgtime and strength
                   into the send buffer, and those fields are otherwise
                   never set for a retraction. */
                Qmsg rm = {0};
                rm.retract    = true;
                rm.synapse_id = s->synapse_id;
                rm.tgtrank    = s->tgtrank;
                rm.s          = s;
                rm.cell_id    = s->owner->owner->id;
                gatherMsg(gd, &rm);
            }
        }
    }

    for( ; m; m = m->next ) {
        if( m->tgtrank==mpi_rank ) {
            revertCell(gd, m->s);
            sendLocalMsg(m->s, wid, 0, 0);
        } else {
            gatherMsg(gd, m);
            message(debug, "Sent message from cell->id=%d to cell->id=%d "
                "in node %d.\n", cmpt->owner->id, m->s->target->id, m->tgtrank);
        }
    }
}

static void gatherCellMsgs(GD *gd, Cell *cell) {
    /* Apply gatherCmptMsgs() to every compartment of cell, in list order. */
    int cursor = 0;
    Compartment *cmpt;

    for( cmpt = il_seq_next(cell->compartments, &cursor);
         cmpt;
         cmpt = il_seq_next(cell->compartments, &cursor) )
        gatherCmptMsgs(gd, cmpt);
}

static void retractLocalMsg(GD *gd, Synapse *s, int window_id) {
  /* Withdraw the steps that synapse s scheduled on its target during
     window window_id, via removeStep().  gd is unused; it is kept so the
     signature matches the other message helpers. */
  removeStep(s, window_id);
}

static void gatherMsg(GD *gd, Qmsg *m) {
  /* Append the wire fields of m (see sendMsg) to the send buffer for rank
     m->tgtrank.  When the buffer is full it grows by BUF_INCR entries; if
     that allocation fails a fatal message is logged and the process
     aborts.  gd is unused. */
  int tgtrank = m->tgtrank;
  sendMsg *newmsg;

  if( BufEnd[tgtrank] >= BufSize[tgtrank] ) {
    int    newSize = BufSize[tgtrank] + BUF_INCR;
    size_t nbytes  = (size_t)newSize * sizeof(sendMsg);
    /* Realloc into a temporary and commit the new size only on success. */
    sendMsg *grown = PyMem_Realloc(msgBuffers[tgtrank], nbytes);

    if( !grown ) {
      /* Report the size actually requested, in bytes. */
      message(fatal, "Could not allocate %lu bytes.\n",
	      (unsigned long)nbytes);
      abort();
    }
    msgBuffers[tgtrank] = grown;
    BufSize[tgtrank]    = newSize;
  }

  newmsg = &msgBuffers[tgtrank][BufEnd[tgtrank]++];
  newmsg->msgtime    = m->msgtime;
  newmsg->synapse_id = m->synapse_id;
  newmsg->cell_id    = m->cell_id;
  newmsg->strength   = m->strength;
  newmsg->retract    = m->retract;
}

static void exchangeRemoteMessages(GD *gd) {
  /* Exchange buffered messages with every other rank using blocking
     point-to-point calls.  The lower-ranked side of each pair receives
     first and the higher-ranked side sends first, so the two never both
     block in a send to each other. */
  int peer;

  for( peer = 0; peer < mpi_size; peer++ ) {
    if( peer < mpi_rank ) {
      sendRemoteMessages(gd, peer);
      receiveRemoteMessages(gd, peer);
    } else if( peer > mpi_rank ) {
      receiveRemoteMessages(gd, peer);
      sendRemoteMessages(gd, peer);
    }
  }
}

static void sendRemoteMessages(GD *gd, int rank) {
#ifdef MPI
  /* Send rank's buffered messages: first the count (always, even when 0,
     so the receiver knows whether a payload follows), then the payload if
     non-empty.  Adds the count to gd->sendMsgCnt and empties the buffer.
     Compiles to nothing without MPI. */
  int count = BufEnd[rank];

  if( count )
     message(info, "Sending %d messages to %d windowID=%d.\n",
	 count, rank, gd->windowID);
  gd->sendMsgCnt += count;

  MPI_Send(&count, 1, MPI_INT, rank, SYNAPTIC_MSG_CNT, MPI_COMM_WORLD);

  if( count )
    MPI_Send(msgBuffers[rank], count, sendMsg_MPI, rank,
	     SYNAPTIC_MSG, MPI_COMM_WORLD);

  BufEnd[rank] = 0;
#endif
}


#ifdef MPI
static void dispatchMsgs(GD *gd, int num) {
  /* Deliver the num messages just received into recvBuf to their local
     targets.  For each one: revert the target cell when required, then
     either retract the synapse's steps for the current window or
     schedule the new message with the sender's time and strength.
     Uses the file-scope synapseIndex to map synapse_id to a Synapse. */
  int i;

  for(i=0; i<num; i++) {
    sendMsg *msg = &recvBuf[i];
    Synapse *s   = synapseIndex[msg->synapse_id];
    double arrvtime;

    message(debug, "Received message from cell->id=%d to cell->id=%d.\n",
		msg->cell_id, s->target->id);

    /* Revert the target on every round after the first.  On the first
       round, revert only if the message arrives within the current window
       (at or before windowStart+window); later arrivals can simply be
       scheduled.
       NOTE(review): arrvtime is based on s->owningCell->time, whereas
       sendLocalMsg() deliberately ignores owningCell->time for non-local
       source cells and uses the transmitted msg->msgtime instead.
       Confirm which time is intended here. */
    arrvtime = s->trans_time+s->owningCell->time;
    if( !gd->firstRound || arrvtime<=(gd->window+gd->windowStart) )
      revertCell(gd, s);

    if( msg->retract )
      retractLocalMsg(gd, s, gd->windowID);
    else
      sendLocalMsg(s, gd->windowID, msg->msgtime, msg->strength);
  }
}
#endif

static void receiveRemoteMessages(GD *gd, int rank) {
#ifdef MPI
  /* Mirror of sendRemoteMessages(): receive the message count from rank,
     then (if non-zero) the payload into recvBuf, add the count to
     gd->recvMsgCnt and hand the messages to dispatchMsgs().  recvBuf only
     ever grows; its previous contents are not needed, so it is freed and
     reallocated rather than resized.  Compiles to nothing without MPI. */
  MPI_Status status;
  int count;

  MPI_Recv(&count, 1, MPI_INT, rank, SYNAPTIC_MSG_CNT, MPI_COMM_WORLD, &status);

  if( recvBufSize < count ) {
    if( recvBuf ) PyMem_Free(recvBuf);
    recvBufSize = count;
    recvBuf     = getmain(recvBufSize*sizeof(*recvBuf));
  }

  if( count ) {
    MPI_Recv(recvBuf, count, sendMsg_MPI, rank, SYNAPTIC_MSG,
	     MPI_COMM_WORLD, &status);
    message(info, "Received %d messages from %d windowID=%d.\n",
	 count, rank, gd->windowID);
  }
  gd->recvMsgCnt += count;

  dispatchMsgs(gd, count);
#endif
}

static void freeCellMsgs(Cell *c) {
  /* Free every queued outgoing message on every compartment of c and
     leave each compartment's msgQ empty. */
  int cursor = 0;
  Compartment *cmpt;

  while( (cmpt = il_seq_next(c->compartments, &cursor)) ) {
    Qmsg *cur;
    Qmsg *following;

    for( cur = cmpt->msgQ; cur; cur = following ) {
      following = cur->next;   /* read before the node is released */
      PyMem_Free(cur);
    }
    cmpt->msgQ = 0;
  }
}

static bool hasConverged(GD *gd, Cell *cell) {
  /* A cell has converged when, in every compartment, this round produced
     the same number of action potentials as the previous round and each
     AP time differs from its counterpart by an amount whose value divided
     by 100 does not exceed gd->tolerance. */
  Compartment *cmpt;
  int cursor = 0;

  while( (cmpt = il_seq_next(cell->compartments, &cursor)) ) {
    int k;

    if( da_length(cmpt->currentAP) != da_length(cmpt->previousAP) )
      return false;

    for( k = 0; k < da_length(cmpt->currentAP); k++ ) {
      double delta = fabs(da_get(cmpt->currentAP, k) - da_get(cmpt->previousAP, k));
      if( delta/100 > gd->tolerance )
        return false;
    }
  }

  return true;
}

static void revertCell(GD *gd, Synapse *s) {
  /* Roll the target cell of synapse s back to the state it had at the
     start of the current window so it can be re-integrated.

     Always: strips this window's earlier steps via stripOldMessages() and,
     if the cell's time differs from gd->windowStart, passes it to
     pq_requeue() with gd->windowStart and cell->step_copy.

     Only when the cell is not already marked reDo: increments the
     cellsToGo / reDoCnt / roundReDoCnt counters, restores Y, DYDT,
     lasttime and each compartment's lastEm from their saved copies,
     swaps each compartment's currentAP/previousAP buffers (clearing the
     new currentAP), clears the trace buffers, sets reDo and calls
     stepRewind(). */
  double_array *p;
  Cell *cell = s->target;
  Compartment *cmpt;
  int seq;

  /* strip off old messages from this synapse */
  stripOldMessages(s, gd->windowID);

  if( cell->time != gd->windowStart)
	  pq_requeue(cell, gd->windowStart, cell->step_copy);

  /* Already reverted: reDo is set below and only cleared by
     exchangeMessages(), so the restore must not run twice. */
  if( cell->reDo ) return;

  gd->cellsToGo++;
  gd->reDoCnt++;
  gd->roundReDoCnt++;

  message(debug, "Redoing cell: cell->id=%d.\n", cell->id);

  /* Revert state variables */
  da_copy(cell->Y, cell->Y_copy);
  da_copy(cell->DYDT, cell->DYDT_copy);
  cell->lasttime = cell->last_copy;
  cell->reDo     = true;

  seq = 0;
  while( (cmpt=il_seq_next(cell->compartments, &seq)) ) {
      cmpt->lastEm   = cmpt->lastEm_copy;

      /* Keep this round's AP times as previousAP for hasConverged(). */
      p = cmpt->currentAP;
      cmpt->currentAP = cmpt->previousAP;
      cmpt->previousAP = p;

      /* Reset AP counters. */
      da_clear(cmpt->currentAP);

      /* Reset the trace */
      da_clear(cmpt->traceTimes);
      da_clear(cmpt->traceData);
  }

  stepRewind(cell);
}

static void stripOldMessages(Synapse *s, int window_id) {
    /* The target of s is about to receive new messages from a source that
       is being recalculated.  Strip the steps issued in this window during
       a previous round: those scheduled through s itself, plus those of
       every synapse in s->owner's list whose target is its own owning
       cell (self-connections). */
    Synapse *t;
    int i;

    /* Strip messages from the target cell */
    removeStep(s, window_id);

    /* Messages from the cell to itself need to be found and deleted */
    for(i=0; i<il_length(s->owner->synapse); i++) {
        t = (Synapse*)il_get(s->owner->synapse, i);
        if( t->target == t->owningCell )
            removeStep(t, window_id);
    }
}


static void recordAP(GD *gd, Compartment *cmpt) {
  /* Append the owning cell's current time to cmpt->currentAP.
     gd is unused. */
  double now = cmpt->owner->time;
  da_append(cmpt->currentAP, now);
}

static bool globalConvergence(GD *gd, bool localConvergence) {
  /* Combine this rank's convergence flag with every other rank's: returns
     true only if all ranks converged.  With a single rank (or without MPI)
     the local flag is the answer.  gd is unused.

     The reduction runs on int temporaries: passing bool objects with
     MPI_INT would read and write sizeof(int) bytes, which overruns them
     whenever bool is narrower than int. */
#ifdef MPI
  if( mpi_size > 1 ) {
    int mine = localConvergence ? 1 : 0;
    int all  = 0;

    MPI_Allreduce(&mine, &all, 1, MPI_INT, MPI_LAND, MPI_COMM_WORLD);
    return all != 0;
  }
#endif
  return localConvergence;
}

/* CHANGELOG
   EAT 27/11/02
   changed the revertCell requeue action to pend the requeue so that
   the queue maintains integrity during the convergence check loop.

   EAT 10Dec02
   Nonlocal messages were using synaptic strengths taken from the
   synapse structure in the target node, however the activity
   dependent modification takes place in the source node.  The code
   was modified to send the synaptic strength at the time of the
   action potential.

   EAT 23Jan03
   Allowed conditional compilation of MPI.

   EAT 29APR04
   V0.2

   EAT 31AUG07
   - Reinstate MPI compilation
   - Change compile define from double negative (#ifndef NOMPI -> #ifdef MPI)
   - Move MPI_Init call from pymain to here

   EAT 10Sep07
   - memcpy restore of mustStep was not using the full buffer length.

   EAT 01Oct07
   - Dropped use of the old 'pend' queue methods since it didn't
     provide any benefits. This allowed updating of cell->time in
	 revertCell, which in turn fixed a couple of bugs with the
	 synchronous messaging.

   EAT 01Oct07
   - added partition performance counters.

   EAT 08Oct07
   - sendLocalMsg modified to accept the cell->time from remote cells
     (previously it stupidly picked the time from the dummy copy
	 of the local structure.)
   - send/recv message only printed when msgcount>0

   EAT 01May08
   - added stepRewind call to the resetCell thingy

   EAT 05May08
   - got rid of asynchronous message support
*/
