'''
Load balance requires a gid distribution that is not known at present: it
cannot be calculated until the connections are known. In particular, the
complexity of a granule is dominated by its number of MGRS.
All the connections and complexities can be calculated in parallel if we
temporarily use a whole-cell gid distribution in which the rank is easily
derived from the gid, e.g. rank = gid%nhost. It is then easy to communicate
the needed information to each rank.
'''
import params
granules = params.granules
from common import *
from all2all import all2all
import util
from gidfunc import *
import mkmitral
# Start time taken with NEURON's h.startsw(); the __main__ block at the end of
# this file prints h.startsw() - t_begin on rank 0 as the total elapsed time.
t_begin = h.startsw()
def gid2rank(gid):
    ''' Return the rank that owns gid under the temporary round-robin
    whole-cell distribution (rank = gid % nhost). '''
    owner = gid % nhost
    return owner
import lateral_connections as latconn
# Number of spine connections accepted so far for each granule owned by this
# rank. A granule is owned here when (ggid - gid_granule_begin) % nhost == rank,
# the same rule ggid2rank() applies; detect_over_connected_gc() increments it.
# (A second, identical definition of gid2rank() used to follow here; the one
# defined above is the only one needed.)
gc2nconn = {ggid: 0 for ggid in granules.ggid2pos.keys()
            if (ggid - params.gid_granule_begin) % nhost == rank}
def glom2ranks(glomid):
    ''' Return the set of ranks (gid % nhost) that hold at least one mitral or
    mtufted cell of glomerulus glomid, i.e. the ranks the glomerulus's cells
    are split across. '''
    first_mitral = glomid * params.Nmitral_per_glom
    first_mtufted = glomid * params.Nmtufted_per_glom + params.gid_mtufted_begin
    mitral_gids = range(first_mitral, first_mitral + params.Nmitral_per_glom)
    mtufted_gids = range(first_mtufted, first_mtufted + params.Nmtufted_per_glom)
    ranks = set(gid % nhost for gid in mitral_gids)
    ranks.update(gid % nhost for gid in mtufted_gids)
    return ranks
def ggid2rank(ggid):
    ''' Return the rank that owns granule ggid: granule gids are dealt out
    round-robin starting from params.gid_granule_begin. '''
    offset = ggid - params.gid_granule_begin
    return offset % nhost
# ----------------------------------------------------------------------------------------------
def connect2gc(cilist, r, gl2gc):
    ''' Connect each lateral-dendrite segment in cilist to a granule, in place.

    cilist: list of connection-info tuples whose element 0 is the mitral gid.
    r:      mitral gid -> random stream passed to latconn.connect_to_granule.
    gl2gc:  glomerulus id -> set of granule positions already used by that
            glomerulus; the chosen position is added to it on success.

    On success entry i becomes ci[:3] + (ggid, gisec, gx, ci[-1]). When a
    TypeError occurs (presumably connect_to_granule returning None when no
    granule is available -- confirm in lateral_connections), entry i is set
    to None.
    '''
    for idx, conninfo in enumerate(cilist):
        mgid = conninfo[0]
        glom_gcs = gl2gc[mgid2glom(mgid)]
        try:
            ggid, gisec, gx, gpos = latconn.connect_to_granule(conninfo, r[mgid], glom_gcs)
            cilist[idx] = conninfo[:3] + (ggid, gisec, gx, conninfo[-1])
            glom_gcs.add(gpos)
        except TypeError:
            cilist[idx] = None
def detect_intraglom_conn(cilist, GL_to_GCs):
    ''' Resolve duplicate glomerulus->granule connections made on different ranks.

    The cells of one glomerulus may be spread over several ranks (see
    glom2ranks), so two ranks can independently connect the same glomerulus
    to the same granule. Each (glomid, ggid) pair proposed here is sent to
    every other rank sharing that glomerulus. A local connection whose pair
    was also proposed by a rank >= this one is rejected, so among competing
    ranks the highest one keeps the connection.

    Side effect: for every received pair, the granule position is added to
    GL_to_GCs[glomid] when both the glomerulus and the granule are known here.

    cilist: connection infos produced by connect2gc; None entries (failed
            connections) are skipped and appear in neither output list.
    Returns (good_pair, bad_pair): accepted connection infos, and those that
    must be connected again.
    '''
    # one outgoing list per rank, including ranks that receive nothing
    outbox = {dest: [] for dest in range(nhost)}
    for ci in filter(None, cilist):
        glomid = mgid2glom(ci[0])
        # notify every other rank holding cells of this glomerulus
        for dest in glom2ranks(glomid):
            if dest != rank:
                outbox[dest].append((glomid, ci[3]))
    inbox = all2all(outbox)

    conflicts = set()
    for src, pairs in inbox.items():
        if src >= rank:
            conflicts.update(pairs)
        for glomid, ggid in pairs:
            try:
                GL_to_GCs[glomid].add(granules.ggid2pos[ggid])
            except KeyError:
                pass  # glomerulus or granule not known on this rank

    good_pair, bad_pair = [], []
    for ci in filter(None, cilist):
        if (mgid2glom(ci[0]), ci[3]) in conflicts:
            bad_pair.append(ci)
        else:
            good_pair.append(ci)
    return good_pair, bad_pair
