/*
 * Copyright (C) 2004 Evan Thomas
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */

/***************************************************************
Implementation of the famous Hines method for compartmental
models consisting only of Hodgkin-Huxley like channels.  It also
includes Karoly Antal's step size adaptation method and Karoly's
adaptation to handle internal Ca and Ca dependent currents.
****************************************************************/

#define P3_MODULE
#include "ndl.h"

#include <stdio.h>

/* Per-cell data for the Hines solver, built by hines_init() and stored in
   cell->solverData.  Row i of the matrix A corresponds to compartment
   EliminateOrder[i] of the cell.  "Adjacent" elements are those at
   |row-col| == 1; "far" elements are any other off-diagonal position
   (see loadA()).  Only far elements above the diagonal are stored; the
   solver treats the far part of A as symmetric. */
typedef struct {
	int     n;              /* number of compartments = order of A */
	int    *EliminateOrder; /* row -> index into cell->compartments */
	int    *nodes;          /* elimination positions of compartments with 3 or
	                           more connections, bracketed by 0 and n; not read
	                           by the solver code visible here */
	int     nodeCnt;        /* number of entries used in nodes */
	double *upper;          /* upper[i] = A[i][i+1]  (n-1 entries) */
	double *lower;          /* lower[i] = A[i+1][i]  (n-1 entries) */
	double *diag;           /* diag[i]  = A[i][i], step-size independent part */
	double *offdiagval;     /* far upper values, parallel to offdiagrow and
	                           offdiagcol; sorted by column after hines_init */
	double *offdiagvalindx; /* offdiagvalindx[i] = far upper value in row i */
	int    *offdiagindx;    /* column of the far upper element in row i, or -1
	                           if none; n+1 entries, offdiagindx[n] is -2 */
	int    *offdiagcol;     /* column of each far upper entry */
	int    *offdiagrow;     /* row of each far upper entry */
	int     offdiagcnt;     /* number of far upper entries */
} solverData;

/* Work buffers shared, being file-scope statics, by every cell integrated
   with this module.  hines_init() grows them with PyMem_Realloc and never
   shrinks them:
     RHS, newupper, newoffdiagval, newdiag  hold bufsize doubles (the largest
                                            compartment count seen so far);
     Ysave, Ysave2                          hold statebufsize doubles (the
                                            largest state-vector length seen). */
static double *RHS, *newupper, *newoffdiagval, *newdiag, *Ysave, *Ysave2;
static int bufsize = 0;
static int statebufsize = 0;

/* Debug helper: print the label s followed by the n doubles of x, each
   preceded by a space, on one line of stdout.  Neither argument is
   modified, so both are const-qualified. */
static void pdv(const char *s, const double *x, int n) {
	int i;
	printf("%s", s);
	for(i=0; i<n; i++) printf(" %g", x[i]);
	printf("\n");
}
/* Debug helper: print the label s followed by the n ints of x, each
   preceded by a space, on one line of stdout.  Neither argument is
   modified, so both are const-qualified. */
static void piv(const char *s, const int *x, int n) {
	int i;
	printf("%s", s);
	for(i=0; i<n; i++) printf(" %d", x[i]);
	printf("\n");
}

/* Walk the compartment graph depth-first from thisCmpt, never stepping back
   into lastCmpt (the compartment we arrived from), and record the visit
   order used for Hines elimination.

   thisCmpt        compartment to visit now.
   lastCmpt        compartment we came from (0 for the starting compartment).
   EliminateOrder  receives thisCmpt->ownerindx at position *indx, i.e. in
                   visit order.  hines_init() later uses these values as
                   indices into cell->compartments.
   nodes, nodeCnt  for every compartment with 3 or more connections, its
                   visit position plus one is appended to nodes.
   orderLength     capacity of EliminateOrder.  Reaching it means the walk is
                   revisiting compartments, i.e. the graph has a cycle.
   indx            running count of compartments visited so far.

   Side effects on each visited compartment: treeOrderRow is set to its
   visit position, and treeOrderCol (allocated here with getmain) receives,
   for each connection i, the treeOrderRow of that connected compartment.
   NOTE(review): ABEND is assumed not to return; otherwise the writes below
   would overrun EliminateOrder. */
