// Top-level call made when this file is loaded. mulfit_optimizers_append is
// defined elsewhere; presumably it adds an entry to the fitter's optimizer
// list using the display label "Simulated Annealing" and the template name
// "MulfitSimAnnealSeq" -- TODO confirm. The meaning of the third argument (0)
// is not visible from this file.
mulfit_optimizers_append("Simulated Annealing", "MulfitSimAnnealSeq", 0)

// Simulated-annealing optimizer built on a downhill simplex whose vertex
// errors are perturbed by temperature-dependent random noise.
begintemplate MulfitSimAnnealSeq

// Entry points (prun, sa_efun, showopt) and run bookkeeping: lowest error
// seen (minerr), number of error-function evaluations (nefun), elapsed time.
public prun, sa_efun, minerr, nefun, time
// saveflag: write the annealing path to disk; verbose: trace every evaluation.
public saveflag, showopt, verbose
// Simplex step coefficients (alpha, beta, gamma), initial simplex scale
// (lambda), convergence tolerance (ftol) and starting temperature.
public alpha, beta, gamma, lambda, ftol, start_temptr
// Annealing schedule: restart threshold factor, cooling multiplier, and the
// simplex iteration budget per temperature step.
public restart_factor, cool_rate, iter_per_step
// Exported, but not referenced by any procedure in this template.
public savepath_name

// r: Random used for the noise; simplex: Matrix with one vertex per row;
// old_simplex: copy of simplex taken before each trial step.
objref r, simplex, old_simplex
// simplex_errors: one (noisy) error per vertex; new_point: point currently
// being evaluated; best_point: best point recorded so far; mean_point:
// centroid of all vertices except the highest.
objref simplex_errors, new_point, best_point, mean_point
// simperrors_nonoise: vertex errors with the stored noise removed;
// simplex_noise: the noise currently included in each simplex_errors entry.
objref simperrors_nonoise, simplex_noise
// savepath: File used for the annealing path; pf: fitness object supplied to
// init(); tl is declared but not used in this template.
objref savepath, tl, pf
strdef savepath_name

// Constructor.
//   $o1  fitness object; this template calls its efun(), doarg_get() and
//        parmlist members.
// Creates the helper objects and installs the default settings.
proc init() {
  // collaborators and helper objects
  pf = $o1
  r = new Random()
  savepath = new File()
  new_point = new Vector()

  // simplex coefficients
  alpha = 1.0            // reflection
  beta = 2.0             // extrapolation
  gamma = 0.5            // contraction
  lambda = 2             // multiplier used to build the initial simplex
  ftol = 1e-5            // relative tolerance that ends a simplex run

  // annealing schedule
  start_temptr = 0.1
  cool_rate = 0.9
  iter_per_step = 100
  restart_factor = -1    // gives a negative restart threshold in prun()

  // run state
  nefun = 0
  st = 0
  eb = 1e9               // best error so far; large initial value

  // output switches
  saveflag = 0
  verbose = 0
}


// Evaluate the error function at one point and update the run statistics.
//   $1   number of parameters
//   &$2  pointer to the first of the $1 parameter values
// Returns the error reported by pf.efun. Increments nefun on every call and,
// unless stoprun is set, lowers minerr when the new error is smaller
// (minerr == -1 marks "no value recorded yet").
func sa_efun() { local e, i
  // i is local so the verbose loop below does not write a template-scope
  // variable of the same name.
  nefun += 1
  e =  pf.efun($1, &$&2)
  if (!stoprun) {
    if (minerr == -1 || e < minerr) {
      minerr = e
    }
  }
  doNotify()
  if (verbose) {
    // trace line: evaluation count, error, then every parameter value
    printf("%d %14.12g ",nefun, e)
    for i = 0,$1-1 {
      printf("%14.12g ",$&2[i])
    }
    printf("\n")
  }
  return e
}