def detect_over_connected_gc(_cilist):
    ''' Enforce the per-granule spine limit params.granule_nmax_spines.

    The granule gid (ci[3]) of each new connection is sent to the rank that
    owns that granule (ggid2rank). The owner accepts the request while
    gc2nconn[ggid] is below the limit, incrementing the count. Otherwise it
    sends the ggid back as rejected. For each rejection, one connection info
    targeting that granule is moved to bad_pair; all the others go to good_pair.

    Returns (good_pair, bad_pair).
    NOTE: its call in mk_mconnection_info is currently commented out.
    '''
    requests = {}   # owner rank -> [ggid, ...]
    by_ggid = {}    # ggid -> [connection info, ...] proposed on this rank
    for conn in _cilist:
        ggid = conn[3]
        requests.setdefault(ggid2rank(ggid), []).append(ggid)
        by_ggid.setdefault(ggid, []).append(conn)
    requests = all2all(requests)

    rejected = {}   # requesting rank -> [ggid, ...] refused for being full
    for src, ggids in requests.items():
        for ggid in ggids:
            if gc2nconn[ggid] >= params.granule_nmax_spines:
                rejected.setdefault(src, []).append(ggid)
            else:
                gc2nconn[ggid] += 1
    rejected = all2all(rejected)

    bad_pair = []
    for ggids in rejected.values():
        for ggid in ggids:
            bad_pair.append(by_ggid[ggid].pop(0))
    good_pair = [conn for conns in by_ggid.values() for conn in conns]
    return good_pair, bad_pair
def mk_mconnection_info(model):
    ''' Generate the connections for mitral (and, eventually, tufted) cells.

    Steps:
    1. Collect the lateral-dendrite connection requests of every mitral on
       this rank (latconn.lateral_connections).
    2. Repeatedly connect them to granules (connect2gc) and send back
       cross-rank duplicates to be redone (detect_intraglom_conn). This
       continues until no rank has connections left.
    3. Append the accepted connection infos to model.mconnections[mgid].

    Tufted cells are not handled yet (their code paths are disabled), so the
    reported mTufted count is always 0. The logged err is the percentage of
    requested connections that could not be made.

    Every rank must call this together: it performs collective pc.allreduce
    and all2all operations.
    '''
    r = {}           # mitral gid -> random stream for its lateral connections
    GL_to_GCs = {}   # glomerulus id -> set of granule positions already used
    for gid in model.mitrals.keys():
        r[gid] = params.ranstream(gid, params.stream_latdendconnect)
        GL_to_GCs.setdefault(mgid2glom(gid), set())

    to_conn = []
    for cellid, cell in model.mitrals.items():
        to_conn += latconn.lateral_connections(cellid, cell)
    nrequested = pc.allreduce(len(to_conn), 1)  # total over all ranks

    cilist = []
    it = 0
    # keep iterating while any rank (allreduce max) still has connections to make
    while pc.allreduce(len(to_conn), 2) > 0:
        connect2gc(to_conn, r, GL_to_GCs)
        _cilist, to_conn = detect_intraglom_conn(to_conn, GL_to_GCs)
        # NOTE: the granule over-connection check (detect_over_connected_gc)
        # is currently disabled and not applied here.
        cilist += _cilist
        it += 1

    # The allreduce must run on every rank; only the division is guarded, so
    # that a run with no requested connections reports err=0 instead of
    # raising ZeroDivisionError.
    nmade = pc.allreduce(len(cilist), 1)
    frac_made = float(nmade) / nrequested if nrequested else 1.0

    for ci in cilist:
        model.mconnections.setdefault(ci[0], []).append(ci)
    MCconn = len(cilist)
    mTCconn = 0  # tufted connections are not generated yet

    util.elapsed('Mitral %d and mTufted %d cells connection infos. generated (it=%d,err=%.3g%%)'%(int(pc.allreduce(MCconn,1)),
                  int(pc.allreduce(mTCconn,1)),
                  int(pc.allreduce(it,2)),
                  (1-frac_made)*100))