static void makeEliminateOrder(Compartment *thisCmpt, 
							   Compartment *lastCmpt,
							   int *EliminateOrder,
							   int *nodes,
							   int *nodeCnt,
							   int orderLength,
							   int *indx) {
	int i, n = il_length(thisCmpt->compartments);

	/* More visits than compartments: the graph must contain a cycle */
	if( *indx>=orderLength ) {
		ABEND("Cell compartments seem to be cyclic.\n", thisCmpt->owner);
	}

	/* Branch point: three or more connected compartments */
	if( n>=3 ) 
		nodes[(*nodeCnt)++] = *indx+1;


	EliminateOrder[(*indx)++] = thisCmpt->ownerindx;

	thisCmpt->treeOrderCol = getmain(n*sizeof(int));
	thisCmpt->treeOrderRow = *indx-1;

	for(i=0; i<n; i++) {
		Compartment *dep = il_get(thisCmpt->compartments, i);
		/* Descend into every neighbour except the one we came from */
		if( dep != lastCmpt )
			makeEliminateOrder(dep,
				thisCmpt,
				EliminateOrder,
				nodes,
				nodeCnt,
				orderLength,
				indx);
		/* dep has been visited by now (either just above or, for the
		   parent, before this call), so its position is known */
		thisCmpt->treeOrderCol[i] = dep->treeOrderRow;
	}
}

							   

bool hines_interp(int n, double xinterp, double *yinterp, double *fsal) { 
    ABEND("Interpolation is not supported by the Hines method (yet).\n", (Cell*)Py_None);
	return true;
}


/* Return element A[row][col] of the Hines matrix held in sd (0 for
   positions that are not stored).

   Adjacent elements come from upper/lower.  Far elements are stored only
   above the diagonal, once per row, in offdiagindx/offdiagvalindx.  The
   matrix is symmetric there, so a far element below the diagonal is read
   from its mirror.  offdiagval must NOT be used here: it is indexed by
   insertion order (later column-sorted order), not by row. */
static double getA(solverData *sd, int row, int col) {

	if( row==col )   return sd->diag[col];
	if( row==col-1 ) return sd->upper[row];
	if( row-1==col ) return sd->lower[col];

	if( sd->offdiagindx[row]==col ) return sd->offdiagvalindx[row];
	if( sd->offdiagindx[col]==row ) return sd->offdiagvalindx[col];

	return 0;
}

/* Debug helper: write the Hines matrix of sd as a dense n x n table to the
   file "hines.dat" in the current directory.  Non-zero entries use "%3g",
   and zero entries are written as "  0".  If the file cannot be opened,
   the failure is reported with perror() and nothing is written (previously
   the NULL stream was dereferenced). */
void prtsdAF(solverData *sd) {
	int row, col;
	double x;
	FILE *f = fopen("hines.dat", "w");

	if( f==NULL ) {
		perror("hines.dat");
		return;
	}

	for(row=0; row<sd->n; row++) {
		for(col=0; col<sd->n; col++) {
			x = getA(sd, row, col);
			if( x )
				fprintf(f, "%3g ", x);
			else
				fprintf(f, "  0 ");
		}
		fprintf(f, "\n");
	}
	fprintf(f, "\n");

	fclose(f);
}

/* Debug helper: write the n values of rhs to the file `name`, one per line
   in "%3g" format, replacing any existing content.  If the file cannot be
   opened, the failure is reported with perror() and nothing is written
   (previously the NULL stream was dereferenced). */
void prtrhsF(double *rhs, int n, char *name) {
	int i;
	FILE *f = fopen(name, "w");

	if( f==NULL ) {
		perror(name);
		return;
	}

	for(i=0; i<n; i++)
		fprintf(f, "%3g\n", rhs[i]);

	fclose(f);
}
/* Debug helper: dump the Hines matrix of sd to stdout as a dense table.
   Zero entries print as ".", and every row ends with the compartment index
   (EliminateOrder) that row represents.  After a blank line, the whole
   elimination order is printed on one line. */