// Run one complete fit and return the final error.
// Reads the starting point from pf.doarg_get(), runs the annealing loop,
// re-evaluates the best point found, and (when saveflag is set) appends the
// recorded path to savepath.fit behind a "start" header.
// Sets minerr, nefun, time, eb and restart_temptr.
func prun() { local ran
  ran = 0                       // becomes 1 once the annealing loop was entered
  minerr = -1
  restart_temptr = restart_factor * start_temptr
  if (saveflag) {
    savepath.wopen("savepath.tmp")
    // lines is intentionally NOT local: anneal_step() increments the
    // template-level variable, and the header written below must report that
    // count.
    lines = 0
  }
  pf.doarg_get(new_point)
  nefun = 0
  eb = 1e9
  st = startsw()
  time = st
  if (new_point.size() == 0) {
    minerr = pf.efun(0, &time) // time is dummy here
  } else {
    if (stoprun) {
      // do not leave the path file open on the early exit
      if (savepath.isopen()) {
        savepath.close()
      }
      return minerr
    }
    ran = 1
    // fit_simanneal_seq is a proc; whatever is stored here is replaced by the
    // re-evaluation of the best point two lines below.
    minerr = fit_simanneal_seq()
    new_point = best_point.c
    minerr = sa_efun(new_point.size(),&new_point.x[0])
  }
  time = startsw() - st
  if (savepath.isopen()) {
    savepath.close()
    savepath.aopen("savepath.fit")
    savepath.printf("start\n")
    savepath.printf("%d %d\n", lines, pf.parmlist.count + 2)
    savepath.close()
    system("cat savepath.tmp >> savepath.fit")
  }
  if (ran) {
    // The summary is only meaningful (and best_point is only guaranteed to be
    // a Vector) when the annealing loop actually ran.
    printf("Final temperature: %9.5g\tnefun: %d\tFinal error: %9.5g\tat point ",temptr,nefun,eb)
    best_point.printf("%9.4g")
  }
  return minerr
}

// Build the options GUI: one panel with the editable settings and a second
// panel showing the current temperature, lowest error and tolerance.
proc showopt() {
  // allowed ranges for the schedule settings
  variable_domain(&iter_per_step, 1, 1e12)
  variable_domain(&cool_rate, 0, 1)
  variable_domain(&start_temptr, 0, 1e12)

  // editable settings
  xpanel("")
  xpvalue("alpha", &alpha, 1)
  xpvalue("beta", &beta, 1)
  xpvalue("gamma", &gamma, 1)
  xpvalue("lambda", &lambda, 1)
  xpvalue("ftol", &ftol, 1)
  xpvalue("Start temperature", &start_temptr, 1)
  xpvalue("Rate of cooling", &cool_rate, 1)
  xpvalue("Iterations per cooling step", &iter_per_step, 1)
  xpvalue("Restart factor", &restart_factor, 1)
  xpanel()

  // progress values
  xpanel("")
  xpvalue("Current temperature", &temptr)
  xpvalue("Current lowest error", &elo)
  xpvalue("Current tolerance", &rtol)
  xpanel()
}

