// Global object references shared by the procedures in this file.
objref pnm, pc, nil, cells, load_balance_, gidvec, binfo
// Supporting hoc files; loaded before any of the objects below are created.
{load_file("netparmpi.hoc")}
{load_file("binfo.hoc")}
{load_file("loadbal.hoc")}
{localloadfile("lbcreate.hoc")}
{localloadfile("mscreate.hoc")}
pnm = new ParallelNetManager(0) //ncell to be determined later
// pc is a shorthand for the manager's pc member.
pc = pnm.pc
serial = 0 // 0 if parallel, 1 if serial
// A single host means the run is serial.
if (pc.nhost == 1) { serial = 1 }

// base_gid(gid)
// Returns the base gid for gid. When binfo has been created the answer
// comes from binfo.base_gid. Otherwise a gid at or above splitbit has
// splitbit subtracted and a smaller gid is returned unchanged.
func base_gid() {local g
	g = $1
	if (object_id(binfo) != 0) {
		return binfo.base_gid(g)
	}
	if (g < splitbit) {
		return g
	}
	return g - splitbit
}

// thishost_gid(gid)
// Returns the gid, related to the argument, that exists on this host,
// or -1 if there is none.
// When binfo has been created the lookup is delegated to
// binfo.thishost_gid. Otherwise the argument itself is tried first and
// then its counterpart, which differs from it by splitbit.
func thishost_gid() {local g, other
	g = $1
	if (object_id(binfo)) {
		return binfo.thishost_gid(g)
	}
	if (pc.gid_exists(g)) {
		return g
	}
	// Use the same boundary as base_gid (>= splitbit). With a strict >
	// the gid equal to splitbit would be treated as a base gid and the
	// wrong counterpart (2*splitbit instead of 0) would be probed.
	if (g >= splitbit) {
		other = g - splitbit
	}else{
		other = g + splitbit
	}
	if (pc.gid_exists(other)) {
		return other
	}
	return -1
}

// read_load_balance_info(prefix)
// Builds the file name "<prefix>.<nhost>", asks load_balance_ to read it
// for this host's id, and then resets gidvec to an empty Vector.
proc read_load_balance_info() { localobj fname
	fname = new String()
	sprint(fname.s, "%s.%d", $s1, pc.nhost)
	load_balance_.read_load_balance_info(fname.s, pc.id)
	gidvec = new Vector()
}

// for pcitr(&index, &gid [, all_pieces])
// Iterates over the gids handled by this host, storing a running index
// in the first pointer argument and the gid in the second.
// The set that is visited depends on load_balance_phase:
//   3 or 7 : entries of gidvec. Only gids equal to their base_gid are
//            visited unless a third argument equal to 1 is given. Note
//            that with load balance, the gid exists only after the cell
//            is created.
//   5      : every entry of gidvec (whole cell gids).
//   other  : round robin, gid = pc.id, pc.id + nhost, ... below pnm.ncell.
iterator pcitr() {local idx, g, all_pieces
	if (load_balance_phase == 3 || load_balance_phase == 7) {
		all_pieces = 0
		if (numarg() == 3) {
			all_pieces = $3
		}
		for idx = 0, gidvec.size-1 {
			g = gidvec.x[idx]
			$&1 = idx
			$&2 = g
			if (all_pieces == 1) {
				iterator_statement
			}else if (g == base_gid(g)) {
				iterator_statement
			}
		}
	}else if (load_balance_phase == 5) {
		for idx = 0, gidvec.size-1 {
			g = gidvec.x[idx]
			$&1 = idx
			$&2 = g
			iterator_statement
		}
	}else{
		idx = 0
		for (g = pc.id; g < pnm.ncell; g += pc.nhost) {
			$&1 = idx
			$&2 = g
			iterator_statement
			idx += 1
		}
	}
}

// gid_distribute()
// When load balancing is not in use, start with an empty gidvec and let
// the network manager assign gids round robin. Does nothing otherwise.
proc gid_distribute() {
	if (use_load_balance != 0) {
		return
	}
	gidvec = new Vector()
	pnm.round_robin()
}

