// Make sure a global ParallelContext named `pc` exists. Nothing is done when
// a symbol called "pc" is already declared, so an instance created by an
// earlier file is left untouched.
// NOTE(review): the declaration and construction go through execute1(),
// presumably so that `pc` is only declared when the guard actually runs
// rather than when this statement is parsed -- confirm.
if (name_declared("pc") == 0) {
    execute1("objref pc")
    execute1("pc = new ParallelContext()")
}
// sortspikes(tvec, gidvec, tout, gidout, nbin)
//
// Globally sort spikes by time across all ranks of `pc`.
//   $o1  Vector of spike times on this rank (rounded and sorted IN PLACE)
//   $o2  Vector of gids parallel to $o1 (reordered IN PLACE with $o1)
//   $o3  output Vector: spike times this rank ends up owning
//   $o4  output Vector: gids parallel to $o3
//   $5   number of ranks (0..nbin-1) that receive sorted spikes; clamped
//        to pc.nhost
//
// The global time range [tmin, tmax) is cut into nbin equal intervals and
// every spike is sent to the rank that owns its interval. Each receiving rank
// sorts what it got by time and, within groups of equal time, by gid.
// If there are no spikes anywhere the outputs are left untouched.
// With a single host the outputs are plain copies of the time-sorted inputs
// (gids are not tie-sorted in that case).
proc sortspikes() {local res, i, j, k, imin, x, nbin, tmin, tmax, tvl, spike_count, n \
    localobj srt, s1, s2, cnts, d1, d2, gs
    // spike_count and n are declared local above so that they no longer
    // leak into (or clobber) the global namespace.
    s1 = $o1
    s2 = $o2
    d1 = $o3
    d2 = $o4
    nbin = $5 // number of ranks (from 0) to do the sorting
    if (nbin > pc.nhost) { nbin = pc.nhost }

    // for testing, round to the file output resolution
    if (1) {
        res = .00001
        s1.div(res).add(.5).floor().mul(res)
    }

    // total number of spikes over all ranks (allreduce type 1 = sum)
    spike_count = pc.allreduce(s1.size, 1)
    if (pc.id == 0) {printf("spike_count %d, sort using %d ranks\n", spike_count, nbin)}
    if (spike_count == 0) { return }

    // not always sorted even on per processor basis
    srt = s1.sortindex
    s1.index(s1, srt)
    s2.index(s2, srt)

    if (pc.nhost == 1) {
        // should also sort gid
        d1.copy(s1)
        d2.copy(s2)
        return
    }

    // min and max spike time; a rank without spikes contributes neutral
    // values so that it does not influence the reductions
    if (s1.size) {
        tmin = s1.x[0]
        tmax = s1.x[s1.size-1]
    } else {
        tmin = 1e9
        tmax = 0
    }
    tmin = pc.allreduce(tmin, 3)     // type 3 = min
    tmax = pc.allreduce(tmax, 2) + 1 // type 2 = max; +1 so every spike is < tmax

    // exchange: cnts.x[i] is the number of local spikes that go to rank i.
    // s1 is sorted, so the spikes for rank i form one contiguous run and a
    // single forward scan (index j) is enough. Ranks >= nbin keep count 0.
    tvl = (tmax - tmin)/nbin
    cnts = new Vector(pc.nhost)
    j = 0
    for i=0, nbin - 1 {
        x = tmin + (i+1)*tvl // exclusive upper bound of interval i
        k = 0
        while (j < s1.size) {
            if (s1.x[j] < x) {
                j += 1
                k += 1
            } else {
                break
            }
        }
        cnts.x[i] = k
    }
    pc.alltoall(s1, cnts, d1)
    pc.alltoall(s2, cnts, d2)

    if (d1.size) {
        // received data is a concatenation of per-source runs: sort by time
        srt = d1.sortindex
        d1.index(d1, srt)
        d2.index(d2, srt)

        // now sort the gids without destroying the spiketime sort
        gs = new Vector()
        imin = 0 // start of the current run of equal spike times
        n = d1.size
        for i=1, n {
            // Nested ifs on purpose: d1.x[i] must not be evaluated when
            // i == n (one past the end), which only serves to flush the
            // final run.
            if (i < n) if (d1.x[imin] == d1.x[i]) {
                continue
            }
            // run [imin, i-1] has ended; sort its gids if it has > 1 element
            if (i - imin > 1) {
                gs.resize(0)
                gs.copy(d2, imin, i-1)
                gs.sort
                d2.copy(gs, imin, 0, gs.size-1)
            }
            imin = i
        }
    }
    pc.barrier()
}
// Running totals maintained by spike2file():
// accumulated startsw() time difference over all spike2file() calls
spike2file_time = 0
// number of completed spike2file() calls; while it is still 0 the output
// files are truncated before the first append
spike2file_ncall = 0
// spike2file(tvec, gidvec [, nbin [, nf]])
//
// Globally sort the spikes with sortspikes() and append them as
// "time gid" lines to the files spk000.dat, spk001.dat, ...
//   $o1  Vector of spike times on this rank (emptied on return)
//   $o2  Vector of gids parallel to $o1 (emptied on return)
//   $3   optional number of ranks that sort/write (default pc.nhost);
//        clamped to the range 1..pc.nhost
//   $4   optional number of output files (default 1, only read when exactly
//        4 args are given); clamped to the range 1..nbin
//
// Ranks 0..nbin-1 hold the sorted spikes. Rank r writes to file number
// int(r/nbinperfile); the ranks sharing a file take turns in rank order,
// separated by barriers. Files are truncated on the first call only and
// appended to afterwards. Updates the globals spike2file_ncall and
// spike2file_time.
proc spike2file() { local i, j, nf, ts, nbin, nspike, nbinperfile \
    localobj outf, s, vs, vg
    ts = startsw()
    nbin = pc.nhost // everybody sorts
    nf = 1 // everybody (nbin consecutive ranks) serialized into one file
    // nbin/nf should be an integer
    if (numarg() >= 3) { nbin = $3 }
    if (numarg() == 4) { nf = $4 }
    if (nbin > pc.nhost) {
        nbin = pc.nhost
    }
    // guard against zero/negative requests, which would otherwise end up
    // as a zero divisor (or a nonsensical grouping) below
    if (nbin < 1) {
        nbin = 1
    }
    if (nf > nbin) {
        nf = nbin
    }
    if (nf < 1) {
        nf = 1
    }
    // ranks per file, rounded up when nf does not divide nbin
    nbinperfile = int(nbin/nf)
    if (nbinperfile*nf < nbin) { nbinperfile += 1 }

    // nhost ranks sort into ranks 0...nbin
    // those ranks write into nf files. ie. rank r writes to file
    // int(rank/(nbinperfile))
    vs = new Vector()
    vg = new Vector()
    sortspikes($o1, $o2, vs, vg, nbin)
    nspike = pc.allreduce(vs.size, 1)
    if (pc.id == 0) printf("spike2file %d spikes sorted in %g sec on %d ranks\n", nspike, startsw() - ts, nbin)

    // the file name depends only on this rank, so build it once
    s = new String()
    sprint(s.s, "spk%03d.dat", int(pc.id/nbinperfile))

    // nothing beyond nbin will write a file
    outf = new File()
    for j=0, nbinperfile-1 { // adjacent ranks that serialize into same file
        if (pc.id < nbin && pc.id%(nbinperfile) == j) {
            if (j == 0 && spike2file_ncall == 0) {
                // first writer of the very first call: start with an empty file
                outf.wopen(s.s)
                outf.close()
            }
            outf.aopen(s.s)
            for i=0, vs.size-1 {
                //outf.printf("%.5f %d\n", vs.x[i], vg.x[i])
                outf.printf("%.10g %d\n", vs.x[i], vg.x[i])
            }
            outf.close()
        }
        // the next rank in the group must not start before this one is done
        pc.barrier()
    }

    // the caller's vectors are consumed
    $o1.resize(0)
    $o2.resize(0)
    spike2file_ncall += 1
    ts = startsw() - ts
    if (pc.id == 0) printf("spike2file call#%d spikes=%d nbin=%d nfile=%d in %g sec\n", spike2file_ncall, nspike, nbin, nf, ts)
    spike2file_time += ts
}
// spiketest(nbin, nf)
//
// Smoke test for spike2file(): every rank makes 10 spikes at times 0..9 with
// "gids" equal to the time plus 1000*pc.id, writes them through spike2file()
// using $1 sorting ranks and $2 files, then shuts the parallel run down and
// quits.
proc spiketest() {localobj tvec, gidvec
    tvec = new Vector(10)
    tvec.indgen()
    // independent copy of the times, offset per rank
    gidvec = tvec.c()
    gidvec.add(1000*pc.id)
    spike2file(tvec, gidvec, $1, $2)
    pc.runworker()
    pc.done()
    quit()
}