#!/usr/bin/env python
#encoding: utf-8
"""
grid.place_network_ui.py -- Chaco2 and Traits UI interface for PlaceNetwork model class

Subclass of grid.place_network.PlaceNetwork which enables a graphical interface.

Copyright (c) 2007 Columbia University. All rights reserved.
"""

# Library imports
import numpy, time, wx

# Package imports
from .place_network import PlaceNetwork
from .core.view_model import ViewModel, ViewModelHandler
from .tools.images import array_to_rgba

# Traits imports
from enthought.traits.api import Property, Array, Instance, Range, Float, Bool
from enthought.traits.ui.api import View, VSplit, Group, Item, Heading, Include

# Chaco2 imports
from enthought.chaco.api import Plot, PlotLabel, ArrayPlotData, DataRange1D, hot
from enthought.enable.component_editor import ComponentEditor

# Global constants
# Width given to the first Item of each column in initialization_view
COL_WIDTH = 240
# Upper bound on the number of output units PlaceNetworkUI shows in its units plot
DISP_UNITS = 200

# View for model initialization
#
# The dialog is assembled from named sub-groups: a left column of run
# settings, a right column of model parameters, and two stand-alone items
# underneath.

# Left column: run configuration. The duration field is enabled only for
# random-walk trajectories; the dwell factor only for the other types.
_init_settings_group = Group(
    Item(name='desc', label='Description', width=COL_WIDTH),
    Item(name='num_trials', label='Trials'),
    Item(name='traj_type', label='Trajectory'),
    Item(name='dwell_factor', enabled_when="traj_type!='randwalk'"),
    Item(name='T', label='Duration (s)', enabled_when="traj_type=='randwalk'"),
    Item(name='dt', label='dt (s)'),
    Item(name='monitor_dt', label='Display dt (s)'),
    Item(name='N_CA', label='Output Units', style='readonly'),
    Item(name='N_EC', label='Grid Inputs', style='readonly'),
    Item(name='C_W', label='Connectivity'),
    label='Initialization',
    show_border=True,
)

# Nested inside the parameter column: the two nonlinearity settings
_init_nonlinearity_group = Group(
    Item(name='phi_lambda', label='Threshold'),
    Item(name='phi_sigma', label='Smoothness'),
    label='Field Nonlinearity:',
)

# Right column: model parameters
_init_parameters_group = Group(
    Item(name='J0', width=COL_WIDTH),
    Item(name='tau_r', label='Tau (s)'),
    _init_nonlinearity_group,
    label='Parameters',
    show_border=True,
)

initialization_view = View(
    Heading('Set values prior to running simulation'),
    Group(
        _init_settings_group,
        _init_parameters_group,
        orientation='horizontal',
    ),
    Item(name='growl', label='Growl On/Off', style='simple'),
    Item(name='projdir', label='Project Directory'),
    buttons=['Revert', 'Cancel', 'OK'],
    title='Model Setup',
    kind='livemodal',
    resizable=False,
)

# View for main simulation visualization
#
# A vertical split: the three live plots on top, the parameter controls and
# the nonlinearity plot underneath.

# Top pane: field, unit and trajectory plots side by side
_sim_plots_row = Group(
    Item(name='field_plots', editor=ComponentEditor()),
    Item(name='units_plot', editor=ComponentEditor()),
    Item(name='traj_plot', editor=ComponentEditor()),
    show_border=False,
    show_labels=False,
    orientation='horizontal',
)

# Control for the input gain J0
_sim_gain_group = Group(
    Item(name='J0'),
    label='Input Gain',
    show_border=True,
)

# Controls for the nonlinearity parameters
_sim_nonlinearity_group = Group(
    Item(name='phi_lambda', label='Lambda'),
    Item(name='phi_sigma', label='Sigma'),
    label='Field Nonlinearity',
    show_border=True,
)

