#encoding: utf-8
"""
grid.analysis.scan -- AbstractAnalysis subclass for exploring statistics of spatial 
    map properties by scanning a single parameter with even sampling.

Written by Joe Monaco, 05/14/2008.
Copyright (c) 2008 Columbia University. All rights reserved.
"""

# Library imports
import numpy as N, scipy as S, os

# Package imports
from .. import PlaceNetworkStd, CheckeredRatemap, GridCollection
from ..core.analysis import AbstractAnalysis
from ..tools.string import snake2title

# Traits imports
from enthought.traits.api import Enum, Button
from enthought.traits.ui.api import View, Group, Item, Include

# Chaco imports for custom analysis view
from enthought.chaco.api import (ArrayPlotData, ArrayDataSource, Plot, 
    GridContainer)
from enthought.enable.component_editor import ComponentEditor

# Names of the placemap statistics a scan collects, grouped by the level at
# which each is measured: the whole stage map, individual units, and
# individual place fields.
STAGE_DATA = (
    'sparsity',
    'stage_coverage',
    'stage_repr',
    'peak_rate',
)
UNIT_DATA = (
    'max_rate',
    'num_fields',
    'coverage',
)
FIELD_DATA = (
    'area',
    'diameter',
    'peak',
    'average',
)
# Every statistic, in stage -> unit -> field order
DATA = STAGE_DATA + UNIT_DATA + FIELD_DATA


class BaseScan(AbstractAnalysis):

    """
    Base class for 1D parameter scans of placemap statistics.

    Subclasses implement collect_data, which must fill self.results with
    'param', 'samples', and, for every name in DATA, a '<name>' array of means
    and a '<name>_err' array of errors, both aligned with 'samples'.
    create_plots then draws one plot per statistic.
    """

    label = 'Stat Scan'
    save_current_plot = Button

    traits_view = \
        View(
            Item('figure', show_label=False, editor=ComponentEditor()),
            title='Network Population Scan',
            kind='live',
            resizable=True,
            width=0.75,
            height=0.75,
            buttons=['Cancel', 'OK'])

    def collect_data(self, ntrials=10, npoints=4, param='J0', bounds=(1, 4),
        **kwargs):
        """Run the scan and populate self.results; subclasses must override."""
        raise NotImplementedError

    def create_plots(self):
        """
        Build a grid of line plots, one per statistic in DATA.

        Each plot shows the mean of a statistic at every scanned parameter
        value (line + circle markers) over a shaded band spanning
        mean +/- 1.96 * error (a 95% confidence interval when the stored error
        is a standard error of the mean). NaN means and errors are drawn as 0.

        The plotted arrays are cleaned copies; self.results is not modified.
        """
        # Grid container holding one plot per statistic; a 3 x 4 grid has
        # room for all len(DATA) == 11 statistics
        self.figure = container = GridContainer(fill_padding=True,
            spacing=(5, 5), padding=[20, 20, 40, 10], bgcolor='linen',
            shape=(3, 4))

        samples = N.asarray(self.results['samples'], 'd')

        # Cleaned means, confidence-band polygons, and band extents per stat
        data_dict = {}
        limits = {}
        for d in DATA:
            # Copy so that zeroing NaNs does not overwrite the stored results
            mean = N.array(self.results[d], 'd')
            mean[N.isnan(mean)] = 0.0

            # The 95% confidence intervals (asarray * scalar is a new array)
            conf_interval = 1.96 * N.asarray(self.results[d + '_err'], 'd')
            conf_interval[N.isnan(conf_interval)] = 0.0

            upper = mean + conf_interval
            lower = mean - conf_interval
            data_dict[d] = mean
            # Closed polygon: upper edge left-to-right, lower edge back
            data_dict[d + '_err'] = N.r_[upper, lower[::-1]]
            limits[d] = (lower.min(), upper.max())

        data = ArrayPlotData(
            index=samples,
            err_ix=N.r_[samples, samples[::-1]],
            **data_dict)

        pad_factor = 0.08
        styles = {'line_width': 1.5, 'color': 'darkcyan'}

        # X-axis range is shared by every plot; widen a degenerate
        # (single-sample) range so the axis is never zero-width
        x_low, x_high = samples.min(), samples.max()
        x_pad = pad_factor * (x_high - x_low)
        x_low -= x_pad
        x_high += x_pad
        if x_low == x_high:
            x_low -= 1
            x_high += 1

        # Create individual plots and add to grid container
        for d in DATA:
            p = Plot(data, padding=[25, 25, 25, 40])

            # Plot the error band, line, and scatter plots
            p.plot(('err_ix', d + '_err'), name=d+'_p', type='polygon',
                edge_color='transparent', face_color='silver')
            p.plot(('index', d), name=d+'_l', type='line', **styles)
            p.plot(('index', d), name=d+'_s', type='scatter', marker='circle',
                marker_size=int(2*styles['line_width']), color=styles['color'],
                line_width=0)

            # Y-axis range covers the whole plotted confidence band, padded
            low, high = limits[d]
            padding = pad_factor * (high - low)
            low -= padding
            high += padding
            if low == high:
                low -= 1
                high += 1
            p.value_range.set_bounds(low, high)

            p.index_range.set_bounds(x_low, x_high)

            # Labels, grids and ticks
            p.title = snake2title(d)
            p.x_axis.title = snake2title(self.results['param'])
            p.y_grid.visible = p.x_grid.visible = False
            p.x_axis.tick_in = p.y_axis.tick_in = 0

            # Add the plot to the grid container
            container.add(p)


