# -*- coding: utf-8 -*-
"""
Created on Wed Sep 28 15:04:57 2016
@author: dalbis
"""
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 25 17:30:16 2015
@author: dalbis
"""
import sys
from multiprocessing import Pool
from grid_rate import GridRate
from grid_rate_avg import GridRateAvg
from grid_spikes import GridSpikes
import itertools
import socket
import traceback
import os
import numpy as np
import datetime,time
from simlib import gen_hash_id,format_val,logSim,format_elapsed_time,print_progress,ensureDir,ensureParentDir
from grid_const import ModelType
from grid_functions import map_merge
# maximum processes per host
procs_by_host={'compute1':50,
'compute2':50,
'compute3':8,
'cluster01':20,
'cluster02':5,
'mcclintock':6}
# batch data path
batch_data_folder_map={
ModelType.MODEL_RATE:'../results/grid_rate_batch',
ModelType.MODEL_RATE_AVG:'../results/grid_rate_avg_batch',
ModelType.MODEL_SPIKING:'../results/grid_spiking_batch'
}
def function(sim):
try:
if sim.do_run is True:
sim.post_init(do_print=False)
sim.run(do_print=False)
sim.post_run(do_print=False)
except Exception:
print
print 'Exception in running %s'%sim.dataPath
traceback.print_exc()
class GridBatch(object):
def __init__(self,model_type,batch_default_map,batch_override_map,force=False):
self.model_type=model_type
self.batch_default_map=batch_default_map
self.batch_override_map=batch_override_map
self.force=force
self.startTimeStr=''
self.endTimeStr=''
self.elapsedTime=0
def post_init(self):
##############################################################################
###### CREATE POOL
##############################################################################
self.batch_data_folder=batch_data_folder_map[self.model_type]
ensureDir(self.batch_data_folder)
# create pool
self.host=socket.gethostname()
if self.host in procs_by_host.keys():
self.num_procs=procs_by_host[self.host]
else:
self.num_procs=7
self.pool=Pool(processes=self.num_procs)
self.sims=[]
self.hashes=[]
self.all_par_values=sorted(itertools.product(*self.batch_override_map.values()))
self.batch_override_str=' '.join([ '%s (%s-%s)'%(key,format_val(min(values)),
format_val(max(values))) for key,values in self.batch_override_map.items()])
# loop over all different paramater values
for par_values in self.all_par_values:
override_param_map={k:v for (k,v) in zip(self.batch_override_map.keys(),par_values)}
parMap=map_merge(self.batch_default_map,override_param_map)
if self.model_type == ModelType.MODEL_RATE:
self.sim_class=GridRate
elif self.model_type == ModelType.MODEL_RATE_AVG:
self.sim_class=GridRateAvg
elif self.model_type == ModelType.MODEL_SPIKING:
self.sim_class=GridSpikes
sim=self.sim_class(parMap)
#print sim.hash_id+' Run: %s'%sim.do_run
if self.force:
sim.force_gen_inputs=True
sim.force_gen_corr=True
sim.do_run=True
if sim.do_run is True:
self.sims.append(sim)
self.hashes.append(sim.hash_id)
# generate batch hash
self.batch_hash=gen_hash_id('_'.join(self.hashes))
self.batch_data_path=os.path.join(self.batch_data_folder,'%s_data.npz'%self.batch_hash)
self.batch_params_path=os.path.join(self.batch_data_folder,'%s_params.txt'%self.batch_hash)
self.batch_summary_str=\
"\n\nBATCH HASH: %s\n\nBATCH PARAMS = %s\n\n"%\
(self.batch_hash,
self.batch_override_str
)
print self.batch_summary_str
self.toSaveMap={'hashes':self.hashes,
'batch_override_map':self.batch_override_map,
'batch_default_map':self.batch_default_map
}
if os.path.exists(self.batch_data_path) and not self.force:
return False
else:
print '\n\n*** BATCH DATA NOT PRESENT!! ***\n\n'
print self.batch_data_path
print '%d/%d simulations to be run'%(len(self.sims),len(self.all_par_values))
return True
def run(self):
##############################################################################
###### RUN POOL
##############################################################################
startTime=time.time()
startTimeDate=datetime.datetime.fromtimestamp(time.time())
self.startTimeStr=startTimeDate.strftime('%Y-%m-%d %H:%M:%S')
print 'BATCH MODE: Starting %d/%d processes on %s'%(len(self.sims),self.num_procs,self.host)
for sim in self.sims:
self.pool.apply_async(function, args=(sim,))
self.pool.close()
self.pool.join()
# logging simulation end
endTime=datetime.datetime.fromtimestamp(time.time())
self.endTimeStr=endTime.strftime('%Y-%m-%d %H:%M:%S')
self.elapsedTime =time.time()-startTime
print 'Batch simulation ends: %s'%self.endTimeStr
print 'Elapsed time: %s\n' %format_elapsed_time(self.elapsedTime)
def post_run(self):
#############################################################################
##### MERGE DATA
#############################################################################
print
print 'SIMULATIONS COMPLETED'
print
print 'Merging data...'
sys.stdout.flush()
initial_weights_map={}
final_weights_map={}
final_weight_score_map={}
final_weight_angle_map={}
final_weight_spacing_map={}
final_weight_phase_map={}
final_weight_cx_map={}
evo_weight_scores_map={}
final_rates_map={}
final_rate_score_map={}
final_rate_angle_map={}
final_rate_spacing_map={}
final_rate_phase_map={}
final_rate_cx_map={}
evo_weight_profiles_map={}
start_clock=time.time()
# load/compute data to show for each combination of parameter_values
idx=-1
for chash,par_values in zip(self.hashes,self.all_par_values):
idx+=1
print_progress(idx,len(self.all_par_values),start_clock=start_clock)
sys.stdout.flush()
dataPath=os.path.join(self.sim_class.results_path,'%s_data.npz'%chash)
try:
data=np.load(dataPath,mmap_mode='r')
except Exception:
print 'This file is corrupted: %s'%dataPath
initial_weights_map[par_values]=data['J0']
final_weights_map[par_values]=data['final_weights']
final_weight_score_map[par_values]=data['final_weight_score']
final_weight_angle_map[par_values]=data['final_weight_angle']
final_weight_spacing_map[par_values]=data['final_weight_spacing']
final_weight_phase_map[par_values]=data['final_weight_phase']
final_weight_cx_map[par_values]=data['final_weight_cx']
if 'scores' in data.keys():
evo_weight_scores_map[par_values]=data['scores']
final_rates_map[par_values]=data['final_rates']
final_rate_score_map[par_values]=data['final_rate_score']
final_rate_angle_map[par_values]=data['final_rate_angle']
final_rate_spacing_map[par_values]=data['final_rate_spacing']
final_rate_phase_map[par_values]=data['final_rate_phase']
final_rate_cx_map[par_values]=data['final_rate_cx']
# fourier profiles over time
import gridlib as gl
L=data['paramMap'][()]['L']
n=data['paramMap'][()]['n']
num_snaps=self.batch_default_map['num_snaps']
J_mat=data['J_vect'].reshape(n,n,num_snaps)
weights_dft,weights_freqs,weigths_allfreqs=gl.dft2d_num(J_mat,L,n)
weights_dft_profiles=gl.dft2d_profiles(weights_dft)
evo_weight_profiles_map[par_values]=weights_dft_profiles
mergedDataMap={
'initial_weights_map': initial_weights_map,
'final_weights_map': final_weights_map,
'final_weight_score_map':final_weight_score_map,
'final_weight_angle_map':final_weight_angle_map,
'final_weight_spacing_map':final_weight_spacing_map,
'final_weight_phase_map':final_weight_phase_map,
'final_weight_cx_map':final_weight_cx_map,
'evo_weight_scores_map':evo_weight_scores_map,
'final_rates_map':final_rates_map,
'final_rate_score_map':final_rate_score_map,
'final_rate_angle_map':final_rate_angle_map,
'final_rate_spacing_map':final_rate_spacing_map,
'final_rate_phase_map':final_rate_phase_map,
'final_rate_cx_map':final_rate_cx_map,
'evo_weight_profiles_map':evo_weight_profiles_map,
'weights_freqs':weights_freqs
}
self.toSaveMap=map_merge(self.toSaveMap,mergedDataMap)
# save
ensureParentDir(self.batch_data_path)
logSim(self.batch_hash,self.batch_override_str,self.startTimeStr,self.endTimeStr,self.elapsedTime,self.batch_default_map,self.batch_params_path,doPrint=False)
print
print 'BATCH HASH: %s'%self.batch_hash
np.savez(self.batch_data_path,**self.toSaveMap)
print
print 'Batch data saved in: %s\n'%self.batch_data_path
print