
######## GAR - FULL BAYES #################
# Full-Bayes Gibbs sampler for a sparse precision matrix Omega with an
# element-wise penalty
#   Lambda[i, j] = Alpha[i] * Alpha[j] * Beta[bin(i, j)],  Lambda[i, i] = Alpha[i]^2,
# where bin(i, j) is the quantile bin of the edge covariate W1[i, j] (or the
# joint bin of W1[i, j] and W2[i, j]). The function g(w) = Beta[bin(w)] is
# therefore a step function of the covariate(s).
#
# Arguments:
#   X          n x d data matrix (rows = observations, columns = variables).
#   W1         d x d matrix of edge covariates. Bin edges are computed from its
#              upper triangle, but bin membership is evaluated on the full
#              matrix, so W1 is presumably symmetric (an asymmetric W1 yields an
#              asymmetric Lambda) -- confirm with callers.
#   W2         optional second d x d covariate matrix; any non-numeric value
#              (default FALSE) means the univariate model g(W1).
#   K          number of quantile bins per covariate (K bins, or K * K with W2).
#   B          number of Gibbs iterations.
#   Burn       first iteration used in posterior summaries (draws Burn:B);
#              must satisfy 1 <= Burn <= B.
#   r1, s1     r1 enters the shape and s1 the rate of the Gamma draw for Alpha^2.
#   r2, s2     r2 enters the shape and s2 the rate of the Gamma draw for Beta^2.
#   monot.incr if TRUE, Beta is constrained to be non-decreasing across bins.
#              Only supported without W2; with W2 a warning is issued and Beta
#              is sampled unconstrained.
#   mon.tol    relative slack of the monotonicity constraint:
#              Beta[k-1] * (1 - mon.tol) <= Beta[k] <= Beta[k+1] * (1 + mon.tol).
#   Plot       if TRUE, redraw the current g at every iteration.
#   marg.map   if TRUE, re-run the Tau/Omega sampler for B iterations with
#              Lambda fixed at Lambda.map; Omega.mean and Omegas then come from
#              that second chain.
#   zgrid      grid resolution of the 3-D plot (bivariate case only).
#   ori3d      `theta` viewing angle passed to persp() (bivariate case only).
#
# Returns a list with:
#   Omega.mean, Alpha.mean, Beta.mean, Lambda.mean  averages over draws Burn:B;
#   Alpha.map, Beta.map  modes of a log-normal fit to the Alpha / Beta draws;
#   Lambda.map           penalty matrix built from Alpha.map and Beta.map;
#   Omega.map            QUIC estimate with penalty 2 * Lambda.map / n;
#   Omegas, Alphas, Betas, Lambdas  the full chains (length B).
GAR.FB <- function(X, W1, W2=FALSE, K=5, B=2000, Burn=400, r1=1, s1=1, r2=0.01, s2=0.00001, monot.incr=FALSE, mon.tol=0, Plot=TRUE, marg.map=FALSE, zgrid=20, ori3d=200){

  library(MASS)
  library(mgcv)
  library(QUIC)
  library(statmod)
  library(truncdist)

  # Make a covariance matrix numerically positive definite: add an increasing
  # ridge to its correlation matrix until the smallest eigenvalue is >= 1e-10,
  # then renormalise to unit diagonal and restore the original scales.
  POS <- function(MAT, d = ncol(MAT)) {
    D <- diag(sqrt(diag(MAT)))
    MAT <- cov2cor(MAT)
    b <- 0
    while (min(eigen(MAT, only.values = TRUE)$values) < 1e-10) {
      b <- b + 0.00001
      MAT <- MAT + diag(d) * b
    }
    D %*% cov2cor(MAT) %*% D
  }

  # Element-wise mean of the list elements indexed by `set`.
  listmean <- function(LIST, set) {
    Reduce("+", lapply(set, function(i) LIST[[i]])) / length(set)
  }

  # Build the bin indicator matrices. W1 is cut at its K quantiles computed
  # over the upper triangle; with W2, each W1 bin is further cut at the K
  # quantiles of W2 restricted to that bin (K * K bins in total). Bins are
  # (lower, upper]; the first lower edge is min - 1 and the last upper edge is
  # Inf, so every upper-triangular entry falls in exactly one bin.
  IWfun <- function(W1, W2 = FALSE, K = 10) {
    Q <- upper.tri(diag(ncol(W1)))
    w1 <- W1[Q]
    Q1 <- c(min(w1) - 1, quantile(w1, (1:K) / K))
    Q1[K + 1] <- Inf
    Q2 <- NA
    if (!is.numeric(W2)) {
      IW <- lapply(1:K, function(k) round(W1 > Q1[k] & W1 <= Q1[k + 1]))
    } else {
      w2 <- W2[Q]
      Q2 <- lapply(1:K, function(k) {
        sel <- w1 > Q1[k] & w1 <= Q1[k + 1]
        edges <- c(min(w2[sel]) - 1, quantile(w2[sel], (1:K) / K))
        edges[K + 1] <- Inf
        edges
      })
      SEL <- cbind(rep(1:K, each = K), rep(1:K, K))
      IW <- lapply(seq_len(nrow(SEL)), function(j) {
        k1 <- SEL[j, 1]
        k2 <- SEL[j, 2]
        round(W1 > Q1[k1] & W1 <= Q1[k1 + 1] &
                W2 > Q2[[k1]][[k2]] & W2 <= Q2[[k1]][[k2 + 1]])
      })
    }
    list(IW = IW, Q1 = Q1, Q2 = Q2, K = K, KK = length(IW))
  }

  # Evaluate the bivariate step function g at a single point (w1, w2).
  stepfun <- function(w1, w2, Beta, Q1, Q2, K) {
    SEL <- cbind(rep(1:K, each = K), rep(1:K, K))
    inds <- sapply(seq_len(nrow(SEL)), function(j) {
      k1 <- SEL[j, 1]
      k2 <- SEL[j, 2]
      round(w1 > Q1[k1] & w1 <= Q1[k1 + 1] &
              w2 > Q2[[k1]][[k2]] & w2 <= Q2[[k1]][[k2 + 1]])
    })
    sum(inds * Beta)
  }

  # Gibbs step for the node scales: Alpha[i]^2 ~ Gamma((r + d + 1) / 2, s + Ci),
  # updated one node at a time using the already-updated Alpha[-i].
  up.Alpha <- function(d, r, s, Omega, Tau, Alpha, Beta, IW, KK) {
    G <- 0
    for (k in seq_len(KK)) {
      G <- G + Beta[k] * IW[[k]]
    }
    for (i in seq_len(d)) {
      Ci <- Omega[i, i] + sum(((Alpha[-i] * Omega[i, -i] * G[i, -i])^2) / Tau[i, -i])
      Alpha[i] <- sqrt(rgamma(1, shape = (r + d + 1) / 2, rate = s + Ci))
    }
    Alpha
  }

  # Unconstrained Gibbs step for the bin levels:
  # Beta[k]^2 ~ Gamma((r + Ds[k]) / 2, s + Ek), independently across bins.
  up.Beta <- function(r, s, Ds, Omega, Tau, Alpha, IW, KK, Q) {
    Beta <- numeric(KK)
    AOA2 <- (diag(Alpha) %*% Omega %*% diag(Alpha))^2
    for (k in seq_len(KK)) {
      Ek <- sum(IW[[k]][Q] * AOA2[Q] / Tau[Q])
      Beta[k] <- sqrt(rgamma(1, shape = (r + Ds[k]) / 2, rate = s + Ek))
    }
    Beta
  }

  # Monotone Gibbs step: same full conditional as up.Beta, truncated so that
  # Beta[k - 1] * (1 - mon.tol) <= Beta[k] <= Beta[k + 1] * (1 + mon.tol).
  # Bins are updated left to right, so the lower bound uses the fresh value.
  up.Beta.monot <- function(r, s, Ds, Omega, Tau, Alpha, Beta, IW, KK, Q, mon.tol = .9) {
    # Sentinels: the first bin is bounded below by 0, the last above by Inf.
    Beta <- c(0, Beta, Inf)
    AOA2 <- (diag(Alpha) %*% Omega %*% diag(Alpha))^2
    for (k in seq_len(KK)) {
      Ek <- sum(IW[[k]][Q] * AOA2[Q] / Tau[Q])
      # The Gamma draw is for Beta^2, so the bounds on Beta must be squared.
      lower <- max(0, Beta[k] * (1 - mon.tol))^2
      upper <- (Beta[k + 2] * (1 + mon.tol))^2
      Beta[k + 1] <- sqrt(rtrunc(1, "gamma", a = lower, b = upper,
                                 shape = (r + Ds[k]) / 2, rate = s + Ek))
    }
    Beta[2:(KK + 1)]
  }

  # Penalty matrix: Lambda = diag(Alpha) L diag(Alpha), where L holds Beta[k]
  # on the entries of bin k and 1 on the diagonal.
  Lambda.build <- function(Alpha, Beta, IW, KK) {
    L <- 0
    for (k in seq_len(KK)) {
      L <- L + Beta[k] * IW[[k]]
    }
    diag(L) <- 1
    diag(Alpha) %*% L %*% diag(Alpha)
  }

  # Latent scales: 1 / Tau[i, j] ~ InvGaussian(mean = 1 / (Lambda[i, j] |Omega[i, j]|),
  # shape = 2) on the upper triangle, mirrored; the diagonal is fixed at 1.
  up.Tau <- function(d, Omega, Lambda, Q) {
    MEAN <- 1 / (Lambda[Q] * abs(Omega[Q]))
    Tau <- Omega * 0
    Tau[Q] <- 1 / rinvgauss(d * (d - 1) / 2, mean = MEAN, shape = 2)
    Tau <- Tau + t(Tau)
    diag(Tau) <- 1
    Tau
  }

  # Column-wise Gibbs sweep over Omega: for each i, draw the off-diagonal
  # column from a normal and the Schur complement from a Gamma.
  up.Omega <- function(d, n, S, Omega, Lambda, Tau) {
    for (i in seq_len(d)) {
      iO <- solve(Omega[-i, -i])
      # nrow = d - 1 keeps diag() from building a scalar-sized identity when d = 2.
      D <- 2 * diag(Lambda[-i, i]^2 / Tau[-i, i], nrow = d - 1)
      Ai <- solve((S[i, i] + 2 * Lambda[i, i]) * iO + D)
      Eta <- mvrnorm(1, mu = -Ai %*% S[-i, i], Sigma = Ai)
      Xi <- rgamma(1, shape = n / 2 + 1, rate = S[i, i] / 2 + Lambda[i, i])
      Omega[-i, i] <- Eta
      Omega[i, -i] <- Eta
      Omega[i, i] <- Xi + t(Eta) %*% iO %*% Eta
    }
    Omega
  }

  # Log-normal summary of the draws LIST[set]: mean M and covariance COV of the
  # log draws, and the log-normal mode exp(M - COV %*% 1).
  modelog <- function(LIST, set) {
    MAT <- log(t(sapply(set, function(i) c(LIST[[i]]))))
    M <- colMeans(MAT)
    COV <- cov(MAT)
    MODE <- c(exp(M - COV %*% rep(1, NCOL(MAT))))
    list(MODE = MODE, COV = COV, M = M)
  }

  stopifnot(Burn >= 1, Burn <= B)
  d <- ncol(X)
  n <- nrow(X)
  bivariate <- is.numeric(W2)
  do.plot <- isTRUE(as.logical(Plot))
  use.monot <- isTRUE(as.logical(monot.incr))
  if (use.monot && bivariate) {
    warning("monot.incr = TRUE is only supported without W2; ",
            "sampling Beta without the monotonicity constraint.", call. = FALSE)
    use.monot <- FALSE
  }

  S <- cov(X) * (n - 1)
  Q <- upper.tri(diag(d))
  SEG <- IWfun(W1 = W1, W2 = W2, K = K)
  IW <- SEG$IW
  KK <- length(IW)
  # Number of upper-triangular entries in each bin.
  Ds <- vapply(IW, function(M) sum(M[Q]), numeric(1))
  Alpha <- rep(1, d)
  Beta <- 1:KK
  Tau <- matrix(1, d, d)
  Omega <- solve(POS(cov(X)))
  Lambda <- Lambda.build(Alpha = Alpha, Beta = Beta, IW = IW, KK = KK)
  Omegas <- vector("list", B)
  Alphas <- vector("list", B)
  Betas <- vector("list", B)
  Lambdas <- vector("list", B)
  z <- matrix(NA, nrow = zgrid, ncol = zgrid)
  if (do.plot) {
    dev.new()
  }

  for (b in seq_len(B)) {
    Tau <- up.Tau(d = d, Omega = Omega, Lambda = Lambda, Q = Q)
    Omega <- up.Omega(d = d, n = n, S = S, Omega = Omega, Lambda = Lambda, Tau = Tau)
    Alpha <- up.Alpha(d = d, r = r1, s = s1, Omega = Omega, Tau = Tau, Alpha = Alpha,
                      Beta = Beta, IW = IW, KK = KK)
    if (use.monot) {
      Beta <- up.Beta.monot(r = r2, s = s2, Ds = Ds, Omega = Omega, Tau = Tau, Alpha = Alpha,
                            Beta = Beta, IW = IW, KK = KK, Q = Q, mon.tol = mon.tol)
    } else {
      Beta <- up.Beta(r = r2, s = s2, Ds = Ds, Omega = Omega, Tau = Tau, Alpha = Alpha,
                      IW = IW, KK = KK, Q = Q)
    }
    Lambda <- Lambda.build(Alpha = Alpha, Beta = Beta, IW = IW, KK = KK)
    Alphas[[b]] <- Alpha
    Betas[[b]] <- Beta
    Lambdas[[b]] <- Lambda
    Omegas[[b]] <- Omega

    if (do.plot) {
      main.txt <- paste("iteration ", b, "/", B, sep = "")
      if (!bivariate) {
        qW <- SEG$Q1
        qW[1] <- min(W1[Q])
        qW[K + 1] <- max(W1[Q])
        plot(qW, Beta[c(1:K, K)], type = "s", xlab = "w1", ylab = "g(w1)",
             main = main.txt, lwd = 2, bty = "n")
      } else {
        ww1 <- seq(min(W1[Q]), max(W1[Q]), length.out = zgrid)
        ww2 <- seq(min(W2[Q]), max(W2[Q]), length.out = zgrid)
        for (i in seq_len(zgrid)) {
          for (j in seq_len(zgrid)) {
            z[i, j] <- stepfun(ww1[i], ww2[j], Beta = Beta, Q1 = SEG$Q1, Q2 = SEG$Q2, K = K)
          }
        }
        persp(ww1, ww2, z, theta = ori3d, xlab = "w1", ylab = "w2", zlab = "g(w1,w2)",
              ticktype = "detailed", main = main.txt, zlim = c(0, max(z) * 1.1))
      }
    }
  }

  keep <- Burn:B
  Alpha.mean <- listmean(Alphas, keep)
  Beta.mean <- listmean(Betas, keep)
  Lambda.mean <- listmean(Lambdas, keep)
  Omega.mean <- listmean(Omegas, keep)

  Alpha.map <- modelog(Alphas, keep)$MODE
  Beta.map <- modelog(Betas, keep)$MODE
  Lambda.map <- Lambda.build(Alpha = Alpha.map, Beta = Beta.map, IW = IW, KK = KK)

  Omega.map <- QUIC(S / n, rho = 2 * Lambda.map / n, msg = 0)$X

  # Optional second chain with the penalty fixed at Lambda.map; it continues
  # from the last Omega of the main chain and overwrites Omegas.
  if (isTRUE(as.logical(marg.map))) {
    for (b in seq_len(B)) {
      Tau <- up.Tau(d = d, Omega = Omega, Lambda = Lambda.map, Q = Q)
      Omega <- up.Omega(d = d, n = n, S = S, Omega = Omega, Lambda = Lambda.map, Tau = Tau)
      Omegas[[b]] <- Omega
    }
    Omega.mean <- listmean(Omegas, keep)
  }

  list(Omega.mean = Omega.mean, Omega.map = Omega.map,
       Alpha.mean = Alpha.mean, Beta.mean = Beta.mean, Lambda.mean = Lambda.mean,
       Alpha.map = Alpha.map, Beta.map = Beta.map, Lambda.map = Lambda.map,
       Omegas = Omegas, Alphas = Alphas, Betas = Betas, Lambdas = Lambdas)
}