void prtsdA(solverData *sd) {
	int r, c;

	for(r=0; r<sd->n; r++) {
		for(c=0; c<sd->n; c++) {
			double a = getA(sd, r, c);
			if( a != 0 )
				printf("%3g ", a);
			else
				printf("  . ");
		}
		printf("  %3d\n", sd->EliminateOrder[r]);
	}
	printf("\n");

	c = 0;
	while( c < sd->n )
		printf("%3d ", sd->EliminateOrder[c++]);
	printf("\n");

}


/* Enter the coefficient x at matrix position (row, col) of sd.  Structural
   inconsistencies are reported through ABEND on thisCell.

   - diagonal:                     accumulated, diag[row] += x
   - adjacent upper (col==row+1):  stored in upper[row]
   - adjacent lower (row==col+1):  stored in lower[col]
   - far upper (col > row+1):      recorded once per row in offdiagindx and
                                   offdiagvalindx, and appended to the
                                   offdiagrow/offdiagcol/offdiagval list.  A
                                   second far upper element in the same row
                                   is an error.
   - far lower (row > col+1):      not stored.  The solver reuses the mirror
                                   far upper value, so this only checks that
                                   A[col][row] holds the same column and value. */
static void loadA(solverData *sd, int row, int col, double x, Cell *thisCell) {
	int indx = sd->offdiagcnt;

	if( row==col ) {
		sd->diag[row] += x;
		return;
	}
	if( row==col-1) {
		sd->upper[row] = x;
		return;
	}
	if( row-1==col ) {
		sd->lower[col] = x;
		return;
	}

	if( row>col ) {
		/* A bit of sanity checking won't go astray. Note that this does not
		   catch all logic errors.  The mirror element lives in row `col`.  If
		   that row has no far element recorded yet, there is nothing to
		   compare against.  offdiagindx is indexed by row; offdiagrow, which
		   was tested here before, is indexed by insertion count, so the check
		   could be skipped when it should have run. */
		if( sd->offdiagindx[col] == -1 ) return;

		if( sd->offdiagindx[col] != row || sd->offdiagvalindx[col] != x )
			ABEND("Far off diagonal elements are not symmetric.\n", thisCell);

		return;
	}

	/* Far upper element: each row may own at most one */
	if( sd->offdiagindx[row] != -1 ) {
		ABEND("An element is already loaded in the far off diagonal location.\n",
			thisCell);
	}

	/* Per-row view, used by getA and by the elimination/back substitution */
	sd->offdiagindx[row] = col;
	sd->offdiagvalindx[row]  = x;

	/* List view; hines_init later sorts it by column for forward elimination */
	sd->offdiagcol[indx] = col;
	sd->offdiagrow[indx] = row;
	sd->offdiagval[indx] = x;

	sd->offdiagcnt++;
}

/* Prepare `cell` for Hines integration.

   1. Verify that every compartment is eligible: no VoltageClamp, and every
      dynamics object is Hines-integrable.  Anything else ABENDs.
   2. Grow the file-scope work buffers (RHS, newupper, newdiag,
      newoffdiagval, Ysave, Ysave2) if this cell is larger than any seen.
   3. Order the compartments for elimination: a depth-first walk from a
      compartment with exactly one connection, then reversed.
   4. Assemble the step-size independent part of the matrix from the axial
      conductances.  loadA() enforces at most one far element above the
      diagonal per row.
   5. Sort the far entries by column, as goosestep()'s forward elimination
      consumes them in that order.
   6. Store the result in cell->solverData and choose the initial step:
      min(max(1e-3, gd->minStep), cell->mustStep[0].time/2).

   NOTE(review): like the rest of this file, this assumes ABEND does not
   return to its caller. */
