/*
 * Copyright (C) 2004 Evan Thomas
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */

/****************************************************************
  The Rosenbrock/Wolfbrandt low-order semi-implicit Runge-Kutta
  method (similar to the one used in MATLAB's ode23s function).

  The step size selection strategy is the same as the one used
  in RK32.

  Ref: The MATLAB ODE Suite
       L.F. Shampine and M.W. Reichelt
       SIAM Journal on Scientific Computing 18(1), 1997, pp. 1-22
*****************************************************************/


#define P3_MODULE
#include "ndl.h"

/* Step-size control (see rw23_solve): after an accepted step the step
   size is scaled by SAFETY*errmax^PGROW, after a rejected one by
   SAFETY*errmax^PSHRNK, where errmax is the tolerance-normalised error
   estimate.  -0.33 is approximately -1/3.  The negative values are
   parenthesised so they expand safely inside any expression. */
#define PGROW  (-0.33)
#define PSHRNK (-0.33)
/* Safety factor applied to the predicted step-size ratio. */
#define SAFETY 0.8
/* After an accepted step the step size changes by at most this factor
   (up or down). */
#define FACTOR 5.0

/* Method coefficient d = 1/(2+sqrt(2)) (40 digits); the iteration matrix
   formed in rw23_solve is I - h*D*J. */
#define D   0.2928932188134524755991556378951509607152
/* 6+sqrt(2) (40 digits); coefficient used when forming the k3 stage. */
#define E32 7.414213562373095048801688724209698078570
/* Approximately sqrt(DBL_EPSILON) for IEEE double.
   NOTE(review): not referenced by the code visible in this file. */
#define SQRT_DBL_EPSILON   1.5e-8

/* Work storage allocated by rw23_init and used by rw23_solve and
   rw23_interp.  These are file-scope statics shared by every caller, so
   the functions in this file are not reentrant.
     k1, k2            stage vectors, also used for interpolation
     lasty, lastt,
     laststep          state, time and size of the last step attempted
     W, spW            dense / sparse iteration matrix storage
     F1, F2            derivatives at the mid-point and end-point stages
     dFdt              finite-difference time derivative
     k3                third stage vector
     ytemp             candidate state */
static double *k1, *k2, *lasty, lastt, laststep;
static double **W, *spW;
static double *F1, *F2, *dFdt, *k3, *ytemp;


/* Per-state-variable dependency record built by makedep(): one entry per
   element of cell->Y, the array being stored in cell->dependencies.
   For a compartment voltage Vi, dVidVi, dVjdVi and dgjdVi are taken from
   that compartment.  For any other state variable, dgjdVi and dVjdVi are
   fresh empty lists and dxjdgi is the Dynamics that owns the variable.
   NOTE(review): makedep() assigns dVidVi only for voltages and dxjdgi only
   for other states; the unassigned field holds whatever getmain() returned
   -- confirm whether getmain() zero-fills. */
typedef struct partialsDependencyTag {
    internal_list *dgjdVi;   /* Vi: the compartment's dynamics list (gates that may depend on Vi); otherwise empty */
    Compartment   *dVidVi;   /* Vi: the compartment itself (diagonal term) */
    internal_list *dVjdVi;   /* Vi: compartments connected to this one; otherwise empty */
    Dynamics      *dxjdgi;   /* non-voltage state: the Dynamics owning it (access to derivs and current) */
} partialsDependency;

/* qsort() comparator for ints: returns -1, 0 or 1 as the int at pa is
   less than, equal to or greater than the int at pb.  Uses comparisons
   rather than subtraction, so it cannot overflow. */
static int cmp(const void *pa, const void *pb) {
    const int a = *(const int *)pa;
    const int b = *(const int *)pb;

    return (a > b) - (a < b);
}

