/*
* Copyright (C) 2004 Evan Thomas
* 
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or (at
* your option) any later version.
* 
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* General Public License for more details.
* 
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*
*/

/* 
Fifth order Radau IIA, solved iteratively using Newton's method.
See Solving Ordinary Differential Equations II, by Hairer & Wanner
*/

#define P3_MODULE
#include "ndl.h"

/* Step-size control (see rd5_solve).  After an accepted step the next step
   is SAFETY*h*errIntmax**(-PGROW), further scaled by a factor depending on
   the Newton iteration count, or FACTOR*h when errIntmax <= ERRCON (or is
   zero).  After a rejected step it is SAFETY*h*errIntmax**(-PSHRNK). */
#define PGROW  0.25
#define PSHRNK 0.25
#define SAFETY 0.9
#define FACTOR 4.0
#define ERRCON 0.003    /* (FACTOR/SAFETY)**(-1/PGROW)  */
/* Scales the Newton convergence tolerance: eps*KAPPA*(1-last)/last. */
#define KAPPA  1e-1
/* Newton iterations allowed before the step is halved and retried. */
#define NEWTON_MAX 10


/* Coefficients of the 3-stage Radau IIA method; see Solving Ordinary
   Differential Equations II by Hairer & Wanner.  The arrays are indexed
   from 1 (entry 0 is unused padding): c holds the stage nodes, b the
   weights and a the coefficient matrix (its last row equals b). */
static const double c[] = {0, 0.1550510257216822, 0.6449489742783178,
1.0000000000000000};
static const double b[] = {0, 0.3764030627004673, 0.5124858261884216,
0.1111111111111111};
static const double a[4][4] = {{0, 0, 0, 0},
{0, 0.1968154772236604, -0.0655354258501984,
0.0237709743482202},
{0, 0.3944243147390872, 0.2920734116652284,
-0.0415487521259979},
{0, 0.3764030627004673, 0.5124858261884216,
0.1111111111111111}};
static const double eig = 2.748888295956773e-01;  /* real eigenvalue of A */
/* Weights on the stage increments Z1, Z2, Z3 in the local error
   estimate computed by rd5_solve (index 0 unused). */
static const double e[] = {0, -2.762305454748599, 3.799355982527286e-01,
-9.162960986522578e-02};

/* Node-derived constants used by rd5_interp() for dense output. */
/* (4+sqrt(6))/10-1 */
#define C2M1 -0.3550510257216821901802715925294108608034
/* (4-sqrt(6))/10-1 */
#define C1M1 -0.8449489742783178098197284074705891391966
/* (4-sqrt(6))/10 */
#define C1 0.1550510257216821901802715925294108608034
/* (4-sqrt(6))/10 - (4+sqrt(6))/10 */
#define C1MC2 -0.4898979485566356196394568149411782783932
/* (4+sqrt(6))/10 */
#define C2 0.6449489742783178098197284074705891391966


/* State shared between rd5_solve() and rd5_interp().  Z holds the 3*n
   stage increments laid out as Z1 | Z2 | Z3.  lasty and lastt are y and t
   at the start of the last attempted step, and laststep is the step size
   rd5_solve() recorded for it. */
static double *Z, *lasty, laststep, lastt;
/* J: Jacobian (matrix(n)).  M: Newton iteration matrix (matrix(3*n)).
   spJ/spM: storage for the sparse path. */
static double **J, **M, *spJ, *spM;
/* dy1..dy3: stage derivatives.  temp: 3*n scratch/residual vector.
   err: local error estimate (n). */
static double *dy1, *dy2, *dy3, *temp, *err;
/* spMI: integer workspace passed to linsolve alongside spM.
   spMCnt: element count allocated for spM/spMI (3*6*n). */
static int *spMI, spMCnt;

/*
 * Ensure this module's file-static work arrays are large enough for `cell`.
 *
 * The buffers are sized for the largest state length n (da_length(cell->Y))
 * and the largest sparse-pattern count (cell->spCnt) seen so far.  They are
 * only reallocated when a larger cell arrives; otherwise the existing
 * storage is reused.
 *
 * gd   - unused by this method.
 * cell - supplies the state length and cell->spCnt.
 */