// want_all_spikes()
// Requests spike recording for every gid visited by pcitr (called
// without the optional third argument).
proc want_all_spikes() {local idx, cellgid
	for pcitr(&idx, &cellgid) {
		pnm.spike_record(cellgid)
	}
}

// spike2file()
// Writes the recorded spikes (time, gid pairs from pnm.spikevec and
// pnm.idvec) to "out<nhost>.dat". Host 0 first truncates the file, then
// each pass of the pnm.serialize() iterator appends this host's spikes.
// NOTE(review): serialize() presumably lets the hosts run the body one
// at a time so the appends do not interleave -- confirm in netparmpi.hoc.
proc spike2file() {local i  localobj outf, s
	// i is declared local so the loop below does not overwrite a global i.
	s = new String()
	sprint(s.s, "out%d.dat", pnm.nhost)
	outf = new File()
	if (pc.id == 0) {outf.wopen(s.s) outf.close } // start new file
	for pnm.serialize() {
		outf.aopen(s.s)
		for i=0, pnm.idvec.size-1 {
			outf.printf("%g\t%d\n", pnm.spikevec.x[i], pnm.idvec.x[i])
		}
		outf.close
	}
}

// par_ncell_ counts every call of par_create (it is the gid of the next
// cell); par_local_ncell_ counts the cells actually created on this host.
par_ncell_ = 0
par_local_ncell_ = 0

// par_create(creation_statement)
// Considers the cell whose gid is the current value of par_ncell_ and
// creates it on this host if it belongs here, then advances par_ncell_.
//   phase 3 : created with load_balanced_create when the gid is found in
//             load_balance_.gid.
//   phase 7 : created with multisplit_create when that returns nonzero.
//   other   : the statement is executed when pc.gid_exists(gid); the new
//             cell must be the next entry of the cells list and its gid
//             is appended to gidvec.
proc par_create() {local idx  localobj c, nc
	if (load_balance_phase == 3) {
		idx = load_balance_.gid.indwhere("==", par_ncell_)
		if (idx != -1) {
			load_balanced_create(idx, $s1)
			par_local_ncell_ += 1
		}
	}else if (load_balance_phase == 7) {
		if (multisplit_create(par_ncell_, $s1)) {
			par_local_ncell_ += 1
		}
	}else if (pc.gid_exists(par_ncell_)) {
		// the cells list must hold exactly the cells created so far
		if (cells.count != par_local_ncell_) {
			printf("%d Assertion error in par_create before executing: %s\n", pc.id, $s1)
			quit()
		}
		execute($s1)
		par_local_ncell_ += 1
		c = cells.object(cells.count-1)
		gidvec.append(par_ncell_)
		// The three statements below are deferred and done for all
		// the cells at once in par_register_cells. The problem is that
		// because topology has now changed, connect2target forces a
		// complete update of internal data structures.
//		c.connect2target(nil, nc)
//		pc.cell(par_ncell_, nc)
//		pc.outputcell(par_ncell_)
	}
	par_ncell_ += 1
}

// par_register_cells()
// Associates every gid visited by pcitr (all pieces) with a NetCon
// source via pc.cell. For a gid equal to its base_gid the cell's own
// connect2target supplies the NetCon and the gid is also declared an
// output cell. For any other gid a NetCon on v(.5) is made in the
// forsec loop over the cell's "all" section list (only while nc is
// still nil, i.e. for the first section visited).
proc par_register_cells() {local idx, gid, is_base  localobj cellobj, nc
	for pcitr(&idx, &gid, 1) {
		cellobj = cells.object(idx)
		nc = nil
		is_base = (gid == base_gid(gid))
		if (is_base) {
			cellobj.connect2target(nil, nc)
		}else{
			forsec cellobj.all {
				if (object_id(nc) == 0) {
					nc = new NetCon(&v(.5), nil)
				}
			}
		}
		pc.cell(gid, nc, 0)
		if (is_base) {
			pc.outputcell(gid)
		}
	}
}

// tdat_ holds per-host timing data filled in by prun (slots 0-4) and by
// the load balance printing procedures (slot 5).
objref tdat_
tdat_ = new Vector(6)
objref cvec, clist
cvec = new Vector(100)
clist = new List()

