# encoding: utf-8
"""
model.py -- AbstractModel is a framework for writing time-series models

AbstractModel provides a Traits-based foundation for models. It provides a flexible
messaging service with Messager, time-series data collection and tracking with 
TimeSeries, exception handling and saves results and complete information 
about each simulation. 

Model subclasses simply define all the Traits and support methods needed for 
the scientific problem and then implement per-trial setup (trial_setup()) and
the per-timestep computation (run_timestep()). AbstractModel handles the rest.

AbstractModel -- Abstract modeling framework
ModelPostMortem -- Skeleton object storing completed simulation results

Copyright (c) 2007, 2008 Columbia University. All rights reserved.

This software is provided AS IS under the terms of the Open Source MIT License. 
See http://www.opensource.org/licenses/mit-license.php.
"""

# Library imports 
import time, os

# Package imports
from . import TS_FMT, MODEL_MESSAGES as MSG
from .timeseries import TimeSeries, TSPostMortem
from ..tools.array_container import ArrayContainer
from ..tools.messager import Messager
from ..tools.path import unique_path

# Traits imports
from enthought.traits.api import (HasTraits, Property, Instance, Directory, 
    File, Trait, Range, List, String, Int, Float, true, false)
        

class ModelPostMortem(ArrayContainer):
    
    """
    Converts AbstractModel.results data into attributes of skeleton object
    
    The shape of the array attributes depends on whether one or more trials were
    run in the model simulation:
    - num_trials == 1: time is the first dimension
    - num_trials > 1: trials x time are the first two dimensions 
    - in both cases: the shape of the tracked data follows time. 
    
    Public methods:
    get_trial_data -- single-trial post-mortem for a particular trial (trials are 
        indexed starting from 1)
    
    The **tracking** and **ntrials** attributes are set as a convenience.
    """ 
    
    def __init__(self, data):
        """Transform AbstractModel.results into array attributes on this object
        
        For multi-trial data, each data attribute is a numpy object array where
        each element is a trial data array. This avoids copying data since non-
        contiguous ndarray creation is iffy at this point.
        """
        from numpy import array
        ntrials = len(data)
        self.ntrials = ntrials
        if ntrials == 0:
            return
        self.tracking = data[0].tracking
        self.t = data[0].t
        _d = {}
        for key in self.tracking:
            if ntrials == 1:
                _d[key] = getattr(data[0], key)
            else:
                _d[key] = array([None]*ntrials, dtype='O')
                for trial in xrange(ntrials):
                    _d[key][trial] = getattr(data[trial], key)
        self.__dict__.update(_d)
    
    def get_trial_data(self, trial):
        """Extract a single trial of data from this object
        
        Required argument:
        trial -- trial number, indexed from 1
        
        Returns a single-trial data object.
        """
        if self.ntrials == 1:
            return self
        if trial < 1 or trial > self.ntrials:
            raise ValueError, 'bad trial number (%d)'%trial

        trial = int(trial)
        _d = {}
        for key in self.tracking:
            _d[key] = getattr(self, key)[trial-1]
        _d['t'] = self.t
        
        return ModelPostMortem([TSPostMortem(**_d)])