void rd5_init(GD *gd, Cell *cell) {
	static int nMax, spCntMax;
	int n = da_length(cell->Y);

	(void)gd;

	if( n>nMax ) {
		if( temp  ) freevec(temp );
		if( Z     ) freevec(Z    );
		if( dy1   ) freevec(dy1  );
		if( dy2   ) freevec(dy2  );
		if( dy3   ) freevec(dy3  );
		if( err   ) freevec(err  );
		if( lasty ) freevec(lasty);

		/* temp and Z hold all three stages (3*n); the rest are per-stage (n). */
		temp     = vector(3*n);
		Z        = vector(3*n);
		dy1      = vector(n);
		dy2      = vector(n);
		dy3      = vector(n);
		err      = vector(n);
		lasty    = vector(n);

		/* Both the sparse and full matrices need to be allocated
		   since, in general, both might be in use in a single run. */
		if( J ) freemat(J, nMax);
		if( M ) freemat(M, 3*nMax);
		J  = matrix(n);
		M  = matrix(3*n);
		if( spM ) freevec(spM);
		spM   = vector(3*6*n);
		/* Release the OLD index array.  (Previously this freed spM, the
		   buffer allocated just above, leaving it dangling and leaking spMI.) */
		if( spMI ) PyMem_Free(spMI);
		spMI  = getmain(3*6*n*sizeof(int));
		spMCnt = 3*6*n;

		nMax = n;
	}

	if( cell->spCnt > spCntMax ) {
		if( spJ ) freevec(spJ);
		spJ   = vector(cell->spCnt);
		spCntMax = cell->spCnt;
	}
}

/*
 * Dense output for the most recent rd5_solve() step.
 *
 * For each component this evaluates the cubic polynomial through
 *   y0 (s = -1), y0+Z1 (s = C1M1), y0+Z2 (s = C2M1), y0+Z3 (s = 0),
 * where y0 = lasty and y0+Z3 is the value at the END of the step.  The
 * polynomial is written in Newton form with divided differences taken over
 * nodes measured from the end of the step, as in Hairer & Wanner's RADAU5
 * (routine CONTR5).  Here s = (tinterp - tend)/laststep with
 * tend = lastt + laststep.
 *
 * n       - number of state components.
 * tinterp - interpolation time; presumably in [lastt, lastt+laststep]
 *           (TODO confirm against callers).
 * yinterp - receives the n interpolated values.
 * fsal    - unused by this method.
 *
 * Always returns true.
 */
bool rd5_interp(int n, double tinterp, double *yinterp, double *fsal) {
  int i;
  /* Offset from the END of the step, in units of the step: s in [-1, 0].
     The divided differences below are built on nodes 0, C2M1, C1M1, -1,
     so the polynomial must be anchored at the step end, not at lastt. */
  double s = (tinterp - lastt)/laststep - 1.0;
  double cont1, cont2, cont3, ak, acont3;

  (void)fsal;

  for(i=0; i<n; i++) {
      cont1 = (Z[i+n] - Z[i+2*n])/C2M1;      /* [0, C2M1]            */
      ak = (Z[i]-Z[i+n])/C1MC2;              /* [C2M1, C1M1]         */
      acont3 = Z[i]/C1;                      /* [C1M1, -1]           */
      acont3 = (ak - acont3)/C2;             /* [C2M1, C1M1, -1]     */
      cont2 = (ak - cont1)/C1M1;             /* [0, C2M1, C1M1]      */
      cont3 = cont2 - acont3;                /* [0, C2M1, C1M1, -1]  */
      /* Base value is the end-of-step solution y0 + Z3. */
      yinterp[i] = lasty[i] + Z[i+2*n] +
          s*(cont1 + (s-(C2M1))*(cont2 + (s-(C1M1))*cont3));
  }

  return true;
}

