// Object references shared by the procedures in this file.
// nil is not assigned anywhere in the visible code; it is passed where a
// NULL object argument is wanted (e.g. connect2target(nil, nc)).
objref pnm, pc, nil, cells, load_balance_, gidvec
// NOTE(review): these files presumably supply ParallelNetManager, LoadBalance
// and load_balanced_create, which are used below -- confirm.
{load_file("netparmpi.hoc")}
{load_file("loadbal.hoc")}
{load_file("lbcreate.hoc")}
// gid offset: par_ncappend falls back to gid+splitbit when gid itself is not
// on this host, and par_register_cells does not call pc.outputcell for
// gids >= splitbit.
splitbit = 2^28

pnm = new ParallelNetManager(0) //ncell to be determined later
pc = pnm.pc
serial = 0 // 0 if parallel, 1 if serial
if (pc.nhost == 1) { serial = 1 }

// load_balance(phase): select the load balance mode for this run.
//   phase 1    measure mechanism complexity (rank 0 writes mcomplex.dat), quit
//   phase 2    load the split cell balance calculation code
//   phase 3    load split cell balancing and read the "splitbal" info
//   phase 4    load whole cell balancing and read the "cx" info
//   any other phase >= 3 is rejected with execerror
// Sets the globals load_balance_phase and use_load_balance (1 when phase >= 3).
proc load_balance() {
	load_balance_phase = $1
	use_load_balance = (load_balance_phase >= 3)
	if (pc.id == 0) {
		printf("load_balance_phase = %d\n", load_balance_phase)
	}
	// easier to use mcomplex.hoc instead of load_balance_phase == 1
	// to construct mcomplex.dat
	calculate_mechanism_complexity(load_balance_phase == 1)
	if (load_balance_phase == 1) {
		// measurement only: shut down the parallel context and exit
		pc.runworker() pc.done() quit()
	}else if (load_balance_phase == 2) {
		load_file("../common/splitcellbalcalc.hoc")
	}else if (load_balance_phase == 3) {
		load_file("../common/splitcellbal.hoc")
		read_splitcell_info("splitbal")
	}else if (load_balance_phase == 4) {
		load_file("../common/wholecellbal.hoc")
		read_wholecell_info("cx")
	}else if (use_load_balance) {
		execerror("invalid load_balance_phase", "")
	}
}

// read_load_balance_info(basename)
// Builds the file name "<basename>.<nhost>", asks load_balance_ to read it
// for this host's id, and then replaces gidvec with a fresh empty Vector.
proc read_load_balance_info() { localobj fname
	fname = new String()
	sprint(fname.s, "%s.%d", $s1, pc.nhost)
	load_balance_.read_load_balance_info(fname.s, pc.id)
	gidvec = new Vector()
}

// Default round-robin iterator over the gids owned by this host.
// Usage: for pcitr(&index, &gid) { ... }
//   $&1 receives the local index (0, 1, 2, ...)
//   $&2 receives the gid (pnm.myid, pnm.myid + pnm.nhost, ... < pnm.ncell)
// When balanced, this is replaced by the ...balance.hoc implementation.
iterator pcitr() {local idx, gid
	idx = 0
	gid = pnm.myid
	while (gid < pnm.ncell) {
		$&1 = idx
		$&2 = gid
		iterator_statement
		idx += 1
		gid += pnm.nhost
	}
}

// Usage: for serialize_output(file) { ...write to file... }
// $o1 is a File object. Rank 0 first opens it for writing and closes it
// again. Then the hosts take turns in rank order: the host whose turn it is
// opens the file for appending, runs the loop body once and closes the
// file; a barrier separates consecutive turns.
iterator serialize_output() {local rank
	if (pc.id == 0) {
		$o1.wopen()
		$o1.close()
	}
	pc.barrier()
	for (rank = 0; rank < pc.nhost; rank += 1) {
		if (rank == pc.id) {
			$o1.aopen()
			iterator_statement
			$o1.close()
		}
		pc.barrier()
	}
}