// prun_record_first_(cellname, trace_index, record_time)
// Private helper for prun. Finds the first object in cells whose name
// equals cellname and records the voltage at comp[1](0.5) into
// Jtrace[trace_index]. When record_time is 1, t is also recorded into
// tvec. Does nothing if no cell has that name.
proc prun_record_first_() {local i  localobj c
	for i=0, cells.count-1 {
		c = cells.object(i)
		if (strcmp(c.name, $s1) == 0) {
			Jtrace[$2].record(&c.comp[1].v(0.5))
			if ($3 == 1) {
				tvec.record(&t)
			}
			break
		}
	}
}

// prun()
// Runs the parallel simulation to tstop and collects timing data.
//   tdat_.x[0] : change in pc.wait_time over the run
//   tdat_.x[1] : pc.step_time
//   tdat_.x[2] : pc.send_time
//   tdat_.x[3] : pc.vtransfer_time(0), for gaps
//   tdat_.x[4] : pc.vtransfer_time(1), for splitcells
// runtime receives the elapsed startsw() time.
// On host 0 with jordan_trace == 1, voltage traces are set up for the
// first cell of each listed type before initialization.
// savestatetest == 1 saves the state at tstop/2; savestatetest == 2
// restores a saved state before solving to tstop.
proc prun() {
	pc.setup_transfer()
	pnm.set_maxstep(10)
	runtime=startsw()
	tdat_.x[0] = pnm.pc.wait_time
	//print "prun ", pc.id
	//for i = 0, cvec.size-1 {
	//	print cvec.x[i], clist.x[i]
	//}
	if( pc.id == 0 ) {
		if (jordan_trace == 1) {
			// only the first trace also records the time vector
			prun_record_first_("suppyrRS", 0, 1)
			prun_record_first_("suppyrFRB", 1, 0)
			prun_record_first_("supaxax", 2, 0)
			prun_record_first_("spinstell", 3, 0)
			prun_record_first_("tuftIB", 4, 0)
			prun_record_first_("tuftRS", 5, 0)
			prun_record_first_("nontuftRS", 6, 0)
			prun_record_first_("deepaxax", 7, 0)
			prun_record_first_("deepbask", 8, 0)
			prun_record_first_("supbask", 9, 0)
			prun_record_first_("supLTS", 10, 0)
			prun_record_first_("deepLTS", 11, 0)
		}
	}
	stdinit()
	if (savestatetest == 1) {
		pnm.psolve(tstop/2)
		savestate()
	}else if (savestatetest == 2) {
		restorestate()
	}
	pnm.psolve(tstop)
	tdat_.x[0] = pnm.pc.wait_time - tdat_.x[0]
	runtime = startsw() - runtime
	tdat_.x[1] = pnm.pc.step_time
	tdat_.x[2] = pnm.pc.send_time
	tdat_.x[3] = pc.vtransfer_time(0) // for gaps
	tdat_.x[4] = pc.vtransfer_time(1) // for splitcells
//      printf("%d wtime %g\n", pnm.myid, waittime)
}

objref mxhist_

// mkhist(size)
// On host 0 only: creates mxhist_ as a Vector of the given size and
// hands it to pc.max_histogram.
proc mkhist() {
	if (pnm.myid != 0) {
		return
	}
	mxhist_ = new Vector($1)
	pc.max_histogram(mxhist_)
}
// prhist()
// On host 0, when mxhist_ exists, prints the histogram entries from
// index 0 up to the last nonzero entry (index 0 if all are zero).
proc prhist() {local bin, last
	if (pnm.myid == 0 && object_id(mxhist_)) {
		printf("histogram of #spikes vs #exchanges\n")
		// locate the highest bin with a nonzero value
		last = 0
		for bin = 0, mxhist_.size-1 {
			if (mxhist_.x[bin] != 0) {
				last = bin
			}
		}
		for bin = 0, last {
			printf("%d\t %d\n", bin, mxhist_.x[bin])
		}
		printf("end of histogram\n")
	}
}


mindelay_ = 1e9

