# Simulate a subset of the mitral and granule gids in the context of the full
# network: only the subset cells are built here, and the input spikes they
# need from the rest of the network are replayed from a full-network spike
# file (see patstim below).
import custom_params
# NOTE(review): presumably read when `params` is imported below, so keep this
# assignment ahead of that import.
custom_params.filename = 'cfg27'

# Mitral and granule gids that generated the most spikes in an Nglom=10 full run.

subset = [185,189,187,188,186]
# bindict connection file written by the full model (loaded in
# get_and_make_subset_connections).
connection_file = 'dummy.dic'
# Full-run spike file: replayed as input and used as the reference in spkcompare.
spike_file = 'g37train.spk2'
# Saved weights to load after construction; '' skips loading.
weights_file = '' 

import params

params.tstop=1000
import odorstim
# Only index 37 of the 'Mint' Fmax table stays nonzero; indices 0-126 other
# than 37 are zeroed (NOTE(review): presumably one entry per glomerulus --
# confirm).  range(...)+range(...) relies on Python 2, where range is a list.
odorstim.Fmax['Mint'][37]=0.7
for i in range(0, 37)+range(38,127):
  odorstim.Fmax['Mint'][i] = 0
import net_mitral_centric as nmc # nmc.dc sets up full model gids
from common import *
model = getmodel()
# NOTE(review): set after getmodel(); presumably only read later during cell
# construction (mksubset) -- confirm it takes effect.
params.orn_std_e=0

''' Mitrals and their connections can be constructed dynamically. Granule
connections, however, must come from a connection file generated by the full
model.
'''
def subset_distrib(gids, model):
  '''Restrict model to the requested gids and split them into mitrals/granules.

  gids >= ncell are dropped.  Sets model.gids (all kept gids),
  model.mitral_gids (kept gids < nmitral) and model.granule_gids (the rest).
  '''
  kept = set(g for g in gids if g < ncell)
  mitrals = set(g for g in kept if g < nmitral)
  model.gids = kept
  model.mitral_gids = mitrals
  model.granule_gids = kept - mitrals

# Restrict this model to the chosen subset and report the mitral/granule split.
subset_distrib(subset, model)

print 'mitrals ', model.mitral_gids
print 'granules ', model.granule_gids

def mksubset(model):
  '''Build the cells, connections and gap junctions for the gids in model.

  Order: mitrals and their connection info (nmc.dc), mitral registration,
  granules and their registration, mgrs objects read from connection_file
  (get_and_make_subset_connections), synapses, then gap junctions (GJ).
  When rank == 0, prints the elapsed setup time since nmc.t_begin.
  '''
  nmc.dc.mk_mitrals(model)
  nmc.dc.mk_mconnection_info(model)
  # Recompute the granule set after the mitral setup above.
  # NOTE(review): presumably mk_mitrals/mk_mconnection_info may change
  # model.gids or model.mitral_gids -- confirm this recomputation is needed.
  model.granule_gids = model.gids - model.mitral_gids
  nmc.register_mitrals(model)
  nmc.build_granules(model)
  nmc.register_granules(model)
  # NOTE(review): this helper works on the module-level `model`, not on the
  # `model` argument; they are the same object for the mksubset(model) call
  # in this file.
  get_and_make_subset_connections()
  nmc.build_synapses(model)
  import GJ
  GJ.init_gap_junctions()
  if rank == 0: print "network subset setuptime ", h.startsw() - nmc.t_begin

def get_and_make_subset_connections():
  '''Load the full model's connection info and create the mgrs this subset needs.

  Loads connection_file with bindict, lets chk_consist select the md_gids
  touching this model's mitrals/granules, builds one mgrs per selected
  md_gid, and stores each in model.mgrss keyed by its md_gid.  Prints the
  total number of entries in model.mgrss.
  '''
  import bindict
  bindict.load(connection_file)
  needed = chk_consist(bindict.gid_dict, bindict.mgid_dict, bindict.ggid_dict)

  import mgrs
  for md_gid in needed:
    info = bindict.gid_dict[md_gid]
    slot = mgrs.gid2mg(md_gid)[3]
    syn = mgrs.mk_mgrs(info[0], info[1], info[2], info[3], 0, info[4], slot)
    model.mgrss[syn.md_gid] = syn
  print("%d mgrs created" % len(model.mgrss))

