# Parallel-run driver setup: pulls in the shared NEURON/MPI context
# (h, pc, rank, nhost, params, getmodel) from the project's common module.
from common import *
# Hoc helper that serializes recorded spikes to disk (legacy path; the
# binary writer in binsave is used instead at the end of prun).
h.load_file("spike2file.hoc")
# Global spike-recording buffers filled by pc.spike_record in __main__:
# idvec holds source gids, spikevec the matching spike times.
idvec = h.Vector()
idvec.buffer_size(5000000)
spikevec = h.Vector()
spikevec.buffer_size(5000000)
# Python-2 floor division: one output file per 64 ranks, at least one file.
n_spkout_files = max(nhost/64, 1) # each file contains spikes from 64 ranks
n_spkout_sort = min(n_spkout_files*8, nhost) #each file serializes from 8 ranks
# so each sorting rank gathers spikes from nhost/n_spkout_sort ranks
# Advance the simulation in chunks of this many ms between checkpoints.
checkpoint_interval = 100000
# Periodic synaptic-weight reset: disabled by default; interval in ms.
clean_weights_active = False
clean_weights_interval = 10500.0
from weightsave import weight_reset as clean_weights, weight_file
def prun(tstop):
    """Run the parallel simulation from t=0 to tstop (ms).

    Initializes the model (fixed-step, cache-efficient), advances time in
    checkpoint_interval chunks via pc.psolve, saves recorded spikes with
    binsave, and prints timing/load-balance statistics on rank 0.
    NOTE(review): call order matters — pc.setup_transfer must precede
    pc.set_maxstep, and stdinit must follow cvode configuration.
    """
    isaved=0
    cvode = h.CVode()
    # Enable cache-efficient mode before stdinit for better memory layout.
    cvode.cache_efficient(1)
    #pc.spike_compress(0,0,1)
    # Set up gap-junction/source-variable transfer across ranks.
    pc.setup_transfer()
    #pc.timeout(0)
    # 10 ms cap on the integration interval between spike exchanges;
    # returns the actual minimum NetCon delay across all ranks.
    mindelay = pc.set_maxstep(10)
    if rank == 0: print 'mindelay = %g'%mindelay
    runtime = h.startsw()
    # Baseline for spike-exchange wait time; subtracted again at the end.
    exchtime = pc.wait_time()
    inittime = h.startsw()
    # Fixed-step integration (variable-step disabled).
    cvode.active(0)
    # if rank == 0: print 'cvode active=', cvode.active()
    h.stdinit()
    inittime = h.startsw() - inittime
    if rank == 0:
        if clean_weights_active:
            print 'weights reset active at %g ms' % clean_weights_interval
        else:
            print 'weights reset not active'
        print 'init time = %g'%inittime
    # NOTE(review): tnext_clean is unused while the clean-weights loop
    # below remains commented out.
    tnext_clean = clean_weights_interval
    # Advance in checkpoint_interval chunks, clamped to tstop.
    while h.t < tstop:
        told = h.t
        tnext = h.t + checkpoint_interval
        if tnext > tstop:
            tnext = tstop
        #if clean_weights_active:
            #while tnext_clean < tnext:
                #pc.psolve(tnext_clean)
                #clean_weights()
                #tnext_clean += clean_weights_interval
        pc.psolve(tnext)
        # if rank == 0:
        #     print 'sim. checkpoint at %g' % h.t
        # Guard against a stalled solver: if time did not advance, bail
        # out instead of spinning forever.
        if h.t == told:
            if rank == 0:
                print "psolve did not advance time from t=%.20g to tnext=%.20g\n"%(h.t, tnext)
            break
    # weight_file(params.filename+('.%d'%isaved))
    # save spikes and dictionary in a binary format to
    # make them more comprimibles
    import binsave
    binsave.save(params.filename, spikevec, idvec)
    # h.spike2file(params.filename, spikevec, idvec, n_spkout_sort, n_spkout_files)
    runtime = h.startsw() - runtime
    # Per-rank timing breakdown; reduced across ranks in printperf.
    comptime = pc.step_time()
    splittime = pc.vtransfer_time(1)
    gaptime = pc.vtransfer_time()
    exchtime = pc.wait_time() - exchtime
    if rank == 0: print 'runtime = %g'% runtime
    printperf([comptime, exchtime, splittime, gaptime])
def printperf(p):
avgp = []
maxp = []
header = ['comp','spk','split','gap']
for i in p:
avgp.append(pc.allreduce(i, 1)/nhost)
maxp.append(pc.allreduce(i, 2))
if rank > 0:
return
b = avgp[0]/maxp[0]
print 'Load Balance = %g'% b
print '\n ',
for i in header: print '%12s'%i,
print '\n avg ',
for i in avgp: print '%12.2f'%i,
print '\n max ',
for i in maxp: print '%12.2f'%i,
print ''
if __name__ == '__main__':
    # Minimal smoke-test configuration: shrink the network to 1 mitral
    # cell / 2 cells total before building.
    import common
    import util
    common.nmitral = 1
    common.ncell = 2
    import net_mitral_centric as nmc
    # Distribute cells round-robin across ranks for the current model.
    nmc.build_net_roundrobin(getmodel())
    # Record spikes from every cell on this rank (-1 = all gids) into the
    # module-level spikevec/idvec buffers used by prun.
    pc.spike_record(-1, spikevec, idvec)
    from odorstim import OdorStim
    from odors import odors
    # Apple odor pulse: onset 10 ms, duration 20 ms, amplitude/param 100.
    # NOTE(review): argument meanings taken from call site only — confirm
    # against OdorStim.setup.
    ods = OdorStim(odors['Apple'])
    ods.setup(nmc.mitrals, 10., 20., 100.)
    # Run 200 ms, then let util.finish handle teardown (MPI exit etc.).
    prun(200.)
    util.finish()