'''
Load balancing requires a gid distribution that is unknown at present and
cannot be calculated until the connections are known. In particular, the
complexity of a granule is dominated by its number of MGRS.
Calculation of all the connections and complexities can be accomplished
in parallel if we temporarily use a whole cell gid distribution in which
rank is easily derivable from gid, e.g. rank = gid%nhost. Then it is easy
to communicate the information needed to each rank.
'''

import params
granules = params.granules
from common import *
from all2all import all2all
import util
from gidfunc import *
import mkmitral

# Reference time taken while this module is being executed; the rank-0
# "determine_connections" report further down prints h.startsw() minus it.
t_begin = h.startsw()

def gid2rank(gid):
  ''' Map a gid to a rank under the temporary whole cell distribution:
      the owner is gid modulo nhost.
  '''
  owner = gid % nhost
  return owner


import lateral_connections as latconn



# Per-granule connection counters kept on this rank, all starting at zero.
# Only granules whose index (gid counted from params.gid_granule_begin)
# falls on this rank in round-robin order are tracked here.
gc2nconn = {}
for ggid in granules.ggid2pos.keys():
  if (ggid - params.gid_granule_begin) % nhost != rank:
    continue  # counted on another rank
  gc2nconn[ggid] = 0
  
def gid2rank(gid):
  ''' Rank owning gid in the temporary round-robin distribution, i.e. the
      remainder of gid divided by nhost.

      NOTE(review): an identical gid2rank is already defined earlier in this
      module; this later definition is the one that stays bound to the name.
  '''
  remainder = gid % nhost
  return remainder
    
# ranks associated with the cells of one glomerulus
def glom2ranks(glomid):
  ''' Return the set of ranks (gid modulo nhost) of every mitral gid and every
      mtufted gid belonging to glomerulus glomid.

      Mitral gids of a glomerulus are the Nmitral_per_glom consecutive integers
      starting at glomid*Nmitral_per_glom; mtufted gids are laid out the same
      way with Nmtufted_per_glom, shifted by params.gid_mtufted_begin.
  '''
  ranks = set()

  for mgid in range(glomid*params.Nmitral_per_glom, (glomid+1)*params.Nmitral_per_glom):
    ranks.add(mgid % nhost)

  first_mt = glomid*params.Nmtufted_per_glom + params.gid_mtufted_begin
  stop_mt = (glomid+1)*params.Nmtufted_per_glom + params.gid_mtufted_begin
  for mtgid in range(first_mt, stop_mt):
    ranks.add(mtgid % nhost)

  return ranks

def ggid2rank(ggid):
  ''' Rank assigned to granule ggid: round-robin over the granule index, i.e.
      the gid counted from params.gid_granule_begin, modulo nhost.
  '''
  granule_index = ggid - params.gid_granule_begin
  return granule_index % nhost

# ----------------------------------------------------------------------------------------------
# try to attach every pending dendritic site to a granule
def connect2gc(cilist, r, gl2gc):
  ''' Rewrite the entries of cilist in place with the granule they connect to.

      cilist -- list of connection-info tuples; element 0 is the gid of the
                cell owning the site
      r      -- dict gid -> random stream handed to latconn.connect_to_granule
      gl2gc  -- dict glomid -> set of granule positions already used by that
                glomerulus; the newly chosen position is added to it

      On success the entry becomes
        entry[:3] + (ggid, gisec, gx) + (entry[-1],)
      If a TypeError is raised inside the try block the entry is replaced by
      None (presumably connect_to_granule returns None when it finds no
      granule, which makes the 4-way unpacking fail -- TODO confirm).
  '''
  for idx, entry in enumerate(cilist):
    cell_gid = entry[0]
    used_positions = gl2gc[mgid2glom(cell_gid)]
    try:
      ggid, gisec, gx, gpos = latconn.connect_to_granule(entry, r[cell_gid], used_positions)
      cilist[idx] = entry[:3] + (ggid, gisec, gx) + (entry[-1],)
      used_positions.add(gpos)
    except TypeError:
      cilist[idx] = None