// gid_distribute(): when load balancing is not in use, assign every gid
// yielded by pcitr to this host. Does nothing when use_load_balance is set.
proc gid_distribute() {local idx, gid
	if (use_load_balance != 0) { return }
	for pcitr(&idx, &gid) {
		pc.set_gid2node(gid, pc.id)
	}
}

// want_all_spikes(): request spike recording (pnm.spike_record) for every
// gid that pcitr yields on this host.
proc want_all_spikes() {local idx, gid
	for pcitr(&idx, &gid) {
		pnm.spike_record(gid)
	}
}

// spike2file(): write the recorded spikes of all hosts to out<nhost>.dat.
// Each line is "<spike time>\t<gid>", taken pairwise from pnm.spikevec and
// pnm.idvec. The hosts write one after another via serialize_output.
proc spike2file() { local k  localobj out, fname
	fname = new String()
	sprint(fname.s, "out%d.dat", pnm.nhost)
	out = new File(fname.s)
	for serialize_output(out) {
		for (k = 0; k < pnm.idvec.size(); k += 1) {
			out.printf("%g\t%d\n", pnm.spikevec.x[k], pnm.idvec.x[k])
		}
	}
}

par_ncell_ = 0		// number of par_create calls made on this host
par_local_ncell_ = 0	// number of cells actually created on this host

// par_create(gid, "cell creation expression")
//   $1   gid of the cell
//   $s2  HOC expression that constructs the cell, e.g. "new Cell()"
// Returns the cell object if it was created on this host, otherwise the
// unassigned (NULL) localobj c.
// In phase 3 the cell is built by load_balanced_create when the gid is listed
// in load_balance_.gid. Otherwise the cell is created only if this host owns
// the gid; it is appended to pnm.cells and registered as an output cell.
obfunc par_create() {local i  localobj c, s, nc
	if (load_balance_phase == 3) {
		if ((i = load_balance_.gid.indwhere("==", $1)) != -1) {
			c = load_balanced_create(i, $s2)
			par_local_ncell_ += 1
		}
	}else{
		if (pc.gid_exists($1)) {
			// pnm.cells must contain exactly the cells created so far
			if (pnm.cells.count != par_local_ncell_) {
				// report the creation expression; arg 1 is the numeric gid,
				// so the string to print is $s2
				printf("%d Assertion error in par_create before executing: %s\n", pc.id, $s2)
				quit()
			}
			s = new String()
			sprint(s.s, "pnm.cells.append(%s)", $s2)
			execute(s.s)
			par_local_ncell_ += 1
			c = pnm.cells.object(pnm.cells.count-1)
			// In phase 3 the equivalent of the following three calls is
			// deferred to par_register_cells and done for all cells at once,
			// because once the topology has changed connect2target forces a
			// complete update of internal data structures.
			// In this branch they are executed immediately.
			c.connect2target(nil, nc)
			pc.cell($1, nc, 0)
			pc.outputcell($1)
		}
	}
	par_ncell_ += 1
	return c
}

// par_register_cells(): phase 3 only. For every (index, gid) pair yielded by
// pcitr, associate the gid with a spike source for pnm.cells.object(index).
//   gid <  splitbit: the source comes from the cell's connect2target and the
//                    gid is also registered as an output cell.
//   gid >= splitbit: a NetCon on v(.5) of the first section in the cell's
//                    "all" list is used and no output cell is registered.
proc par_register_cells() {local idx, gid  localobj cell, nc
	if (load_balance_phase != 3) { return }
	for pcitr(&idx, &gid, 1) {
		cell = pnm.cells.object(idx)
		nc = nil
		if (gid < splitbit) {
			cell.connect2target(nil, nc)
			pc.cell(gid, nc, 0)
			pc.outputcell(gid)
		}else{
			forsec cell.all {
				// only the first section visited creates the NetCon
				if (object_id(nc) == 0) {
					nc = new NetCon(&v(.5), nil)
				}
			}
			pc.cell(gid, nc, 0)
		}
	}
}