/*
 * makedep - for each of the n = da_length(cell->Y) state variables, build
 * a partialsDependency record (the array is stored in cell->dependencies)
 * and one row of the sparse Jacobian pattern index cell->sparsePattern.
 *
 * Layout of cell->sparsePattern on return:
 *   [0]        n + 1
 *   [i+1]      one past the last entry belonging to row i, so the
 *              off-diagonal column indices of row i occupy
 *              [sparsePattern[i], sparsePattern[i+1])
 *   [n+1 ...]  the column indices themselves, row by row, each row in
 *              ascending order; column i of row i (the diagonal) is never
 *              stored here
 * cell->spCnt is left equal to the allocated length of sparsePattern,
 * which may exceed the number of entries in use (sparsePattern[n]).
 * NOTE(review): this layout resembles the 0-based form of Numerical
 * Recipes' row-indexed sparse storage -- confirm against sparse_jacobian()
 * and linsolve().
 *
 * NOTE(review): cell->sparsePattern and cell->dependencies are overwritten
 * without freeing any earlier value, and rw23_init() calls this on every
 * initialisation that uses sparse linear algebra -- confirm earlier
 * allocations are released elsewhere.
 */
static void makedep(Cell *cell) {
    /* Make the dependency tree required for the sparse derivative
       calculation.
       */

    int n = da_length(cell->Y);
    partialsDependency *dep = getmain(n*sizeof(*dep));
    /* k is the index of the last pattern entry written; starting at n
       places the first column index at n+1, just after the n+1 row
       pointers. */
    int i, j, k=n;
    Compartment *cmpt, *adjcmpt;
    Dynamics *d=0;
    bool isVi;
    int *holdj, holdjindx;
	int cmpt_seq, dyn_seq;

    /* Initial capacity 3n entries; grown by about 10% at a time below. */
    cell->sparsePattern = getmain(3*n*sizeof(int));
    cell->sparsePattern[0] = n + 1;
    cell->spCnt = 3*n;

    /* Scratch list of the candidate columns of one row.
       NOTE(review): assumes no row has more than n candidates. */
    holdj = getmain(n*sizeof(int));
    
    message(debug, "makedep() start.\n");

    cell->dependencies = dep;

    for(i=0; i<n; i++) {

        /* What type of state variable? */
        /* State i is either the voltage of a compartment (its Emindx) or
           lies in [d->y, d->y+d->n) of a Dynamics that has derivs.  On the
           jump to stateFound, cmpt is the compartment found and, for a
           non-voltage state, d is the Dynamics (from cmpt->dynamics) that
           contains state i. */
        isVi = 0;
		cmpt_seq = 0;
        while( (cmpt=il_seq_next(cell->compartments, &cmpt_seq)) ) {
            if( cmpt->Emindx == i ) {
                isVi = 1;
                goto stateFound;
            }
            dyn_seq = 0;
            while( (d=il_seq_next(cmpt->dynamics, &dyn_seq)) ) {
                if( d->y<=i && i<(d->y+d->n) && d->derivs) {
                    goto stateFound;
                }
            }
        }
        /* No compartment or dynamics owns state i. */
        ABEND("Logic error in makedep().\n", cell);

stateFound:
	/* Create the dependency information */
        if( isVi ) {
            /* diagonal element */
            dep[i].dVidVi = cmpt;
            /* find compartments to which this one connects */
            dep[i].dVjdVi = cmpt->compartments;
            /* Gate variables that are dependent on this Vi */
            dep[i].dgjdVi = cmpt->dynamics;
        } else {
            /* Empty lists */
            dep[i].dgjdVi   = il_new("Dynamics");
            dep[i].dVjdVi   = il_new("Compartment");
            /* Access to derivs and current */
            dep[i].dxjdgi   = d;
        }

	/* Create the sparse Jacobian pattern index */
    holdjindx = 0;
	if( isVi ) {
	  /* Assume each gate variable contributes */
	  /* NOTE(review): unlike the ownership search above, this loop does
	     not require d->derivs -- confirm d->y/d->n are meaningful for
	     dynamics without derivs. */
	  dyn_seq = 0;
	  while( (d=il_seq_next(cmpt->dynamics, &dyn_seq)) ) {
	    for(j=d->y; j<(d->y+d->n); j++) {
            holdj[holdjindx++] = j;
	    }
	  }
	  /* Vj for each connected compartment */
	  cmpt_seq = 0;
	  while( (adjcmpt=il_seq_next(cmpt->compartments, &cmpt_seq)) ) {
          holdj[holdjindx++] = adjcmpt->Emindx;
	  }
	} else {
	    /* Assume each gate variable depends on Vi */
        holdj[holdjindx++] = d->owner->Emindx;
    }

    /* It is worth the effort of getting these in order
       (for example if it needs to be printed) */
    qsort(holdj, holdjindx, sizeof(int), cmp);

    /* Append this row's off-diagonal columns to the pattern. */
    for(j=0; j<holdjindx; j++) {
        /* The diagonal is not stored in the column list. */
        if( i==holdj[j] ) continue;
        /* Grow by ~10% (at least one entry) so index k stays in bounds;
           ABEND is called if the reallocation fails. */
        if( ++k >= cell->spCnt ) {
            cell->spCnt = (int)((float)cell->spCnt*1.1+1);
            cell->sparsePattern = PyMem_Realloc(cell->sparsePattern, 
                cell->spCnt*sizeof(int));
            if( cell->sparsePattern==0 ) 
                ABEND("No memory in makedep\n", cell);
        }
        cell->sparsePattern[k] = holdj[j];
    }
    /* Row i ends one past the last entry written. */
    cell->sparsePattern[i+1] = k + 1;
    }   
    /* NOTE(review): holdj came from getmain(); assumes getmain() allocates
       with the PyMem allocator so PyMem_Free is the matching release. */
    PyMem_Free(holdj);

    message(debug, "makedep() end.\n");
}

