// Iterator over the subset of cell gids assigned to this cpu, chosen with
// the lpt (least processing time) whole cell algorithm,
// assuming a specified complexity for mitral and granule cells.

// lptiter_gidvec: the gids handled by this host. It is left unset here,
//   filled in by lptiter_set() and walked by the cell_gids iterator.
// pc: declared here but never assigned in this part of the file
//   (cxhelp works with its own local object of the same name).
objref lptiter_gidvec, pc
// NOTE(review): presumably these files provide the LoadBalance, Mitral,
// Granule/GranuleSpine templates and the split-cell helpers used below --
// confirm against the files themselves.
{load_file("loadbal.hoc")}
{load_file("mitral.hoc")}
{load_file("granule.hoc")}
{load_file("split.hoc")}

// Measure the complexity constants used for load balancing and store them
// in the globals cxgranule, cxspine, cxmitral, cxsecden, cxmainum_mitral,
// cxfi, cxampa and cxtd.
// Temporary cells are instantiated only so that LoadBalance can report the
// complexity of the cell containing the currently accessed section.
// Host 0 prints the measured values; every host finishes with gid_clear().
proc cxhelp() {localobj probe, bal, pctx, ptypes, pcosts
	bal = new LoadBalance()

	// whole granule cell, measured from its soma
	probe = new Granule()
	probe.soma cxgranule = bal.cell_complexity()

	// a single granule spine, measured from its neck
	probe = new GranuleSpine()
	probe.neck cxspine = bal.cell_complexity()

	// whole mitral cell first ...
	probe = new Mitral()
	probe.soma cxmitral = bal.cell_complexity()
	// ... then with both secondary dendrites detached, so that one
	// secondary dendrite and the remaining main piece are measured
	// separately
	probe.secden[0] disconnect()
	probe.secden[1] disconnect()
	probe.secden[0] cxsecden = bal.cell_complexity()
	probe.soma cxmainum_mitral = bal.cell_complexity()

	// per-instance complexity of the mechanisms looked up by name
	ptypes = bal.mt[1]
	pcosts = bal.m_complex_[1]
	ptypes.select("FastInhib")
	cxfi = pcosts.x(ptypes.selected())
	ptypes.select("AmpaNmda")
	cxampa = pcosts.x(ptypes.selected())
	ptypes.select("ThreshDetect")
	cxtd = pcosts.x(ptypes.selected())

	pctx = new ParallelContext()
	// the leading 1 is a manual switch: set it to 0 to silence the report
	if (1 && pctx.id == 0) {
		printf("cxgranule=%g\n", cxgranule)
		printf("cxspine=%g\n", cxspine)
		printf("cxmitral=%g\n", cxmitral)
		printf("cxsecden=%g\n", cxsecden)
		printf("cxmainum_mitral=%g\n", cxmainum_mitral)
		printf("cxfi=%g\n", cxfi)
		printf("cxampa=%g\n", cxampa)
		printf("cxtd=%g\n", cxtd)
	}
	pctx.gid_clear()
}
// Choose how the complexity constants are obtained.
// The condition is hard-coded to 1, so cxhelp() always runs and measures
// them; change it to 0 to use the fixed values in the else branch instead.
if (1) {
	cxhelp()
}else{
	// fixed complexities, unused while the condition above is 1
	cxgranule=216
	cxspine = 23
	cxmitral=1475
	cxsecden=602
	cxmainum_mitral=273
	cxfi=3
	cxampa=4
	cxtd=1
}
// total complexity assigned to this host
cxcpu = 0 // to be computed by lptiter_set

// Iterate over the gids assigned to this host.
// On each pass $&1 receives the gid and $&2 its position in
// lptiter_gidvec. The gid list is built on first use by lptiter_set().
iterator cell_gids() { local pos
	if (!object_id(lptiter_gidvec)) {
		// not partitioned yet
		lptiter_set()
	}
	for pos=0, lptiter_gidvec.size-1 {
		$&1 = lptiter_gidvec.x[pos]
		$&2 = pos
		iterator_statement
	}
}