# detect granules claimed twice by the same glomerulus on different ranks
def detect_intraglom_conn(cilist, GL_to_GCs):
  ''' Split the new connections into accepted ones and ones to redo.

      cilist    -- connection infos produced by connect2gc (falsy entries are
                   ignored); ci[0] is the cell gid and ci[3] the granule gid
      GL_to_GCs -- dict glomid -> set of granule positions, updated in place
                   with the positions of the granules reported by other ranks

      Every (glomid, ggid) pair is sent to all the other ranks returned by
      glom2ranks(glomid).  A local connection is rejected when the same pair
      arrives in a message whose key is >= this rank; nothing is sent to the
      own rank, so in practice these are presumably the pairs coming from
      higher ranks (depends on what all2all returns -- TODO confirm).

      Returns (good_pair, bad_pair), two lists of connection infos.
      Collective: calls all2all.
  '''
  # one outgoing list per rank, possibly left empty
  outgoing = dict((dest, []) for dest in range(nhost))
  for ci in cilist:
    if not ci:
      continue
    glomid = mgid2glom(ci[0])
    for dest in glom2ranks(glomid):
      if dest != rank:
        outgoing[dest].append((glomid, ci[3]))
  incoming = all2all(outgoing)

  claimed = set()
  for src, pairs in incoming.items():
    # only these pairs take precedence over the local connections
    if src >= rank:
      claimed.update(pairs)

    # whatever the sender, remember that the glomerulus now uses that granule;
    # unknown glomeruli or granules are silently skipped
    for glomid, ggid in pairs:
      try:
        GL_to_GCs[glomid].add(granules.ggid2pos[ggid])
      except KeyError:
        pass

  good_pair = []
  bad_pair = []
  for ci in cilist:
    if not ci:
      continue
    if (mgid2glom(ci[0]), ci[3]) in claimed:
      bad_pair.append(ci)
    else:
      good_pair.append(ci)

  return good_pair, bad_pair

# reject new connections onto granules that already reached their spine limit
def detect_over_connected_gc(_cilist):
  ''' Split new connections according to the capacity of the target granule.

      _cilist -- connection infos; ci[3] is the granule gid

      Each granule gid is sent to ggid2rank(ggid), where gc2nconn holds its
      connection count.  That rank increments the count while it is below
      params.granule_nmax_spines and otherwise sends the gid back as refused.
      For every refused gid the oldest pending connection info with that
      granule is rejected.

      Returns (good_pair, bad_pair), two lists of connection infos.
      Collective: calls all2all twice.
  '''
  outgoing = {}   # destination rank -> granule gids requested by this rank
  pending = {}    # granule gid -> connection infos waiting for an answer

  for ci in _cilist:
    ggid = ci[3]
    outgoing.setdefault(ggid2rank(ggid), []).append(ggid)
    pending.setdefault(ggid, []).append(ci)

  requests = all2all(outgoing)

  # answer the requests for the granules counted on this rank
  refused = {}
  for src, ggids in requests.items():
    for ggid in ggids:
      if gc2nconn[ggid] < params.granule_nmax_spines:
        gc2nconn[ggid] += 1
      else:
        refused.setdefault(src, []).append(ggid)

  refused = all2all(refused)

  # one rejected connection per refused gid, taken in arrival order
  bad_pair = []
  for ggids in refused.values():
    for ggid in ggids:
      bad_pair.append(pending[ggid].pop(0))

  # whatever is still pending has been accepted
  good_pair = []
  for remaining in pending.values():
    good_pair.extend(remaining)

  return good_pair, bad_pair
  

def mk_mconnection_info(model):
  ''' Generate the connections for mitral and tufted cells.

      Only the cells in model.mitrals are processed at present; the tufted
      handling is disabled, so the mTufted count in the report is always 0.

      Every lateral dendrite site returned by latconn.lateral_connections is
      attached to a granule by connect2gc.  Connections rejected by
      detect_intraglom_conn are retried until no rank has any left.  The
      per-granule limit check (detect_over_connected_gc) is not applied.

      The accepted connection infos are appended to model.mconnections, a
      dict cell gid -> list of connection infos that must already exist.

      Collective: every rank has to call this function, since it calls
      pc.allreduce and (through detect_intraglom_conn) all2all.
  '''
  r = {}            # cell gid -> random stream used to pick granules
  GL_to_GCs = {}    # glomid -> set of granule positions already used
  to_conn = []      # connection infos still waiting for a granule
  cilist = []       # accepted connection infos

  # initialization
  for gid in model.mitrals.keys():
    r[gid] = params.ranstream(gid, params.stream_latdendconnect)
    GL_to_GCs.setdefault(mgid2glom(gid), set())

  # lateral dendrites positions
  for cellid, cell in model.mitrals.items():
    to_conn += latconn.lateral_connections(cellid, cell)

  # number of connections requested over all ranks (allreduce type 1: sum)
  nrequested = pc.allreduce(len(to_conn), 1)

  # connect to granule cells; allreduce type 2 (max) keeps all the ranks
  # iterating until the slowest one has nothing left to connect
  it = 0
  while pc.allreduce(len(to_conn), 2) > 0:
    connect2gc(to_conn, r, GL_to_GCs)
    # accepted connections vs. connections to redo; also updates GL_to_GCs
    accepted, to_conn = detect_intraglom_conn(to_conn, GL_to_GCs)
    cilist += accepted
    it += 1

  # Fraction of the requested connections that were made.  The reduction is
  # always performed so the ranks stay in step; the division is guarded
  # because nothing may have been requested at all, in which case nothing is
  # missing either.
  nmade = pc.allreduce(len(cilist), 1)
  if nrequested:
    made_ratio = float(nmade)/nrequested
  else:
    made_ratio = 1.0

  # fill the model data
  MCconn = 0
  mTCconn = 0
  for ci in cilist:
    model.mconnections.setdefault(ci[0], []).append(ci)
    MCconn += 1

  util.elapsed('Mitral %d and mTufted %d cells connection infos. generated (it=%d,err=%.3g%%)'%(int(pc.allreduce(MCconn,1)),\
                                                                                                       int(pc.allreduce(mTCconn,1)),\
                                                                                                       int(pc.allreduce(it,2)),\
                                                                                                       (1-made_ratio)*100))