// mindelay()
// Returns the smallest mindelay_ value. With more than one host, every
// other host is asked (through pc.context) to post its mindelay_; the
// posted values are taken one by one and the local mindelay_ is lowered
// to the smallest of them.
func mindelay() {local rank, remote
	if (pc.nhost <= 1) {
		return mindelay_ // see nc_append
	}
	pc.context("{pc.post(\"mindelay\", mindelay_)}")
	for rank = 1, pc.nhost-1 {
		pc.take("mindelay", &remote)
		if (remote < mindelay_) {
			mindelay_ = remote
		}
	}
	return mindelay_ // see nc_append
}

objref tavg_stat, tmin_stat, tmax_stat, idmin_stat, idmax_stat

// poststat()
// Posts this host's id and its tdat_ vector under the key "poststat".
// Executed on the workers at the request of getstat.
proc poststat() {
	pnm.pc.post("poststat", pnm.myid, tdat_)
}

// getstat()
// Combines the tdat_ vectors of all hosts. Afterwards
//   tavg_stat            : element-wise sum divided by pnm.nhost
//   tmin_stat, tmax_stat : element-wise minimum and maximum
//   idmin_stat, idmax_stat : id of the host that supplied each extreme
// All start from this host's tdat_ (ids start at 0).
proc getstat() {local w, slot, sender  localobj remote
	remote = tdat_.c
	tavg_stat = tdat_.c
	tmin_stat = tdat_.c
	tmax_stat = tdat_.c
	idmin_stat = tdat_.c.fill(0)
	idmax_stat = tdat_.c.fill(0)
	if (pnm.nwork > 1) {
		pnm.pc.context("poststat()\n")
		// one message is expected from each of the other workers
		for w = 0, pnm.nwork-2 {
			pnm.pc.take("poststat", &sender, remote)
			tavg_stat.add(remote)
			for slot = 0, tdat_.size-1 {
				if (remote.x[slot] > tmax_stat.x[slot]) {
					idmax_stat.x[slot] = sender
					tmax_stat.x[slot] = remote.x[slot]
				}
				if (remote.x[slot] < tmin_stat.x[slot]) {
					idmin_stat.x[slot] = sender
					tmin_stat.x[slot] = remote.x[slot]
				}
			}
		}
	}
	tavg_stat.div(pnm.nhost)
}

objref spstat_