/*
 * rw23_init - prepare the file-scope work storage for integrating `cell`.
 *
 * The per-state vectors and the dense matrix W are reallocated only when
 * this cell has more state variables than any cell seen before (they are
 * never shrunk).  When the cell uses sparse linear algebra its dependency
 * data and sparse pattern are rebuilt with makedep(), and the sparse value
 * array spW is regrown (with 100 entries of headroom) whenever cell->spCnt
 * exceeds the largest count seen so far.  A zero cell->step is replaced by
 * 0.1.
 */
void rw23_init(GD *gd, Cell *cell) {
	/* High-water marks for the sizes the storage was allocated for. */
	static int nMax, spCntMax;
	int n = da_length(cell->Y);

	if( n > nMax ) {
		/* Every per-state work vector, in free/allocation order. */
		double **work[] = { &F1, &F2, &dFdt, &k1, &k2, &k3, &lasty, &ytemp };
		int nwork = (int)(sizeof work / sizeof work[0]);
		int w;

		for( w = 0; w < nwork; w++ ) {
			if( *work[w] ) freevec(*work[w]);
		}
		for( w = 0; w < nwork; w++ ) {
			*work[w] = vector(n);
		}

		/* The dense matrix is allocated even when the sparse solver is
		   selected, since both may be used within a single run. */
		if( W ) freemat(W, nMax);
		W = matrix(n);
		matrix_init(gd, n);

		nMax = n;
	}

	if( cell->sparseLinearAlgebra ) makedep(cell);

	if( cell->spCnt > spCntMax ) {
		if( spW ) freevec(spW);
		spW = vector(cell->spCnt + 100);
		spCntMax = cell->spCnt;
	}

	if( cell->step == 0 ) cell->step = 0.1;
}