/*
 * Attempt one step of the fifth-order Radau IIA method from cell->time with
 * trial step cell->step.
 *
 * The three stage increments Z = (Z1 | Z2 | Z3) are found by a Newton
 * iteration that uses a Jacobian evaluated once, at the start of the step.
 * If the iteration has not converged after NEWTON_MAX iterations, or its
 * error measure becomes non-finite, h is halved and the iteration restarted.
 * A local error estimate then decides acceptance.
 *
 * cell    - integration state.  cell->Y and cell->time are advanced in place
 *           only when the step is accepted.  The function, Newton and
 *           Jacobian counters are incremented.
 * eps     - error tolerance; also the floor used when scaling by |y|.
 * minStep - not used by this method.
 * hnext   - receives the proposed size of the next (or retried) step.
 *
 * Returns true if the step was accepted and false if it was rejected (Y and
 * time unchanged).  If the step size underflows, it sets a Python
 * RuntimeError and longjmps to excpt_exit instead of returning.
 *
 * Side effects: records lastt, lasty and laststep, and leaves the final
 * stage increments in Z, for use by rd5_interp().
 */
bool rd5_solve(Cell *cell, double eps, double minStep, double *hnext) {
    int i, j;
    double errNewtmax, errNewtlast, errIntmax;

    double h;
    int newton_cnt;

    double *y = cell->Y->data;
    double *x = &cell->time;
    double htry = cell->step;
    int n = da_length(cell->Y);
    double *fac = cell->fac->data;
    int *functionCnt = &cell->functionCnt;
    int *newtonCnt = &cell->newtonCnt;
    int *jacobianCnt = &cell->jacobianCnt;
    bool UseSparse = cell->sparseLinearAlgebra;

    (void)minStep;

    if( UseSparse ) ABEND("Sparse is currently broken in rdIIA.\n", cell);

    h = htry;

    /* Start of the step: base point for interpolation.  laststep is
       recorded after the Newton iteration, because that may halve h. */
    lastt = *x;
    for(i=0; i<n; i++) lasty[i] = y[i];

    /******************
    Newton Iteration
    ******************/

    /* Initialise */
    for(i=0; i<3*n; i++) Z[i] = 0;
    errNewtlast = 0.5;
    newton_cnt = 0;

    /* Evaluate the Jacobian once at (x, y).  It is reused for every Newton
       iteration, including restarts with a halved step. */
    (*functionCnt)++;
    derivs(*x, y, temp);
    (*jacobianCnt)++;
    if( UseSparse )
      sparse_jacobian(*x, y, temp, fac, n, &spJ, &cell->sparsePattern,
                      &cell->spCnt, functionCnt, cell);
    else
      jacobian(*x, y, temp, fac, n, J, functionCnt, cell);

    while(1) { /* Newton iteration ... */
        (*newtonCnt)++;

        /* Stage values y + Z_k, k = 1..3. */
        for(i=0; i<n; i++) {
            temp[i]     = y[i] + Z[i];
            temp[i+n]   = y[i] + Z[i+n];
            temp[i+2*n] = y[i] + Z[i+2*n];
        }

        *functionCnt += 3;
        derivs(*x + c[1] * h,  temp,      dy1);
        derivs(*x + c[2] * h, &temp[n],   dy2);
        derivs(*x + c[3] * h, &temp[2*n], dy3);

        /* Residual of the stage equations:
           -Z_k + h * sum_l a[k][l] * f(x + c_l h, y + Z_l). */
        for(i=0; i<n; i++) {
            temp[i] = - Z[i] + h *
                (a[1][1] * dy1[i] + a[1][2] * dy2[i] + a[1][3] * dy3[i]);

            temp[i+n] = - Z[i+n] + h *
                (a[2][1] * dy1[i] + a[2][2] * dy2[i] + a[2][3] * dy3[i]);

            temp[i+2*n] = - Z[i+2*n] + h *
                (a[3][1] * dy1[i] + a[3][2] * dy2[i] + a[3][3] * dy3[i]);
        }

        /* 3n x 3n Newton matrix: block (k,l) is delta_kl*I - h*a[k][l]*J.
           It is rebuilt every iteration, presumably because linsolve
           factorises M in place (TODO confirm). */
        for(i=0; i<n; i++) {
            for(j=0; j<n; j++) {
                M[i][j]           = (i==j) - h * a[1][1] * J[i][j];
                M[i][j+n]         =        - h * a[1][2] * J[i][j];
                M[i][j+2*n]       =        - h * a[1][3] * J[i][j];

                M[i+n][j]         =        - h * a[2][1] * J[i][j];
                M[i+n][j+n]       = (i==j) - h * a[2][2] * J[i][j];
                M[i+n][j+2*n]     =        - h * a[2][3] * J[i][j];

                M[i+2*n][j]       =        - h * a[3][1] * J[i][j];
                M[i+2*n][j+n]     =        - h * a[3][2] * J[i][j];
                M[i+2*n][j+2*n]   = (i==j) - h * a[3][3] * J[i][j];
            }
        }

        linsolve(M, spM, spMI, temp, 3*n, eps, cell, 0, UseSparse);

        /* Apply the correction (left in temp by linsolve) and measure its
           size relative to |y|, floored at eps.  The fabs matters: without
           it, negative components were scaled by eps alone. */
        errNewtmax = 0.0;
        for(i=0; i<3*n; i++) {
            double t = fabs(temp[i]/max(fabs(y[i%n]),eps));
            Z[i] += temp[i];
            errNewtmax = errNewtmax > t ? errNewtmax : t;
        }

        if(errNewtmax < eps * KAPPA * (1-errNewtlast) / errNewtlast) break;

        /* Should we restart the Newton iteration with a smaller step? */
        if( ++newton_cnt > NEWTON_MAX ||
            !finite(errNewtmax) ||
            isnan(errNewtmax)) {
                h = h/2;
                if( *x + h == *x ) {
                    PyErr_SetString(PyExc_RuntimeError,
                        "Step size too small in rd5qc().\n");
                    longjmp(excpt_exit, 1);
                }
                /* Re-initialise */
                for(i=0; i<3*n; i++) Z[i] = 0;
                errNewtlast = 0.5;
                newton_cnt = 0;
            } else
                errNewtlast = errNewtmax;
    }

    /* h is now final for this step; rd5_interp() needs the actual value. */
    laststep = h;

    /* Local error estimate: solve
       (I - h*eig*J) err = h*eig*f(x,y) + sum_k e[k]*Z_k
       using the leading n x n part of M. */
    (*functionCnt)++;
    derivs(*x, y, err);
    for(i=0; i<n; i++) {
        err[i] = eig*h*err[i] + e[1]*Z[i] + e[2]*Z[i+n] + e[3]*Z[i+2*n];
    }

    for(i=0; i<n; i++)
        for(j=0; j<n; j++)
            M[i][j] = (i==j) - h * eig * J[i][j];
    linsolve(M, spM, spMI, err, n, eps, cell, 0, UseSparse);

    errIntmax = 0;
    for(i=0; i<n; i++) {
        double yt = fabs(y[i] + Z[i+2*n]);
        double scale = max(max(fabs(y[i]), yt), eps);
        double ratio = fabs(err[i]/scale);
        errIntmax = errIntmax > ratio ? errIntmax : ratio;
    }

    errIntmax /= eps;
    if( errIntmax <= 1.0 ) {
      /* Step accepted.  Grow the next step, damped when Newton needed
         many iterations. */
      if( errIntmax == 0 ) *hnext = h * FACTOR;
      else
        *hnext = errIntmax > ERRCON ?
            SAFETY * h * pow(errIntmax, -PGROW) *
            (2*NEWTON_MAX + 1) / (2*NEWTON_MAX + newton_cnt)
            : FACTOR * h;
        if( *x == *x+h ) {
            PyErr_SetString(PyExc_RuntimeError,
                "Step size too small in rd5qc().\n");
            longjmp(excpt_exit, 1);
        }
        *x += h;
        for(i=0; i<n; i++) y[i] += Z[i+2*n];
        return true; /* Step accepted */
    }

    /* Step rejected */
    *hnext = SAFETY * h * pow(errIntmax, -PSHRNK);

    return false;
}
