# Dimension-by-dimension local optimization algorithm
# Tuomo Maki-Marttunen, 2013-2017 (CC-BY 4.0)

from pylab import *
import time

def _is_unset(value):
  # True when an optional argument was left at its "not given" marker (None or empty).
  return value is None or asarray(value).size == 0


def _draw_powinterval(mean, std, maxtries=100000):
  # Draw a shrink factor from N(mean, std**2) restricted to (0, 1); None if every try fails.
  for _ in range(maxtries + 1):
    draw = mean + std*randn()
    if 0 < draw < 1:
      return draw
  return None


def _line_candidates(x, lo, hi, shrink, ninterval):
  # Geometric 1-D grid around x: ninterval steps towards lo, then ninterval steps towards hi.
  fractions = shrink**arange(1, ninterval + 1)
  return concatenate((x - (x - lo)*fractions, x + (hi - x)*fractions))


def minimizedimbydim(fcn=lambda x: x[0]**2+x[1]**2,thrs=array([[-1.,-1.],[1.,1.]]),initx=None,powinterval=1./3,Ninterval=40,Niter=10,powintervalstd=None,nhoursmax=inf):
  """Minimize fcn inside a box by repeated one-dimensional grid searches.

  In every iteration each coordinate is visited once, in random order. For the
  visited coordinate, 2*Ninterval candidate values are tried: the current value
  moved towards the lower (resp. upper) bound by a fraction p**k of the
  remaining gap, k = 1..Ninterval, where p is a freshly drawn shrink factor.
  The best candidate replaces the current value only if it strictly lowers
  the objective.

  Parameters
  ----------
  fcn : callable
    Objective; called with a 1-D float array and expected to return a scalar.
  thrs : array-like, shape (2, ndim)
    Row 0 holds the lower bounds, row 1 the upper bounds.
  initx : array-like, optional
    Starting point of shape (ndim,), or several starting points ("particles")
    of shape (Npart, ndim) that are optimized independently. None or an empty
    sequence draws one uniformly random point inside the box. The argument is
    never modified.
  powinterval : float
    Mean of the shrink factor p.
  Ninterval : int
    Number of candidate values on each side of the current value.
  Niter : int
    Number of sweeps over all coordinates.
  powintervalstd : float or 1-element sequence, optional
    Standard deviation of p. None or an empty sequence means powinterval/5.
  nhoursmax : float
    Wall-clock budget in hours, checked at the start of every iteration.

  Returns
  -------
  list
    [x, errsall]. x is a list holding one 1-D array when a single starting
    point was used, or an (Npart, ndim) array when initx was 2-D. errsall is
    an (Niter, Npart) array with each particle's objective value after every
    iteration; rows of iterations skipped because of nhoursmax stay zero.
    If no shrink factor inside (0, 1) can be drawn, a message is printed and
    [nan, nan] is returned.
  """
  bounds = asarray(thrs, dtype=float)
  lower, upper = bounds[0], bounds[1]

  if _is_unset(powintervalstd):
    std = powinterval/5
  else:
    # Accept a plain number as well as the 1-element sequence form.
    std = asarray(powintervalstd, dtype=float).item()

  if _is_unset(initx):
    # One independent uniform draw per dimension.
    start = lower + (upper - lower)*rand(len(lower))
  else:
    # array() copies, so the caller's starting point is left untouched.
    start = array(initx, dtype=float)
  single = start.ndim == 1
  xs = start.reshape(1, -1) if single else start
  if xs.ndim != 2:
    raise ValueError('initx must have shape (ndim,) or (Npart, ndim)')

  npart, ndim = xs.shape
  errs = array([float(fcn(x)) for x in xs], dtype=float).reshape(npart)
  errsall = zeros([Niter, npart])

  tstart = time.time()
  # shuffle() works in place and therefore needs a list, not a range object.
  dimord = list(range(ndim))
  shuffle(dimord)

  for it in range(Niter):
    if time.time() - tstart > nhoursmax*3600:
      break
    # The leading dimension stays first; only the remaining ones are reshuffled.
    rest = dimord[1:]
    if len(rest) > 1:
      shuffle(rest)
    dimord = dimord[:1] + rest

    for idim in dimord:
      shrink = _draw_powinterval(powinterval, std)
      if shrink is None:
        print('try smaller powintervalstd!')
        return [nan, nan]

      for j in range(npart):
        current = xs[j, idim]
        candidates = _line_candidates(current, lower[idim], upper[idim], shrink, Ninterval)
        if candidates.size == 0:
          continue

        # Evaluate on a scratch copy so the current point is only changed
        # when a strictly better candidate has been found.
        trial = xs[j].copy()
        cand_errs = zeros(candidates.size)
        for k, value in enumerate(candidates):
          if value == current:
            # Same point as the current one: reuse the known error.
            cand_errs[k] = errs[j]
          else:
            trial[idim] = value
            cand_errs[k] = float(fcn(trial))

        # argmin would pick a NaN; treat failed evaluations as "infinitely bad".
        cand_errs[isnan(cand_errs)] = inf
        best = argmin(cand_errs)
        if cand_errs[best] < errs[j]:
          xs[j, idim] = candidates[best]
          errs[j] = cand_errs[best]

    errsall[it, :] = errs
    print("minimizedimbydim: iter = "+str(it)+", err = "+str(float(errs.min())))

  return [[xs[0]] if single else xs, errsall]