// postspstat()
// Fills spstat_ with the values from cvode.spike_stat followed by four
// values from pc.spike_statistics (its return value and its three
// pointer arguments), then posts it with this host's id under the key
// "postspstat".
proc postspstat() {local first
	spstat_ = new Vector()
	cvode.spike_stat(spstat_)
	// index of the first of the four appended slots
	first = spstat_.size
	spstat_.resize(first + 4)
	spstat_.x[first] = pc.spike_statistics(&spstat_.x[first+1], \
		&spstat_.x[first+2], &spstat_.x[first+3])
	pnm.pc.post("postspstat", pnm.myid, spstat_)
}
// print_spike_stat_info()
// Gathers the spike statistics of every host (see postspstat for the
// layout of one vector) and prints, per statistic, the total, minimum,
// maximum and the ids of the hosts that supplied the extremes.
// Then prints setuptime, runtime and the timing summaries held in
// tavg_stat, tmin_stat, tmax_stat, idmin_stat and idmax_stat; those are
// not computed here and must already be set.
proc print_spike_stat_info() {local k, slot, sender  localobj remote, total, lo, hi, lo_id, hi_id, labels
	remote = new Vector()

	// this host's own statistics
	spstat_ = new Vector()
	cvode.spike_stat(spstat_)
	k = spstat_.size
	spstat_.resize(spstat_.size + 4)
	spstat_.x[k] = pc.spike_statistics(&spstat_.x[k+1], &spstat_.x[k+2],\
		&spstat_.x[k+3])

	// accumulators start from the local values, ids from 0
	total = spstat_.c
	lo = spstat_.c
	hi = spstat_.c
	lo_id = spstat_.c.fill(0)
	hi_id = spstat_.c.fill(0)

	if (pnm.nwork > 1) {
		pnm.pc.context("postspstat()\n")
		for k = 0, pnm.nwork-2 {
			pnm.pc.take("postspstat", &sender, remote)
			total.add(remote)
			for slot = 0, remote.size-1 {
				if (remote.x[slot] > hi.x[slot]) {
					hi_id.x[slot] = sender
					hi.x[slot] = remote.x[slot]
				}
				if (remote.x[slot] < lo.x[slot]) {
					lo_id.x[slot] = sender
					lo.x[slot] = remote.x[slot]
				}
			}
		}
	}

	// one label per statistic, in vector order
	labels = new List()
	labels.append(new String("eqn"))
	labels.append(new String("NetCon"))
	labels.append(new String("deliver"))
	labels.append(new String("NC deliv"))
	labels.append(new String("PS send"))
	labels.append(new String("S deliv"))
	labels.append(new String("S send"))
	labels.append(new String("S move"))
	labels.append(new String("Q insert"))
	labels.append(new String("Q move"))
	labels.append(new String("Q remove"))
	labels.append(new String("max sent"))
	labels.append(new String("sent"))
	labels.append(new String("received"))
	labels.append(new String("used"))

	// spike statistics table
	printf("%10s %13s %10s %10s    %5s   %5s\n",\
		"", "total", "min", "max", "idmin", "idmax")
	for k = 0, spstat_.size-1 {
		printf("%-10s %13.0lf %10d %10d    %5d   %5d\n",\
			labels.object(k).s, total.x[k], lo.x[k], hi.x[k], lo_id.x[k], hi_id.x[k])
	}

	// average timings
	printf("\n%-10s %-10s %-10s %-10s %-10s %-10s %-10s %-10s\n",\
		"setup", "run", "avgspkxfr", "avgcomp", "avgx2q", "avgvxfr", "avgsplit", "avgcmplx")
	printf("%-10.4g %-10.4g", setuptime, runtime)
	for k = 0, tdat_.size-1 {
		printf(" %-10.4g", tavg_stat.x[k])
	}

	// extreme timings with the host ids that produced them
	printf("\n\n%5s %-15s %-15s %-15s %-15s %-15s %-15s\n", \
		"", "id   spkxfr", "id   com", "id   x2q", "id   vxfr", "id   split", "id   cmplx")
	printf("%-5s", "min")
	for k = 0, tdat_.size-1 {
		printf(" %-4d %-10.4g", idmin_stat.x[k], tmin_stat.x[k])
	}
	printf("\n%-5s", "max")
	for k = 0, tdat_.size-1 {
		printf(" %-4d %-10.4g", idmax_stat.x[k], tmax_stat.x[k])
	}
	printf("\n")
}

// perf2file()
// Appends one line to "perf.dat": nhost, ncell, load_balance_phase,
// use_gap, setuptime, runtime, then the tavg_stat values, then the
// (idmin_stat, tmin_stat) pairs, then the (idmax_stat, tmax_stat) pairs.
proc perf2file() { local k  localobj out
	out = new File()
	out.aopen("perf.dat")
	out.printf("%d %d %d %d    %g %g     ",pnm.nhost, pnm.ncell, load_balance_phase, use_gap, setuptime, runtime)
	for k = 0, tdat_.size-1 {
		out.printf(" %g", tavg_stat.x[k])
	}
	out.printf("     ")
	for k = 0, tdat_.size-1 {
		out.printf(" %d %g ", idmin_stat.x[k], tmin_stat.x[k])
	}
	out.printf("     ")
	for k = 0, tdat_.size-1 {
		out.printf(" %d %g ", idmax_stat.x[k], tmax_stat.x[k])
	}
	out.printf("\n")
	out.close
}