bool rw23_interp(int n, double tinterp, double *yinterp, double *fsal) { 
  /*
   * Dense output over the step most recently attempted by rw23_solve().
   * With h = laststep and s = (tinterp - lastt)/h, writes for each i < n
   *
   *   yinterp[i] = lasty[i] + h*(c1*k1[i] + c2*k2[i])
   *   c1 = s(1-s)/(1-2D),   c2 = s(s-2D)/(1-2D)
   *
   * s = 0 reproduces lasty and s = 1 gives lasty + h*k2.  fsal is not used.
   * Always returns true.
   */
  const double h = laststep;
  const double s = (tinterp - lastt)/h;
  const double denom = 1-2.0*D;
  const double c1 = s*(1-s)/denom;
  const double c2 = s*(s-2.0*D)/denom;
  int i = 0;

  (void)fsal;

  while( i < n ) {
    yinterp[i] = lasty[i] + h*(c1*k1[i] + c2*k2[i]);
    i++;
  }

  return true;
}

/*
 * rw23_solve - attempt a single Rosenbrock/Wolfbrandt step for `cell`.
 *
 * The step starts at (cell->time, cell->Y) with trial size cell->step and
 * uses cell->DYDT as the derivative at the start point.
 *
 *   cell     integration state; cell->functionCnt, newtonCnt and
 *            jacobianCnt are incremented as work is done.
 *   eps      error tolerance: the step is accepted when
 *            (h/eps) * max_i |(k1 - 2*k2 + k3)_i / 6| / scale_i <= 1,
 *            with scale_i = max(|y_i|, |ynew_i|, eps).
 *   minStep  not used by this function.
 *   hnext    receives the proposed size of the next step (accepted) or of
 *            the retry (rejected).
 *
 * Returns true when the step is accepted: cell->time advances by h,
 * cell->Y holds the new state and cell->DYDT the derivative there, ready
 * to be reused as the first stage of the next step.  Returns false when
 * the step is rejected; cell->time, cell->Y and cell->DYDT are unchanged.
 * A step whose new state or error estimate is not finite is always
 * rejected, with the largest permitted reduction (factor 0.1).
 *
 * In both cases k1, k2 and the step start (lastt, lasty, laststep) are left
 * in the file statics for rw23_interp().
 */
