# encoding: utf-8
"""
search.py -- Basic genetic algorithm parameter space search
    
Concrete subclasses must override the interface methods set_targets and
calculate_outputs, and may optionally override setup_engines. Also, a
sampling function that computes the results of a parameter point (whose
coordinate values are passed in as keyword arguments) must either be passed
in as sample_func to the constructor or set as the default trait value for
sample_func in the class definition. These results are passed into
calculate_outputs for evaluation against the fitness targets specified in
set_targets. See the doc strings below for details.
    
Requires an available IPython.kernel ipcontroller instance.

Created by Joe Monaco on 2009-07-30.
Refactored by Joe Monaco on 2010-01-27.
Copyright (c) 2009 Columbia University. All rights reserved.
Copyright (c) 2010 Johns Hopkins University. All rights reserved.
"""

# Library imports
import numpy as np, os
from IPython.kernel.client import MapTask
from enthought.traits.api import Trait, Function, Float, Tuple, Int, false

# Package imports
from .analysis import AbstractAnalysis


# Example sampling function to be passed in as sample_func
def sample_point(**kwargs):
    """Compute a sample point with the given input parameters
    
    Returns a sample output value or object that will be passed to the
    calculate_outputs method for objective evaluation.
    """
    raise NotImplementedError
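

# A minimal sketch of a concrete sampling function (hypothetical toy
# objective; the parameter names 'x' and 'y' are illustrative only): it
# evaluates a noisy quadratic at the given parameter point and returns a
# value that a matching calculate_outputs implementation could score.
def example_quadratic_sample(x=0.0, y=0.0, **kwargs):
    """Toy sample result: squared distance of (x, y) from (1, 2) plus noise"""
    return (x - 1.0)**2 + (y - 2.0)**2 + 0.01 * np.random.randn()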


class GeneticSearch(AbstractAnalysis):
    
    """
    Performs a basic genetic algorithm search of a parameter space 
    
    See the AbstractAnalysis base class documentation and the collect_data
    method signature and docstring for usage.
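    
    A minimal subclass sketch (hypothetical objective name and sampling
    function; run_model stands in for a user-supplied simulation function):
    
        class SparsitySearch(GeneticSearch):
            label = "sparsity search"
            sample_func = Function(run_model)
            def set_targets(self):
                self.targets['sparsity'] = (0.5, 1.0)   # (target, weight)
            def calculate_outputs(self, sample):
                return {'sparsity': sample.sparsity}
        
        search = SparsitySearch()
        search.collect_data(spawn=64, params={'J': (0.1, 2.0), 'g': (0.5, 5.0)})
        search.create_plots()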
    """
    
    label = "genetic search"
    
    # Search function and parameters
    targets = Trait(dict, desc='Objective criteria dictionary')
    min_fitness = Float(0.0, desc='Minimum possible fitness level')
    sample_func = Function(sample_point, desc='Parameter sampling function')
    
    # Plotting control traits
    norm_caxis = false(desc='Normalize fitness colors across generations')
    use_bounds = false(desc='Use initial parameter bounds for all plots')
    figsize_inches = Tuple((11,10))
    marker_size = Int(9)
    
    #
    # Convenience methods
    #
    
    @classmethod
    def get_param_dict(cls, p):
        """Gets a parameter dictionary from a parameter point record
        """
        pdict = {}
        for param in p.dtype.names:
            if param != 'fitness':
                pdict[param] = p[param]
        return pdict
    
    #
    # Subclass override (interface) methods:
    #
        
    def set_targets(self):
        """Populates *targets* dictionary with named objective output targets
        
        Set each objective element as a (target, weight) tuple, e.g.:
        self.targets['sparsity'] = (0.5, 1.0)
        
        Objective targets must be non-zero.
        
        NOTE: Required subclass override.
        """
        raise NotImplementedError('override set_targets to specify objectives')
    
    def calculate_outputs(self, sample):
        """Computes objective output values for the results of a sample point
        
        For each objective target defined by set_targets, calculate the 
        corresponding output values for the given sample results and return
        them in a dictionary.
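        
        E.g., assuming the sample results expose a hypothetical sparsity
        attribute matching a 'sparsity' objective target:
        return {'sparsity': sample.sparsity}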
        
        NOTE: Required subclass override.
        """
        raise NotImplementedError('override calculate_outputs for samples')
    
    def setup_engines(self, mec):
        """Prepares ipengines for running sample_point tasks
        
        NOTE: Optional subclass override.
        """
        pass

    #
    # End interface methods
    #
    
    def get_sample_fitness(self, sample):
        """Convenience wrapper: computes scalar fitness from sample results
        """
        return self.fitness(self.calculate_outputs(sample))

    def fitness(self, outputs):
        """Gets scalar objective fitness value for a set of output criteria
        """
        # Get a sorted list of the objective criteria names
        names = outputs.keys()
        names.sort()
        
        # Pair each actual output value with its target, and collect the
        # relative weights
        values = [(outputs[k], self.targets[k][0]) for k in names]
        weights = np.array([self.targets[k][1] for k in names])
        
        # Compute squared relative error for each (output, target) pair and
        # aggregate them into a single weighted error
        norm_error = lambda v: ((max(v)-min(v))/float(max(v)))**2
        errors = np.array([norm_error(v) for v in values])
        E = np.dot(weights, errors) / (weights.sum() + errors.shape[0])
    
        # Return fitness as the inverse of aggregate error
        return 1 / E

    def collect_data(self, spawn=96, max_diff=0.1, min_gen=10, conv_tol=0.05, 
        params={}, **kwargs):
        """Run a genetic algorithm search of a defined parameter space
        
        Fitness is determined by using output objective criteria generated by 
        the results of the sample_point and calculate_outputs methods. The
        set_targets method defines the targets for these criteria.
        
        Keyword arguments:
        spawn -- number of individual simulations per generation
        max_diff -- maximum normalized difference from maximum fitness for
            a parameter point to contribute to the next generation
        min_gen -- minimum number of generations to search
        conv_tol -- convergence criterion: if consecutive generations have
            a normalized difference of median fitness less than this
            amount, then convergence has been achieved and the search will
            terminate, provided at least min_gen generations have been run
        params -- search parameters with initial bounds; dict where each
            key-value pair specifies, respectively, a parameter name and a 
            (min, max) tuple with the lower and upper search bounds
            
        Additional kwargs are passed on to the sample_func calls.
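        
        Example call on a concrete subclass (hypothetical parameter names):
        search.collect_data(spawn=64, params={'J': (0.1, 2.0), 'g': (0.5, 5.0)})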
        """
        # Get ipcontroller clients
        mec = self.get_multiengine_client()
        tc = self.get_task_client()
        
        # Save recarrays of parameters, fitness
        self.results['evolution'] = [] 
        self.results['med_fitness'] = []
        self.results['max_fitness'] = []

        # Clear out the ipengine queues and do custom setup
        mec.clear_queue()
        mec.reset()
        self.setup_engines(mec)
        
        # Store parameter bounds        
        self.results['parameters'] = params
        p_names = params.keys()
        bounds = params.copy()
        
        # Set and store target objective criteria
        self.set_targets()
        self.results['targets'] = self.targets
        
        # Per-generation dict to store parameter points, fitness
        pts = {}
        for p in p_names:
            pts[p] = np.empty(spawn, 'd')
        pts['fitness'] = np.empty(spawn, 'd')
        
        # Set default parameter point from keyword arguments
        pdict = kwargs
        
        # Initialize parameter search radius, center from bounds
        radius = {}
        center = {}
        for k in bounds: 
            radius[k] = (bounds[k][1] - bounds[k][0]) / 2.0
            center[k] = (bounds[k][1] + bounds[k][0]) / 2.0 + \
                np.zeros(spawn, 'd') # centers are per-spawn vectors

        # Initialize loop variables
        gen = 0
        delta = np.inf
        max_fitness = self.min_fitness # guards final report on early break
        old_fitness = self.min_fitness
        cur_fitness = old_fitness
        
        # Evolutionary loop
        while gen < min_gen or delta > conv_tol:
            
            # Break loop if 'stop' file exists
            if os.path.isfile(os.path.join(self.datadir, 'stop')):
                self.out('Found stop file, breaking at generation %d!'%gen)
                break
            
            # Spawn new individuals
            self.out('Spawning generation-%d tasks'%gen)
            tasks = []
            for i in xrange(spawn):
                for p in p_names:
                    lower = center[p][i] - radius[p]
                    upper = center[p][i] + radius[p]
                    if lower < bounds[p][0]:
                        lower = bounds[p][0]
                    if upper > bounds[p][1]:
                        upper = bounds[p][1]
                    pts[p][i] = \
                        pdict[p] = \
                            lower + (upper - lower) * np.random.rand()
                tasks.append(
                    tc.run(
                        MapTask(self.sample_func, kwargs=pdict)))
            
            # Wait for simulation tasks to complete
            self.out('Now waiting for %d simulations to complete...'%spawn)
            tc.barrier(tasks)
            
            # Collate the fitness results for this generation
            self.out('Storing fitness data for generation %d...'%gen)
            pts['fitness'] = np.array(
                [self.get_sample_fitness(tc.get_task_result(t_id)) 
                    for t_id in tasks])
            self.results['med_fitness'].append(np.median(pts['fitness']))
            self.results['max_fitness'].append(np.max(pts['fitness']))
            tc.clear()
            
            # Append the results data for this generation
            array_list = [pts[p] for p in p_names] + [pts['fitness']]
            array_fmt = ','.join(['d']*(len(p_names)+1))
            self.results['evolution'].append(
                np.rec.fromarrays(array_list, 
                    formats=array_fmt, names=p_names+['fitness']))
            
            # Compute delta median fitness for convergence criterion
            old_fitness, cur_fitness = cur_fitness, np.median(pts['fitness'])
            delta = (cur_fitness - old_fitness) / old_fitness
            self.out('Generational improvement delta: %.2f%%'%(100*delta))
            
            # Seed the next generation (set centers acc. to winners)
            max_unit = np.argmax(pts['fitness'])
            max_fitness = pts['fitness'][max_unit]
            survivors = \
                (pts['fitness'] >= (1.0-max_diff)*max_fitness).nonzero()[0]
            seed_fitness = pts['fitness'][survivors]
            seed_fitness.sort()
            survivors = list(survivors) + [max_unit] # winner = 2x seed
            for i in xrange(spawn):
                for p in p_names:
                    center[p][i] = pts[p][survivors[np.mod(i, len(survivors))]]
            self.out('Fitness of surviving seeds:\n%s'%seed_fitness)
            
            # Set search radius (jitter) commensurate with improvement
            window_factor = 1.0 - 0.75 / (np.exp(-delta)+1)
            for p in p_names:
                radius[p] *= window_factor
            self.out('Reduced search radii by %.2f%%'%(100*(1-window_factor)))
            
            # Increment gen counter
            gen += 1
        
        # End of search console output
        self.out('Genetic search complete: %d generations'%gen)
        self.out('Maximum fitness achieved: %.3f'%(max_fitness))

        # Convert list results to numerical arrays
        self.results['med_fitness'] = np.array(self.results['med_fitness'])
        self.results['max_fitness'] = np.array(self.results['max_fitness'])
        
        # All done!
        self.out('Good-bye!')

    def create_plots(self):
        """Plot per-generation pairwise parameter point plots
    
        Set these traits to control plot output:
        use_bounds -- plots limits all set to initial parameter bounds
        norm_caxis -- normalize color axis across generation figures
        marker_size -- marker size in points
        figsize_inches -- (w, h) tuple specifying figure size in inches
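        
        E.g., for color-normalized plots of an existing search (hypothetical
        instance name):
        search.norm_caxis = True
        search.create_plots()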
        """
        from matplotlib.pylab import figure, subplot, scatter, ylabel, xlabel, \
            xlim, ylim, rcParams, colorbar, axis, title
        
        # Alias data from the search results
        evo = self.results['evolution']
        params = self.results['parameters']
        p_names = params.keys()
        p_names.sort()
        Np = len(p_names) # number of parameters

        # Set color axis variables and key flag strings
        vargs = {}
        normstr = ''
        boundsstr = ''
        if self.norm_caxis:
            vargs['vmin'] = min([g.fitness.min() for g in evo])
            vargs['vmax'] = max([g.fitness.max() for g in evo])
            normstr = '_norm'
        if self.use_bounds:
            boundsstr = '_full'

        # Create per-generation param x param plots
        rcParams['figure.figsize'] = self.figsize_inches
        self.figure = {}
        for gen in xrange(len(evo)):
            f = figure()
            f.set_size_inches(self.figsize_inches)
            for j in xrange(0,Np-1):
                for i in xrange(j+1,Np):
                    subplot(Np-1, Np-1, (Np-1)*(i-1) + (j+1))
                    scatter(evo[gen][p_names[j]], evo[gen][p_names[i]], 
                        c=evo[gen].fitness, s=self.marker_size, 
                        edgecolors='none', **vargs)
                    if j==0:
                        ylabel(p_names[i])
                    if i==Np-1:
                        xlabel(p_names[j])
                    if j==0 and i==1:
                        title('generation_%02d'%(gen+1))
                    if self.use_bounds:
                        xlim(params[p_names[j]])
                        ylim(params[p_names[i]])
                    else:
                        axis('tight')
                    
            # Add colorbar in upper right
            subplot(Np-1, Np-1, Np-1)
            colorbar(fraction=0.5, aspect=10)
            axis('off')    
            
            # Store this figure in the plots dictionary
            figkey = 'gen%02d%s%s'%(gen+1, normstr, boundsstr)
            self.figure[figkey] = f