// calculate_mechanism_complexity(measure)
// Creates a fresh load_balance_ object and fills its two m_complex_
// vectors.
//   measure == 1 : calls ExperimentalMechComplex, and host 0 writes
//                  "value name" lines for both vectors to mcomplex.dat.
//   otherwise    : the values are read back from mcomplex.dat if that
//                  file can be opened; nothing happens if it cannot.
proc calculate_mechanism_complexity() {local m, k, nval  localobj f, mname
	mname = new String()
	f = new File()
	load_balance_ = new LoadBalance()
	if ($1 != 1) {
		if (f.ropen("mcomplex.dat")) {
			for k = 0, 1 {
				nval = load_balance_.m_complex_[k].size
				load_balance_.m_complex_[k].scanf(f, nval)
			}
			f.close
		}
		return
	}
	load_balance_.ExperimentalMechComplex()
	if (pnm.myid == 0) {
		f.wopen("mcomplex.dat")
		for k = 0, 1 {
			for m = 0, load_balance_.m_complex_[k].size - 1 {
				// fetch the name of mechanism m into mname.s
				load_balance_.mt[k].select(m)
				load_balance_.mt[k].selected(mname.s)
				f.printf("%g %s\n", load_balance_.m_complex_[k].x[m], mname.s)
			}
		}
	}
	f.close
}

// print_load_balance_info(how)
// Computes the complexity of every cell piece on this host and adds it
// to tdat_.x[5]. Uses load_balance_ when it exists, else a new
// LoadBalance.
//   how == 0 : nothing is written.
//   how == 1 : writes "bal.<myid>" with the cell count and one
//              "index gid complexity 0" line per cell.
//   how == 2 : as 1, but the last field is the number of resolutions
//              and each is listed on a following line.
proc print_load_balance_info() {local idx, k, gid, cx, how  localobj lb, f, fname, res, ri, rj
	how = $1
	fname = new String()
	sprint(fname.s, "bal.%04d", pnm.myid)
	if (object_id(load_balance_)) {
		lb = load_balance_
	}else{
		lb = new LoadBalance()
	}
	if (how > 0) {
		f = new File()
		f.wopen(fname.s)
		f.printf("%d\n", cells.count)
	}
	for pcitr(&idx, &gid, 1) {
		cx = lb.cell_complexity(pnm.pc.gid2obj(gid))
		tdat_.x[5] += cx
		if (how == 2) {
			res = lb.resolutions(ri, rj)
			f.printf("%d %d %d %d\n", idx, gid, cx, res.size)
			for k = 0, res.size-1 {
				f.printf("%d %d %d\n", res.x[k], ri.x[k], rj.x[k])
			}
		}else if (how == 1) {
			f.printf("%d %d %d %d\n", idx, gid, cx, 0)
		}
	}
	if (how > 0) {
		f.close()
	}
}

// print_single_split_load_balance_info(filename)
// Writes the complexity and resolution data of all cells to one file.
// Each host first collects its own data (also adding the complexities
// to tdat_.x[5]). Host 0 then starts the file with "1" and pnm.ncell,
// and the hosts append their sections in rank order, separated by
// barriers.
proc print_single_split_load_balance_info() {local idx, j, k, gid, cx, rank \
  localobj lb, f, res, ri, rj, saved, cxvec
	if (object_id(load_balance_)) {
		lb = load_balance_
	}else{
		lb = new LoadBalance()
	}

	// gather: three vectors per cell go into saved, one number into cxvec
	saved = new List()
	cxvec = new Vector()
	for pcitr(&idx, &gid, 1) {
		cx = lb.cell_complexity(pnm.pc.gid2obj(gid))
		tdat_.x[5] += cx
		cxvec.append(cx)
		res = lb.resolutions(ri, rj)
		saved.append(res.c)
		saved.append(ri.c)
		saved.append(rj.c)
	}

	f = new File()
	if (pc.id == 0) {
		f.wopen($s1)
		f.printf("1\n%d\n", pnm.ncell)
		f.close
	}

	// write: one host at a time
	for rank = 0, pc.nhost - 1 {
		if (rank == pc.id) {
			f.aopen($s1)
			k = 0
			for pcitr(&idx, &gid, 1) {
				res = saved.o(k)
				ri = saved.o(k + 1)
				rj = saved.o(k + 2)
				k += 3
				cx = cxvec.x[idx]
				f.printf("%d %d %d %d\n", idx, gid, cx, res.size)
				for j = 0, res.size-1 {
					f.printf("%d %d %d\n", res.x[j], ri.x[j], rj.x[j])
				}
			}
			f.close()
		}
		pc.barrier()
	}
}

objref cb_cx, cb_gid