// Run the noisy downhill simplex at the current temperature (temptr).
//   $1  iteration budget for this temperature step
// Returns the remaining budget: a negative value means the budget ran out,
// a non-negative value means the loop ended on the tolerance test (or on
// stoprun). Uses and updates simplex, simplex_errors, simplex_noise,
// mean_point, old_simplex and the template-level ilo, ihi, inhi, elo, ehi,
// enhi and rtol.
func noisy_simplex() { local i, e, iter, esave, denom
  iter = $1
  r.uniform(1e-20,1)

  while (1) {
    if (stoprun) break
    // find the highest, next highest and lowest points
    //
    // at this point in the loop, simplex_errors includes the noise from the
    // last iteration. however, we also want to keep track of the simplex
    // errors WITHOUT the noise.
    simperrors_nonoise = simplex_errors.c.sub(simplex_noise)
    simplex_errors = simperrors_nonoise.c
    for i = 0, npoints-1 {

      // By storing the noise applied at each step, in addition to the
      // noise-amplified error, we are able to retain the actual function
      // values observed, which is crucial to the Press algorithm.  Otherwise
      // we keep adding noise on top of itself, which is a real problem when the
      // noise is large!!
      //
      simplex_noise.x[i] = -1*temptr*log(r.repick())
      simplex_errors.x[i] += simplex_noise.x[i]
    }
    ilo = simplex_errors.min_ind()
    ihi = simplex_errors.max_ind()
    ehi = simplex_errors.x[ihi]
    // temporarily mask the highest entry to locate the second highest
    simplex_errors.x[ihi] = -1
    inhi = simplex_errors.max_ind()
    simplex_errors.x[ihi] = ehi
    elo = simplex_errors.x[ilo]
    enhi = simplex_errors.x[inhi]

    // test for completion
    // When both extreme errors are exactly 0 the relative spread would be
    // 0/0; treat that case as fully converged instead of dividing by zero.
    denom = abs(ehi) + abs(elo)
    if (denom > 0) {
      rtol = 2 * abs(ehi - elo)/denom
    } else {
      rtol = 0
    }
    //printf("%14.12g %14.12g %14.12g\n",rtol, ehi, elo)
    //printf("nefun: %d Temp: %g iter: %d ehi: %g elo: %g rtol: %g ftol:%g ",nefun,temptr,iter,ehi,elo,rtol,ftol)
    //simplex.getrow(ilo).printf()
    if (rtol < ftol || iter < 0) {
      break
    }

    // Compute vector average of all points except highest
    mean_point.fill(0)
    for i = 0, npoints-1 {
      if (i != ihi) mean_point.add(simplex.getrow(i))
    }
    mean_point.div(npoints-1)
    iter -= 2                   // charge two evaluations up front
    old_simplex = simplex.c()
    if (stoprun) break
    esave = ehi
    e = new_step(alpha)         // Reflect simplex from high point
    if (e < elo) {              // if lower than the lowest... N.B Press et al use <= but this seems to lead to loops
      if (stoprun) break
      if (verbose) printf("New best point. Try a further extrapolation.\n")
      e = new_step(beta)        // ...Extrapolate in the same direction
    } else if (e >= enhi) {     // else if higher than the second-highest...
      if (stoprun) break
      // new_step() lowers ehi when it replaces the highest vertex, so
      // ehi < esave means the reflection above was accepted.
      if (ehi < esave) {
        if (verbose) printf("Do a one dimensional contraction (+)\n")
        esave = ehi
        e = new_step(gamma)       // ...Do a 1D positive contraction
      } else {
        if (verbose) printf("Do a one dimensional contraction (-)\n")
        e = new_step(-1*gamma)    // ...Do a 1D negative contraction
      }
      if (e >= esave) {         // if the highest point is still there...
        if (verbose) printf("Contract around the best point\n")
        for i = 0, npoints-1 {  // ...contract about the lowest point
          if (stoprun) break
          if (i != ilo) {
            new_point = simplex.getrow(i).add(simplex.getrow(ilo)).mul(gamma)
            simplex.setrow(i,new_point)
            simplex_errors.x[i] = sa_efun(new_point.size(),&new_point.x[0])
            // freshly evaluated vertex: its stored error carries no noise
            simplex_noise.x[i]  = 0
          }
        }
        iter -= npoints-1       // charge the extra evaluations
      }
    } else {
      iter += 1                 // only one evaluation was used; refund one
    }
  }
  return iter
}

// Try a new point on the line through the highest vertex and the centroid.
//   $1  step factor: the trial point is mean_point*(1+$1) - highest*$1,
//       taken from old_simplex
// Evaluates the trial point, records it as the best point when its error is
// not above eb, then subtracts a random temperature-dependent fluctuation
// from the error. If the result is below ehi, the highest vertex is replaced
// and ehi lowered. Returns the fluctuated error.
func new_step() { local e, fluc
  new_point = mean_point.c.mul(1+$1).sub(old_simplex.getrow(ihi).mul($1))
  e = sa_efun(new_point.size(),&new_point.x[0])
  if (e <= eb) {
    // Store a copy: assigning the objref alone would make best_point refer to
    // the same Vector as new_point, so any later in-place change to that
    // Vector would silently alter the recorded best point.
    best_point = new_point.c
    eb = e
  }
  fluc = -1*temptr*log(r.repick())
  e -= fluc

  if (e < ehi) {
    if (verbose) printf("Replaced highest point\n")
    ehi = e
    simplex.setrow(ihi,new_point)
    simplex_errors.x[ihi] = e
    // stored as negative noise so that errors - noise recovers the
    // unperturbed error of this vertex
    simplex_noise.x[ihi]  = -1*fluc
  }
  return e
}