bool rw23_solve(Cell *cell, double eps, double minStep, double *hnext)
{
    int i, j;
    double errmax, temp;
    bool stepFinite = true;
    int **spWI = &cell->sparsePattern;
    int *spCnt = &cell->spCnt;

    double *y = cell->Y->data;
    double *F0 = cell->DYDT->data;
    double *t = &cell->time;
    double h = cell->step;
    int n = da_length(cell->Y);
    double *fac = cell->fac->data;
    int *functionCnt = &cell->functionCnt;
    int *newtonCnt = &cell->newtonCnt;
    int *jacobianCnt = &cell->jacobianCnt;
    bool UseSparse = cell->sparseLinearAlgebra;

    (void)minStep;  /* step-size floor is not enforced here */

    /* Three linear solves per step (stages k1, k2, k3). */
    *newtonCnt += 3;

    /* Jacobian at (t, y) into spW (sparse) or W (dense). */
    (*jacobianCnt)++;
    if( UseSparse )
        sparse_jacobian(*t, y, F0, fac, n, &spW, spWI,
                        spCnt, functionCnt, cell);
    else
        jacobian(*t, y, F0, fac, n, W, functionCnt, cell);

    /* Turn the Jacobian (assumed to be dF/dy as left by jacobian() /
       sparse_jacobian()) into the iteration matrix I - h*D*J.  In the
       sparse layout spW[0..n-1] are the diagonal entries and row i's
       off-diagonal entries are at [(*spWI)[i], (*spWI)[i+1]). */
    if( UseSparse ) {
        for(i=0; i<n; i++) {
            spW[i] = 1 - h*D*spW[i];
            for(j=(*spWI)[i]; j<(*spWI)[i+1]; j++)
                spW[j] *= - h*D;
        }
    } else {
        for(i=0; i<n; i++) {
            for(j=0; j<n; j++)
                W[i][j] = (double)(i==j) - h*D*W[i][j];
        }
    }

    /* Forward-difference approximation of dF/dt at (t, y).
       NOTE(review): eps is the error tolerance here, not machine epsilon,
       and SQRT_DBL_EPSILON appears unused in this file -- confirm that the
       increment sqrt(eps)*min(h, 0.1) is intended. */
    (*functionCnt)++;
    temp = sqrt(eps)*min(h, 0.1);
    derivs(*t+temp, y, dFdt);
    for(i=0; i<n; i++) dFdt[i] = (dFdt[i] - F0[i]) / temp;

    /* These data form the starting point for interpolation */
    lastt = *t;
    for(i=0; i<n; i++) lasty[i] = y[i];
    laststep = h;

    /* k1: W k1 = F0 + h*D*dF/dt (linsolve call 0 presumably factorises W) */
    for(i=0; i<n; i++)
        k1[i] = F0[i] + h*D*dFdt[i];
    linsolve(W, spW, *spWI, k1, n, eps, cell, 0, UseSparse);

    /* F1: derivative at the mid-point y + h/2*k1 */
    for(i=0; i<n; i++) ytemp[i] = y[i] + h*k1[i]/2.0;
    (*functionCnt)++;
    derivs(*t+h/2.0, ytemp, F1);

    /* k2: W (k2 - k1) = F1 - k1 */
    for(i=0; i<n; i++)
        k2[i] = F1[i]-k1[i];
    linsolve(W, spW, *spWI, k2, n, eps, cell, 1, UseSparse);
    for(i=0; i<n; i++)
        k2[i] += k1[i];

    /* Candidate y(t+h) */
    for(i=0; i<n; i++) ytemp[i] = y[i] + h*k2[i];

    /* F2: derivative at the candidate state */
    (*functionCnt)++;
    derivs(*t+h, ytemp, F2);

    /* k3 */
    for(i=0; i<n; i++)
        k3[i] = F2[i] - E32*(k2[i]-F1[i]) - 2.0*(k1[i]-F0[i]) + h*D*dFdt[i];
    linsolve(W, spW, *spWI, k3, n, eps, cell, 1, UseSparse);

    /* Local error estimate (h/6)*(k1 - 2*k2 + k3), relative to
       max(|y|, |ynew|, eps), in the max norm.  A NaN never compares greater
       than errmax, and an infinite ytemp[i] makes the scale infinite, so
       without the explicit check a non-finite component would be ignored
       and the step accepted with a non-finite state. */
    errmax = 0;
    for(i=0; i<n; i++) {
        double scale = max(max(fabs(y[i]), fabs(ytemp[i])), eps);
        temp = (k1[i] - 2.0*k2[i] + k3[i])/6.0;
        temp = fabs(temp/scale);
        if( !(temp < HUGE_VAL) || !(fabs(ytemp[i]) < HUGE_VAL) ) {
            stepFinite = false;
            break;
        }
        if( errmax<temp ) errmax = temp;
    }
    errmax = errmax*h/eps;

    if( stepFinite && errmax <= 1.0 ) {
        /********
          Accept
        *********/
        if( errmax == 0 ) {
            *hnext = h * FACTOR;
        } else {
            temp = SAFETY * pow(errmax, PGROW);
            temp = min(temp, FACTOR);   /* Don't grow too much!   */
            temp = max(1/FACTOR, temp); /* Don't shrink too much! */
            *hnext = temp * h;
        }
        *t += h;
        for(i=0; i<n; i++) {
            y[i] = ytemp[i];
            F0[i] = F2[i];   /* derivative at the new state, reused next step */
        }
        return true;
    }

    /*****************************************************************
      Reject - truncation error too large (or not finite), reduce step
    ******************************************************************/
    temp = stepFinite ? SAFETY * pow(errmax, PSHRNK) : 0.0;
    temp = max(0.1, temp); /* Don't shrink too much! */
    *hnext = temp * h;
    return false;
}