// par_ncappend(srcgid, targetgid, synid, weight, delay)
// Wrapper around pnm.nc_append; all five arguments are passed through.
// Without load balancing (or in phase 5) the call is forwarded unchanged.
// Otherwise the target is looked up on this host, first under $2 and then
// under $2 + splitbit, and the connection is made only if the target exists
// here and its synapse $3 reports has_loc().
// Returns the value of pnm.nc_append, or -1 when no connection was made.
func par_ncappend() {local gid, synid  localobj cell, syn
	if (use_load_balance == 0 || load_balance_phase == 5) {
		return pnm.nc_append($1, $2, $3, $4, $5)
	}
	gid = $2
	synid = $3
	if (pc.gid_exists(gid)) {
		cell = pc.gid2cell(gid)
	}else if (pc.gid_exists(gid+splitbit)) {
		gid += splitbit
		cell = pc.gid2cell(gid)
	}
	if (object_id(cell) == 0) {
		return -1
	}
	syn = cell.synlist.object(synid)
	if (syn.has_loc()) {
		return pnm.nc_append($1, gid, synid, $4, $5)
	}
	return -1
}

// Per-host timing data filled in by prun:
//   x[0] change in pnm.pc.wait_time between just before stdinit and the end
//        of the solve
//   x[1] pnm.pc.step_time     x[2] pnm.pc.send_time
//   x[3] pc.vtransfer_time(0) x[4] pc.vtransfer_time(1)
//   x[5] accumulated by calc_load_balance
objref tdat_   
tdat_ = new Vector(6)
// overwritten in prun with the value returned by pc.set_maxstep
mindelay_ = 1e9
// prun(): run the parallel simulation to tstop and record timing in tdat_
// and in the global runtime. In load balance phase 2 it instead prints the
// split cell balance info and quits.
proc prun() {
	if (load_balance_phase == 2) {
		print_splitcell_balance_info()
		pc.runworker() pc.done() quit()
	}else{
		calc_load_balance()
	}
	pc.setup_transfer()
        mindelay_ = pc.set_maxstep(10)
        // the runtime measurement includes stdinit as well as the solve
        runtime=startsw()
        tdat_.x[0] = pnm.pc.wait_time
        stdinit()
// disabled developer toggle: with the outer 0 changed to 1 this either runs
// to tstop/2 and calls savestate (inner 1) or calls restorestate (inner 0)
if (0) {
	if (0) {
		pnm.psolve(tstop/2)
		savestate()
	}else{
		restorestate()
	}
}
        pnm.psolve(tstop)
        tdat_.x[0] = pnm.pc.wait_time - tdat_.x[0]
        runtime = startsw() - runtime
        tdat_.x[1] = pnm.pc.step_time
        tdat_.x[2] = pnm.pc.send_time
	tdat_.x[3] = pc.vtransfer_time(0) // for gaps
	tdat_.x[4] = pc.vtransfer_time(1) // for splitcells
//      printf("%d wtime %g\n", pnm.myid, waittime)
}

// Histogram vector handed to pc.max_histogram; only created on host 0.
objref mxhist_

// mkhist(size): on host 0, create mxhist_ with $1 elements and register it
// with pc.max_histogram. Other hosts do nothing.
proc mkhist() {
	if (pnm.myid != 0) { return }
	mxhist_ = new Vector($1)
	pc.max_histogram(mxhist_)
}
// prhist(): on host 0, if mkhist created mxhist_, print its bins from index 0
// up to the last non-zero bin (bin 0 alone when every bin is zero).
proc prhist() {local k, last
	if (pnm.myid != 0 || object_id(mxhist_) == 0) { return }
	printf("histogram of #spikes vs #exchanges\n")
	// locate the highest index holding a non-zero count
	last = 0
	for (k = 0; k < mxhist_.size(); k += 1) {
		if (mxhist_.x[k] != 0) {
			last = k
		}
	}
	for (k = 0; k <= last; k += 1) {
		printf("%d\t %d\n", k, mxhist_.x[k])
	}
	printf("end of histogram\n")
}