void hines_init(GD *gd, Cell *cell) {
	int m, n = il_length(cell->compartments);
	int statelen = da_length(cell->Y);
	int i, j, k, indx = 0, *temp, *temprow, *tempcol, lastmin, lastj=0;
	double *tempval;
	Compartment *cmpt = 0;
	solverData *sd;

	/* With no compartments there is nowhere to start the ordering walk
	   (cmpt would stay NULL). */
	if( n==0 )
		ABEND("Cell has no compartments.\n", cell);

	sd = getmain(sizeof(*sd));

	/* Check that the cell is eligible for this method */
	for(i=0; i<n; i++) {
		cmpt = il_get(cell->compartments, i);
		if( cmpt->ClampMode==VoltageClamp )
			ABEND("Hines integration is not available in VoltageClamp.\n", cell);
		m = il_length(cmpt->dynamics);
		for(j=0; j<m; j++) {
			Dynamics *d = il_get(cmpt->dynamics, j);
			if( !PyObject_TypeCheck(d, &p3_HinesIntegrableDynamicsType) &&
				!d->HinesIntegrable) {
				message(fatal, "Dynamics type \"%s\" is not eligible for the hines integration.\n",
					d->ob_type->tp_name);
				ABEND("", cell);
			}
		}
	}

	/* Allocate work areas.  PyMem_Realloc leaves the old block valid on
	   failure, so a buffer pointer is replaced only by a successful result. */
	if( n>bufsize ) {
		double **work[] = { &RHS, &newupper, &newdiag, &newoffdiagval };
		for(i=0; i<4; i++) {
			double *grown = PyMem_Realloc(*work[i], n*sizeof(double));
			if( grown==0 )
				ABEND("No Memory in hines:hines_init()\n", cell);
			*work[i] = grown;
		}
		bufsize = n;
	}
	if( statelen>statebufsize ) {
		double **save[] = { &Ysave, &Ysave2 };
		for(i=0; i<2; i++) {
			double *grown = PyMem_Realloc(*save[i], statelen*sizeof(double));
			if( grown==0 )
				ABEND("No Memory in hines:hines_init()\n", cell);
			*save[i] = grown;
		}
		statebufsize = statelen;
	}

	sd->n              = n;
	sd->EliminateOrder = getmain(n*sizeof(int));
	sd->nodes          = getmain((n+1)*sizeof(int));
	sd->nodes[0]       = 0;
	sd->nodeCnt        = 1;
	sd->diag           = getmain(n*sizeof(double));
	sd->upper          = getmain((n-1)*sizeof(double));
	sd->lower          = getmain((n-1)*sizeof(double));
	sd->offdiagvalindx = getmain(n*sizeof(double));
	sd->offdiagval     = getmain(n*sizeof(double));
	sd->offdiagindx    = getmain((n+1)*sizeof(int));
	sd->offdiagrow     = getmain(n*sizeof(int));
	sd->offdiagcol     = getmain(n*sizeof(int));
	sd->offdiagcnt     = 0;
	for(i=0; i<n; i++) sd->offdiagindx[i] = sd->offdiagrow[i] = sd->offdiagcol[i] = -1;
	sd->offdiagindx[n] = -2;

	/* Make the order in which the Hines elimination will occur.  First find
	   a compartment (segment in Hines' terminology) that has exactly one
	   other compartment connected to it.  If there is none, e.g. a single
	   compartment, the last one is used.  Then call makeEliminateOrder to
	   build the order recursively. */
	for(i=0; i<n; i++) {
		cmpt = il_get(cell->compartments, i);
		if( il_length(cmpt->compartments) == 1 ) break;
	}
	makeEliminateOrder(cmpt, 
		0,
		sd->EliminateOrder,
		sd->nodes,
		&sd->nodeCnt,
		n,
		&indx);
	sd->nodes[sd->nodeCnt++] = n;

	/* Every compartment must have been reached from the start.  An
	   unreached one would leave an EliminateOrder slot and its treeOrderCol
	   unset (the old test, indx<n-1, let one such compartment through). */
	if( indx<n ) ABEND("Cell is not simply connected.\n", cell);

	/* Reverse the order. */
	temp  = getmain(n*sizeof(int));
	for(j=0; j<n; j++) {
		temp[n-1-j]  = sd->EliminateOrder[j];
	}
	PyMem_Free(sd->EliminateOrder);
	sd->EliminateOrder = temp;

	/* Load A.  treeOrderCol[k] is the visit position of the k-th connected
	   compartment; rows are numbered in reverse visit order, hence n-1-... */
	for(i=0; i<n; i++) {
		int cmptIndx = sd->EliminateOrder[i];
		Compartment *cmpti = il_get(cell->compartments, cmptIndx);
		for(k=0; k<il_length(cmpti->compartments); k++) {
			int    col = n - 1 - cmpti->treeOrderCol[k];
			double a   = da_get(cmpti->axial_conductance, k);
			loadA(sd, i, col, a, cell);
			loadA(sd, i, i, -a, cell);
		}
	}

	/* These arrays need to be sorted in column order.  Selection sort: the
	   first minimum found is taken, so equal columns keep their relative
	   order.  Taken entries are retired by setting their column to n+1. */
	temprow = getmain(n*sizeof(int));
	tempcol = getmain(n*sizeof(int));
	tempval = getmain(n*sizeof(double));
	for(i=0; i<sd->offdiagcnt; i++) {
		lastmin = n;
		for(j=0; j<sd->offdiagcnt; j++) {
			if( sd->offdiagcol[j]<lastmin ) {
				lastmin = sd->offdiagcol[j];
				lastj = j;
			}
		}
		temprow[i] = sd->offdiagrow[lastj];
		tempcol[i] = sd->offdiagcol[lastj];
		tempval[i] = sd->offdiagval[lastj];
		sd->offdiagcol[lastj] = n+1;
	}
	/* Unused slots keep the -1 "no entry" marker the unsorted arrays had */
	for(i=sd->offdiagcnt; i<n; i++) {
		temprow[i] = -1;
		tempcol[i] = -1;
		tempval[i] = 0;
	}
	PyMem_Free(sd->offdiagrow);
	PyMem_Free(sd->offdiagcol);
	PyMem_Free(sd->offdiagval);
	sd->offdiagrow = temprow;
	sd->offdiagcol = tempcol;
	sd->offdiagval = tempval;

	cell->solverData = sd;

	cell->step = min(max(1e-3, gd->minStep), cell->mustStep[0].time/2);

	return;
}