# default (round-robin) gid sets of this rank
def round_robin_distrib(model):
  ''' Store on model the gids dealt to this rank in round-robin order.

      model.gids         -- rank, rank+nhost, ... below ncell
      model.mitral_gids  -- rank, rank+nhost, ... below nmitral
      model.granule_gids -- the gids of model.gids that are not mitral gids
  '''
  every_gid = set(range(rank, ncell, nhost))
  mitral_only = set(range(rank, nmitral, nhost))

  model.gids = every_gid
  model.mitral_gids = mitral_only
  model.granule_gids = every_gid - mitral_only

# Executed while the module is loaded: give the model returned by getmodel()
# its default round-robin gid sets.
round_robin_distrib(getmodel())

'''
In this section, we presume the connections are determined by
m2g_connections.py, i.e. the mitral statistics are controlled and they give
rise to unknown granule statistics.
'''


def mk_mitrals(model):
  ''' Build one mitral per gid of model.mitral_gids.

      model.mitrals is reset to an empty dict and then filled one cell at a
      time with gid -> mkmitral.mkmitral(gid).  The total number of mitrals
      over all ranks is reported through util.elapsed (collective: calls
      pc.allreduce).
  '''
  model.mitrals = {}
  for gid in model.mitral_gids:
    model.mitrals[gid] = mkmitral.mkmitral(gid)

  nmitrals_all_ranks = pc.allreduce(len(model.mitrals), 1)
  util.elapsed('%d mitrals created and connections to mitrals determined'%int(nmitrals_all_ranks))

def mk_gconnection_info_part1(model):
  ''' Group the connection infos of this rank by destination rank.

      model.rank_gconnections is rebuilt as {rank : [ci, ...]}, where the rank
      is gid2rank(ci[3]) and ci[3] is the granule gid of the connection info.
      Nothing is communicated here.

      After mk_gconnection_info_part2()
      rank_gconnections is the connection info for granules on rank ggid%nhost
      also granule_gids are the granules on this rank (granules with no
      connection will not exist)
  '''
  model.rank_gconnections = {}
  for cilist in model.mconnections.values():
    for ci in cilist:
      # setdefault replaces the has_key test, which no longer exists in Python 3
      model.rank_gconnections.setdefault(gid2rank(ci[3]), []).append(ci)

def mk_gconnection_info_part2(model):
  ''' Exchange the per-rank connection infos and derive the local granules.

      model.rank_gconnections is replaced by the result of all2all, and
      model.granule_gids becomes the set of granule gids (ci[3]) found in the
      connection infos received.  Collective: calls all2all.
  '''
  received = all2all(model.rank_gconnections)
  model.rank_gconnections = received
  util.elapsed('rank_gconnections known')

  local_granules = set()
  for src in received:
    for ci in received[src]:
      local_granules.add(ci[3])
  model.granule_gids = local_granules
  util.elapsed('granule gids known on each rank')

def mk_gconnection_info(model):
  ''' Run both parts of the granule connection info construction and report
      the number of granules summed over all ranks (collective).
  '''
  for step in (mk_gconnection_info_part1, mk_gconnection_info_part2):
    step(model)

  ngranules = pc.allreduce(len(model.granule_gids), 1)
  util.elapsed('mk_gconnection_info (#granules = %d)'%int(ngranules))


# Stand-alone run: build the mitrals and their connections, group the
# connection infos by destination rank, then print what all2all reports.
if __name__ == '__main__':
  model = getmodel()
  mk_mitrals(model)
  mk_mconnection_info(model)
  mk_gconnection_info_part1(model)

  # NOTE(review): judging by the label printed below, the -1 argument makes
  # all2all return message sizes instead of the payload -- confirm in all2all.
  sizes = all2all(model.rank_gconnections, -1)
  # presumably util.serialize() lets the ranks print one after the other -- verify
  for r in util.serialize():
    print rank, " all2all sizes ", sizes

# Not under the __main__ guard, so this also runs on import: rank 0 prints the
# h.startsw() difference since t_begin was taken.
if rank == 0: print "determine_connections ", h.startsw()-t_begin