// mindelay(): return the smallest mindelay_ over all hosts.
// With more than one host, every other host is asked (via pc.context) to
// post its own mindelay_; the nhost-1 posted values are taken and the
// global mindelay_ on this host is lowered to the minimum seen.
func mindelay() {local rank, remote
	if (pc.nhost <= 1) {
		return mindelay_
	}
	pc.context("{pc.post(\"mindelay\", mindelay_)}")
	for (rank = 1; rank < pc.nhost; rank += 1) {
		pc.take("mindelay", &remote)
		if (remote < mindelay_) {
			mindelay_ = remote
		}
	}
	return mindelay_ // see nc_append
}

// Statistics over the hosts' tdat_ vectors, filled in by getstat:
// sum/average, minimum, maximum, and the host ids posted with the minimum
// and maximum values.
objref tavg_stat, tmin_stat, tmax_stat, idmin_stat, idmax_stat

// poststat(): post this host's id and tdat_ under the key "poststat".
// getstat runs this on the other hosts through pc.context.
proc poststat() {
	pnm.pc.post("poststat", pnm.myid, tdat_)
}
// getstat(): combine the tdat_ timing vectors of all hosts.
// Starts from this host's tdat_ (ids initialised to 0), then takes the
// "poststat" message of each of the other nwork-1 hosts and updates, per
// element, the sum, the minimum and maximum and the ids that posted them.
// Finally the sum is divided by pnm.nhost to give tavg_stat.
proc getstat() {local w, k, id localobj wdat
	wdat = tdat_.c
	tavg_stat = tdat_.c
	tmin_stat = tdat_.c
	tmax_stat = tdat_.c
	idmin_stat = tdat_.c.fill(0)
	idmax_stat = tdat_.c.fill(0)
	if (pnm.nwork > 1) {
		pnm.pc.context("poststat()\n")
		for (w = 0; w <= pnm.nwork-2; w += 1) {
			pnm.pc.take("poststat", &id, wdat)
			tavg_stat.add(wdat)
			for (k = 0; k < tdat_.size(); k += 1) {
				if (wdat.x[k] > tmax_stat.x[k]) {
					idmax_stat.x[k] = id
					tmax_stat.x[k] = wdat.x[k]
				}
				if (wdat.x[k] < tmin_stat.x[k]) {
					idmin_stat.x[k] = id
					tmin_stat.x[k] = wdat.x[k]
				}
			}
		}
	}
	tavg_stat.div(pnm.nhost)
}

// This host's spike statistics: the values from cvode.spike_stat followed by
// four values from pc.spike_statistics.
objref spstat_