// Partition all pieces (or whole cells) over the hosts with lptiter_lpt()
// and keep this host's share:
//   lptiter_gidvec <- gids assigned to pnm.myid
//   cxcpu          <- summed complexity of those gids
proc lptiter_set() {local k, nsyn  localobj weights, owner, ids
	// ids and weights are parallel: weights.x[k] is the complexity of
	// the piece whose gid is ids.x[k]
	if (!is_split) {
		// whole cells: mitrals come first, then granules
		ids = new Vector(ncell)
		ids.indgen()
		ids.add(num_mitral_begin)

		weights = new Vector(ncell)
		weights.fill(cxmitral, 0, num_mitral-1)
		if (num_granule > 0) {
			weights.fill(cxgranule, num_mitral, ncell - 1)
		}
		// add the synapse cost on both secondary dendrites of each mitral
		for k=0, num_mitral-1 {
			nsyn = how_many_syn_on_secden(k, 0) + how_many_syn_on_secden(k, 1)
			weights.x[k] += (cxfi + cxtd)*nsyn
		}
		// add the synapse cost on each granule
		// NOTE(review): fcx() also counts cxspine per granule synapse,
		// this branch does not -- confirm that is intended.
		for k=num_mitral, ncell-1 {
			nsyn = how_many_syn_on_granule(k - num_mitral)
			weights.x[k] += (cxtd + cxampa)*nsyn
		}
	}else{
		// Mitrals are split into three pieces. The main piece keeps the
		// standard whole cell gid and consists of the soma, axon,
		// primary and tuft. The other two pieces are the left and right
		// secondary dendrite with
		// gid = mgid + splitbit*(secondary_index + 1)
		ids = new Vector(num_mitral*3 + num_granule)
		weights = new Vector(ids.size)
		for k=0, num_mitral-1 {
			ids.x[k] = k
			ids.x[num_mitral + k] = k + splitbit
			ids.x[2*num_mitral + k] = k + 2*splitbit
		}
		// granule gids follow the three mitral groups
		for k=0, num_granule-1 {
			ids.x[3*num_mitral + k] = num_mitral + k
		}
		for k=0, ids.size-1 {
			weights.x[k] = fcx(ids.x[k])
		}
	}

	// owner.x[k] is the host chosen for piece k
	owner = lptiter_lpt(weights, pnm.nhost, 0)
	// positions of the pieces that belong to this host
	lptiter_gidvec = owner.c.indvwhere("==", pnm.myid)
	// reduce weights to this host's pieces and total them
	weights.index(weights, lptiter_gidvec)
	cxcpu = weights.sum
	// turn the positions into the gids themselves
	lptiter_gidvec.index(ids, lptiter_gidvec)
//	printf("cxcpu=%g\n", cxcpu)
//	for k=0, lptiter_gidvec.size-1 { printf("%d %d\n", pnm.myid, lptiter_gidvec.x[k])}
}

// Complexity of a piece given the piece gid ($1).
// The tests run from the largest gid range downward:
//   $1 >= 2*splitbit : left mitral secondary dendrite
//   $1 >= splitbit   : right mitral secondary dendrite
//   $1 >= num_mitral : granule
//   otherwise        : main mitral piece
func fcx() {local nsyn
	if ($1 >= 2*splitbit) {
		// left mitral secden
		nsyn = how_many_syn_on_secden(basegid($1), 1)
		return cxsecden + (cxtd+cxfi)*nsyn
	}
	if ($1 >= splitbit) {
		// right mitral secden
		nsyn = how_many_syn_on_secden(basegid($1), 0)
		return cxsecden + (cxtd+cxfi)*nsyn
	}
	if ($1 >= num_mitral) {
		// granule: the synapse count is looked up by granule index
		nsyn = how_many_syn_on_granule($1 - num_mitral)
		return cxgranule + (cxspine+cxtd+cxampa)*nsyn
	}
	// main mitral
	return cxmainum_mitral
}

// from loadbal.hoc
// least processing time algorithm
// $o1 is vector of weights
// $2  is number of partitions
// $3  nonzero prints the piece weights and the partition complexities
// return is vector of partition indices parallel to weights
// Side effect: the global thread_cxbal_ is set to max/mean of the
// partition complexities, or to 1 when the mean is 0.
// w and ip are declared local so that a call does not create or overwrite
// global variables of those names.
obfunc lptiter_lpt() {local i, j, w, ip  localobj wx, ix, pw
	if ($3) {
		print $o1.size, " piece weights"
		$o1.printf
	}
	// indices of the weights, heaviest first
	wx = $o1.sortindex.reverse
	ix = new Vector($o1.size)
	// pw.x[k] is the total weight placed on partition k so far
	pw = new Vector($2)
	for i=0, $o1.size-1 {
		j = wx.x[i]
		w = $o1.x[j]
		// give the piece to the currently lightest partition
		ip = pw.min_ind
		pw.x[ip] += w
		ix.x[j] = ip
	}
	if ($3) {
		print $2, " partition complexities"
		pw.printf
	}
	if (pw.mean) {
		thread_cxbal_ = pw.max/(pw.mean)
	}else{
		thread_cxbal_ = 1
	}
	return ix
}