class MultiNetworkScan(BaseScan):

    """
    Analyze a 1D scan of regularly-spaced distributions of network simulations
    across parameter space.

    See core.analysis.AbstractAnalysis documentation and collect_data method
    signature and docstring for usage.
    """

    label = 'Network Scan'

    def collect_data(self, ntrials=10, npoints=5, param='J0', bounds=(0.5, 8),
        **kwargs):
        """
        Store statistics about placemap data from multiple trials along a 1D
        parameter scan.

        Keyword arguments:
        ntrials -- number of network trials to run per sample point (>= 1)
        npoints -- number of sample points, inclusive of the bounds (>= 1)
        param -- string name of PlaceNetwork parameter to scan
        bounds -- (low, high) pair of bounds for the parameter scan; a
            reversed pair is accepted and swapped
        Any other keyword arguments override the default PlaceNetworkStd
        parameters used for every simulation.

        On completion self.results holds 'bounds' (as given), 'param',
        'samples' (the scanned values, ascending) and, for every name in
        DATA, '<name>' (mean at each sample point) and '<name>_err' (standard
        error of that mean).

        Raises ValueError if param is not a user parameter of PlaceNetworkStd,
        if ntrials or npoints is less than 1, or if bounds is not a pair.
        """
        # Validate all inputs before creating the grid collection and
        # changing the working directory
        if param not in PlaceNetworkStd().traits(user=True):
            raise ValueError('param (%s) is not a user parameter' % param)
        if ntrials < 1:
            raise ValueError('ntrials (%s) must be at least 1' % ntrials)
        if npoints < 1:
            raise ValueError('npoints (%s) must be at least 1' % npoints)
        bounds = tuple(bounds)
        if len(bounds) != 2:
            raise ValueError(
                'bounds (%s) must be a (low, high) pair' % (bounds,))

        # Store bounds (in the order given) and scan parameter
        self.results['bounds'] = N.array(bounds)
        self.results['param'] = param

        # Load cortex
        self.out('Creating grid collection object...')
        EC = GridCollection()
        # NOTE: changes the process working directory; it is not restored
        os.chdir(self.datadir)

        # Default model parameters; caller keyword arguments override them
        pdict = dict(EC=EC,
                     growl=False,
                     desc='scan',
                     projdir=self.datadir,
                     refresh_weights=True,
                     refresh_orientation=False,
                     refresh_phases=False,
                     refresh_traj=False,
                     traj_type='checker',
                     num_trials=ntrials,
                     monitoring=True)
        pdict.update(kwargs)

        # Create the ascending vector of sample points (bounds inclusive)
        low, high = bounds
        self.out('Creating %s scan vector from %.2f to %.2f'
            % (param, low, high))
        if low > high:
            low, high = high, low
        pts = N.linspace(low, high, num=npoints)
        self.results['samples'] = pts

        # One array of means and one of standard errors per statistic in
        # DATA, indexed by sample point. All start at zero, so a sample point
        # that never records data holds a defined value, not garbage.
        means = dict((name, N.zeros(npoints, 'd')) for name in DATA)
        errors = dict((name, N.zeros(npoints, 'd')) for name in DATA)

        # Unit-data keys whose names differ from the statistic in UNIT_DATA
        unit_keys = {'max_rate': 'max_r'}

        def safe_mean(values):
            """Mean of values; NaN (without a warning) when there are none."""
            values = N.asarray(values, 'd')
            if values.size == 0:
                return N.nan
            return values.mean()

        def sem(values):
            """
            Standard error of the mean: sample SD (ddof=1) / sqrt(n).

            NaN when there are no values; 0.0 for a single value, where no
            spread can be measured.
            """
            values = N.asarray(values, 'd')
            n = values.size
            if n == 0:
                return N.nan
            if n == 1:
                return 0.0
            return N.std(values, ddof=1) / N.sqrt(n)

        def pool(records, key):
            """Concatenate entry `key` of every trial's data into one array."""
            # Single concatenate instead of growing an array trial by trial
            return N.concatenate(
                [N.asarray(rec[key], 'd').ravel() for rec in records])

        def record(name, i, values):
            """Store the mean and standard error of values at index i."""
            means[name][i] = safe_mean(values)
            errors[name][i] = sem(values)

        # Per-sample data collection method
        def run_sample_point(i, model):
            """Simulate one sample point and record its statistics at i."""
            self.out('Running (%d): %s = %.4f'%(i, param, getattr(model, param)))

            # Run the model simulation and save the results
            model.advance_all()

            # Create a ratemap per trial and extract its field and unit data
            ratemaps, field_data, unit_data = [], [], []
            for trial in range(ntrials):
                ir = CheckeredRatemap(model.post_mortem(trial=trial+1))
                ir.compute_coverage()
                ratemaps.append(ir)
                field_data.append(ir.get_field_data())
                unit_data.append(ir.get_unit_data())

            # Stage map statistics: one value per trial
            for name in STAGE_DATA:
                record(name, i, [getattr(ir, name) for ir in ratemaps])

            # Per-unit statistics: pooled across all trials
            for name in UNIT_DATA:
                record(name, i, pool(unit_data, unit_keys.get(name, name)))

            # Per-field statistics: pooled across all trials
            for name in FIELD_DATA:
                record(name, i, pool(field_data, name))

        # Execute data collection process for each sample point
        self.out('Beginning data collection process')
        for i, p in enumerate(pts):
            pdict[param] = p
            self.execute(run_sample_point, i, PlaceNetworkStd(**pdict))

        # Store the mean data results and the error data
        for name in DATA:
            self.results[name] = means[name]
            self.results[name + '_err'] = errors[name]

        # Good-bye!
        self.out('All done!')