// Build the initial simplex around new_point, evaluate every vertex, then
// alternate noisy_simplex() runs with anneal_step() cooling until a simplex
// run ends with a non-negative remaining budget or stoprun is set.
// Expects new_point to hold the starting parameters and eb to hold the
// initial best error (prun() sets both). Leaves the result in best_point/eb.
proc fit_simanneal_seq() { local i, e
  // initialise the simplex
  nparam = new_point.size()
  npoints = nparam + 1
  simplex = new Matrix(npoints,nparam)
  simplex.setrow(0,new_point)
  for i = 1,npoints-1 {
    simplex.setrow(i,new_point)
    // NOTE(review): a parameter whose starting value is 0 is left unchanged
    // by this multiplication, so vertex i then coincides with vertex 0.
    simplex.x[i][i-1] *= lambda
  }
  // Start the best-point record at the starting point rather than at an
  // all-zero vector, so an early stop never reports a point that was not
  // evaluated.
  best_point = new_point.c
  mean_point = new Vector(nparam)
  // initialise the error vector
  simplex_errors = new Vector(npoints)
  for i = 0,npoints-1 {
    if (stoprun) return
    new_point = simplex.getrow(i)
    e = sa_efun(new_point.size(),&new_point.x[0])
    simplex_errors.x[i] = e
    // the initial vertices also take part in the best-ever tracking
    if (e <= eb) {
      best_point = new_point.c
      eb = e
    }
  }

  //initialize the vector which stores the applied noise.
  simplex_noise = new Vector(npoints,0)

  //Start the optimisation
  iter = -1
  temptr = start_temptr
  while (iter < 0 && !stoprun) {
    iter = iter_per_step
    iter = noisy_simplex(iter)
    if (stoprun) return
    anneal_step(0)
  }
}

// Finish one temperature step: log progress, lower the temperature and, when
// the temperature has dropped below restart_temptr, re-insert the best point
// into the simplex.
//   $1  0 selects geometric cooling (temptr *= cool_rate); any other value
//       selects the schedule based on the gap between elo and eb
proc anneal_step() { local newTemptr, i, in_simplex
  time = startsw() - st

  if (savepath.isopen) {
    // one row per temperature step; lines counts the rows written
    lines += 1
    savepath.printf("%g %d %d %-12.8g ", temptr, nefun, time, eb)
    for i = 0,nparam-1 {
      savepath.printf("%g ",best_point.x[i])
    }
    savepath.printf("\n")
  }
  printf("temptr=%g nefun=%d time=%d eb=%-12.8g\n", temptr, nefun, time, eb)

  if ($1 == 0) {                // simplest annealing schedule
    temptr *= cool_rate
  } else {                      // more complex schedule
    newTemptr = 1.0*(elo-eb)
    if (newTemptr < temptr) {
      // never cool by more than a factor of two in one step
      if (newTemptr > 0.5*temptr) {
        temptr = newTemptr
      } else {
        temptr = 0.5*temptr
      }
    }
  }

  if (temptr < restart_temptr) {
    restart_temptr *= restart_factor
    // count the simplex rows that already equal best_point
    in_simplex = 0
    for i = 0, npoints-1 {
      in_simplex += simplex.getrow(i).eq(best_point)
    }
    if (in_simplex > 0) {
      print "Best point is in simplex. No restart."
    } else {
      simplex.setrow(ihi,best_point)
      // Keep the bookkeeping in step with the new coordinates: without this
      // the vertex would still carry the error and noise of the point it
      // replaced. eb is the unperturbed error of best_point, so the stored
      // noise for this vertex is 0.
      simplex_errors.x[ihi] = eb
      simplex_noise.x[ihi] = 0
      printf("Restarting at temperature %g\n",temptr)
    }
  }
}

endtemplate MulfitSimAnnealSeq