#encoding: utf-8
"""
analysis.py -- Parallelized data collection and figure creation

BaseAnalysis provides general functionality for running many 'embarrassingly 
parallel' long-running calls in an IPython.kernel ipcluster in order to collect 
data (e.g., running model simulations for various parameter sets and comparing
results across parameters).

Subclasses provide all relevant data collection and graphics code; figures
are created as Matplotlib Figure objects.

Written by Joe Monaco, 03/30/2008.
Updated to use IPython.kernel, 01/26/2010.
Copyright (c) 2008-2009 Columbia University. All rights reserved.
Copyright (c) 2010-2011 Johns Hopkins University. All rights reserved.

This software is provided AS IS under the terms of the Open Source MIT License. 
See http://www.opensource.org/licenses/mit-license.php.
"""

# Library imports
from os import path as _path, makedirs
from matplotlib.figure import Figure
from numpy import ndarray as _ndarray

# IPython imports
from IPython.kernel import client as IPclient
from IPython.kernel.multiengineclient import FullBlockingMultiEngineClient as IPmec
from IPython.kernel.taskclient import BlockingTaskClient as IPtc

# Traits imports
from enthought.traits.api import HasTraits, String, Directory, Either, \
    Trait, Tuple, Instance, false, true

# Package imports
from . import ANA_DIR
from ..tools.bash import CPrint
from ..tools.path import unique_path

# Constants
PICKLE_FILE_NAME = 'analysis.pickle'


class BaseAnalysis(HasTraits):
    
    """
    Base functionality for data analysis and plotting of model results
    
    Data collection is initiated by calling an instance of a BaseAnalysis
    subclass with arguments that get passed to the subclass's collect_data()
    method. Afterwards, the associated figure can be created and displayed
    by calling the view() method (see the usage sketch at the end of this
    docstring).
    
    Subclasses should override:
    collect_data -- do everything necessary to create the arrays of data that
        you wish to be plotted; these arrays should be stored in the *results*
        dictionary attribute
    create_plots -- all the graphics code for plotting the data collected in
        *results*; *figure* must be set to a Matplotlib figure handle or a
        dict of figure handles keyed by filename stems. Do not call this
        method directly; use view() and save_plots() instead.
    label -- override the trait default to be more descriptive

    Constructor keyword arguments:
    desc -- short phrase describing the analysis being performed
    datadir -- specify path where all the data should be saved
    figure_size -- (width, height) of the figure in pixels (default (600, 800))
    log -- whether to log output messages to analysis.log in the datadir
        (default True)

    Public methods for data collection:
    get_multiengine_client -- gets an ipcontroller multi-engine client and 
        verifies the presence of valid ipengines
    get_task_client -- like above, but returns a task client for queueing
        a sequence of tasks to be farmed out to the ipcluster
    
    Public methods:
    reset -- cleans this analysis object so that it may be called again
    view -- brings up the figure(s) created by create_plots()
    save_plots -- saves the figure as an image file in *datadir*
    save_data -- pickles the *results* dictionary in *datadir*
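
    Example usage (an illustrative sketch only; the subclass, result keys,
    and plotting code below are hypothetical and not part of this package):

        import numpy as np
        from pylab import figure

        class SineAnalysis(BaseAnalysis):

            def collect_data(self, freq=1.0):
                # Long-running computation goes here; work can be farmed out
                # to the ipcluster via get_multiengine_client() or
                # get_task_client().
                t = np.linspace(0.0, 1.0, 1000)
                self.results = {'t': t, 'x': np.sin(2 * np.pi * freq * t)}

            def create_plots(self):
                self.figure = figure()
                ax = self.figure.add_subplot(111)
                ax.plot(self.results['t'], self.results['x'])

        ana = SineAnalysis(desc='example run')
        ana(freq=2.0)                           # runs collect_data; autosaves
        ana.view()                              # creates and shows the figure
        ana.save_plots(stem='sine', fmt='pdf')  # saves image(s) to datadir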
    """
    
    # Data location
    label = String(__name__.split('.')[-1])
    desc = String
    datadir = Directory
    autosave = true
    
    # A matplotlib figure handle or dict of handles
    figure = Trait(None, dict, Instance(Figure))
    figure_size = Tuple((600, 800))
    
    # IPython.kernel clients
    mec = Trait(None, Instance(IPmec))
    tc = Trait(None, Instance(IPtc))
    
    # Console and log output
    log = true
    logfd = Instance(file)
    out = Instance(CPrint)

    # Dictionary for storing collected data; a "finished" flag
    results = Trait(dict)
    finished = false

    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        
        try:
            if not _path.exists(self.datadir):
                makedirs(self.datadir)
        except OSError:
            self.out('Reverting to base directory:\n%s'%ANA_DIR, error=True)
            self.datadir = ANA_DIR
        finally:
            self.datadir = _path.abspath(self.datadir)

        self.out('%s initialized:\n%s'%(self.__class__.__name__, str(self)))
    
    def __call__(self, *args, **kw):
        """Execute data collection; this is a wrapper for collect_data
        """
        if self.finished:
            self.out('Already completed. Create new instance or reset.')
            return 
            
        self.out('Running data collection...')
        try:
            self.collect_data(*args, **kw)
        except Exception, e:
            import pdb, sys
            self.out('Unhandled exception:\n%s: %s'%(e.__class__.__name__,
                e.message), error=True)
            pdb.post_mortem(sys.exc_info()[2])
        else:
            self.finished = True
            if len(self.results):
                self._save_call_args(*args, **kw)
                self.out('Finished collecting data:\n%s'%'\n'.join(['%s: %s'%
                    (k, self.results[k]) for k in self.results 
                        if type(self.results[k]) is _ndarray]))
                if self.autosave:
                    self.save_data()
            else:
                self.out('Warning: No results found! Analysis incomplete?')
        finally:
            if self.logfd and not self.logfd.closed:
                self.logfd.close()
        return
    
    def __str__(self):
        """Column-formatted output of information about this analysis object
        """
        col_w = 16
        s =  ['Subclass:'.ljust(col_w) + self.__class__.__name__]
        if self.desc:
            s += ['Description:'.ljust(col_w) + self.desc]
        s += ['Directory:'.ljust(col_w) + self.datadir]
        if self.mec is not None:
            s += ['Engines:'.ljust(col_w) + str(self.mec.get_ids())]
        else:
            s += ['Engines:'.ljust(col_w) + 'N/A']
        s += ['Autosave:'.ljust(col_w) + str(self.autosave)]
        s += ['Log output:'.ljust(col_w) + str(self.log)]
        s += ['Completed:'.ljust(col_w) + str(self.finished)]
        if self.results:
            s += ['Results:'.ljust(col_w) + '%d items:'%len(self.results)]
            res_list = str(self.results.keys())[1:-1]
            if len(res_list) < 60:
                s += [' '*col_w + res_list]
            else:
                res_split = res_list[:60].split(',')
                res_split[-1] = ' etc.'
                res_list = ','.join(res_split)
                s += [' '*col_w + res_list]
        else:
            s += ['Results:'.ljust(col_w) + 'None']
        return '\n'.join(s)
    
    # Subclass override methods
    
    def collect_data(self, *args, **kw):
        """Subclass override; set the results dictionary
        """
        raise NotImplementedError

    def create_plots(self, *args, **kw):
        """Subclass override; create figure object
        """
        raise NotImplementedError

    # Public methods
    
    def get_multiengine_client(self):
        """Gets a multi-engine client for an ipcontroller
        
        Returns None if a valid connection could not be established.
        """
        if self.mec is not None:
            return self.mec
        
        # Create and return new multi-engine client
        mec = None
        try:
            mec = IPclient.MultiEngineClient()
        except Exception, e:
            self.out('Could not connect to ipcontroller:\n%s: %s'%
                (e.__class__.__name__, e.message), error=True)
        else:
            engines = mec.get_ids()
            N = len(engines)
            if N:
                self.out('Connected to %d ipengine instances:\n%s'%(N, 
                    str(engines)))
            else:
                self.out('No ipengines connected to controller', error=True)
        finally:
            self.mec = mec
        return mec
        
    def get_task_client(self):
        """Gets a task client for an ipcontroller
        
        Returns None if a valid connection could not be established.
        """
        if self.tc is not None:
            return self.tc
        
        # Create and return new task client
        tc = None
        try:
            tc = IPclient.TaskClient()
        except Exception, e:
            self.out('Could not connect to ipcontroller:\n%s: %s'%
                (e.__class__.__name__, e.message), error=True)
        finally:
            self.tc = tc
        return tc

    def reset(self):
        """Reset analysis state so that it can be called again
        """
        self.finished = False
        self.results = {}
        self.datadir = self._datadir_default()
        if not _path.exists(self.datadir):
            makedirs(self.datadir)
        self.log = False    # close old log file
        self.log = True     # open new log file
        return True
    
    def execute(self, func, *args, **kw):
        """Wrapper for executing long-running function calls in serial
        """
        if not callable(func):
            self.out("Function argument to execute() must be callable",
                error=True)
            return
        
        # Log and execute the call
        self.out('Running %s():\n%s\nArgs: %s\nKeywords: %s'%
            (func.__name__, str(func), args, kw))
        func(*args, **kw)
        return

    def view(self):
        """Bring up the figure for this analysis
        """
        if self._no_data():
            return
        
        self.create_plots()
        success = False
        if isinstance(self.figure, (dict, Figure)):
            self.out('Bringing up MPL figure(s)...')
            from pylab import isinteractive, ion, show
            if not isinteractive():
                ion()
                show()
            success = True
        else:
            self.out('No valid figure object found!', error=True)
            
        return success
    
    def save_data(self):
        """Saves the results data for a completed analysis
        """
        if self._no_data():
            return
            
        filename = _path.join(self.datadir, PICKLE_FILE_NAME)
        try:
            fd = open(filename, 'wb')
        except IOError:
            self.out('Could not open save file!', error=True)
        else:
            import cPickle
            try:
                cPickle.dump(dict(self.results), fd)
            except cPickle.PicklingError, e:
                self.out('PicklingError: %s'%str(e), error=True)
            except TypeError, e:
                self.out('TypeError: %s'%str(e), error=True)
            else:
                self.out('Analysis data saved to file:\n%s'%filename)
            finally:
                fd.close()
        return
    
    @classmethod
    def load_data(cls, pickle_path='.'):
        """Gets a new analysis object containing saved results data
        """
        if not pickle_path.endswith(PICKLE_FILE_NAME.split('.')[-1]):
            pickle_path = _path.join(pickle_path, PICKLE_FILE_NAME)

        pickle_path = _path.abspath(pickle_path)        

        if not _path.exists(pickle_path):
            raise IOError, 'Analysis data not found: %s'%pickle_path
        
        # Open the file
        fd = file(pickle_path, 'rb')
        
        # Unpickle the results data
        import cPickle
        results = cPickle.load(fd)
        datadir = _path.split(pickle_path)[0]
        fd.close()
        
        # Create and return the new analysis object
        new_analysis = cls(results=results, datadir=datadir, finished=True,
            log=False)
        CPrint(prefix=cls.__name__)('Results loaded from path:\n%s'%
            pickle_path)
        return new_analysis
    
    def save_plots(self, stem='figure', fmt='pdf'):
        """Saves the current results plots as image file(s)
        
        Optional keyword arguments:
        stem -- base filename for image file (default 'figure')
        fmt -- specifies image format for saving, either 'pdf' or 'png'
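
        Example (illustrative): analysis.save_plots(stem='rate_maps', fmt='png')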
        """
        if self._no_data():
            return
        
        # Validate format specification
        if fmt not in ('pdf', 'png'):
            self.out("Image format must be either 'pdf' or 'png'", error=True)
            return
        
        # Inline function for creating unique image filenames
        get_filename = \
            lambda stem: unique_path(_path.join(self.datadir, stem), 
                fmt="%s_%02d", ext=fmt)
        filename_list = []
        
        # Create and save figure(s) as specified
        figure_saved = False
        if isinstance(self.figure, dict):
            for fig_stem in self.figure:
                f = self.figure[fig_stem]
                if isinstance(fig_stem, str) and isinstance(f, Figure):
                    fn = get_filename(fig_stem)
                    f.savefig(fn)
                    filename_list.append(fn)
            filename_list.sort()
            figure_saved = True
        elif isinstance(self.figure, Figure):
            dpi = float(self.figure.get_dpi())
            self.figure.set_size_inches(
                (self.figure_size[0]/dpi, self.figure_size[1]/dpi))
            fn = get_filename(stem)
            self.figure.savefig(fn)
            filename_list.append(fn)
            figure_saved = True
        else:
            self.out('Figure object is not valid. Please recreate.', error=True)
        
        # Output results of save operation and return
        if figure_saved:
            self.out('Figure(s) saved as:\n%s'%('\n'.join(filename_list)))
        else:
            self.out('Plots have not been created!', error=True)            
        return figure_saved
        
    # Support methods
    
    def _no_data(self):
        """Whether analysis contains incomplete results data
        """
        good = self.finished and self.results
        if not good:
            self.out('Run data collection first!', error=True)
        return not good    
    
    def _save_call_args(self, *args, **kw):
        """Saves call arguments to a log file in the analysis directory
        """
        fd = None
        try:
            fd = file(_path.join(self.datadir, 'call_args.log'), 'a')
        except IOError:
            self.out('Failed to open file for saving call arguments',
                error=True)
        else:
            s = []
            if args:
                s += ['Arguments: %s\n'%str(args)[1:-1]]
            if kw:
                keys = kw.keys()
                keys.sort()
                s += ['Keywords:']
                s += ['%s = %s'%(k, kw[k]) for k in keys]
            if not (args or kw):
                s = ['-'*60, 'No arguments passed to call', '-'*60]
            fd.write('\n'.join(s) + '\n')
        finally:
            if fd and not fd.closed:
                fd.close()
        return
    
    # Traits change notification handlers
    
    def _logfd_changed(self, old, new):
        if old and not old.closed:
            old.close()
        self.out.outfd = new
    
    def _log_changed(self, logging):
        if logging:
            self.logfd = self._logfd_default()
        else:
            self.logfd = None
                
    # Traits defaults

    def _out_default(self):
        return CPrint(prefix=''.join(self.label.title().split()), 
            outfd=self.logfd)

    def _logfd_default(self):
        if self.log:
            try:
                fd = file(_path.join(self.datadir, 'analysis.log'), 'a')
            except IOError:
                self.log = False
            else:
                return fd
        return None
    
    def _datadir_default(self):
        """Set a subdirectory path for the data based on label and description"""
        munge = lambda s: '_'.join(s.strip().lower().split())
        
        sd_stem = munge(self.label) + '-'
        if self.desc:
            sd_stem += munge(self.desc) + '-'
        stem = _path.join(ANA_DIR, sd_stem)

        return unique_path(stem)