def chk_consist(md_dict, m_dict, g_dict):
  '''Incomplete consistency check of the loaded connection dictionaries.

  Only the cells of this subset (model.mitral_gids, model.granule_gids) are
  checked.  For each subset mitral, every gid listed in m_dict[mitral] must
  be the mitral itself (gid < nmitral), a key of g_dict
  (nmitral <= gid < ncell), or a key of md_dict (gid >= ncell); the latter
  are collected.  For each subset granule present in g_dict, gid+1 for every
  listed gid must be a key of md_dict and is collected
  (NOTE(review): presumably the listed gids are gd_gids whose paired md_gid
  is gd_gid+1 -- confirm).  Granules missing from g_dict are reported.

  Returns the set of collected md_dict keys.
  Raises AssertionError on the first inconsistency (KeyError if a subset
  mitral is missing from m_dict).  The checks are explicit raises rather
  than assert statements so they are not stripped under python -O.
  '''
  def require(ok, msg):
    # Same exception type as the former assert statements, but always active.
    if not ok:
      raise AssertionError(msg)

  subset = set()
  for mgid in model.mitral_gids:
    for gid in m_dict[mgid]:
      if gid < nmitral:
        require(gid == mgid,
                'mitral %d lists another mitral gid %d' % (mgid, gid))
      elif gid < ncell:
        require(gid in g_dict,
                'mitral %d lists granule %d missing from g_dict' % (mgid, gid))
      else:
        require(gid in md_dict,
                'mitral %d lists gid %d missing from md_dict' % (mgid, gid))
        subset.add(gid)
  for ggid in model.granule_gids:
    if ggid in g_dict:
      for gid in g_dict[ggid]:
        require(gid + 1 in md_dict,
                'granule %d: gid %d missing from md_dict' % (ggid, gid + 1))
        subset.add(gid + 1)
    else:
      print('granule %d was not used in full model' % ggid)
  return subset

def patstim(filename):
  '''Replay recorded full-network spikes into the subset via a PatternStim.

  filename -- spike file readable by binspikes.SpikesReader.

  Only the spikes the subset needs are read: those of every gid in
  model.gids and, for each mgrs in model.mgrss, its gd_gid when rs.md is
  truthy and its md_gid when rs.gd is truthy.  The set of needed gids is
  printed, and each needed gid absent from the file is reported.

  Returns (tvec, gidvec, ps, spk_standard): the spike times sorted by time
  and their gids (hoc Vectors), the PatternStim playing them, and a dict
  mapping each gid in model.gids to its spikes from the file ([] if none),
  which spkcompare uses to check the subset run against the full run.
  '''
  wanted = set(model.gids)
  for rs in model.mgrss.values():
    if rs.md:
      wanted.add(rs.gd_gid)
    if rs.gd:
      wanted.add(rs.md_gid)
  print(wanted)

  from binspikes import SpikesReader
  reader = SpikesReader(filename)

  # Reference output of each simulated cell, for comparison after the run.
  spk_standard = {}
  for gid in model.gids:
    if reader.header.has_key(gid):
      spk_standard[gid] = reader.retrieve(gid)
    else:
      spk_standard[gid] = []

  # Input events.  This also includes the simulated cells' own output
  # spikes, which can be handy for debugging.
  tvec = h.Vector()
  gidvec = h.Vector()
  for gid in wanted:
    if reader.header.has_key(gid):
      for t in reader.retrieve(gid):
        tvec.append(t)
        gidvec.append(gid)
    else:
      print('no spikes from %d' % gid)

  # Reorder both vectors by spike time.
  order = tvec.sortindex()
  tvec.index(tvec, order)
  gidvec.index(gidvec, order)

  stim = h.PatternStim()
  stim.play(tvec, gidvec)
  return (tvec, gidvec, stim, spk_standard)

# Build the subset cells, their connections and gap junctions.
mksubset(model)


# inits the weights: load saved weights unless weights_file is ''.
if weights_file:
  import weightsave
  weightsave.weight_load(weights_file)

import dummysyns
dummysyns.mk_dummy_syns([])

# odorstim was already imported at the top of the file; this re-import just
# returns the cached module.
import odorstim
odseq = odorstim.OdorSequence(params.odor_sequence)

# (tvec, gidvec, PatternStim, spk_standard) -- kept at module level so the
# PatternStim stays referenced and spkcompare() can read ps[3].
ps = patstim(spike_file)