// postspstat(): rebuild spstat_ and post it, together with this host's id,
// under the key "postspstat". print_spike_stat_info runs this on the other
// hosts through pc.context.
proc postspstat() {local base
	spstat_ = new Vector()
	cvode.spike_stat(spstat_)
	// append four slots: the return value of pc.spike_statistics and the
	// three values it stores through its pointer arguments
	base = spstat_.size
	spstat_.resize(base + 4)
	spstat_.x[base] = pc.spike_statistics(&spstat_.x[base+1], &spstat_.x[base+2], &spstat_.x[base+3])
	pnm.pc.post("postspstat", pnm.myid, spstat_)
}
// print_spike_stat_info(): print two summary tables.
// 1) Spike statistics: this host's spstat_ is rebuilt, the "postspstat"
//    vectors of the other nwork-1 hosts are taken, and for each entry the
//    total, min, max and the ids that posted the min and max are printed.
// 2) Timing: setuptime, runtime and the tavg/tmin/tmax/idmin/idmax
//    statistics. These globals are only read here (getstat assigns the
//    *_stat vectors), so they must have been set beforehand.
proc print_spike_stat_info() {local k, m, id  localobj incoming, total, lo, hi, idlo, idhi, labels
	incoming = new Vector()

	// this host's own statistics, same layout as postspstat posts
	spstat_ = new Vector()
	cvode.spike_stat(spstat_)
	k = spstat_.size
	spstat_.resize(k + 4)
	spstat_.x[k] = pc.spike_statistics(&spstat_.x[k+1], &spstat_.x[k+2], &spstat_.x[k+3])

	total = spstat_.c
	lo = spstat_.c
	hi = spstat_.c
	idlo = spstat_.c.fill(0)
	idhi = spstat_.c.fill(0)

	// fold in the statistics posted by the other hosts
	if (pnm.nwork > 1) {
		pnm.pc.context("postspstat()\n")
		for (k = 0; k <= pnm.nwork-2; k += 1) {
			pnm.pc.take("postspstat", &id, incoming)
			total.add(incoming)
			for (m = 0; m < incoming.size(); m += 1) {
				if (incoming.x[m] > hi.x[m]) {
					idhi.x[m] = id
					hi.x[m] = incoming.x[m]
				}
				if (incoming.x[m] < lo.x[m]) {
					idlo.x[m] = id
					lo.x[m] = incoming.x[m]
				}
			}
		}
	}

	// row labels, one per spstat_ entry, in order
	labels = new List()
	labels.append(new String("eqn"))
	labels.append(new String("NetCon"))
	labels.append(new String("deliver"))
	labels.append(new String("NC deliv"))
	labels.append(new String("PS send"))
	labels.append(new String("S deliv"))
	labels.append(new String("S send"))
	labels.append(new String("S move"))
	labels.append(new String("Q insert"))
	labels.append(new String("Q move"))
	labels.append(new String("Q remove"))
	labels.append(new String("max sent"))
	labels.append(new String("sent"))
	labels.append(new String("received"))
	labels.append(new String("used"))

	// spike statistics table
	printf("%10s %13s %10s %10s    %5s   %5s\n", "", "total", "min", "max", "idmin", "idmax")
	for (k = 0; k < spstat_.size(); k += 1) {
		printf("%-10s %13.0lf %10d %10d    %5d   %5d\n", labels.object(k).s, total.x[k], lo.x[k], hi.x[k], idlo.x[k], idhi.x[k])
	}

	// timing averages
	printf("\n%-10s %-10s %-10s %-10s %-10s %-10s %-10s %-10s\n", "setup", "run", "avgspkxfr", "avgcomp", "avgx2q", "avgvxfr", "avgsplit", "avgcmplx")
	printf("%-10.4g %-10.4g", setuptime, runtime)
	for (k = 0; k < tdat_.size(); k += 1) {
		printf(" %-10.4g", tavg_stat.x[k])
	}

	// timing extremes with the ids that reported them
	printf("\n\n%5s %-15s %-15s %-15s %-15s %-15s %-15s\n", "", "id   spkxfr", "id   com", "id   x2q", "id   vxfr", "id   split", "id   cmplx")
	printf("%-5s", "min")
	for (k = 0; k < tdat_.size(); k += 1) {
		printf(" %-4d %-10.4g", idmin_stat.x[k], tmin_stat.x[k])
	}
	printf("\n%-5s", "max")
	for (k = 0; k < tdat_.size(); k += 1) {
		printf(" %-4d %-10.4g", idmax_stat.x[k], tmax_stat.x[k])
	}
	printf("\n")
}

// perf2file(): append one line of performance data to perf.dat:
// nhost, ncell, load_balance_phase, setuptime, runtime, then the tavg_stat
// values, then (idmin, tmin) pairs, then (idmax, tmax) pairs.
// Only reads the *_stat vectors, so getstat must have filled them first.
proc perf2file() { local k  localobj perf
	perf = new File()
	perf.aopen("perf.dat")
	perf.printf("%d %d %d     %g %g     ",pnm.nhost, pnm.ncell, load_balance_phase, setuptime, runtime)
	for (k = 0; k < tdat_.size(); k += 1) {
		perf.printf(" %g", tavg_stat.x[k])
	}
	perf.printf("     ")
	for (k = 0; k < tdat_.size(); k += 1) {
		perf.printf(" %d %g ", idmin_stat.x[k], tmin_stat.x[k])
	}
	perf.printf("     ")
	for (k = 0; k < tdat_.size(); k += 1) {
		perf.printf(" %d %g ", idmax_stat.x[k], tmax_stat.x[k])
	}
	perf.printf("\n")

	perf.close()
}