def round_robin_distrib(model):
    ''' Give this rank every nhost-th gid starting at rank (round-robin).

    Sets model.gids (all cells), model.mitral_gids (gids below nmitral) and
    model.granule_gids (the remaining gids of this rank).
    '''
    all_gids = set(range(rank, ncell, nhost))
    mitral_gids = set(range(rank, nmitral, nhost))
    model.gids = all_gids
    model.mitral_gids = mitral_gids
    model.granule_gids = model.gids - model.mitral_gids
# install the default distribution as soon as this module is imported
round_robin_distrib(getmodel())
'''
In this section, assume the connections are determined by m2g_connections.py,
i.e. the mitral statistics are controlled, which results in granule statistics
that are not known in advance.
'''
def mk_mitrals(model):
    ''' Build one mitral cell per gid in model.mitral_gids.

    The cells are stored in model.mitrals (gid -> cell), filled incrementally.
    The global number of mitrals (allreduce over all ranks) is then logged.
    '''
    model.mitrals = {}
    for mgid in model.mitral_gids:
        model.mitrals[mgid] = mkmitral.mkmitral(mgid)
    total = int(pc.allreduce(len(model.mitrals), 1))
    util.elapsed('%d mitrals created and connections to mitrals determined' % total)
def mk_gconnection_info_part1(model):
    ''' Group this rank's mitral->granule connection infos by destination rank.

    Builds model.rank_gconnections = {rank: [ci, ...]}. Each connection info
    ci from model.mconnections is filed under gid2rank(ci[3]), the rank that
    owns granule ci[3] in the round-robin gid % nhost distribution.
    mk_gconnection_info_part2() then exchanges these lists, so that each rank
    holds the connections of its own granules. Granules without any
    connection will then not exist.
    '''
    # assigned up front, as before, so model.rank_gconnections is the dict being filled
    rank_gconnections = model.rank_gconnections = {}
    for cilist in model.mconnections.values():
        for ci in cilist:
            # setdefault replaces dict.has_key (removed in Python 3)
            rank_gconnections.setdefault(gid2rank(ci[3]), []).append(ci)
def mk_gconnection_info_part2(model):
    ''' Exchange the per-destination connection lists built by part1.

    After the exchange, model.rank_gconnections maps each source rank to the
    connection infos it sent here. model.granule_gids becomes the set of
    granule gids (ci[3]) appearing in them.
    '''
    # transfer the gconnection info to the proper rank
    model.rank_gconnections = all2all(model.rank_gconnections)
    util.elapsed('rank_gconnections known')
    model.granule_gids = {ci[3] for cis in model.rank_gconnections.values() for ci in cis}
    util.elapsed('granule gids known on each rank')
def mk_gconnection_info(model):
    ''' Run both granule-connection phases (group by rank, then exchange) and
    log the global number of granules that received connections. '''
    for phase in (mk_gconnection_info_part1, mk_gconnection_info_part2):
        phase(model)
    ngranules = int(pc.allreduce(len(model.granule_gids), 1))
    util.elapsed('mk_gconnection_info (#granules = %d)' % ngranules)
if __name__ == '__main__':
    # Standalone run: build this rank's mitrals and their connection infos,
    # group the connections by destination rank, then report the exchange
    # sizes and the total elapsed time.
    model = getmodel()
    mk_mitrals(model)
    mk_mconnection_info(model)
    mk_gconnection_info_part1(model)
    # NOTE(review): the -1 argument presumably makes all2all return only the
    # message sizes instead of the payloads -- confirm against all2all.py.
    sizes = all2all(model.rank_gconnections, -1)
    # presumably util.serialize() lets the ranks take turns, so the per-rank
    # prints below do not interleave -- TODO confirm in util.py.
    for r in util.serialize():
        print rank, " all2all sizes ", sizes
    if rank == 0: print "determine_connections ", h.startsw()-t_begin