from util import *
from all2all import all2all
import heapq
def lpt(cx, npart):
    '''Distribute complexity items over npart partitions, largest item first.

    cx is a list of (cx, gid) items. Note that it is sorted IN PLACE into
    descending complexity order as a side effect. Each item in turn is
    given to the partition whose running total is currently the smallest.

    Returns a list of npart tuples (total_cx, [(cx, gid), ...]) ordered by
    ascending total_cx. Totals start from 0.0 so they are always floats.
    Raises IndexError if cx is non-empty and npart < 1.
    '''
    cx.sort(key=lambda item: item[0], reverse=True)

    # Priority queue of partitions keyed on total complexity. At this time
    # we do not care which partition ends up on which rank, so an entry is
    # just (total_cx, index, items). The unique index breaks ties between
    # equal totals so the comparison never reaches the item lists (slow,
    # and a TypeError on Python 3 when gids are not mutually orderable).
    # Entries with identical totals and ascending index already satisfy
    # the heap invariant, so no heapify is needed.
    heap = [(0.0, index, []) for index in range(npart)]

    # Each item goes into the current least complex partition.
    for item in cx:
        total, index, items = heap[0]
        items.append(item)
        heapq.heapreplace(heap, (total + item[0], index, items))

    # Drop the internal tie-break index from the returned partitions.
    return [(total, items) for total, _, items in sorted(heap)]
def statistics(parts):
    '''Return a one line load balance summary for a list of partitions.

    parts is a list of (total_cx, [(cx, gid), ...]) as returned by lpt.
    The summary string reports:
      loadbal      average partition total / largest partition total
                   (1.0 when the largest total is not positive)
      total_cx     sum of the partition totals
      npart        number of partitions
      ncx          number of items over all partitions
      max_part_cx  largest partition total (never below 0)
      max_cx       largest single item complexity (never below 0)
    An empty parts list yields loadbal=1 with all other values 0.
    '''
    npart = len(parts)
    ncx = sum(len(part[1]) for part in parts)
    total_cx = sum(part[0] for part in parts)
    # The leading 0 keeps the maxima at 0 for empty or non-positive input.
    max_part_cx = max([0] + [part[0] for part in parts])
    max_cx = max([0] + [item[0] for part in parts for item in part[1]])

    loadbal = 1.0
    if npart > 0 and max_part_cx > 0.:
        # float() guards against truncating integer division when the
        # totals are ints on Python 2.
        avg_part_cx = float(total_cx) / npart
        loadbal = avg_part_cx / max_part_cx

    return "loadbal=%g total_cx=%g npart=%d ncx=%d max_part_cx=%g max_cx=%g"%(loadbal,total_cx,npart,ncx,max_part_cx, max_cx)
if __name__ == '__main__':
    # Self test: partition a ten item list and an empty list into three
    # parts and print the input, the partitions and their statistics.
    # NOTE(review): serialize and finish are not used below; the import is
    # kept in case importing them from util matters -- confirm and remove.
    from util import serialize, finish
    for cx in ([(i, i) for i in range(10)],[]):
        # Single-argument print(...) with % formatting gives the same
        # output on Python 2 and Python 3 as the old print statement did.
        print("%s %s %s" % (len(cx), ' complexity items ', cx))
        pinfo = lpt(cx, 3)
        print("%s %s %s" % (len(pinfo), ' lpt partitions ', pinfo))
        print(statistics(pinfo))