#encoding: utf-8
"""
analysis.py -- Parallelized data collection and figure creation

AbstractAnalysis provides general functionality for running many 'embarrassingly 
parallel' long-running calls in an IPython.kernel ipcluster in order to collect 
data (e.g., running model simulations for various parameter sets and comparing
results across parameters).

Subclasses provide all relevant data collection and graphics code. Figures
can be created as either Matplotlib figures or Chaco containers.

Written by Joe Monaco, 03/30/2008.
Updated to use IPython.kernel, 01/26/2010.
Copyright (c) 2008-2009 Columbia University. All rights reserved.
Copyright (c) 2010-2011 Johns Hopkins University. All rights reserved.

This software is provided AS IS under the terms of the Open Source MIT License. 
See http://www.opensource.org/licenses/mit-license.php.
"""

# Library imports
from numpy import ndarray
from os import path, makedirs
from matplotlib.figure import Figure

# IPython imports
from IPython.kernel import client as IPclient
from IPython.kernel.multiengineclient import FullBlockingMultiEngineClient as IPmec
from IPython.kernel.taskclient import BlockingTaskClient as IPtc

# Enthought imports
from enthought.chaco.api import BasePlotContainer
from enthought.enable.component_editor import ComponentEditor
from enthought.traits.ui.api import View, Group, Item
from enthought.traits.api import (HasTraits, String, Directory, Either, 
    Trait, Tuple, Instance, false, true)

# Package imports
from . import ANA_DIR
from ..tools.bash import CPrint
from ..tools.path import unique_path

# Constants
PICKLE_FILE_NAME = 'analysis.pickle'


class AbstractAnalysis(HasTraits):
    
    """
    Base functionality for data analysis and plotting of model results
    
    Data collection is initiated by calling an instance of an AbstractAnalysis
    subclass with arguments that get passed to the subclass's collect_data()
    method. Afterwards, the associated figure can be created and displayed 
    by calling the view() method.
    
    Subclasses should override:
    collect_data -- do everything necessary to create the arrays of data that
        you wish to be plotted; these arrays should be stored in the *results*
        dictionary attribute
    create_plots -- all the graphics code for appropriately plotting the data
        collected in *results*; for Matplotlib code, *figure* must be set to a 
        figure handle or a dict of figure handles keyed by filename stubs; for 
        Chaco code, *figure* must be an instance of a BasePlotContainer 
        subclass and can be shown as a Traits View. In use, this should not be 
        called directly (use view() and save_plots()).
    label -- override the trait default to be more descriptive

    Constructor keyword arguments:
    desc -- short phrase describing the analysis being performed
    datadir -- specify path where all the data should be saved
    figure_size -- tuple (width, height) of figure in pixels (600, 800)
    log -- whether to log output messages to analysis.log in the datadir (True)

    Public methods for data collection:
    get_multiengine_client -- gets an ipcontroller multi-engine client and 
        verifies the presence of valid ipengines
    get_task_client -- like above, but returns a task client for queueing
        a sequence of tasks to be farmed out to the ipcluster
    
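    Public methods:
    reset -- cleans this analysis object so that it may be called again
    execute -- logs and runs a long-running function call in serial
    view -- brings up the figure(s) created by create_plots()
    save_plots -- saves the figure as an image file in *datadir*
    save_data -- pickles the *results* dictionary in *datadir*
    load_data -- class method; builds a new analysis object from saved data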
    """
    
    # Data location
    # label: short name of the analysis; the default is the last component of
    #   this module's dotted name and is meant to be overridden by subclasses
    # desc: optional short description; when set it becomes part of the
    #   default data directory name (see _datadir_default)
    # datadir: directory receiving the log, the pickled results, and figures
    # autosave: when true, __call__ pickles the results via save_data() as
    #   soon as data collection finishes with a non-empty results dictionary
    label = String(__name__.split('.')[-1])
    desc = String
    datadir = Directory
    autosave = true
    
    # Chaco container, a matplotlib figure handle, or dict of handles
    # (figure_size is the (width, height) applied by save_plots() to a single
    # matplotlib figure or a Chaco container before it is written out)
    figure = Trait(None, dict, Instance(Figure), Instance(BasePlotContainer))
    figure_size = Tuple((600, 800))
    
    # IPython.kernel clients
    # (both start as None and are cached here by get_multiengine_client() and
    # get_task_client() respectively)
    mec = Trait(None, Instance(IPmec))
    tc = Trait(None, Instance(IPtc))
    
    # Console and log output
    # log: changing this flag opens or drops the log file (see _log_changed)
    # logfd: handle of the open log file, or None when not logging
    # out: CPrint instance that every message in this class goes through
    log = true
    logfd = Instance(file)
    out = Instance(CPrint)

    # Dictionary for storing collected data; a "finished" flag
    # (finished is set by __call__ once collect_data() returns without error)
    results = Trait(dict)
    finished = false

    # View for Chaco plots (this can be overridden to customize the view)
    traits_view = \
        View(
            Item("figure", editor=ComponentEditor(), show_label=False),
            title='Analysis View',
            resizable=True,
            width=0.5,
            height=0.5,
            kind='nonmodal',
            buttons=['Cancel', 'OK'])

    def __init__(self, **traits):
        """Initialize traits and make sure the data directory exists

        All keyword arguments are handed to HasTraits.__init__ as trait
        values. If *datadir* does not exist it is created; when creation
        fails with OSError the analysis falls back to ANA_DIR. In every case
        *datadir* ends up stored as an absolute path, and a summary of the
        new object is printed through *out*.
        """
        HasTraits.__init__(self, **traits)
        
        try:
            if not path.exists(self.datadir):
                makedirs(self.datadir)
        except OSError:
            # Switch to the fallback directory *before* reporting: the log
            # file is opened inside datadir (see _logfd_default), so emitting
            # the message first could try to open the log in the directory
            # that just failed to be created.
            self.datadir = ANA_DIR
            self.out('Reverting to base directory:\n%s'%ANA_DIR, error=True)
        finally:
            self.datadir = path.abspath(self.datadir)

        self.out('%s initialized:\n%s'%(self.__class__.__name__, str(self)))
    
    def __call__(self, *args, **kw):
        """Execute data collection; this is a wrapper for collect_data

        All positional and keyword arguments are forwarded to collect_data().
        Nothing is done if the analysis is already marked as finished.

        On success the *finished* flag is set; if *results* is non-empty the
        call arguments are logged, the array-valued results are echoed, and
        the results are pickled when *autosave* is on. If collect_data()
        raises, the error is reported and a pdb post-mortem session is
        started on the traceback. The log file is closed in every case.
        """
        if self.finished:
            self.out('Already completed. Create new instance or reset.')
            return 
            
        self.out('Running data collection...')
        try:
            self.collect_data(*args, **kw)
        except Exception as e:
            import pdb
            import sys
            # str(e) instead of the deprecated e.message, which is empty for
            # exceptions constructed with more than one argument
            self.out('Unhandled exception:\n%s: %s'%(e.__class__.__name__, 
                str(e)), error=True)
            pdb.post_mortem(sys.exc_info()[2])
        else:
            self.finished = True
            if self.results:
                self._save_call_args(*args, **kw)
                # Only array-valued results are echoed to the console/log
                array_lines = ['%s: %s'%(k, self.results[k]) 
                    for k in self.results 
                        if isinstance(self.results[k], ndarray)]
                self.out('Finished collecting data:\n%s'%
                    '\n'.join(array_lines))
                if self.autosave:
                    self.save_data()
            else:
                self.out('Warning: No results found! Analysis incomplete?')
        finally:
            if self.logfd and not self.logfd.closed:
                self.logfd.close()
        return
    
    def __str__(self):
        """Summarize this analysis object as a two-column text table

        Each row is a heading padded to a fixed width followed by its value.
        The names of the stored results are listed on an extra row and cut
        short with ' etc.' once the listing reaches 60 characters.
        """
        col_w = 16
        indent = ' ' * col_w

        def row(heading, value):
            return heading.ljust(col_w) + value

        lines = [row('Subclass:', self.__class__.__name__)]
        if self.desc:
            lines.append(row('Description:', self.desc))
        lines.append(row('Directory:', self.datadir))

        if self.mec is None:
            engines = 'N/A'
        else:
            engines = str(self.mec.get_ids())
        lines.append(row('Engines:', engines))

        lines.append(row('Autosave:', str(self.autosave)))
        lines.append(row('Log output:', str(self.log)))
        lines.append(row('Completed:', str(self.finished)))

        if not self.results:
            lines.append(row('Results:', 'None'))
            return '\n'.join(lines)

        lines.append(row('Results:', '%d items:'%len(self.results)))
        # Strip the surrounding brackets from the printed list of keys
        res_list = str(self.results.keys())[1:-1]
        if len(res_list) >= 60:
            # Drop the (possibly partial) last entry and mark the truncation
            pieces = res_list[:60].split(',')
            pieces[-1] = ' etc.'
            res_list = ','.join(pieces)
        lines.append(indent + res_list)
        return '\n'.join(lines)
    
    # Subclass override methods
    
    def collect_data(self, *args, **kw):
        """Hook for subclasses: gather data and store it in *results*

        Receives whatever arguments the analysis object was called with. The
        base implementation only raises NotImplementedError.
        """
        raise NotImplementedError()

    def create_plots(self, *args, **kw):
        """Hook for subclasses: build the plots and assign them to *figure*

        Invoked without arguments by view(). The base implementation only
        raises NotImplementedError.
        """
        raise NotImplementedError()

    # Public methods
    
    def get_multiengine_client(self):
        """Gets a multi-engine client for an ipcontroller
        
        The client is cached in *mec*, so later calls return the same object.
        Returns None if a connection could not be established. Note that a
        client whose controller has no engines attached is still returned;
        that situation is only reported as an error message.
        """
        if self.mec is not None:
            return self.mec
        
        # Create and return new multi-engine client
        mec = None
        try:
            mec = IPclient.MultiEngineClient()
        except Exception as e:
            # str(e) instead of the deprecated e.message, which is empty for
            # exceptions constructed with more than one argument
            self.out('Could not connect to ipcontroller:\n%s: %s'%
                (e.__class__.__name__, str(e)), error=True)
        else:
            engines = mec.get_ids()
            N = len(engines)
            if N:
                self.out('Connected to %d ipengine instances:\n%s'%(N, 
                    str(engines)))
            else:
                self.out('No ipengines connected to controller', error=True)
        finally:
            # Cache the outcome (None on failure, so the next call retries)
            self.mec = mec
        return mec
        
    def get_task_client(self):
        """Gets a task client for an ipcontroller
        
        The client is cached in *tc*, so later calls return the same object.
        Returns None if a valid connection could not be established.
        """
        if self.tc is not None:
            return self.tc
        
        # Create and return new task client
        tc = None
        try:
            tc = IPclient.TaskClient()
        except Exception as e:
            # str(e) instead of the deprecated e.message, which is empty for
            # exceptions constructed with more than one argument
            self.out('Could not connect to ipcontroller:\n%s: %s'%
                (e.__class__.__name__, str(e)), error=True)
        finally:
            # Cache the outcome (None on failure, so the next call retries)
            self.tc = tc
        return tc

    def reset(self):
        """Reset analysis state so that it can be called again

        Clears *finished* and *results*, points *datadir* at a freshly
        generated default directory (created here if it does not exist yet,
        with ANA_DIR as the fallback if creation fails), and cycles *log* so
        that a new log file is opened in the new directory.

        Always returns True.
        """
        self.finished = False
        self.results = {}

        # The constructor creates datadir when it is missing; do the same
        # here, otherwise reopening the log and any later save would target
        # a directory that does not exist.
        new_dir = self._datadir_default()
        reverted = False
        try:
            if not path.exists(new_dir):
                makedirs(new_dir)
        except OSError:
            new_dir = ANA_DIR
            reverted = True
        self.datadir = path.abspath(new_dir)

        self.log = False    # close old log file
        self.log = True     # open new log file

        # Report the fallback only after the log has been reopened
        if reverted:
            self.out('Reverting to base directory:\n%s'%ANA_DIR, error=True)
        return True
    
    def execute(self, func, *args, **kw):
        """Wrapper for executing long-running function calls in serial

        Arguments:
        func -- the callable to run; anything else is rejected with an error
            message
        *args, **kw -- passed through to *func* unchanged

        The call is announced through *out* before it is made. The return
        value of *func* is discarded; this method always returns None.
        """
        if not callable(func):
            self.out("Function argument to execute() must be callable",
                error=True)
            return
        
        # Callable objects (e.g. functools.partial instances) do not have a
        # __name__ attribute; fall back to their class name for the message
        name = getattr(func, '__name__', func.__class__.__name__)

        # Log and execute the call
        self.out('Running %s():\n%s\nArgs: %s\nKeywords: %s'%
            (name, str(func), args, kw))
        func(*args, **kw)
        return

    def view(self):
        """Create the plots and display them

        Calls create_plots() and then shows *figure*: matplotlib figures (a
        single handle or a dict of handles) through pylab, a Chaco container
        through the Traits UI. Returns True if something could be shown and
        False if *figure* is not a recognized object; returns None when there
        is no data to plot yet.
        """
        if self._no_data():
            return
        
        self.create_plots()
        fig = self.figure

        if isinstance(fig, (dict, Figure)):
            self.out('Bringing up MPL figure(s)...')
            from pylab import isinteractive, ion, show
            # Only switch pylab to interactive mode if it is not already
            if not isinteractive():
                ion()
                show()
            return True

        if isinstance(fig, BasePlotContainer):
            self.out('Bringing up Chaco view...')
            self.configure_traits()
            return True

        self.out('No valid figure object found!', error=True)
        return False
    
    def save_data(self):
        """Saves the results data for a completed analysis

        The *results* dictionary is copied into a plain dict and pickled to
        PICKLE_FILE_NAME inside *datadir*. Failure to open the file and
        pickling errors (PicklingError, TypeError) are reported through *out*
        rather than raised. Returns None.
        """
        if self._no_data():
            return
            
        filename = path.join(self.datadir, PICKLE_FILE_NAME)
        try:
            # Text mode is kept on purpose: load_data() reads the file in
            # text mode and the default pickle protocol is ASCII-based
            fd = open(filename, 'w')
        except IOError:
            self.out('Could not open save file!', error=True)
            return

        import cPickle
        try:
            cPickle.dump(dict(self.results), fd)
        except cPickle.PicklingError as e:
            self.out('PicklingError: %s'%str(e), error=True)
        except TypeError as e:
            self.out('TypeError: %s'%str(e), error=True)
        else:
            self.out('Analysis data saved to file:\n%s'%filename)
        finally:
            fd.close()
        return
    
    @classmethod
    def load_data(cls, picklepath='.'):
        """Gets a new analysis object containing saved results data

        Arguments:
        picklepath -- path to a saved pickle file, or to the directory that
            contains one (default: the current directory)

        Returns a new instance of *cls* with the unpickled *results*, its
        *datadir* set to the directory of the pickle file, *finished* set and
        logging turned off.

        Raises IOError if the pickle file does not exist.
        """
        import cPickle

        # A path is taken to be the pickle file itself only if it ends with
        # the pickle extension and is not a directory (a directory may have
        # a name ending in 'pickle' too); otherwise look for the file inside
        suffix = PICKLE_FILE_NAME.split('.')[-1]
        if path.isdir(picklepath) or not picklepath.endswith(suffix):
            picklepath = path.join(picklepath, PICKLE_FILE_NAME)

        picklepath = path.abspath(picklepath)        

        if not path.exists(picklepath):
            raise IOError('Analysis data not found: %s'%picklepath)
        
        # Unpickle the results data, closing the file even if loading fails
        # NOTE: unpickling can execute arbitrary code; only load files from
        # a trusted source
        fd = open(picklepath, 'r')
        try:
            results = cPickle.load(fd)
        finally:
            fd.close()
        datadir = path.split(picklepath)[0]
        
        # Create and return the new analysis object
        new_analysis = cls(results=results, datadir=datadir, finished=True,
            log=False)
        CPrint(prefix=cls.__name__)('Results loaded from path:\n%s'%
            picklepath)
        return new_analysis
    
    def save_plots(self, stem='figure', fmt='pdf'):
        """Saves the current results plots as image file(s)
        
        For Chaco plots: if you want to save as a PDF, then the reportlab 
        library is required since kiva.backend_pdf requires a Canvas object. 
        (Available at http://www.reportlab.org/)
        
        Optional keyword arguments:
        stem -- base filename for image file (default 'figure'); ignored when
            *figure* is a dict, whose string keys are used as stems instead
        fmt -- specifies image format for saving, either 'pdf' or 'png'

        Returns True if at least one image file was written and False if
        not. Returns None when there is no data, when *fmt* is invalid, or
        when reportlab is missing for a Chaco PDF.
        """
        if self._no_data():
            return
        
        # Validate format specification
        if fmt not in ('pdf', 'png'):
            self.out("Image format must be either 'pdf' or 'png'", error=True)
            return
        
        # Inline function for creating unique image filenames
        def get_filename(name):
            return unique_path(path.join(self.datadir, name), 
                fmt="%s_%02d", ext=fmt)

        filename_list = []
        
        # Create and save figure(s) as specified
        figure_saved = False
        figure = self.figure
        if isinstance(figure, dict):
            # Entries that are not (str key, Figure value) pairs are skipped
            for key in figure:
                fig = figure[key]
                if isinstance(key, str) and isinstance(fig, Figure):
                    fn = get_filename(key)
                    fig.savefig(fn)
                    filename_list.append(fn)
            filename_list.sort()
            # Only claim success if something was actually written
            figure_saved = bool(filename_list)
        elif isinstance(figure, Figure):
            # Convert the requested pixel size to inches for matplotlib
            dpi = float(figure.get_dpi())
            figure.set_size_inches(
                (self.figure_size[0]/dpi, self.figure_size[1]/dpi))
            fn = get_filename(stem)
            figure.savefig(fn)
            filename_list.append(fn)
            figure_saved = True
        elif isinstance(figure, BasePlotContainer):
            container = figure
            fn = get_filename(stem)
            container.bounds = list(self.figure_size)
            container.do_layout(force=True)
            # Compare by value: identity ('is') only works by accident for
            # strings that happen to be interned
            if fmt == 'png':
                from enthought.kiva.backend_image import GraphicsContext
                gc = GraphicsContext(
                    (container.bounds[0]+1, container.bounds[1]+1))
                container.draw(gc)
                gc.save(fn)
                figure_saved = True
            elif fmt == 'pdf':
                from enthought.kiva.backend_pdf import GraphicsContext
                try:
                    from reportlab.pdfgen.canvas import Canvas
                except ImportError:
                    self.out('Chaco plot PDF generation requires reportlab!',
                        error=True)
                    return
                gc = GraphicsContext(Canvas(fn))
                container.draw(gc)
                gc.save()
                figure_saved = True
            filename_list.append(fn)
        else:
            self.out('Figure object is not valid. Please recreate.', error=True)
        
        # Output results of save operation and return
        if figure_saved:
            self.out('Figure(s) saved as:\n%s'%('\n'.join(filename_list)))
        else:
            self.out('Plots have not been created!', error=True)            
        return figure_saved
        
    # Support methods
    
    def _no_data(self):
        """Tell whether usable results are missing

        Returns False when the analysis is finished and *results* is
        non-empty; otherwise prints a reminder to run data collection and
        returns True.
        """
        if self.finished and self.results:
            return False
        self.out('Run data collection first!', error=True)
        return True
    
    def _save_call_args(self, *args, **kw):
        """Saves call arguments to a log file in the analysis directory

        The positional arguments and the keyword arguments (sorted by name)
        are appended to 'call_args.log' in *datadir*; a short banner is
        written instead when the call had no arguments at all. A failure to
        open the file is reported through *out* and otherwise ignored.
        """
        try:
            fd = open(path.join(self.datadir, 'call_args.log'), 'a')
        except IOError:
            # Return right away: there is no file handle to write or close
            self.out('Failed to open file for saving call arguments', 
                error=True)
            return

        if args or kw:
            s = []
            if args:
                # [1:-1] strips the parentheses of the tuple representation
                s.append('Arguments: %s\n'%str(args)[1:-1])
            if kw:
                s.append('Keywords:')
                s.extend('%s = %s'%(k, kw[k]) for k in sorted(kw))
        else:
            s = ['-'*60, 'No arguments passed to call', '-'*60]

        try:
            fd.write('\n'.join(s) + '\n')
        finally:
            fd.close()
        return
    
    # Traits change notification handlers
    
    def _logfd_changed(self, old, new):
        """Change handler for *logfd*: close the old file, redirect output

        The previous log file is closed if it is still open, and the new
        value is installed as the output file of the *out* printer.
        """
        still_open = old and not old.closed
        if still_open:
            old.close()
        self.out.outfd = new
    
    def _log_changed(self, logging):
        """Change handler for *log*: open a log file or drop the current one
        """
        self.logfd = self._logfd_default() if logging else None
                
    # Traits defaults

    def _out_default(self):
        """Default for *out*: a CPrint prefixed with the squeezed label

        The prefix is the title-cased *label* with all whitespace removed;
        output is directed to the current log file.
        """
        words = self.label.title().split()
        return CPrint(prefix=''.join(words), outfd=self.logfd)

    def _logfd_default(self):
        """Default for *logfd*: 'analysis.log' in *datadir*, opened to append

        Returns None when logging is off. If the file cannot be opened,
        logging is switched off and None is returned as well.
        """
        if not self.log:
            return None

        log_path = path.join(self.datadir, 'analysis.log')
        try:
            handle = open(log_path, 'a')
        except IOError:
            self.log = False
            return None
        return handle
    
    def _datadir_default(self):
        """Default for *datadir*: a unique path under ANA_DIR

        The path stem is built from the label and, when given, the
        description. Each is lower-cased with runs of whitespace replaced by
        underscores, and each is followed by a dash.
        """
        def munge(text):
            return '_'.join(text.strip().lower().split())

        parts = [munge(self.label)]
        if self.desc:
            parts.append(munge(self.desc))
        sd_stem = '-'.join(parts) + '-'

        return unique_path(path.join(ANA_DIR, sd_stem))