class AbstractModel(HasTraits):

    """
    Generalized base model class for constructing a time-series model
    
    Model parameters and settings should be specified as keyword arguments to 
    the subclass constructor.
    
    Attributes are categorized by Trait metadata:
    track -- setting to True will cause attribute to be tracked as a time-series
    user -- specify whether this attribute is a user-settable parameter
    
    Model attributes:
    label -- short name for the model (default module name)
       * this will be used as the project subdirectory
    desc -- descriptive phrase to be added to filenames (default label)
    app_name -- full name for this model (default label)
    icon_path -- full path to the model icon (default mandle_joe.png)
    console -- bool value toggles console messages (default True)
    num_trials -- number of trials to run (default 1)
    
    Important attributes:
    out -- the Messager object that handles messaging
    ts -- the TimeSeries object that handles all data collection
    trial -- current trial number
    trial_result -- if no variables are being tracked by *ts*, then the
        attribute whose name is stored in trial_result will be saved as the
        trial result data for each trial
    t -- current time in seconds of the current trial
    done -- whether all trials have completed
    pause -- setting to True pauses simulation (advance() restarts)
    monitoring -- (bool) toggles progress monitoring messages

    Public methods:
    info -- display a nice complete description of this model
    post_mortem -- get a ModelPostMortem object for the completed data
    reset -- trash any progress and reset the simulation
    advance -- run the next trial (or resume from pausing)
    advance_all -- run through all trials
    save_data -- save simulation data and a model information file
    parameter_dict -- get a dictionary of all user-settable parameters

    Subclasses override these methods:
    trial_setup -- setup to be performed prior to each trial
    run_timestep -- the per-timestep computation for this model; this is 
        wrapped by a while loop within advance(), e.g.:
            while self.ts:
                self.run_timestep()
                self.ts.advance()
    """
    
    # Output control
    out = Instance(Messager)
    
    # Program flow
    pause = false(user=True, label='Pause', desc='Toggle active state of model')
    done = false(user=True, label='Stop', desc='End model simulation now')
    
    # Metadata
    label = String(__name__)
    desc = String
    app_name = label
    
    # Progress monitoring
    monitoring = true(user=True)
    monitor = Property(false)
    monitor_dt = Float(1)
    
    # General model machinery
    trial_result = String
    timestamp = Trait(time.strftime(TS_FMT))
    success = true
    num_trials = Range(low=1)
    results = List
    trial = Int(1)
    ts = Instance(TimeSeries)
    dt = Trait(TimeSeries.__class_traits__['dt'])
    T = Trait(TimeSeries.__class_traits__['T'])
    t = Property(Float)
    
    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)
        self._renew_timeseries()
        
    def __str__(self): 
        if self:
            progress = '\nAt: Trial %d of %d @ %.2f/%.2fs\n'%(self.trial, 
                self.num_trials, self.t, self.T)
        else:
            progress = '\nDone: Completed %d trials out of %d\n'% \
                (len(self.results), self.num_trials)
        user_params = self.traits(user=True).keys()
        user_params.sort()
        hdr = self.__class__.__name__ + '(Model) object\n' + '-'*32
        p_str = \
            '\nParameters:\n' + '\n'.join(['\t%s : %s'%(k, repr(getattr(self, k))) \
                for k in user_params])
        if not self.traits(track=True):
            t_str = ''
        else:
            t_str = \
                '\nTracking:\n\t' + ', '.join(self.traits(track=True).keys())
        return hdr + p_str + t_str
            
    def __repr__(self): 
        return self.__str__().split('\n')[0]
    
    def __nonzero__(self): 
        return not self.done 
    
    # Subclass override methods
    
    def trial_setup(self):
        """Subclass override: perform trial setup
        """
        raise NotImplementedError
        
    def run_timestep(self):
        """Subclass override: simulation time-step
        """
        raise NotImplementedError
    
    # Public methods
    
    def info(self): 
        self.out(self, 'Model information')
    
    def post_mortem(self): 
        """Get a PostMortem object for this instance
        """
        return ModelPostMortem(self.results)

    def reset(self):
        """Wipe out all results and go back to initial state
        """
        self.done = False
        self.pause = False
        self.trial = 1
        self.results = []
        self._renew_timeseries()
        self.out('Model has been reset!', 'Reset', 'init')
    
    def advance(self):
        """Run the current trial, save data, and handle exceptions
        """
        if self.done or self.pause:
            return time.sleep(.2)
        if self.t and self.ts:
            self.out('Unpausing trial %d...'%self.trial, notification='monitor')
        else:
            self.out('Starting trial %d...'%self.trial, notification='init')
            self.trial_setup()
            self._renew_timeseries()
        try:
            self._run_trial()
        except KeyboardInterrupt:
            self.out(*MSG['TRIALSTOP'])
            self.done = raw_input('> ').lower() == 'stop'
        except Exception, e:
            import pdb
            self.out(MSG['UNHEXC'][0]%(self.trial, self.t, repr(e).split('(')[0], 
                e.message), MSG['UNHEXC'][1], MSG['UNHEXC'][2])
            pdb.post_mortem(os.sys.exc_info()[2])
            self.success = False
            self.done = True
        else:
            if not self.pause:
                self._store_trial()
                if len(self.results) < self.trial:
                    self.trial = len(self.results)
                    self.out(MSG['TRIALRST'][0]%(self.trial+1), MSG['TRIALRST'][1], 
                        MSG['TRIALRST'][2])
                self.done = self.trial == self.num_trials
                self.trial += not self.done
        finally:
            if self.done:
                self.out('Finished simulation!', notification='complete')
    
    def advance_all(self):
        """Run through all trials
        """
        while not self.done:
            self.advance()

    def save_data(self, dpath='.'):
        """Save pickle files of model info and results data
        
        Saved file details:
        X.info -- a text file describing various properties of the model, its 
            state, and its parameters.
        X.tar.gz -- an ArrayContainer format archive of the post-mortem data of
            this model's results
        """     
        import cPickle
        if not self._create_datapath(dpath):
            self.out(MSG['DPATH'][0]%dpath, MSG['DPATH'][1], MSG['DPATH'][2])
            return

        fn_title = 'Save model data'
        stem_fn = '_'.join(self.label.lower().split())
        if self.desc:
            stem_fn += '-' + '_'.join(self.desc.lower().split())
        info_fn = unique_path(os.path.join(dpath, stem_fn), ext='info', 
            fmt='%s-%03d')
        stem = info_fn[:-4]
        data_fn = stem
        try:
            info_fd = open(info_fn, 'w')
        except IOError:
            self.out('Could not open file(s) for writing:\n->' + stem + 
                '{info,data}', fn_title, 'error')
        else:            
            # Write out the info file
            self._write_info_file(info_fd, stem)
            info_fd.close()
            
            # Save the post-mortem data in a compressed archive
            self.post_mortem().tofile(data_fn)
            
            # All done!
            self.out('Saved files:\n->%s\n->%s'%(info_fn, data_fn+'tar.gz'), 
                fn_title, sticky=True)
        
    def parameter_dict(self):
        """Get a dictionary with all user-settable parameters for this model
        """
        params = self.traits(user=True).keys()
        pdict = {}
        for p in params:
            pdict[p] = getattr(self, p)
        return pdict

    # Traits notifications and properties

    def _get_monitor(self): 
        return bool(self.t % self.monitor_dt < 0.5001*self.dt)
    
    def _get_t(self):
        return self.ts.t
    
    def _pause_changed(self, paused):
        if not paused:
            self.advance()
    
    def _out_default(self):
        return Messager(title=self.label + ' Simulation')
    
    # Private support methods
    
    def _run_trial(self):
        """Run a trial by wrapping subclass timestep method
        """
        while self.ts:
            
            # Enable pausing functionality
            if self.pause: 
                break
            
            # Subclass-provided timestep computation
            self.run_timestep()
            
            # Progress messages
            self._handle_monitor_msg()        
            
            # Store current state and advance the timeseries
            self.ts.advance()

    def _handle_monitor_msg(self):
        """Handle progress monitoring messages
        """
        if self.monitoring and self.monitor:
            self.out('t = %.2fsec (%.1f%%)'%(self.t, 100*self.t/self.T), 
                notification='monitor')
    
    def _renew_timeseries(self):
        """Construct a fresh data tracker
        """
        self.ts = TimeSeries(self, dt=self.dt, T=self.T)
        if self.monitor_dt < self.T / 10.0:
            self.monitor_dt = int(self.T) / 10.0
    
    def _store_trial(self):
        try:
            self.ts.finish()
        except Exception, e:
            self.out(MSG['SAVEFAIL'][0]%e.message, MSG['SAVEFAIL'][1], 
                MSG['SAVEFAIL'][2])
        else:
            if self.ts.tracking == [] and self.trial_result != '':
                self.results.append(getattr(self, self.trial_result))
                self.out('Trial %d data (\'%s\') saved!'%(self.trial, 
                    self.trial_result), 'Trial Saved')
            else:
                self.results.append(self.ts.post_mortem())
                self.out('Trial %d data saved!'%self.trial, 'Trial Saved')
    
    def _create_datapath(self, dpath):
        if not os.path.isdir(dpath):
            if os.path.isfile(dpath):
                self.out(MSG['DPFILE'][0]%dpath, MSG['DPFILE'][1], 
                    MSG['DPFILE'][2])
                return False
            else:
                try:
                    os.makedirs(dpath)
                except OSError:
                    self.out(MSG['DPFAIL'][0]%dpath, MSG['DPFAIL'][1], 
                        MSG['DPFAIL'][2])
                    return False
                else:
                    self.out(MSG['DPCREATE'][0]%dpath, MSG['DPCREATE'][1], 
                        MSG['DPCREATE'][2])
        return True

    def _write_info_file(self, fd, stem):
        from numpy import asarray
        tstamp = time.strftime(TS_FMT)
        div = '='*70
        info = [div, 'Simulation Information File', div]
        div = '-'*70
        info += ['', 'Model subclass    : ' + self.__class__.__name__]
        info += ['Subclass module   : ' + self.__module__, '']
        info += ['-> Instantiated : ' + self.timestamp, 
            '-> Saved        : ' + tstamp, '']
        if not self.success:
            info += ['', '  *** NOTE: This simulation encountered errors! ***', '']
        info += [div, self.__class__.__name__+' Simulation:', div]
        info += ['', div, self.__class__.__name__+' Parameters:', div]
        info += [' * ' + str(k).ljust(15) + '= ' + str(getattr(self, k)) 
            for k in self.traits(user=True)]
        info += ['', div, 'Collected Time-Series:', div]
        info += [' * ' + str(self.ts).split('\n')[0]]
        info += [' * Progress: trial %d @ %.2f seconds'%(self.trial, self.t)]
        if self.success:
            info += [' * Completed successfully!']
        else:
            info += [' * Errors were detected, check results!']
        info += ['', ' * Tracked Variables:']
        info += ['\t- %s : %s'%(k, str(asarray(getattr(self, k)).shape)) 
            for k in self.traits(track=True)]
        info += ['', div, 'Results Data Archive:', div]
        info += ['Post-mortem data:\n\t' + stem + 'tar.gz']
        info += ['', 'Access data by:']
        info += ['>>> from core.api import ModelPostMortem']
        info += ['>>> pm = ModelPostMortem.fromfile(\'%s\')'%(stem + 'tar.gz')]
        info += ['>>> print pm.tracking']
        info += ['', div, self.__class__.__name__ + ' Docstring:', div]
        info += ['', '\n'.join([s.strip() for s in self.__doc__.split('\n')])]
        fd.write('\n'.join(info))
        fd.close()