static bool goosestep(Cell *cell, double t, double h);

/* Fixed-step driver: advance `cell` by exactly one goosestep() of size
   cell->step.  The step is never rejected, so this always returns true.
   *hnext is set to eps before cell->time is advanced by cell->step.  The
   statement order is kept in case the caller passes &cell->step as hnext.
   NOTE(review): presumably the caller supplies the fixed step size in eps;
   confirm against the integrator driver. */
bool hines_fixed_step_solve(Cell *cell, double eps, double minStep, double *hnext) {
	(void)goosestep(cell, cell->time, cell->step);

	*hnext = eps;
	cell->time = cell->time + cell->step;

	return true;
}

/* Adaptive Hines step with step doubling.

   From the state at cell->time, take one step of size h = cell->step and,
   separately, two steps of size h/2.  Compare the two results to estimate
   the error.
   - Accepted (scaled error <= 1), or already at minStep: the two-half-step
     state is kept, cell->time advances by h, *hnext receives the suggested
     next step, and true is returned.
   - Rejected: the state is restored, *hnext receives a reduced step (never
     below minStep), and false is returned.
   Increments functionCnt and newtonCnt by 3 per call. */
bool hines_solve(Cell *cell, double eps, double minStep, double *hnext) {
	/* Step-size control parameters.  These used to be macros with FACTOR
	   the *integer* 2, which made the lower bound 1/FACTOR evaluate to 0. */
	const double pgrow  = -0.19; /* exponent applied to errmax on acceptance */
	const double pshrnk = -0.30; /* exponent applied to errmax on rejection  */
	const double factor = 2.0;   /* accepted steps change h by at most this
	                                factor up and 1/factor down */
	const double safety = 0.8;   /* safety margin on the predicted step */
	double t  = cell->time;
	double h  = cell->step;
	double h2 = h/2;
	double temp = 0, errmax = 0;
	int i, n = cell->Y->len;

	cell->functionCnt += 3;
	cell->newtonCnt   += 3;

	memcpy(Ysave, cell->Y->data, n*sizeof(double));

	/* One whole step; keep its result in Ysave2 */
	goosestep(cell, t, h);
	memcpy(Ysave2, cell->Y->data, n*sizeof(double));

	/* Two half steps from the same starting state */
	memcpy(cell->Y->data, Ysave, n*sizeof(double));
	goosestep(cell, t, h2);
	goosestep(cell, t+h2, h2);

	/* Error: largest relative difference between the two results, each
	   component scaled by the larger magnitude (but at least eps). */
	errmax = 0.0;
	for(i=0; i<n; i++) {
		double scale = max(max(fabs(cell->Y->data[i]), fabs(Ysave2[i])), eps);
		temp = cell->Y->data[i] - Ysave2[i];
		temp = fabs(temp/scale);
		if( errmax<temp ) errmax = temp;
	}

	errmax = errmax*h/eps;

	if( errmax <= 1.0 ) {
		/* Accept */
		if( errmax == 0 ) {
			*hnext = h * factor;
		} else {
			temp = safety * pow(errmax, pgrow);
			temp = min(temp, factor);   /* Don't grow too much!   */
			temp = max(1/factor, temp); /* Don't shrink too much! */
			*hnext = temp * h;
		}
		cell->time += h;
		return true;
	}

	if( h <= minStep ) {
		/* Cannot shrink further: accept despite the error estimate */
		cell->time += h;
		*hnext = minStep;
		return true;
	}

	/* Reject - truncation error too large, reduce stepsize */
	temp = safety * pow(errmax, pshrnk);
	temp = max(0.1, temp); /* Don't shrink too much! */
	*hnext = max(temp * h, minStep);
	memcpy(cell->Y->data, Ysave, n*sizeof(double));
	return false;
}

