#encoding: utf-8
"""
analysis.py -- Parallelized data collection and figure creation
BaseAnalysis provides general functionality for running many 'embarrassingly
parallel' long-running calls in an IPython.kernel ipcluster in order to collect
data (e.g., running model simulations for various parameter sets and comparing
results across parameters).
Subclasses provide all relevant data collection and graphics code. Figures
can be created as Matplotlib figures.
Written by Joe Monaco, 03/30/2008.
Updated to use IPython.kernel, 01/26/2010.
Copyright (c) 2008-2009 Columbia University. All rights reserved.
Copyright (c) 2010-2011 Johns Hopkins University. All rights reserved.
This software is provided AS IS under the terms of the Open Source MIT License.
See http://www.opensource.org/licenses/mit-license.php.
"""
# Library imports
from os import path as _path, makedirs
from matplotlib.figure import Figure
from numpy import ndarray as _ndarray
# IPython imports
from IPython.kernel import client as IPclient
from IPython.kernel.multiengineclient import FullBlockingMultiEngineClient as IPmec
from IPython.kernel.taskclient import BlockingTaskClient as IPtc
# Traits imports
from enthought.traits.api import HasTraits, String, Directory, Either, \
Trait, Tuple, Instance, false, true
# Package imports
from . import ANA_DIR
from ..tools.bash import CPrint
from ..tools.path import unique_path
# Constants
# File name used by BaseAnalysis.save_data() when pickling the results
# dictionary into an analysis data directory, and looked up again by
# BaseAnalysis.load_data() when restoring it.
PICKLE_FILE_NAME = 'analysis.pickle'
class BaseAnalysis(HasTraits):

    """
    Base functionality for data analysis and plotting of model results

    Data collection is initiated by calling an instance of a BaseAnalysis
    subclass with arguments that get passed to the subclass's collect_data()
    method. Afterwards, the associated figure can be created and displayed
    by calling the view() method.

    Subclasses should override:
    collect_data -- do everything necessary to create the arrays of data that
        you wish to be plotted; these arrays should be stored in the *results*
        dictionary attribute
    create_plots -- all the graphics code for appropriately plotting the data
        collected in *results*; for Matplotlib code, *figure* must be set to a
        figure handle or a dict of figure handles keyed by filename stubs. In
        use, this should not be called directly (use view() and save_plots()).
    label -- override the trait default to be more descriptive

    Constructor keyword arguments:
    desc -- short phrase describing the analysis being performed
    datadir -- specify path where all the data should be saved
    figure_size -- tuple (width, height) of figure in pixels (600, 800)
    log -- whether to log output messages to analysis.log in the datadir (True)

    Public methods for data collection:
    get_multiengine_client -- gets an ipcontroller multi-engine client and
        reports the ipengines connected to it
    get_task_client -- like above, but returns a task client for queueing
        a sequence of tasks to be farmed out to the ipcluster
    execute -- logs and runs a single long-running call in serial

    Public methods:
    reset -- cleans this analysis object so that it may be called again
    view -- brings up the figure(s) created by create_plots()
    save_plots -- saves the figure(s) as image file(s) in *datadir*
    save_data -- pickles the *results* dictionary in *datadir*
    load_data -- class method returning a new analysis object built from a
        previously saved *results* pickle
    """

    # Data location
    label = String(__name__.split('.')[-1])
    desc = String
    datadir = Directory
    autosave = true

    # A matplotlib figure handle or dict of handles
    figure = Trait(None, dict, Instance(Figure))
    figure_size = Tuple((600, 800))

    # IPython.kernel clients
    mec = Trait(None, Instance(IPmec))
    tc = Trait(None, Instance(IPtc))

    # Console and log output
    log = true
    logfd = Instance(file)
    out = Instance(CPrint)

    # Dictionary for storing collected data; a "finished" flag
    results = Trait(dict)
    finished = false

    def __init__(self, **traits):
        """Set traits from keywords and make sure the data directory exists

        If the data directory cannot be created, ANA_DIR is used instead. The
        directory is always stored as an absolute path.
        """
        HasTraits.__init__(self, **traits)
        try:
            if not _path.exists(self.datadir):
                makedirs(self.datadir)
        except OSError:
            self.out('Reverting to base directory:\n%s'%ANA_DIR, error=True)
            self.datadir = ANA_DIR
        finally:
            self.datadir = _path.abspath(self.datadir)
        self.out('%s initialized:\n%s'%(self.__class__.__name__, str(self)))

    def __call__(self, *args, **kw):
        """Execute data collection; this is a wrapper for collect_data

        All arguments are passed through to collect_data(). If that call
        raises, the error is reported and a pdb post-mortem session is
        started. On success the analysis is marked finished; when *results*
        is non-empty, the call arguments are logged and the results are
        pickled if *autosave* is set. Nothing is done if the analysis is
        already finished.
        """
        if self.finished:
            self.out('Already completed. Create new instance or reset.')
            return

        self.out('Running data collection...')
        try:
            self.collect_data(*args, **kw)
        except Exception as e:
            import pdb
            import sys
            # Format the exception itself: e.message is deprecated and is
            # empty for exceptions raised with more than one argument
            self.out('Unhandled exception:\n%s: %s'%(e.__class__.__name__,
                e), error=True)
            pdb.post_mortem(sys.exc_info()[2])
        else:
            self.finished = True
            if len(self.results):
                self._save_call_args(*args, **kw)
                # Only array-valued results are echoed in the summary
                self.out('Finished collecting data:\n%s'%'\n'.join(
                    ['%s: %s'%(k, self.results[k]) for k in self.results
                        if isinstance(self.results[k], _ndarray)]))
                if self.autosave:
                    self.save_data()
            else:
                self.out('Warning: No results found! Analysis incomplete?')
        finally:
            # The log file is closed once the call is over; reset() toggles
            # *log*, which opens a new one
            if self.logfd and not self.logfd.closed:
                self.logfd.close()
        return

    def __str__(self):
        """Column-formatted output of information about this analysis object
        """
        col_w = 16

        def row(name, value):
            return name.ljust(col_w) + value

        s = [row('Subclass:', self.__class__.__name__)]
        if self.desc:
            s += [row('Description:', self.desc)]
        s += [row('Directory:', self.datadir)]
        if self.mec is not None:
            s += [row('Engines:', str(self.mec.get_ids()))]
        else:
            s += [row('Engines:', 'N/A')]
        s += [row('Autosave:', str(self.autosave))]
        s += [row('Log output:', str(self.log))]
        s += [row('Completed:', str(self.finished))]
        if self.results:
            s += [row('Results:', '%d items:'%len(self.results))]
            # Key list without the enclosing brackets of the list repr
            res_list = str(list(self.results))[1:-1]
            if len(res_list) >= 60:
                # Truncate long listings, replacing the (possibly cut-off)
                # last entry with ' etc.'
                res_split = res_list[:60].split(',')
                res_split[-1] = ' etc.'
                res_list = ','.join(res_split)
            s += [' '*col_w + res_list]
        else:
            s += [row('Results:', 'None')]
        return '\n'.join(s)

    # Subclass override methods

    def collect_data(self, *args, **kw):
        """Subclass override; set the results dictionary
        """
        raise NotImplementedError

    def create_plots(self, *args, **kw):
        """Subclass override; create figure object
        """
        raise NotImplementedError

    # Public methods

    def get_multiengine_client(self):
        """Gets a multi-engine client for an ipcontroller

        An existing client stored in *mec* is reused. Otherwise a new client
        is created, stored in *mec* and returned.

        Returns None if a valid connection could not be established.
        """
        if self.mec is not None:
            return self.mec

        # Create and return new multi-engine client
        mec = None
        try:
            mec = IPclient.MultiEngineClient()
        except Exception as e:
            self.out('Could not connect to ipcontroller:\n%s: %s'%
                (e.__class__.__name__, e), error=True)
        else:
            engines = mec.get_ids()
            N = len(engines)
            if N:
                self.out('Connected to %d ipengine instances:\n%s'%(N,
                    str(engines)))
            else:
                # NOTE(review): the client is still returned when no engines
                # are connected; only an error message is issued
                self.out('No ipengines connected to controller', error=True)
        finally:
            self.mec = mec
        return mec

    def get_task_client(self):
        """Gets a task client for an ipcontroller

        An existing client stored in *tc* is reused. Otherwise a new client
        is created, stored in *tc* and returned.

        Returns None if a valid connection could not be established.
        """
        if self.tc is not None:
            return self.tc

        # Create and return new task client
        tc = None
        try:
            tc = IPclient.TaskClient()
        except Exception as e:
            self.out('Could not connect to ipcontroller:\n%s: %s'%
                (e.__class__.__name__, e), error=True)
        finally:
            self.tc = tc
        return tc

    def reset(self):
        """Reset analysis state so that it can be called again

        Clears *finished* and *results*, assigns a fresh default data
        directory path and cycles the log file. Always returns True.
        """
        self.finished = False
        self.results = {}
        self.datadir = self._datadir_default()
        # NOTE(review): logging ends up enabled even if it was off before
        self.log = False # close old log file
        self.log = True # open new log file
        return True

    def execute(self, func, *args, **kw):
        """Wrapper for executing long-running function calls in serial

        The call is logged before it is made. The return value of *func* is
        discarded; nothing is returned.
        """
        if not callable(func):
            self.out("Function argument to execute() must be callable",
                error=True)
            return

        # Log and execute the call
        self.out('Running %s():\n%s\nArgs: %s\nKeywords: %s'%
            (func.__name__, str(func), args, kw))
        func(*args, **kw)
        return

    def view(self):
        """Bring up the figure for this analysis

        Calls create_plots() and shows the result with pylab, switching on
        interactive mode first if necessary. Returns True if *figure* holds
        a Figure or a dict, False otherwise, and None if there is no data.
        """
        if self._no_data():
            return

        self.create_plots()
        success = False
        if isinstance(self.figure, (dict, Figure)):
            self.out('Bringing up MPL figure(s)...')
            from pylab import isinteractive, ion, show
            if not isinteractive():
                ion()
            show()
            success = True
        else:
            self.out('No valid figure object found!', error=True)
        return success

    def save_data(self):
        """Saves the results data for a completed analysis

        The *results* dictionary is pickled to PICKLE_FILE_NAME in *datadir*.
        Failures to open the file or to pickle the data are reported through
        *out* rather than raised.
        """
        if self._no_data():
            return

        # Imported before the file is opened so that a failed import cannot
        # leave an open file handle behind
        import cPickle

        filename = _path.join(self.datadir, PICKLE_FILE_NAME)
        try:
            fd = open(filename, 'w')
        except IOError:
            self.out('Could not open save file!', error=True)
        else:
            try:
                cPickle.dump(dict(self.results), fd)
            except cPickle.PicklingError as e:
                self.out('PicklingError: %s'%str(e), error=True)
            except TypeError as e:
                self.out('TypeError: %s'%str(e), error=True)
            else:
                self.out('Analysis data save to file:\n%s'%filename)
            finally:
                fd.close()
        return

    @classmethod
    def load_data(cls, pickle_path='.'):
        """Gets a new analysis object containing saved results data

        Arguments:
        pickle_path -- path to a saved results pickle, or to the directory
            that contains one (default is the current directory)

        Returns a new instance of this class, marked finished and with
        logging turned off, whose *datadir* is the directory of the pickle.

        Raises IOError if the pickle file does not exist.

        NOTE: unpickling can execute arbitrary code; only load analysis data
        from sources that you trust.
        """
        import cPickle

        # Directories (whatever their name) and paths that do not end with
        # the pickle extension are taken to be the containing directory
        ext = PICKLE_FILE_NAME.split('.')[-1]
        if _path.isdir(pickle_path) or not pickle_path.endswith(ext):
            pickle_path = _path.join(pickle_path, PICKLE_FILE_NAME)
        pickle_path = _path.abspath(pickle_path)
        if not _path.exists(pickle_path):
            raise IOError('Analysis data not found: %s'%pickle_path)

        # Unpickle the results data, closing the file even if loading fails
        fd = open(pickle_path, 'r')
        try:
            results = cPickle.load(fd)
        finally:
            fd.close()
        datadir = _path.split(pickle_path)[0]

        # Create and return the new analysis object
        new_analysis = cls(results=results, datadir=datadir, finished=True,
            log=False)
        CPrint(prefix=cls.__name__)('Results loaded from path:\n%s'%
            pickle_path)
        return new_analysis

    def save_plots(self, stem='figure', fmt='pdf'):
        """Saves the current results plots as image file(s)

        Optional keyword arguments:
        stem -- base filename for image file (default 'figure'); only used
            when *figure* is a single Figure, since a dict of figures is
            saved under its own string keys
        fmt -- specifies image format for saving, either 'pdf' or 'png'

        Returns True if at least one image file was written, False if not,
        and None if there is no data or *fmt* is not valid.
        """
        if self._no_data():
            return

        # Validate format specification
        if fmt not in ('pdf', 'png'):
            self.out("Image format must be either 'pdf' or 'png'", error=True)
            return

        # Inline function for creating unique image filenames
        def get_filename(name):
            return unique_path(_path.join(self.datadir, name),
                fmt="%s_%02d", ext=fmt)

        # Create and save figure(s) as specified
        filename_list = []
        if isinstance(self.figure, dict):
            for name in self.figure:
                f = self.figure[name]
                # Entries that are not string-keyed figures are skipped
                if isinstance(name, str) and isinstance(f, Figure):
                    fn = get_filename(name)
                    f.savefig(fn)
                    filename_list.append(fn)
            filename_list.sort()
        elif isinstance(self.figure, Figure):
            # figure_size is in pixels; float division keeps the conversion
            # to inches exact even if the dpi is an integer
            dpi = float(self.figure.get_dpi())
            self.figure.set_size_inches(
                (self.figure_size[0]/dpi, self.figure_size[1]/dpi))
            fn = get_filename(stem)
            self.figure.savefig(fn)
            filename_list.append(fn)
        else:
            self.out('Figure object is not valid. Please recreate.', error=True)

        # Success means that something was actually written to disk
        figure_saved = bool(filename_list)

        # Output results of save operation and return
        if figure_saved:
            self.out('Figure(s) saved as:\n%s'%('\n'.join(filename_list)))
        else:
            self.out('Plots have not been created!', error=True)
        return figure_saved

    # Support methods

    def _no_data(self):
        """Whether analysis contains incomplete results data

        Returns True, after issuing an error message, unless the analysis is
        finished and *results* is non-empty.
        """
        good = self.finished and self.results
        if not good:
            self.out('Run data collection first!', error=True)
        return not good

    def _save_call_args(self, *args, **kw):
        """Saves call arguments to a log file in the analysis directory

        Entries are appended to call_args.log in *datadir*. A failure to
        open the file is reported through *out* rather than raised.
        """
        try:
            fd = open(_path.join(self.datadir, 'call_args.log'), 'a')
        except IOError:
            # Return here: there is no file handle to write to or close
            self.out('Failed to open file for saving call arguments',
                error=True)
            return

        try:
            s = []
            if args:
                # Tuple repr without the enclosing parentheses
                s += ['Arguments: %s\n'%str(args)[1:-1]]
            if kw:
                s += ['Keywords:']
                s += ['%s = %s'%(k, kw[k]) for k in sorted(kw)]
            if not (args or kw):
                s = ['-'*60, 'No arguments passed to call', '-'*60]
            fd.write('\n'.join(s) + '\n')
        finally:
            fd.close()
        return

    # Traits change notification handlers

    def _logfd_changed(self, old, new):
        """Close the previous log file and redirect *out* to the new one"""
        if old and not old.closed:
            old.close()
        self.out.outfd = new

    def _log_changed(self, logging):
        """Open a log file when logging is switched on; drop it otherwise"""
        if logging:
            self.logfd = self._logfd_default()
        else:
            self.logfd = None

    # Traits defaults

    def _out_default(self):
        """Console printer prefixed with the label, writing to the log file"""
        return CPrint(prefix=''.join(self.label.title().split()),
            outfd=self.logfd)

    def _logfd_default(self):
        """Open analysis.log in *datadir* for appending if logging is on

        Returns None if logging is off. If the file cannot be opened,
        logging is switched off and None is returned.
        """
        if self.log:
            try:
                fd = open(_path.join(self.datadir, 'analysis.log'), 'a')
            except IOError:
                self.log = False
            else:
                return fd
        return None

    def _datadir_default(self):
        """Set a subdirectory path for the data based on label and description"""
        # Lower-case the text and join its words with underscores
        munge = lambda s: '_'.join(s.strip().lower().split())
        sd_stem = munge(self.label) + '-'
        if self.desc:
            sd_stem += munge(self.desc) + '-'
        stem = _path.join(ANA_DIR, sd_stem)
        return unique_path(stem)