{load_file("nrngui.hoc")}
{load_file("../common/loadbal.hoc")}

// Load balancer object (class provided by ../common/loadbal.hoc).
objref bal
bal = new LoadBalance()

// Per-cell data appended by rdat(): one gid and one c value per cell.
objref gidvec, cvec
gidvec = new Vector()
cvec = new Vector()

// Per-cell Vectors appended by rdat(): the three columns of each
// cell's trailing triples.
objref splitxlist, splitixlist, splitbrlist
splitxlist = new List()
splitixlist = new List()
splitbrlist = new List()

// Vectors handed to bal.distrib() in balance() and read back afterwards
// when the result file is written.
objref cpu, splitcplx, splitindex, splitbres
cpu = new Vector()
splitcplx = new Vector()
splitindex = new Vector()
splitbres = new Vector()

// Read the data file named by $s1 and append its contents to the global
// gidvec, cvec, splitxlist, splitixlist and splitbrlist.
//
// Layout as consumed here (a stream of numbers read with scanvar):
//   n1
//   n1 times:
//     n2
//     n2 times:
//       <ignored value> gid c n3
//       n3 triples: x ix b
//
// For every record, gid is appended to gidvec, c to cvec, and the three
// columns of the n3 triples become new Vectors appended to splitxlist,
// splitixlist and splitbrlist respectively.
//
// Stops with execerror if the file cannot be opened.
proc rdat() {local i, j, k, n1, n2, n3, c, gid  localobj f, vecx, vecix, vecb
	f = new File()
	// ropen returns 0 on failure; without this check scanvar would fail
	// later with a much less helpful message.
	if (!f.ropen($s1)) {
		execerror("rdat: cannot open", $s1)
	}

	n1 = f.scanvar()
	for i=0, n1-1 {
		n2 = f.scanvar()
		for j=0, n2-1 {
			f.scanvar()	// first field of the record is read and discarded
			gid = f.scanvar()
			c = f.scanvar()
			n3 = f.scanvar()
			gidvec.append(gid)
			cvec.append(c)
			vecx = new Vector(n3)
			vecix = new Vector(n3)
			vecb = new Vector(n3)
			for k = 0, n3-1 {
				vecx.x[k] = f.scanvar()
				vecix.x[k] = f.scanvar()
				vecb.x[k] = f.scanvar()
			}
			splitxlist.append(vecx)
			splitixlist.append(vecix)
			splitbrlist.append(vecb)
		}
	}
	// release the file handle once everything has been read
	f.close()
}


// Greedy assignment of weights to cpus.
// input:  $1  = ncpu
//         $o2 = Vector of weights in order of insertion (more or less
//               presorted)
// output: $o3 = Vector resized to be parallel to $o2; element i is the
//               cpu number that weight i was assigned to
// return: balance = (largest per-cpu total) * ncpu / (sum of all totals)
func lptmodified() {local idx, target, nseed, nweight  localobj cpuload
	nweight = $o2.size
	$o3.resize(nweight)
	cpuload = new Vector($1)

	// The first min(ncpu, nweight) weights each start their own cpu.
	nseed = $1
	if (nseed > nweight) { nseed = nweight }
	idx = 0
	while (idx < nseed) {
		cpuload.x[idx] = $o2.x[idx]
		$o3.x[idx] = idx
		idx = idx + 1
	}

	// Every remaining weight goes to whichever cpu currently carries
	// the smallest total.
	idx = $1
	while (idx < nweight) {
		target = cpuload.min_ind
		cpuload.x[target] += $o2.x[idx]
		$o3.x[idx] = target
		idx = idx + 1
	}

	return cpuload.max*$1/cpuload.sum
}

// Estimate the balance obtained when the $4 largest cells are each cut
// into two pieces and all pieces are distributed by lptmodified().
// input:  $1  = ncpu
//         $o2 = Vector of whole-cell weights
//         $o3 = List holding, per cell, a Vector of candidate piece sizes
//               (presumably ascending -- TODO confirm against the data file)
//         $4  = number of cells to split (clamped to the number of cells)
// return: the balance value from lptmodified(); a summary line is also
//         printed.
func balance2() {local k, nsplit, ncell, ratio, whole, cut  localobj order, piece, owner, cuts
	// cell indices ordered from largest to smallest weight
	order = $o2.sortindex.reverse
	ncell = $o2.size

	nsplit = $4
	if (nsplit > ncell) { nsplit = ncell }

	// slots 0..nsplit-1 and nsplit..2*nsplit-1 hold the two pieces of
	// each split cell; the remaining slots hold the unsplit cells
	piece = new Vector(ncell + nsplit)
	owner = piece.c

	// the nsplit largest cells: split as evenly as possible
	k = 0
	while (k < nsplit) {
		whole = $o2.x[order.x[k]]
		cuts = $o3.object(order.x[k])
		// first candidate exceeding half the cell; the one before it is used
		// NOTE(review): indwhere returning -1 or 0 makes the index below
		// negative -- confirm the input always avoids that
		cut = cuts.indwhere(">", whole/2+.01)
		piece.x[k] = cuts.x[cut-1]
		piece.x[nsplit + k] = whole - piece.x[k]
		k = k + 1
	}

	// all the rest are kept whole
	k = nsplit
	while (k < ncell) {
		piece.x[nsplit + k] = $o2.x[order.x[k]]
		k = k + 1
	}

	ratio = lptmodified($1, piece, owner)
	if (nsplit != 0) {
		printf("%d SplitLPT %d%%\n", $1, int((ratio - 1)*100))
	}else{
		printf("%d wholecell LPT %d%%\n", $1, int((ratio - 1)*100))
	}
	return ratio
}