/* I think I deserve a prize for being such a swell fellow! */
/* Advance every compartment of `cell` from time t by one step of size h
   with the Hines tree solver, using the data built by hines_init().

   For row i (compartment EliminateOrder[i]), with C its capacitance:
     diagonal  = diag[i] - 2C/h - sum(I)    (I, Er from each HIcurrent)
     RHS[i]    = -2C/h * V(t) - sum(I*Er) - sum(current(t+h/2))
   The linear system is solved by forward elimination (adjacent and far
   elements) and back substitution.  Each voltage is then set to
   2*solution - V(t), and every dynamics object's HIupdate (if any) is
   called.  Uses the file-scope work buffers RHS, newdiag, newupper and
   newoffdiagval.  Always returns 1. */
static bool goosestep(Cell *cell, double t, double h) {
	solverData *sd         = (solverData*)cell->solverData;
	int    *eliminateorder = sd->EliminateOrder;
	int     n              = il_length(cell->compartments);
	int     offdiagcnt     = sd->offdiagcnt;
	double *olddiag        = sd->diag;
	double *oldupper       = sd->upper;
	double *oldlower       = sd->lower;
	int    *offdiagindx    = sd->offdiagindx;
	int    *offdiagrow     = sd->offdiagrow;
	int    *offdiagcol     = sd->offdiagcol;
	double *offdiagval     = sd->offdiagval;
	double *offdiagvalindx = sd->offdiagvalindx;
	double newerdiag, I, V, Er;
	int i, j, k;
	Compartment *cmpt;
	Dynamics *d;

	/* Calculate the rhs */
	for(i=0; i<n; i++) {
		cmpt = il_get(cell->compartments, eliminateorder[i]);
		V = GETEM_CMPT(cmpt, t);
		/* capacitive current */
		RHS[i] = - V * 2 * cmpt->capacitance / h;
		/* Diagonal element is dependent on the step size */
		newdiag[i] = olddiag[i] - 2 * cmpt->capacitance / h;

		/* ionic, synaptic injected current */
		for(j=0; j<il_length(cmpt->dynamics); j++) {
			d = il_get(cmpt->dynamics, j);
			if( d->HIcurrent ) {
				d->HIcurrent(d, t, &I, &Er);
				RHS[i] -= I * Er;
				newdiag[i] -= I;
			} else if( d->current ) {
				RHS[i] -= d->current(d, t+h/2);
			}
		}
	}

	/* Forward elimination.  For each row exactly one of newupper[i] and
	   newoffdiagval[i] is meaningful.  newupper[i] is zeroed in the far
	   case because row i+1 always multiplies it by lower[i] (zero there),
	   and the buffer may otherwise hold garbage or another cell's values. */
	newerdiag = newdiag[0];
	if( n>1 ) {  /* upper has only n-1 entries */
		if( offdiagindx[0] != -1 ) {
			newoffdiagval[0] = offdiagvalindx[0] / newerdiag;
			newupper[0] = 0;
		} else
			newupper[0] = oldupper[0] / newerdiag;
	}
	RHS[0] /= newerdiag;

	/* j is the column of the next far entry (sorted by column), or -1 when
	   all offdiagcnt entries have been consumed. */
	k = 0;
	j = (k<offdiagcnt) ? offdiagcol[k] : -1;

	for(i=1; i<n; i++) {
		newerdiag = newdiag[i];

		/* Eliminate the far-off lower elements of row i.  They mirror the
		   stored far upper entries whose column is i. */
		while( j==i ) {
			int r = offdiagrow[k];
			newerdiag = newerdiag - offdiagval[k]*newoffdiagval[r];
			RHS[i] = (RHS[i] - offdiagval[k]*RHS[r]);
			k++;
			j = (k<offdiagcnt) ? offdiagcol[k] : -1;
		}

		/* Eliminate immediately lower diagonal elements */
		newerdiag = newerdiag - oldlower[i-1]*newupper[i-1];

		/* Divide the RHS through by the diagonal */
		RHS[i] = (RHS[i] - oldlower[i-1]*RHS[i-1]) / newerdiag;

		/* Divide the upper off diag element by the diagonal */
		if( i<n-1 ) {
			if( offdiagindx[i] != -1 ) {
				newoffdiagval[i] = offdiagvalindx[i] / newerdiag;
				newupper[i] = 0;
			} else
				newupper[i] = oldupper[i] / newerdiag;
		}
	}

	/* Back substitution */
	for(i=n-2; i>=0; i--) {
		k = offdiagindx[i];
		if( k != -1 )
			RHS[i] -= newoffdiagval[i]*RHS[k];
		else
			RHS[i] -= newupper[i]*RHS[i+1];
	}

	/* Store the new voltages and advance the gating variables */
	for(i=0; i<n; i++) {
		cmpt = il_get(cell->compartments, eliminateorder[i]);
		V = 2*RHS[i]-GETEM_CMPT(cmpt, t);
		SETEM_CMPT(cmpt, V);

		for(j=0; j<il_length(cmpt->dynamics); j++) {
			d = il_get(cmpt->dynamics, j);
			if( d->HIupdate )
				d->HIupdate(d, t, h);
		}
	}

	return 1;
}

/* Changes

   EAT 01Feb08
   - fixed failure to eliminate lower elements in the 0th column during
     forward elimination.

   EAT 06Feb08
   - made the step size initialisation deterministic.

   EAT 03Apr08
   - added minStep support

   EAT 01May08
   - changed initial step size to minStep rather than 1

   EAT 16Jul08
   - hines_fixed_step was ignoring step size changes
     introduced by stepper (ie at window end).

*/