# Trail length plus the program-flow and reset items; the flow item is
# disabled once 'done' is set.
# NOTE(review): '_program_flow' and '_reset_simulation' are not defined in
# this module -- presumably provided by ViewModel; confirm.
_sim_flow_group = Group(
    Item(name='trail_length'),
    Item(name='_program_flow', show_label=False, enabled_when='done==False'),
    Item(name='_reset_simulation', show_label=False),
    show_border=False,
)

# Bottom pane: stacked control groups next to the nonlinearity plot
_sim_controls_row = Group(
    Group(
        _sim_gain_group,
        _sim_nonlinearity_group,
        _sim_flow_group,
        springy=True,
        show_border=False,
    ),
    Item(name='phi_plot', editor=ComponentEditor(), width=0.4),
    show_border=False,
    show_labels=False,
    orientation='horizontal',
)

simulation_view = View(
    VSplit(
        _sim_plots_row,
        _sim_controls_row,
        show_labels=False,
        show_border=False,
    ),
    title='PlaceNetwork Simulation',
    resizable=True,
    height=1.0,
    width=1.0,
    buttons=['Revert', 'Cancel', 'OK'],
    kind='live',
)
                    

class PlaceNetworkUI(ViewModel, PlaceNetwork):

    """
    Graphical, real-time front end for the PlaceNetwork model.

    Combines ViewModel with PlaceNetwork and adds four Chaco plots (synaptic
    fields, unit output, trajectory, nonlinearity) together with
    slider-backed parameters. The views defined at module level
    (initialization_view, simulation_view) display these traits.
    """

    pause = True

    # View traits
    sim_view = simulation_view
    init_view = initialization_view

    # Plot instances; each is built lazily by its _<name>_default method
    field_plots = Instance(Plot)
    units_plot = Instance(Plot)
    traj_plot = Instance(Plot)
    phi_plot = Instance(Plot)

    # Redefine user parameters as Range traits for sliders
    J0 = Range(low=0.0, high=100.0, value=45)

    # Control nonlinearity variables with sliders
    phi_lambda = Range(low=0.0, high=0.5, value=0.04)
    phi_sigma = Range(low=0.001, high=1.0, value=0.02)

    # Field plots tracking data (computed by the _get_h_* methods below)
    h_aff = Property(Float, track=True)
    h_rec = Property(Float, track=True)
    h_sum = Property(Float, track=True)

    # Phi plot data; _phi_updated flags that the plot needs a redraw
    h_range = Property(Array)
    phi_sample = Property(Array)
    _phi_updated = Bool(False)

    t0 = Float

    # Add to the simulation timestep

    def run_timestep(self):
        """Run one PlaceNetwork timestep, then briefly yield to the GUI."""
        PlaceNetwork.run_timestep(self)

        # A 1 ms pause so the GUI gets a chance to process between steps
        time.sleep(0.001)

    # Creating Plot instances as trait default functions

    def _field_plots_default(self):
        """Build the initial 'Synaptic Fields' line plot for field_plots."""
        origin = numpy.zeros(1, 'd')
        plot_data = ArrayPlotData(t=origin, h_rec=origin, h_aff=origin, h_sum=origin)
        plot = Plot(plot_data)

        # One line per tracked field, each drawn against time
        field_lines = (
            ('h_aff', 'Afferent', 'royalblue'),
            ('h_rec', 'Recurrent', 'tomato'),
            ('h_sum', 'Total', 'sienna'),
        )
        for series, label, line_color in field_lines:
            plot.plot(('t', series), name=label, type='line', line_width=1,
                color=line_color)

        # Borderless, translucent legend in the upper-right corner
        legend = plot.legend
        legend.visible = True
        legend.border_visible = False
        legend.align = 'ur'
        legend.bgcolor = (0.8, 0.8, 1.0, 0.4)
        legend.border_padding = 6
        legend.labels = ['Afferent', 'Recurrent', 'Total']

        plot.y_grid.visible = False
        plot.x_grid.visible = False
        plot.title = 'Synaptic Fields'
        plot.x_axis.title = 'Time (s)'
        plot.y_axis.title = 'Field Strength'
        plot.bgcolor = 'mintcream'
        return plot

    def _units_plot_default(self):
        """Build the initial 'Place Cell Output' scatter plot for units_plot."""
        # Show at most DISP_UNITS output units
        shown = min(self.N_CA, DISP_UNITS)
        plot_data = ArrayPlotData(
            i=numpy.arange(shown),
            r=self.r[:shown],
            i_aff=self.i_aff[:shown])
        plot = Plot(plot_data)

        # Unit index vs. r, with i_aff driving the colormap
        plot.plot(('i', 'r', 'i_aff'), type='cmap_scatter', color_mapper=hot,
            marker='circle', marker_size=3, line_width=0)

        plot.title = 'Place Cell Output'
        plot.x_axis.title = 'Output Units'
        plot.y_axis.title = 'Rate / Iaff'
        plot.value_range.set_bounds(0.0, 1.0)
        plot.x_grid.visible = False
        plot.y_grid.visible = False
        plot.bgcolor = 'slategray'
        return plot

    def _traj_plot_default(self):
        """
        Build the initial 'Rat Trajectory' plot for traj_plot.

        Based on TrajectoryView.t_plot in chaco_threading_demo.
        """
        origin = numpy.zeros(1, 'd')
        plot_data = ArrayPlotData(x=origin, y=origin)

        # Map dimensions and starting position of the current trajectory
        arena = self.traj.Map
        height, width = arena.H, arena.W
        start = arena.x0
        plot_data.set_data('x0', origin + start[0])
        plot_data.set_data('y0', origin + start[1])

        plot = Plot(plot_data)
        # 'trail' is the path line; 'head' marks the (x0, y0) point
        plot.plot(('x', 'y'), name='trail', color='red')
        plot.plot(('x0', 'y0'), name='head', type='scatter', marker='circle',
            color='red')

        plot.y_axis.visible = False
        plot.x_axis.visible = False
        plot.y_grid.visible = False
        plot.x_grid.visible = False
        plot.border_visible = True
        plot.border_width = 2
        plot.title = 'Rat Trajectory'

        # Fix the axes to the full extent of the map
        plot.index_range.set_bounds(0, width)
        plot.value_range.set_bounds(0, height)

        # The axes are hidden, so label the extents with overlays instead
        x_label = PlotLabel('X (%d cm)' % width, component=plot,
            overlay_position='bottom')
        plot.overlays.append(x_label)
        y_label = PlotLabel('Y (%d cm)' % height, component=plot,
            overlay_position='left', angle=90)
        plot.overlays.append(y_label)
        return plot

    def _phi_plot_default(self):
        """Build the initial 'Nonlinearity' line plot for phi_plot."""
        plot_data = ArrayPlotData(h=self.h_range, phi=self.phi_sample)
        plot = Plot(plot_data)
        plot.plot(('h', 'phi'), type='line', name='phi', color='slateblue',
            line_width=2.7)

        plot.x_axis.title = 'h'
        plot.y_axis.title = 'Phi[h]'
        plot.x_grid.line_color = 'slategray'
        plot.y_grid.line_color = 'slategray'
        plot.bgcolor = 'khaki'
        plot.title = 'Nonlinearity'
        return plot

    # Callback for updating plot data

    def _update_plots(self):
        """Push the current model state into all four plots."""
        # Field plots data trails
        # NOTE(review): _trails is defined outside this class body --
        # presumably it returns one history array per requested name; confirm.
        times, afferent, recurrent, total = self._trails(
            't', 'h_aff', 'h_rec', 'h_sum')

        # Trails are only pushed to the plots once t has passed dt
        if self.t > self.dt:
            field_data = self.field_plots.data
            field_updates = (
                ('t', times),
                ('h_aff', afferent),
                ('h_rec', recurrent),
                ('h_sum', total),
            )
            for key, trail in field_updates:
                field_data.set_data(key, trail)

            # Trajectory trails; the head marker sits at the newest position
            trail_x, trail_y = self._trails('x', 'y')
            traj_data = self.traj_plot.data
            traj_data.set_data('x', trail_x)
            traj_data.set_data('y', trail_y)
            traj_data.set_data('x0', numpy.array([trail_x[-1]]))
            traj_data.set_data('y0', numpy.array([trail_y[-1]]))

        # Units plot update
        shown = min(self.N_CA, DISP_UNITS)
        units = self.units_plot
        units.data.set_data('r', self.r[:shown])
        units.data.set_data('i_aff', self.i_aff[:shown])
        # Keep 5% headroom above the largest displayed value, never below 1
        headroom = 1.05 * self.r[:shown].max()
        units.value_range.high_setting = max(1, headroom)

        # Phi data update
        self._update_phi_plot()

    def _update_phi_plot(self):
        """Refresh the nonlinearity plot if its parameters have changed."""
        if not self._phi_updated:
            return
        phi_data = self.phi_plot.data
        phi_data.set_data('h', self.h_range)
        phi_data.set_data('phi', self.phi_sample)
        self._phi_updated = False

    # Trajectory changes refresh the plot

    def _stage_changed(self):
        """Rebuild the trajectory and its plot when 'stage' changes."""
        self._refresh_traj_plot()

    def _traj_type_changed(self):
        """Rebuild the trajectory and its plot when 'traj_type' changes."""
        self._refresh_traj_plot()

    def _refresh_traj_plot(self):
        """Install a new trajectory and a freshly built trajectory plot."""
        self.traj = self.new_trajectory()
        self.traj_plot = self._traj_plot_default()

    # Field tracking properties and trait notifications

    def _get_h_aff(self):
        """Return the mean of i_aff."""
        return self.i_aff.mean()

    def _get_h_rec(self):
        """Return the sum of r scaled by -J0."""
        return -self.J0 * self.r.sum()

    def _get_h_sum(self):
        """Return h_aff plus h_rec."""
        return self.h_aff + self.h_rec

    # Nonlinearity plot automation and line data

    def _get_phi_sample(self):
        """Return phi_h evaluated over h_range shifted down by phi_lambda."""
        return self.phi_h(self.h_range - self.phi_lambda)

    def _get_h_range(self):
        """Return sample points from 0 in steps of 0.02, up to at least 2.5."""
        upper = max(2.5, 2.5 * self.phi_lambda)
        return numpy.arange(0, upper, 0.02)

    def _phi_lambda_changed(self):
        """Flag the nonlinearity plot for refresh when phi_lambda changes."""
        self._phi_pause_update()

    def _phi_sigma_changed(self):
        """Flag the nonlinearity plot for refresh when phi_sigma changes."""
        self._phi_pause_update()

    def _phi_pause_update(self):
        """Mark the phi plot stale; redraw immediately if 'pause' is set."""
        self._phi_updated = True
        # Otherwise the flag is picked up by the next _update_plots call
        if self.pause:
            self._update_phi_plot()

    # Convenience functions for calling views

    def setup(self):
        """Open the initialization view ('init_view')."""
        self.configure_traits(view='init_view')

    def simulation(self):
        """Open the simulation view ('sim_view') with a ViewModelHandler."""
        self.configure_traits(view='sim_view', handler=ViewModelHandler())


if __name__ == "__main__":
    # Demo run: show the setup dialog first, then the simulation window.
    # The relative import below needs package context to resolve.
    import os
    from .dmec import GridCollection

    EC = GridCollection()
    demo_settings = dict(C_W=0.33, growl=False, T=300, desc='demo run')
    ca3 = PlaceNetworkUI(EC=EC, **demo_settings)

    ca3.setup()
    ca3.simulation()