// cell_complexity(&total_complexity, &max_complexity, &ncell)
// Computes the complexity of every piece on this host with
// load_balance_, storing the values in cb_cx and the gids in cb_gid.
// The three pointer arguments receive the results of pc.allreduce over
// the local sum, the local maximum and the local count.
proc cell_complexity() {local idx, gid, cx, total, biggest, count
	cb_cx = new Vector()
	cb_gid = new Vector()
	total = 0
	biggest = 0
	count = 0
	for pcitr(&idx, &gid, 1) {
		cx = load_balance_.cell_complexity(pc.gid2obj(gid))
		cb_cx.append(cx)
		cb_gid.append(gid)
		total += cx
		if (cx > biggest) {
			biggest = cx
		}
		count += 1
	}
	count = pc.allreduce(count, 1)
	total = pc.allreduce(total, 1)
	biggest = pc.allreduce(biggest, 2)
	$&1 = total
	$&2 = biggest
	$&3 = count
}

// print_wholecell_info(prefix)
// Writes "<prefix>.dat". Host 0 starts the file with the number of
// hosts; then each host in rank order appends its cell count followed
// by one "gid complexity 0" line per cell (from cb_gid and cb_cx, which
// cell_complexity fills in).
proc print_wholecell_info() {local total, biggest, count, rank, k \
  localobj out, fname
	cell_complexity(&total, &biggest, &count)
	fname = new String()
	sprint(fname.s, "%s.dat", $s1)
	out = new File()
	if (pc.id == 0) {
		out.wopen(fname.s)
		out.printf("%d\n", pc.nhost)
		out.close
	}
	for rank = 0, pc.nhost - 1 {
		if (rank == pc.id) {
			out.aopen(fname.s)
			out.printf("%d\n", cb_cx.size)
			for k = 0, cb_cx.size-1 {
				out.printf("%d %g %g\n", cb_gid.x[k], cb_cx.x[k], 0)
			}
			out.close()
		}
		pc.barrier()
	}
}

// print_multisplit_info(prefix, nhost_target)
// Writes "<prefix>.dat" with the multisplit data of every cell.
// The target piece complexity passed to load_balance_.multisplit is
// total_complexity / nhost_target * msoptfactor. Host 0 starts the file
// with pc.nhost; each host in rank order then appends its cell count
// and one pmsdat record per cell.
proc print_multisplit_info() {local total, biggest, count, rank, k, gid, ntarget, target \
  localobj out, fname, pieces, buf
	buf = new Vector(100)
	pieces = new List()
	ntarget = $2
	cell_complexity(&total, &biggest, &count)
	target = total/ntarget * msoptfactor
	if (pc.id == 0) {
		printf("opt=%g tcx=%g mcx=%g nh=%d nc=%d\n", target, total, biggest, ntarget, count)
	}
	for pcitr(&k, &gid, 1) {
		// cell_complexity is called before multisplit for each cell,
		// its return value is not used here
		load_balance_.cell_complexity(pc.gid2obj(gid))
		load_balance_.multisplit(gid, target, buf)
		pieces.append(buf.c)
	}
	fname = new String()
	sprint(fname.s, "%s.dat", $s1)
	out = new File()
	if (pc.id == 0) {
		out.wopen(fname.s)
		out.printf("%d\n", pc.nhost)
		out.close
	}
	for rank = 0, pc.nhost - 1 {
		if (rank == pc.id) {
			out.aopen(fname.s)
			out.printf("%d\n", pieces.count)
			for k = 0, pieces.count-1 {
				pmsdat(out, pieces.object(k))
			}
			out.close()
		}
		pc.barrier()
	}
}