// calculate_mechanism_complexity(measure)
// Creates a fresh LoadBalance object in load_balance_, then:
//   $1 == 1  runs ExperimentalMechComplex and, on host 0, writes one
//            "<complexity> <name>" line per entry of m_complex_[0] and
//            m_complex_[1] to mcomplex.dat
//   else     if mcomplex.dat can be opened, fills m_complex_[0] and
//            m_complex_[1] from it; a missing file is silently ignored
proc calculate_mechanism_complexity() {local k, m, cnt  localobj f, mname
	mname = new String()
	f = new File()
	load_balance_ = new LoadBalance()
	if ($1 != 1) {
		if (f.ropen("mcomplex.dat")) {
			for m=0, 1 {
				cnt = load_balance_.m_complex_[m].size
				load_balance_.m_complex_[m].scanf(f, cnt)
			}
			f.close()
		}
		return
	}
	load_balance_.ExperimentalMechComplex()
	if (pnm.myid == 0) {
		f.wopen("mcomplex.dat")
		for m=0, 1 {
			for k=0, load_balance_.m_complex_[m].size - 1 {
				// mt[m] supplies the name that belongs to entry k
				load_balance_.mt[m].select(k)
				load_balance_.mt[m].selected(mname.s)
				f.printf("%g %s\n", load_balance_.m_complex_[m].x[k], mname.s)
			}
		}
	}
	// executed on every host, including those that never opened the file
	f.close()
}

// calc_load_balance(): store in tdat_.x[5] the summed complexity
// (load_balance_.cell_complexity) of every object whose gid pcitr yields on
// this host.
// The sum is rebuilt from zero on each call, so calling prun more than once
// in a session does not inflate the reported value.
proc calc_load_balance() {local i, gid, total  localobj b
	b = load_balance_
	total = 0
	for pcitr(&i, &gid, 1) {
		total += b.cell_complexity(pnm.pc.gid2obj(gid))
	}
	tdat_.x[5] = total
}

// savestate(): write this host's simulator state to the file svst.<id>
// (id zero-padded to 4 digits). After the SaveState data, a line
// "Random <count>" and the seq() value of every existing Random object are
// appended, one per line.
proc savestate() {local k  localobj fname, state, out, rlist
	fname = new String()
	sprint(fname.s, "svst.%04d", pc.id)
	out = new File(fname.s)
	state = new SaveState()
	state.save()
	// NOTE(review): the file is written to again below, so the second
	// argument 0 evidently leaves it open -- confirm against SaveState docs
	state.fwrite(out, 0)

	rlist = new List("Random")
	out.printf("Random %d\n", rlist.count())
	for (k = 0; k < rlist.count(); k += 1) {
		out.printf("%d\n", rlist.object(k).seq())
	}
	out.close()
}

// restorestate(): counterpart of savestate. Reads the SaveState data from
// svst.<id>, checks that the stored Random count equals the number of
// Random objects that exist now (execerror otherwise), sets each Random's
// sequence value from the file and finally restores the saved state.
proc restorestate() {local k  localobj fname, state, src, rlist
	fname = new String()
	sprint(fname.s, "svst.%04d", pc.id)
	src = new File(fname.s)
	state = new SaveState()
	state.fread(src, 0)
	rlist = new List("Random")
	if (src.scanvar() != rlist.count()) {
		execerror("Random count unexpected", "")
	}
	for (k = 0; k < rlist.count(); k += 1) {
		rlist.object(k).seq(src.scanvar())
	}
	src.close()
	state.restore()
}