#encoding: utf-8
"""
pathint -- Path integration figure showing time-series data (without cues)

Created by Joe Monaco on 2010-10-07.

Copyright (c) 2009-2011 Johns Hopkins University. All rights reserved.

This software is provided AS IS under the terms of the Open Source MIT License. 
See http://www.opensource.org/licenses/mit-license.php.
"""

# Library imports
import os
import numpy as np
from numpy import pi
import matplotlib as mpl
import matplotlib.pylab as plt

# Package imports
from ..core.analysis import BaseAnalysis
from ..vmo import VMOModel
from ..session import VMOSession
from ..placemap import CirclePlaceMap
from ..compare import correlation_matrix, correlation_diagonals
from ..tools.stats import smooth_pdf
from ..tools.filters import halfwave
from ..tools.images import array_to_image


class PathIntFigure(BaseAnalysis):
    
    """
    Create the example path integration dataset for producing the basic
    introductory figure to the path integration mechanism of the 
    VMOModel model
    """
    
    label = 'path integration'
    
    def collect_data(self, **kwargs):
        """Run basic VMOModel path integration simulation and save relevant
        data for the intro figure of the paper.
        
        Additional keyword arguments are passed on to VMOModel.
        """
        # Set up and run the path integration model
        self.out('Running simulation...')
        model = VMOModel(N_outputs=500, N_theta=1000, N_cues=1, C_W=0.05,
            gamma_local=0, gamma_distal=0, **kwargs)
        model.advance()
        pm = model.post_mortem()
        
        # Save session data object for visualization
        session_fn = 'session.tar.gz'
        self.out('Computing and saving responses to %s...'%repr(session_fn))
        SD = VMOSession(model, save_I=True, save_E=True)
        SD.tofile(os.path.join(self.datadir, session_fn))
        self.results['session_file'] = session_fn
        
        # Save other simulation data
        data_fn = 'data.npz'
        self.out('Saving other data to %s...'%repr(data_fn))
        np.savez(os.path.join(self.datadir, data_fn), vel=pm.vel)
        self.results['data_file'] = data_fn
        
        # Good-bye!
        self.out('All done!')
        
    def create_plots(self, examples=[0,1], lap_inset=[3,4]):
        """Create figure with basic data panels 
        """
        self.figure = {}
        figsize = 10, 12
        plt.rcParams['figure.figsize'] = figsize
        self.figure['pathint'] = f = plt.figure(figsize=figsize)
        f.suptitle(self.label.title())
        
        # Load the simulation data if necessary
        if 'SD' in self.results:
            SD = self.results['SD']
        else:
            self.results['SD'] = SD = VMOSession.fromfile(
                os.path.join(self.datadir, self.results['session_file']))
        if 'data' in self.results:
            data = self.results['data']
        else:
            self.results['data'] = data = np.load(
                os.path.join(self.datadir, self.results['data_file']))
        os.chdir(self.datadir)
        self.out.outfd = file('figure.log', 'w')
        
        # Track angle
        tlim = (SD.t[0], SD.t[-1])
        alpha_star = SD.alpha - SD.alpha[0]
        alpha_star[alpha_star<=0] += 2*pi
        ax = plt.subplot(12,1,1)
        ax.plot(SD.t, alpha_star, 'k-', lw=1.5)
        ax.set_xlim(tlim)
        ax.set_axis_off()
        
        # V_x
        vlim = (-40,40)
        bar_x = 100
        bar_y = vlim[0]+4
        ax = plt.subplot(12,1,2)
        ax.plot(SD.t, data['vel'][:,0,0], 'k-', lw=0.5, zorder=5)
        ax.plot(tlim, [0,0], '-', c='0.6', lw=1, zorder=2)
        ax.plot([bar_x,bar_x+30], [bar_y]*2, 'k-', lw=1.5, zorder=3) # scale bar: 30 s
        ax.plot([bar_x+30]*2, [bar_y,bar_y+10], 'k-', lw=1.5, zorder=3) # 10 cm/s
        self.out('Velocity trace scale bar = 30s x 10cm/s')
        ax.set_xlim(tlim)
        ax.set_ylim(vlim)
        ax.set_axis_off()

        # V_y
        ax = plt.subplot(12,1,3)
        ax.plot(SD.t, data['vel'][:,0,1], 'k-', lw=0.5, zorder=5)
        ax.plot(tlim, [0,0], '-', c='0.6', lw=1, zorder=2)
        ax.set_xlim(tlim)
        ax.set_ylim(vlim)
        ax.set_axis_off()     
        
        # Ex 1
        ax1 = plt.subplot(12,1,4)
        ax1.plot(SD.t, SD.I_cache[examples[0]], 'b-', lw=0.5, zorder=5)
        # ax1.plot(tlim, [0,0], 'k-', lw=1, zorder=7)
        ax1.axis('tight')
        ax1.set_xlim(tlim)
        ax1.set_axis_off()
        
        # Ex 2
        ax2 = plt.subplot(12,1,5)
        ax2.plot(SD.t, SD.I_cache[examples[1]], 'b-', lw=0.5, zorder=5)
        # ax2.plot(tlim, [0,0], 'k-', lw=1, zorder=7)
        ax2.axis('tight')
        ax2.set_xlim(tlim)
        ax2.set_axis_off()
        
        # Set y limits for examples
        max_I = 1.03*np.max(np.absolute(SD.I_cache[examples]))
        ax1.set_ylim(-max_I, max_I)
        ax2.set_ylim(-max_I, max_I)
        
        # Lap indicators
        rect_kw = dict(fill=True, linewidth=0, zorder=0, alpha=0.2)
        cols = ['b', 'g']
        for sp in 2, 3, 4, 5:
            ax = plt.subplot(12,1,sp)
            for lap in xrange(len(SD.laps) - 1):
                lap_rect = mpl.patches.Rectangle((SD.laps[lap], vlim[0]), 
                    SD.laps[lap+1]-SD.laps[lap], vlim[1]-vlim[0], 
                    fc=cols[lap%2], **rect_kw)
                ax.add_artist(lap_rect)
        
        # Example insets
        crop = slice((SD.t >= SD.laps[lap_inset[0]]).nonzero()[0][0], 
            (SD.t >= SD.laps[lap_inset[1]+1]).nonzero()[0][0])
        examples = np.asarray(examples)
        E_max = 1.05*np.absolute(SD.E_cache[examples,crop]).max()
        for subp in ((12,1,6), (12,1,7)), ((12,2,15), (12,2,17)):
            ax = plt.subplot(*subp[0])
            ax.plot(SD.t[crop], SD.I_cache[examples[0],crop], 'b-', lw=0.5)
            ax.plot(SD.t[crop], SD.E_cache[examples[0],crop], 'k-', lw=1.5, aa=True)
            ax.plot([SD.t[crop.start], SD.t[crop.stop]], [SD.thresh]*2, 'k--')
            ax.axis('tight')
            ax.set_ylim(-E_max, E_max)
            ax.set_axis_off()        
            ax = plt.subplot(*subp[1])
            ax.plot(SD.t[crop], SD.I_cache[examples[1],crop], 'b-', lw=0.5)
            ax.plot(SD.t[crop], SD.E_cache[examples[1],crop], 'k-', lw=1.5, aa=True)
            ax.plot([SD.t[crop.start], SD.t[crop.stop]], [SD.thresh]*2, 'k--')
            ax.axis('tight')
            ax.set_ylim(-E_max, E_max)
            ax.set_axis_off()
            bar_x = SD.laps[lap_inset[0]]
            bar_y = -E_max+2
            ax.plot([bar_x,bar_x+5], [bar_y]*2, 'k-', lw=1.5, zorder=3) # scale bar: 5 s
            ax.plot([bar_x+5]*2, [bar_y,bar_y+5], 'k-', lw=1.5, zorder=3) # 5 Hz
        self.out('Input scale bar = 5s x 5Hz')


        # Start new figure
        self.figure['responses'] = f = plt.figure(figsize=figsize)
        f.suptitle('Output Responses')
        
        # Input peak distro
        E_max = SD.E_cache.max()
        ax = plt.subplot(421)
        ax.plot(*smooth_pdf(SD.E_cache.max(axis=1)), c='k', lw=1.5, aa=True)
        ax.plot([SD.thresh]*2, [0, 1.05*ax.get_ylim()[1]], 'r-', zorder=5)
        ax.axis('tight')
        ax.set_xlim(0, 1.05*E_max)
        
        # Input/output trajectories
        # ax = plt.subplot(8,6,4)
        # ax.scatter(x=SD.x[::4], y=SD.y[::4], c=SD.E_cache[examples[0],::4], 
        #     s=1, lw=0, marker='o')
        # ax.axis('equal')
        # ax.set_axis_off()
        
        ax = plt.subplot(843)
        ax.scatter(x=SD.x[::4], y=SD.y[::4], 
            c=halfwave(SD.E_cache[examples[0],::4]-SD.thresh), 
            s=1, lw=0, marker='o')
        ax.axis('equal')
        ax.set_axis_off()
        
        ax = plt.subplot(844)
        ax.pcolor(SD.R_laps[examples[0]].T, cmap=mpl.cm.jet)
        self.out('Example 1, lap pcolor max = %.2f'%SD.R_laps[examples[0]].max())
        ax.axis('tight')
        ax.set_axis_off()
        array_to_image(np.flipud(SD.R_laps[examples[0]].T), 'ex1_pcolor.png', 
            cmap=mpl.cm.jet)
        
        # ax = plt.subplot(8,6,10)
        # ax.scatter(x=SD.x[::4], y=SD.y[::4], c=SD.E_cache[examples[1],::4], 
        #     s=1, lw=0, marker='o')
        # ax.axis('equal')
        # ax.set_axis_off()
        
        ax = plt.subplot(847)
        ax.scatter(x=SD.x[::4], y=SD.y[::4], 
            c=halfwave(SD.E_cache[examples[1],::4]-SD.thresh), 
            s=1, lw=0, marker='o')
        ax.axis('equal')
        ax.set_axis_off()
        
        ax = plt.subplot(848)
        ax.pcolor(SD.R_laps[examples[1]].T, cmap=mpl.cm.jet)
        self.out('Example 2, lap pcolor max = %.2f'%SD.R_laps[examples[1]].max())
        ax.axis('tight')
        ax.set_axis_off()
        array_to_image(np.flipud(SD.R_laps[examples[1]].T), 'ex2_pcolor.png', 
            cmap=mpl.cm.jet)
        
        # Population responses
        # ax = plt.subplot(434)
        R = SD.get_population_matrix()
        # ax.imshow(R, aspect='auto', interpolation='nearest')
        # ax.set_axis_off()
        array_to_image(R, 'pop_response.png', 
            cmap=mpl.cm.jet)
        self.out('Pop matrix max = %.2f'%SD.R.max())        
        
        # ax = plt.subplot(435)
        PM = CirclePlaceMap(SD)
        # ax.imshow(PM.coverage_maps, cmap=mpl.cm.gray_r, aspect='auto', 
        #     interpolation='nearest')
        # ax.set_axis_off()
        array_to_image(PM.coverage_maps, 'field_extents.png', 
            cmap=mpl.cm.gray_r)
        self.out('Num active units = %d'%PM.num_active)
        
        # ax = plt.subplot(8, 3, 9)
        C = correlation_matrix(SD)
        # ax.imshow(np.flipud(C))
        # ax.axis('equal')
        # ax.set_axis_off()
        array_to_image(np.flipud(C), 'corr_matrix.png', 
            cmap=mpl.cm.jet)
        
        # ax = plt.subplot(8, 3, 12)
        D = correlation_diagonals(C, centered=True)
        # ax.plot(D[0], D[1], 'k-', lw=1.5, aa=True)
        # ax.axis('tight')
        # ax.set_ylim(0, 1.05)
        width = (D[0,1]-D[0,0]) * (D[1]>=D[1].max()/2).sum()
        self.out('Corr width at half max = %.1f degrees'%width)
        
        # Population distros
        # ax = plt.subplot(8, 3, 13)
        # ax.plot(*smooth_pdf(PM.field_size, sd=10), c='k', lw=1.5)
        # ax.axis('tight')
        # newy = 1.1*ax.get_ylim()[1]
        # ax.set_xlim(-10, 200)
        # ax.set_ylim(ymax=newy)
        med_size = np.median(PM.field_size)
        # ax.plot([med_size]*2, [0, newy], 'r-', zorder=1)
        self.out('Number of fields = %d'%PM.num_fields.sum())
        self.out('Median field size = %.1f degrees'%med_size)
        
        # ax = plt.subplot(8, 3, 14)
        # ax.plot(*smooth_pdf(PM.maxima[:,1]), c='k', lw=1.5)
        # ax.axis('tight')
        # newy = 1.1*ax.get_ylim()[1]
        # ax.set_xlim(-1, 12)
        # ax.set_ylim(ymax=newy)
        med_rate = np.median(PM.maxima[:,1])
        # ax.plot([med_rate]*2, [0, newy], 'r-', zorder=1)
        self.out('Median firing rate = %.2f Hz'%med_rate)
        
        # ax = plt.subplot(8, 3, 15)
        # ax.plot(*smooth_pdf(SD.I_rate[SD.sortix], sd=0.25), c='k', lw=1.5)
        # ax.axis('tight')
        # newy = 1.1*ax.get_ylim()[1]
        # ax.set_xlim(0, 12)
        # ax.set_ylim(ymax=newy)
        med_info = np.median(SD.I_rate[SD.sortix])
        # ax.plot([med_info]*2, [0, newy], 'r-', zorder=1)
        self.out('Median spatial info = %.1f bits/spike'%med_info)
        
        plt.draw()
        plt.rcParams['figure.figsize'] = plt.rcParamsDefault['figure.figsize']
        self.out.outfd.close()