##### GAR - EMPIRICAL BAYES #####
# Empirical-Bayes (Monte Carlo EM) fit of a sparse precision matrix Omega with
# an element-wise penalty
#   Lambda[i, j] = Alpha[i] * Alpha[j] * G[i, j],  Lambda[i, i] = Alpha[i]^2.
# In each M-step, G is the linear predictor (inverse link, i.e. the reciprocal
# of the fitted mean) of a Gamma GAM. The GAM regresses
# 2 * Alpha[i] * Alpha[j] * E|Omega[i, j]| on the edge covariate(s).
#
# Arguments:
#   X          n x d data matrix (rows = observations, columns = variables).
#   W1         d x d matrix of edge covariates; its upper triangle is used.
#   W2         optional second d x d covariate matrix; any non-numeric value
#              (default FALSE) means a univariate fit on W1.
#   B0, Burn0  Gibbs length and first retained draw of the first E-step. Both
#              grow by a factor `mult` per EM iteration, capped at B and Burn.
#   B, Burn    Gibbs length and first retained draw of the final E-step.
#   Mloglik    length of the reference Gibbs chain (penalty 1 / |Omega.start|)
#              used to Monte Carlo estimate the objective reported as "loglik".
#   fit        "linear" or "splines" (the GAM used for G).
#   iters      number of EM iterations.
#   m.iters    coordinate-ascent passes (Alpha, then G) inside each M-step.
#   knots      spline knots: a vector (univariate; its length is the basis
#              dimension) or a list of two vectors (bivariate).
#   Plot       if TRUE, draw the current g and the objective trace each
#              iteration. Graphics parameters are restored on exit.
#   mult       growth factor of the E-step chain length and burn-in.
#
# A new M-step is accepted only if its estimated objective exceeds the largest
# lower band (estimate - 2 SE) seen so far.
#
# NOTE: calls set.seed(999) before the reference chain, which resets the
# caller's RNG stream.
#
# Returns a list with:
#   E.STEP     final E-step (Omegas, Omegabar, Omegabar.abs, Taubar, n.err);
#   M.STEP     accepted M-step (LAMBDA, ALPHA, G, G.low, G.up, mod);
#   Omega.mean posterior mean of Omega from the final E-step;
#   Omega.map  QUIC estimate with penalty 2 * LAMBDA / n.
GAR.EB <- function(X, W1, W2=FALSE, B0=300, Burn0=100, B = 2000, Burn = 500, Mloglik=1000, fit='linear', iters=30, m.iters=200, knots=NULL, Plot=FALSE, mult=1.005){

  library(MASS)
  library(mgcv)
  library(QUIC)
  library(statmod)
  library(truncdist)

  fit <- match.arg(fit, c("linear", "splines"))
  stopifnot(Burn0 >= 1, Burn0 <= B0, Burn >= 1, Burn <= B, Mloglik >= 1)

  # Make a covariance matrix numerically positive definite: add an increasing
  # ridge to its correlation matrix until the smallest eigenvalue is >= 1e-10,
  # then renormalise to unit diagonal and restore the original scales.
  POS <- function(MAT, d = ncol(MAT)) {
    D <- diag(sqrt(diag(MAT)))
    MAT <- cov2cor(MAT)
    b <- 0
    while (min(eigen(MAT, only.values = TRUE)$values) < 1e-10) {
      b <- b + 0.00001
      MAT <- MAT + diag(d) * b
    }
    D %*% cov2cor(MAT) %*% D
  }

  # Element-wise mean of LIST[set].
  listmean <- function(LIST, set = seq_along(LIST)) {
    MEAN <- LIST[[set[1]]] * 0
    for (i in set) {
      MEAN <- MEAN + LIST[[i]]
    }
    MEAN / length(set)
  }

  # Element-wise mean of abs(LIST[set]).
  listmean.abs <- function(LIST, set = seq_along(LIST)) {
    MEAN <- LIST[[set[1]]] * 0
    for (i in set) {
      MEAN <- MEAN + abs(LIST[[i]])
    }
    MEAN / length(set)
  }

  # Gibbs sampler for (Omega, Tau) with Lambda held fixed; returns all B draws.
  BlockGibbs <- function(n, d, S, Lambda, Omega, Tau, B, Q) {
    Omegas <- vector("list", B)
    Taus <- vector("list", B)
    for (b in seq_len(B)) {
      # Omega update: column-wise normal / Gamma steps.
      for (i in seq_len(d)) {
        iO <- solve(Omega[-i, -i])
        # nrow = d - 1 keeps diag() from building a scalar-sized identity when d = 2.
        D <- 2 * diag(Lambda[-i, i]^2 / Tau[-i, i], nrow = d - 1)
        Ai <- solve((S[i, i] + 2 * Lambda[i, i]) * iO + D)
        Eta <- mvrnorm(1, mu = -Ai %*% S[-i, i], Sigma = Ai)
        Xi <- rgamma(1, shape = n / 2 + 1, rate = S[i, i] / 2 + Lambda[i, i])
        Omega[-i, i] <- Eta
        Omega[i, -i] <- Eta
        Omega[i, i] <- Xi + t(Eta) %*% iO %*% Eta
      }
      Omegas[[b]] <- Omega

      # Tau update: inverse-Gaussian draws on the upper triangle, mirrored.
      Mean <- 1 / (Lambda[Q] * abs(Omega[Q]))
      Tau <- Omega * 0
      Tau[Q] <- 1 / rinvgauss(d * (d - 1) / 2, mean = Mean, shape = 2)
      Tau <- Tau + t(Tau)
      diag(Tau) <- 1
      Taus[[b]] <- Tau
    }
    list(Omegas = Omegas, Taus = Taus)
  }

  # E-step: run the sampler under the current penalty and average draws BURN:B.
  E.step <- function(n, d, S, Omega.start, Tau.start, M.STEP, B, BURN, Q) {
    n.err <- 0
    set <- BURN:B
    BG <- BlockGibbs(n = n, d = d, S = S, Lambda = M.STEP$LAMBDA, Omega = Omega.start,
                     Tau = Tau.start, B = B, Q = Q)
    list(Omegas = BG$Omegas, Omegabar.abs = listmean.abs(BG$Omegas, set = set),
         Omegabar = listmean(BG$Omegas, set = set),
         Taubar = listmean(BG$Taus, set = set), n.err = n.err)
  }

  # M-step: alternate closed-form Alpha updates and a Gamma-GAM refit of G.
  # Runs max(1, iterations) passes so the fitted model is always available.
  M.step <- function(E.STEP, M.STEP, d, x1, x2, iterations = 400, fit = "linear",
                     knots = NULL, DEG = NULL) {

    # Alpha[i] is the positive root of
    # 2 * OMEGA.abs[i, i] * a^2 + b * a - (d + 1) = 0.
    alpha.update <- function(ALPHA, G, OMEGA.abs, d) {
      for (i in seq_len(d)) {
        b <- 2 * sum(ALPHA[-i] * OMEGA.abs[i, -i] * G[i, -i])
        ALPHA[i] <- (sqrt(b^2 + 8 * OMEGA.abs[i, i] * (d + 1)) - b) / (4 * OMEGA.abs[i, i])
      }
      ALPHA
    }

    bivariate <- is.numeric(x2)

    # Fit the Gamma GAM of the upper-triangular responses on the covariate(s).
    # G holds the link-scale prediction, and G.low / G.up hold +-2 SE bands,
    # all symmetrised with a zero diagonal.
    g.update <- function(ALPHA, OMEGA.abs, d, Q) {
      Y <- 2 * diag(ALPHA) %*% OMEGA.abs %*% diag(ALPHA)
      y <- c(abs(Y[Q]))
      if (bivariate) {
        mod <- switch(fit,
          linear = gam(y ~ (x1 + x2)^2, family = Gamma(link = "inverse")),
          splines = gam(y ~ s(x1, x2, bs = "tp", k = DEG),
                        family = Gamma(link = "inverse"), knots = knots),
          stop("unknown fit: ", fit, call. = FALSE))
      } else {
        mod <- switch(fit,
          linear = gam(y ~ x1, family = Gamma(link = "inverse")),
          splines = gam(y ~ s(x1, k = DEG, bs = "cr"),
                        family = Gamma(link = "inverse"), knots = knots),
          stop("unknown fit: ", fit, call. = FALSE))
      }
      pred.gx <- predict(mod, type = "link", se.fit = TRUE)
      G <- G.low <- G.up <- OMEGA.abs * 0
      G[Q] <- pred.gx$fit
      G.low[Q] <- pred.gx$fit - 2 * pred.gx$se.fit
      G.up[Q] <- pred.gx$fit + 2 * pred.gx$se.fit
      list(G = G + t(G), G.low = G.low + t(G.low), G.up = G.up + t(G.up), mod = mod)
    }

    ALPHA <- M.STEP$ALPHA
    G <- list(G = M.STEP$G)
    OMEGA.abs <- E.STEP$Omegabar.abs
    for (ll in seq_len(max(1, iterations))) {
      ALPHA <- alpha.update(ALPHA = ALPHA, G = G$G, OMEGA.abs = OMEGA.abs, d = d)
      G <- g.update(ALPHA = ALPHA, OMEGA.abs = OMEGA.abs, d = d, Q = Q)
    }
    LAMBDA <- diag(ALPHA) %*% G$G %*% diag(ALPHA)
    diag(LAMBDA) <- ALPHA^2
    list(LAMBDA = LAMBDA, ALPHA = ALPHA, G = G$G, G.low = G$G.low, G.up = G$G.up, mod = G$mod)
  }

  # Monte Carlo estimate of the objective for penalty Lambda, using draws of
  # the reference chain. C2 holds the reference-penalty terms. Returns
  # c(estimate - 2 SE, estimate, estimate + 2 SE).
  Q.logpost <- function(d, n, Lambda, Gamma, Omegas, C2, llikset, Q) {
    # Log-penalties on the diagonal and one triangle. `!Q` is used, not `-Q`:
    # negating a logical index coerces it to -1/0 and would drop only Lambda[1, 1].
    C1 <- sum(log(Lambda[!Q]))
    C3 <- vapply(llikset, function(b) sum(abs(Lambda * Omegas[[b]])), numeric(1))
    v <- C1 + C2 - C3
    MEAN <- mean(v) + var(v) / 2
    SE <- sqrt((var(v) + 2 * (var(v)^2)) / length(llikset))
    c(MEAN - 2 * SE, MEAN, MEAN + 2 * SE)
  }

  # Initialization
  M <- B
  M0 <- B0
  BURN <- Burn
  BURN0 <- Burn0
  d <- ncol(X)
  n <- nrow(X)
  S <- cov(X) * (n - 1)
  OMEGA.start <- solve(POS(cov(X)))
  Q <- upper.tri(diag(d))
  lliks <- numeric()
  lliks.up <- numeric()
  lliks.low <- numeric()

  x1 <- c(W1[Q])
  x1sort <- sort(x1)
  x1ord <- order(x1)
  DEG <- NULL
  bivariate <- is.numeric(W2)
  if (!bivariate) {
    x2 <- FALSE
    if (fit == "linear") {
      print("univariate linear fit")
    } else {
      DEG <- length(knots)
      knots <- list(x1 = knots)
      print(paste0("univariate splines fit, degree=", DEG))
    }
  } else {
    x2 <- c(W2[Q])
    if (fit == "linear") {
      print("bivariate linear fit")
    } else {
      knots <- list(x1 = knots[[1]], x2 = knots[[2]])
      DEG <- length(knots[[1]])
      print(paste0("bivariate splines fit, degree=", DEG))
    }
  }

  print("initialization")
  E.STEP <- list(Omegabar.abs = abs(OMEGA.start), Omegabar = OMEGA.start,
                 Taubar = matrix(1, d, d))
  M.STEP <- list(ALPHA = 1 / sqrt(diag(OMEGA.start)), G = 1 / abs(cov2cor(OMEGA.start)))
  M.STEP <- M.step(E.STEP = E.STEP, M.STEP = M.STEP, d = d, x1 = x1, x2 = x2,
                   iterations = m.iters, fit = fit, knots = knots, DEG = DEG)
  # Reference penalty for the objective. Named Gamma0 so it does not shadow the
  # stats::Gamma family used in the GAM fits.
  Gamma0 <- 1 / abs(OMEGA.start)
  set.seed(999)
  OmegasLoglik <- BlockGibbs(n = n, d = d, S = S, Lambda = Gamma0, Omega = E.STEP$Omegabar,
                             Tau = E.STEP$Taubar, B = Mloglik, Q = Q)$Omegas
  # Drop the first ~20% of the reference chain; never index below 1.
  llikset <- max(1, round(length(OmegasLoglik) / 5)):length(OmegasLoglik)
  C2 <- vapply(llikset, function(b) sum(abs(Gamma0 * OmegasLoglik[[b]])), numeric(1))
  lik <- Q.logpost(d = d, n = n, Lambda = M.STEP$LAMBDA, Gamma = Gamma0, Omegas = OmegasLoglik,
                   C2 = C2, llikset = llikset, Q = Q)
  lliks[1] <- lik[2]
  lliks.up[1] <- lik[3]
  lliks.low[1] <- lik[1]

  if (isTRUE(as.logical(Plot))) {
    old.par <- par(no.readonly = TRUE)
    on.exit(par(old.par), add = TRUE)
  }

  # EM iterations
  for (k in seq_len(iters)) {
    print(paste0("EM iteration ", k, "/", iters, ",   loglik=", lik[2]))

    # E-step
    E.STEP <- E.step(n = n, d = d, S = S, Omega.start = E.STEP$Omegabar,
                     Tau.start = E.STEP$Taubar, M.STEP = M.STEP, B = M0, BURN = BURN0, Q = Q)

    # M-step, accepted only if it beats the best lower band seen so far.
    M.STEP.new <- M.step(E.STEP = E.STEP, M.STEP = M.STEP, d = d, x1 = x1, x2 = x2,
                         iterations = m.iters, fit = fit, knots = knots, DEG = DEG)
    lik.new <- Q.logpost(d = d, n = n, Lambda = M.STEP.new$LAMBDA, Gamma = Gamma0,
                         Omegas = OmegasLoglik, C2 = C2, llikset = llikset, Q = Q)
    if (lik.new[2] > max(lliks.low)) {
      M.STEP <- M.STEP.new
      lik <- lik.new
    }
    lliks[k + 1] <- lik[2]
    lliks.up[k + 1] <- lik[3]
    lliks.low[k + 1] <- lik[1]

    # Lengthen the E-step chain geometrically, capped at the final sizes.
    M0 <- min(M, round(M0 * mult))
    BURN0 <- min(BURN, round(BURN0 * mult))

    if (isTRUE(as.logical(Plot))) {
      par(mfrow = c(1, 2), mar = c(4.5, 4.5, 2, 2))
      if (!bivariate) {
        plot(x1sort, M.STEP$G[Q][x1ord], type = "l", lwd = 2, col = "tomato", xlab = "w",
             ylab = "g(w)", bty = "n", ylim = range(c(M.STEP$G.low[Q], M.STEP$G.up[Q])),
             cex.lab = 1.4)
        polygon(c(x1sort, rev(x1sort)),
                c(M.STEP$G.up[Q][x1ord], M.STEP$G.low[Q][rev(x1ord)]),
                col = rgb(1, 0, 0, .2), border = FALSE)
      } else {
        vis.gam(M.STEP$mod, color = "topo", n.grid = 40, xlab = "w1", ylab = "w2",
                zlab = "g(w1,w2)", cex.lab = 1.4, ticktype = "detailed")
      }
      idx <- seq_along(lliks)
      try(plot(idx, lliks, xlab = "iteration",
               ylab = expression(paste("log ", pi, "(", Theta, "|", "data", ")")),
               type = "l", lwd = 2, col = "blue", xlim = c(0, iters + 1),
               ylim = range(c(lliks.low, lliks.up)), bty = "n", pch = 19, main = "",
               cex.lab = 1.4), silent = TRUE)
      try(polygon(c(idx, rev(idx)), c(lliks.up, rev(lliks.low)),
                  col = rgb(0, 0, 1, .2), border = FALSE), silent = TRUE)
    }
  }

  E.STEP <- E.step(n = n, d = d, S = S, Omega.start = E.STEP$Omegabar,
                   Tau.start = E.STEP$Taubar, M.STEP = M.STEP, B = M, BURN = BURN, Q = Q)

  MAP <- QUIC(S / n, rho = 2 * M.STEP$LAMBDA / n, msg = 0)$X

  list(E.STEP = E.STEP, M.STEP = M.STEP, Omega.mean = E.STEP$Omegabar, Omega.map = MAP)
}