def spkrecord(model):
  '''Record the output spikes of every gid in model.gids.

  Returns a (spike_times, spike_gids) pair of hoc Vectors registered with
  pc.spike_record, for comparison with the reference spikes in ps[3]
  (see spkcompare).
  '''
  recorders = (h.Vector(), h.Vector())
  for gid in model.gids:
    pc.spike_record(gid, recorders[0], recorders[1])
  return recorders

# Module-level handle read by spkcompare().
simspikes = spkrecord(model)

def spkcompare():
  '''Compare the subset simulation's spikes with the full-network reference.

  For every gid in model.gids, the spikes recorded in simspikes (see
  spkrecord) are compared with the reference spikes in ps[3] (the
  spk_standard dict returned by patstim).  Only reference spikes before the
  first one later than h.tstop are considered (assumes the reference times
  are sorted).  Spike times agreeing within 1e-6 are reported as the same.

  Returns (nstd, nsim), the reference and simulated spike counts, for the
  first gid whose counts differ, and stops comparing at that point.
  Returns None when every gid has matching counts.
  '''
  for gid in model.gids:
    tstd = h.Vector(ps[3][gid])
    # pick this gid's spikes out of the combined recording
    ix = simspikes[1].c().indvwhere("==", gid)
    tsim = h.Vector().index(simspikes[0], ix)
    # is tstd same as tsim up to tstop?
    # first, same count of spikes (indwhere gives -1 if none exceed tstop)
    nstd = int(tstd.indwhere(">", h.tstop))
    if nstd < 0: nstd = len(tstd)
    nsim = len(tsim)
    if nstd != nsim:
      print("\nFor gid %d, during interval 0 to tstop=%g, different numbers of spikes %d %d" % (gid, h.tstop, nstd, nsim))
      print("tstd")
      tstd.printf()
      print("tsim")
      tsim.printf()
      return nstd, nsim
    # Vector.sub needs equal sizes, but the reference train may continue
    # past tstop; compare against only its first nstd spike times.
    tref = tstd.c()
    tref.resize(nstd)
    diff = tsim.c().sub(tref)
    if diff.c().abs().indwhere(">", 1.e-6) == -1.0:
      print("\n%d spike times for gid %d are the same" % (nstd, gid))
    else:
      print("\n%d spike times for gid %d are different, tsim-tstd:" % (nstd, gid))
      diff.printf()
      for i in range(nstd):
        print("%d %g %g" % (i, tstd.x[i], tsim.x[i]))

h.load_file("nrngui.hoc")

# NOTE(review): uncommenting the next line would presumably turn the standard
# run system's setdt() into a no-op; as written, run() may still adjust dt --
# confirm the effective dt.
#h('proc setdt(){}')
# 1/64 + 1/128 = 3/128 = 0.0234375, exactly representable in binary floating point.
h.dt = 1./64. + 1./128.

# Sequence positions of all hoc Random objects, filled by saverand() and
# re-applied by restorerand().
rseq = []

def saverand():
  '''Record the current sequence position of every hoc Random object.

  The snapshot replaces rseq's contents in place (the list object itself is
  kept).  It does not append, so calling it again does not leave stale
  entries at the front of rseq for restorerand() to pick up.
  '''
  rseq[:] = [r.seq() for r in h.List('Random')]

def restorerand():
  '''Reset every hoc Random object to the position saved by saverand().

  Saved positions are matched to Random objects by their order in
  h.List('Random'); raises IndexError if more Random objects exist now
  than were saved.
  '''
  for i, r in enumerate(h.List('Random')):
    r.seq(rseq[i])

# Re-apply the saved positions during finitialize (handler type 3).
rfih = h.FInitializeHandler(3, restorerand)
saverand()

# A module's __name__ is '__main__' or its import name, never '__nothing__',
# so this block is effectively disabled.  When enabled, it plots the soma
# voltage of each subset cell, runs the simulation and calls spkcompare().
if __name__ == '__nothing__':
  #h.tstop = 1000
  grphs = {}
  for gid in model.gids:
    g = h.Graph()
    h.addplot(g, 0)
    g.size(0, h.tstop, -80, 60)
    g.addvar('gid%d.soma.v(.5)'%gid, pc.gid2cell(gid).soma(.5)._ref_v)
    grphs.update({gid : g})
  h.run()
  spkcompare()
# Load the saved GUI session file.
h.load_file('subsetsim.ses')