// pmsdat(File, dataVec)
// Prints one multisplit record from dataVec to the open File.
// dataVec is read sequentially as:
//   gid, total complexity of cell, number of pieces,
//   then per piece: number of subtrees,
//   then per subtree: complexity, number of children, child ids.
proc pmsdat() {local pos, p, npiece, s, nsub, ch, nchild, child, cx, sum
	pos = 0
	sum = 0

	$o1.printf("%d", $o2.x[pos]) // gid
	pos += 1
	$o1.printf(" %g", $o2.x[pos]) // total complexity of cell
	pos += 1
	npiece = $o2.x[pos]
	pos += 1
	$o1.printf(" %d\n", npiece) // number of pieces
	for p = 0, npiece-1 {
		nsub = $o2.x[pos]
		pos += 1
		$o1.printf("  %d\n", nsub) // number of subtrees
		for s = 0, nsub-1 {
			cx = $o2.x[pos] // subtree complexity
			pos += 1
			sum += cx
			nchild = $o2.x[pos] // number of children in the subtree
			pos += 1
			$o1.printf("   %g %d\n", cx, nchild)
			if (nchild > 0) {
				$o1.printf("    ")
			}
			for ch = 0, nchild - 1 {
				child = $o2.x[pos]
				pos += 1
				$o1.printf(" %d", child)
			}
			if (nchild > 0) {
				$o1.printf("\n")
			}
		}
	}
	// disabled consistency printout
	if (0) {
		printf("gid=%d cell complexity = %g  sum of pieces = %g\n", \
			$o2.x[0], $o2.x[1], sum)
	}
}

// savestate()
// Saves the simulation state of this host to "svst.<id>". After the
// SaveState data, a line "Random <count>" is written followed by the
// seq() value of every Random object, one per line.
proc savestate() {local k  localobj fname, state, out, rands
	fname = new String()
	sprint(fname.s, "svst.%04d", pc.id)
	out = new File(fname.s)

	state = new SaveState()
	state.save()
	state.fwrite(out, 0)

	rands = new List("Random")
	out.printf("Random %d\n", rands.count)
	for k = 0, rands.count-1 {
		out.printf("%d\n", rands.object(k).seq())
	}
	out.close
}

// restorestate()
// Reads the state written by savestate from "svst.<id>". The stored
// Random count must equal the number of existing Random objects, else
// execerror is raised. The stored sequence values are applied to the
// Random objects, and finally the SaveState is restored.
proc restorestate() {local k, nsaved  localobj fname, state, in, rands
	fname = new String()
	sprint(fname.s, "svst.%04d", pc.id)
	in = new File(fname.s)

	state = new SaveState()
	state.fread(in, 0)

	rands = new List("Random")
	nsaved = in.scanvar()
	if (nsaved != rands.count) {
		execerror("Random count unexpected", "")
	}
	for k = 0, rands.count-1 {
		rands.object(k).seq(in.scanvar())
	}
	in.close
	state.restore()
}

// read_wholecell_info(prefix)
// Reads the gid distribution from "<prefix>.<nhost>.dat" and assigns to
// this host every gid whose listed cpu equals pnm.myid.
// The file holds two header numbers (read as ncell and nhost) followed
// by ncell "gid cpu" pairs. gidvec is reset to an empty Vector; it is
// not filled here.
// Raises execerror if the file cannot be opened.
proc read_wholecell_info() {local i, gid, icpu, ncell, nhost  localobj s, f
	gidvec = new Vector()
	s = new String()
	sprint(s.s, "%s.%d.dat", $s1, pc.nhost)
	f = new File()
	if (f.ropen(s.s) == 0) {
		execerror("could not open", s.s)
	}
	if (pc.id == 0) { printf("gids from %s\n", s.s) }
	ncell = f.scanvar()
	nhost = f.scanvar() // not used, but must be read to reach the pairs
	for i=0, ncell - 1 {
		gid = f.scanvar()
		icpu = f.scanvar()
		if (icpu == pnm.myid) {
			pc.set_gid2node(gid, pc.id)
		}
	}
	// release the file explicitly rather than leaving it open until the
	// File object is destroyed
	f.close()
}

// read_multisplit_info(prefix)
// Resets gidvec to an empty Vector and creates binfo as a BalanceInfo
// for this host from the given prefix.
// Note that it is not possible to call pc.set_gid2node at this point,
// since it is not known which piece has the output port until the cell
// is created. So pcitr cannot be called until after cell creation.
proc read_multisplit_info() {
	gidvec = new Vector()
	binfo = new BalanceInfo($s1, pnm.myid, pnm.nhost)
}