// Read "<$s2>.dat", run the load balancer for $1 cpus and write the
// resulting assignment to the file "<$s2>.<$1>".
// input:  $1  = ncpu
//         $s2 = base name of the input and output files
// return: the value returned by bal.distrib()
//
// Output file: first line is gidvec.size, then one line per entry,
// ordered by cpu number, with six integers:
//   cpu gid splitindex splitbres splitcplx c
//
// Stops with execerror if the output file cannot be opened.
func balance() {local i, k, err  localobj f, si, s
	s = new String()
	sprint(s.s, "%s.dat", $s2)
	rdat(s.s)
	err = bal.distrib($1, cvec, splitxlist, splitixlist, \
		cpu, splitcplx, splitindex, splitbrlist, splitbres)
	f = new File()
	sprint(s.s, "%s.%d", $s2, $1)
	// wopen returns 0 on failure; fail loudly instead of silently
	// producing no output file
	if (!f.wopen(s.s)) {
		execerror("balance: cannot open for writing", s.s)
	}
	f.printf("%d\n", gidvec.size)
	// indices that visit the entries in order of increasing cpu number
	si = cpu.sortindex
	for i=0, gidvec.size-1 {
		k = si.x[i]
		// Exactly six values for the six %d conversions. The original
		// passed a seventh argument (bal.cplx.x[cpu.x[k]]) that had no
		// matching conversion and so never reached the file.
		f.printf("%d %d %d %d %d %d\n", cpu.x[k], \
			gidvec.x[k], splitindex.x[k], \
			splitbres.x[k], \
			splitcplx.x[k], cvec.x[k])
	}
	f.close()
	// optional comparisons against the simple LPT estimates:
//	balance2($1, cvec, splitxlist, int($1/2))
//	balance2($1, cvec, splitxlist, 156)
//	balance2($1, cvec, splitxlist, 0)
	return err
}

// Driver: compute the load balance for a cpu count derived from pc.id
// and print a one-line summary.
objref pc
pc = new ParallelContext()
// 32 cpus when pc.id is 0, doubling for each increment of pc.id
// (presumably so that each process handles a different cpu count -- TODO confirm)
ncpu = 32*2^(pc.id)
// reads "splitbal.dat" and writes "splitbal.<ncpu>"
percenterr = balance(ncpu, "splitbal")
//bal.cplx.printf
//print cvec.sum
//print bal.cplx.sum
{printf("ncpu=%d load balance error = %d%% average load = %d   max load = %d\n", ncpu,percenterr, cvec.sum/ncpu, bal.cplx.max)}

// Consistency check on the computed distribution: for every cpu number
// from 0 through cpu.max, collect the splitindex values of the entries
// assigned to that cpu and report the cpu if more than one of them is
// nonzero.
proc chkbal() {local i localobj ix, split
	// nothing to check (and cpu.max is undefined) for an empty assignment
	if (cpu.size == 0) { return }
	ix = new Vector()
	split = new Vector()
	// cpu numbers run from 0 to cpu.max inclusive; the original loop
	// stopped at cpu.max-1 and never examined the last cpu
	for i=0, cpu.max() {
		// indices of the entries assigned to cpu i
		ix.indvwhere(cpu, "==", i)
		split.index(splitindex, ix)
		// the one-Vector form of indvwhere rewrites split in place, so
		// split.size below is the number of nonzero splitindex entries
		if (split.indvwhere("!=", 0).size > 1) {
			printf("cpu %d with %d\n", i, split.size)
		}
	}
}
// Run the consistency check on the distribution computed above.
chkbal()
// Hand control to the ParallelContext worker loop, then shut it down
// (presumably returns immediately on the master process -- TODO confirm)
{pc.runworker()}
{pc.done()}
// exit the interpreter
quit()