# encoding: utf-8
"""
placemap.py -- Constructing spatial ratemaps and basic placemap analysis

Created by Joe Monaco on 02-07-2008.
Copyright (c) 2007, 2008 Columbia University. All rights reserved.
"""

# Library imports
import numpy as N, scipy as S, os

# Package imports
from .stage import StagingMap
from .core.model import AbstractModel
from .tools.bash import CPrint
from .tools.interp import GSmoothInterp2D
from .tools.images import tile2D, array_to_image
from .tools.misc import Null
from .tools.path import unique_path

# Traits imports
from enthought.traits.api import Trait, List, Array, Property, Instance, \
    Float, Int, false


# Place field determinative criteria

# Population noise floor: multiplied by the peak rate across all maps
# (PlaceMap.peak_rate) to give a lower bound on the rate a pixel must reach
# before it is considered as a candidate field peak.
NOISE_FLOOR = 0.20              # Noise floor relative to map peak

# Per-unit cutoff: multiplied by each map's own maximum rate, both when
# scanning for peaks and when thresholding the map into field pixels.
FIELD_CUTOFF = 0.20             # Unit-based cutoff for defining fields relative 
                                # to peak estimated activity rate 
                                # --> MullerKubie89: 20% of peak activity

# Minimum field size: a putative field is kept only if its boolean mask sums
# to strictly more than this value.
# NOTE(review): the comparison is against a pixel count, so the cm^2 unit
# below presumably assumes 1 pixel == 1 cm^2 -- confirm against StagingMap.
MIN_FIELD_SIZE = 50             # Putative fields with areas smaller than this 
                                # area are ignored (cm^2).
                                # --> MullerKubie89: 200 cm^2 contiguous area


class AbstractPlaceMap(StagingMap):

    """
    Superclass with general functionality for ratemap classes
    
    Required constructor argument: the AbstractModel subclass object that has the
    trajectory ('x', 'y') and rate ('r') timeseries OR a ModelPostMortem
    object containing the data ('x', 'y', 'r').
    
    To get maps for network model simulations containing multiple trials, use 
    the classmethod trial_split() to get a python list object of ratemap 
    instances for each trial.
    """

    # Network output data (set to None once the maps are initialized)
    data = Instance(object)
    
    # Per-map normalized ratemap array, filled in by initialize_norm()
    Norm = Array

    def __init__(self, data, **traits):
        """
        Store single-trial simulation data for building ratemaps
        
        Parameters:
        data -- an AbstractModel instance (converted with its post_mortem()
            method) or any object exposing 'x', 'y' and 'r' attributes; its
            'ntrials' attribute is also read
        **traits -- forwarded unchanged to the StagingMap constructor
        
        Raises ValueError if data lacks any of the x, y, r attributes.
        
        If data holds more than one trial, an error message is printed and
        the constructor returns WITHOUT storing the data; use trial_split()
        for multi-trial data instead.
        """
        StagingMap.__init__(self, **traits)
        if isinstance(data, AbstractModel):
            data = data.post_mortem()
        self._validate_data(data)
        if data.ntrials > 1:
            self.out('Use trial_split classmethod for multiple trials!',
                error=True)
            return
        self.data = data
    
    @classmethod
    def _validate_data(cls, data):
        """
        Check that the constructor data argument has the required attributes
        
        Raises ValueError if any of 'r', 'x' or 'y' is missing. Nothing is 
        returned.
        """
        if not (hasattr(data, 'r') and hasattr(data, 'x') and
                hasattr(data, 'y')):
            # Call-style raise is valid on both Python 2 and Python 3
            raise ValueError('Data must have x, y, and r attributes')

    # Public methods
    
    def initialize(self):
        """Delete raw data after map initialization"""
        super(AbstractPlaceMap, self).initialize()
        self.data = None
    
    def initialize_norm(self):
        """
        Compute the normalized maps after creating the raw ratemaps
        
        Each map is divided by its own maximum and stored in Norm; a map 
        whose maximum is zero is stored as all zeros. The raw maps are
        initialized first if necessary.
        """
        if not self._initialized:
            self.initialize()
        
        self.Norm = N.empty((self.num_maps, self.H, self.W), 'd')
        for m in range(self.num_maps):
            # Look up the map maximum once per map
            peak = self.Map[m].max()
            if peak:
                self.Norm[m] = self.Map[m] / peak
            else:
                self.Norm[m] = 0
        self.out('Normalized ratemaps computed!')
    
    def save_map_image(self, desc=None, imgdir='spatial_maps', which='Map'):
        """
        Save PNG images of tiled maps stored in this ratemap object
        
        Parameters:
        desc -- short description of maps (default to model description)
        imgdir -- save images in this subdirectory (default 'spatial_maps')
        which -- string name of attribute map array to tile and save (any
            (num_maps, H, W) shaped array) (default 'Map')
        
        Problems (uninitialized maps, missing/non-array attribute, wrong 
        shape) are reported through self.out and the method returns without
        saving anything.
        """
        if not self._initialized:
            self.out('Ratemaps are not initialized!', error=True)
            return
            
        # Validate attribute map object specified by which keyword
        themap = getattr(self, which, None)
        if not isinstance(themap, N.ndarray):
            self.out('Nonexistent or non-array attribute \'%s\''%which, 
                error=True)
            return
        # Check the rank first so shape[0] is never read from a 0-d array
        if themap.ndim != 3 or themap.shape[0] != self.num_maps:
            self.out('Map object %s does not have shape %dxHxW'%(which, 
                self.num_maps), error=True)
            return
        
        # Set map description
        if desc is None: 
            desc = self.desc
        
        # Unique image file path
        img_stem = '_'.join(desc.split() + [which.lower()])
        img_path = unique_path(os.path.join(imgdir, img_stem), 
            ext='png', fmt='%s_%02d')
        
        # Save the tiled maps to disk; grid lines are drawn slightly above 
        # the map maximum
        array_to_image(tile2D(themap, gridvalue=1.03*themap.max()), img_path)
        self.out('Saved %s as a tiled image:\n-->%s'%(which.lower(), img_path), 
            prefix='Maps saved')
    
    @classmethod
    def trial_split(cls, data):
        """
        Get a list of ratemap objects for a multi-trial simulation
        
        Parameters:
        data -- an AbstractModel instance or an object with 'x', 'y', 'r' 
            attributes that also provides 'ntrials' and get_trial_data()
        
        Returns a list with one instance of this class per trial; trials are
        numbered from 1 and each instance is described as 'trial <n>'.
        """
        if isinstance(data, AbstractModel):
            data = data.post_mortem()
        cls._validate_data(data)
        
        return [cls(data.get_trial_data(t), desc='trial %d'%t)
            for t in range(1, data.ntrials+1)]
    
    # Traits defaults
    
    def _num_maps_default(self):
        """Number of maps is the size of the second axis of the rate data"""
        return self.data.r.shape[1]


class PlaceMap(AbstractPlaceMap):
    
    """
    Baseclass with functionality for dealing with ratemaps of place fields
    
    Public methods:
    find_maxima -- save the location and magnitude of the ratemap maxima 
    find_peaks -- locate all peaks within each ratemap
    store_fields -- create coverage maps for all individual place fields
    get_unit_data -- return a records array of per-unit data*
    get_field_data -- return a records array of per-field data*
    
    * Records data can be accessed as:
    rec_data[i] -- entire record for ith entry
    rec_data['col_name'] -- entire data column
    rec_data[i]['col_name'] -- value of column col_name for ith entry
    
    Data attributes:
    maxima -- [find_maxima()] N_CA x 3 array of position (x, y) and rate of 
        each map's maxima; rows are 0's for silent units
    peaks -- [find_peaks()] object-array of arrays of peak positions (x, y) and 
        rates found in each ratemap; element is None for silent units
    fields -- [store_fields()] object-array of N_fields[m] x H x W arrays of 
        boolean coverage maps for each field in the mth map; element is None 
        for silent units
    
    Stage representation data:
    (These attributes are set by calling compute_coverage())
    coverage_maps -- maps of total field coverage for each place cell
    stage_coverage_map -- a stage map showing coverage across cells
    stage_coverage -- fractional coverage across cells
    stage_repr_map -- a stage map showing relative representation
        of each pixel in the stage (0.0 means no cells have overlapping
        fields, 1.0 means all cells have overlapping fields)
    stage_repr -- average number of fields representing any stage-pixel
    num_active -- the number of active place cells in this map
    sparsity -- the proportion of silent cells in this map
    """
    
    # Primary data sources
    maxima = Array
    peaks = Array
    fields = Array
    
    # Per-PC data
    coverage_maps = Array
    single_maps = Array
    
    # Across-PC data
    stage_coverage_map = Array
    stage_coverage = Float
    stage_repr_map = Array
    stage_repr = Float
    peak_rate = Float
    num_active = Int
    sparsity = Float
    
    # Private attributes
    _surround = Array
    _peaks_found = false
    _fields_stored = false
    _coverage_computed = false
    
    # Public methods for computing data sources
    
    def reset(self, reinitialize=False):
        """Clear the computed place-map results so they can be recomputed
        
        Keyword arguments:
        reinitialize -- if True, also flag the underlying ratemaps as 
            uninitialized (default False)
        
        The peak/field/coverage status flags are cleared and every array 
        result is replaced by the value of its own _<name>_default() method.
        """
        if reinitialize:
            self._initialized = False
        
        # Invalidate the computation status flags
        self._peaks_found = False
        self._fields_stored = False
        self._coverage_computed = False
        
        # Restore each array result from its trait default method
        for name in ('maxima', 'peaks', 'fields', 'coverage_maps', 
                'single_maps', 'stage_coverage_map', 'stage_repr_map'):
            make_default = getattr(self, '_%s_default' % name)
            setattr(self, name, make_default())
    
    def find_maxima(self):
        """Find maxima of each ratemap, storing locations and magnitudes
        
        The following attributes are set:
        maxima -- array data is set such that maxima[m] = [x_m, y_m, z_m];
            the row is all zeros for a map whose maximum is zero
        peak_rate -- maximum firing rate across all maps (0.0 if there are
            no maps)
        
        The ratemaps are initialized first if necessary. Note that 
        num_active and sparsity are NOT set here; compute_coverage() sets 
        them.
        """
        if not self._initialized:
            self.initialize()
        self.out('Determining activity maxima...')
        
        # Pixel-center coordinates, flattened with x varying slowest
        x, y = N.mgrid[0:self.W, 0:self.H]
        xf = x.flatten() + 0.5
        yf = y.flatten() + 0.5
        
        # Scan the maps and store the maxima
        for m, M in enumerate(self.Map):
            # Flip rows and transpose so the flattened rates line up with 
            # the (xf, yf) coordinate vectors above
            zf = N.flipud(M).T.flatten()
            maxix = zf.argmax()
            if zf[maxix]:
                self.maxima[m] = xf[maxix], yf[maxix], zf[maxix]
            else:
                # Explicitly zero silent units so a repeated call cannot 
                # leave a stale maximum behind
                self.maxima[m] = 0
        
        # max() of an empty column raises, so guard the zero-map case
        if self.maxima.shape[0]:
            self.peak_rate = self.maxima[:,2].max()
        else:
            self.peak_rate = 0.0
    
    def find_peaks(self):
        """Locate every local rate peak in each ratemap
        
        For each map with nonzero total rate, peaks[m] is set to an array 
        built from one [x, y, z] row per peak found (position and rate). 
        Maps whose rates sum to zero are skipped and their peaks entry is 
        left untouched.
        
        A pixel counts as a peak when its rate reaches the cutoff and no 
        surrounding pixel has a strictly larger rate.
        """
        # Maxima are needed first: the cutoffs below depend on them
        self.find_maxima()
        self.out('Scanning maps for peaks...')
        
        for m, ratemap in enumerate(self.Map):
            
            # One progress dot per map, printed even for skipped maps
            self.out.printf('.', color='purple')
            if ratemap.sum() == 0.0:
                continue
            
            # The cutoff is the larger of the per-unit threshold and the 
            # population noise floor
            unit_floor = FIELD_CUTOFF * self.maxima[m,2]
            population_floor = NOISE_FLOOR * self.peak_rate
            cutoff = max(unit_floor, population_floor)
            
            # Raster scan over every stage position
            found = []
            for x in self._xrange:
                for y in self._yrange:
                    i, j = self.index(x, y)
                    z = ratemap[i, j]
                    if z < cutoff:
                        continue
                    
                    # Compare against each neighbor; positions that index() 
                    # rejects with IndexError are ignored
                    for offset in self._surround:
                        try:
                            ni, nj = self.index(x+offset[0], y+offset[1])
                        except IndexError:
                            continue
                        if ratemap[ni, nj] > z:
                            break
                    else:
                        # No neighbor was strictly larger: (x, y) is a peak
                        found.append([x, y, z])
            
            self.peaks[m] = N.array(found)
        
        self._peaks_found = True
        self.out.printf('\n')
        self.out('Done!')
    
    def store_fields(self):
        """
        Build boolean masks for the individual firing fields in each map
        
        For every map that yields at least one valid field, fields[m] is set
        to an array stacking one (H, W) boolean mask per field. Maps without 
        peaks or without any valid field keep their existing fields entry.
        
        A map whose above-cutoff area exceeds 40% of the stage is treated as
        a single nonspecific field. Otherwise each peak is grown into a 
        field, which is kept only if it is larger than MIN_FIELD_SIZE and 
        does not overlap a field already kept.
        
        Peaks are computed first if they have not been found yet.
        """
        if not self._peaks_found:
            self.find_peaks()

        self.out('Scanning maps to store individual fields...')
        
        stage_area = float(self.H*self.W)
        
        # Progress dot color by number of fields (3 or more fall through)
        dot_color = {1: 'yellow', 2: 'lightblue'}
        
        for m, ratemap in enumerate(self.Map):
            
            unit_peaks = self.peaks[m]
            if unit_peaks is None:
                self.out.printf('.', color='lightred')
                continue
            
            # All pixels above this unit's field cutoff
            above_cutoff = ratemap > FIELD_CUTOFF*self.maxima[m, 2]
            
            if above_cutoff.sum() / stage_area > .4:
                # Nonspecific response: keep the whole cut as one field
                masks = [above_cutoff]
            else:
                masks = []
                for peak in unit_peaks:
                    candidate = N.zeros((self.H, self.W), '?')
                    self._mark_field(peak[0], peak[1], above_cutoff, 
                        candidate)
                    
                    # Too small to count as a field
                    if candidate.sum() <= MIN_FIELD_SIZE:
                        continue
                    
                    # Overlaps a field that was already stored
                    if any((known*candidate).sum() for known in masks):
                        continue
                    
                    masks.append(candidate)
            
            # Store valid fields and report the outcome for this map
            if masks:
                self.fields[m] = N.array(masks)
                color = dot_color.get(len(masks), 'purple')
            else:
                color = 'lightred'
            self.out.printf('.', color=color)
        
        self._fields_stored = True
        self.out.printf('\n')
        self.out('Done!')
            
    # Field marking recursion
    
    def _mark_field(self, x, y, master, field):
        """
        Mark off the complete field that contains a given position
        
        Required parameters:
        x, y -- location within the field to mark
        master -- array (HxW) of all valid field pixels
        field -- array (HxW) containing the field picked by (x, y) in master
        
        The field array is modified in place: every pixel that is set in 
        master and can be reached from (x, y) through the surrounding-pixel 
        offsets is set to True. Nothing is returned.
        
        The fill uses an explicit stack rather than recursion, so large 
        fields cannot hit the interpreter recursion limit and end up only 
        partially marked.
        """
        offsets = self._surround
        pending = [(x, y)]
        
        while pending:
            px, py = pending.pop()
            
            # Skip positions that index() rejects as out of bounds
            try:
                i, j = self.index(px, py)
            except IndexError:
                continue
            
            # Skip pixels already marked or not part of any field
            if field[i, j] or not master[i, j]:
                continue
            
            field[i, j] = True
            
            # Queue the surrounding pixels for probing
            pending.extend((px+d[0], py+d[1]) for d in offsets)
    
    # Field coverage and stage representation methods
    
    def compute_coverage(self):
        """
        Collapse the stored fields into per-unit and whole-stage coverage
        
        Attributes written by this method:
        coverage_maps, single_maps, stage_coverage_map, stage_coverage, 
        stage_repr_map, stage_repr, num_active and sparsity
        
        Fields are stored first if that has not been done yet. Units whose
        fields entry is None are skipped and do not count as active.
        """
        if not self._fields_stored:
            self.store_fields()
        
        self.out('Collapsing fields for coverage maps...')
        
        npixels = self.H * self.W
        
        # Per-unit pass: union of fields, plus the single-field ratemap
        self.num_active = 0
        for m, unit_fields in enumerate(self.fields):
            if unit_fields is None:
                continue
            
            ratemap = self.Map[m]
            self.coverage_maps[m] = (unit_fields.sum(axis=0) != 0)
            
            if len(unit_fields) == 1:
                self.single_maps[m] = unit_fields[0] * ratemap
            else:
                # Keep the first field whose masked rates contain the 
                # unit's overall maximum
                for mask in unit_fields:
                    masked_rates = mask * ratemap
                    if masked_rates.max() == self.maxima[m, 2]:
                        self.single_maps[m] = masked_rates
                        break
            
            self.num_active += 1
        
        active_fraction = self.num_active / float(self.num_maps)
        self.sparsity = 1 - active_fraction
        
        # Number of units covering each stage pixel
        units_per_pixel = self.coverage_maps.sum(axis=0)
        
        # Full coverage: pixels covered by at least one unit
        self.stage_coverage_map[:] = (units_per_pixel != 0)
        self.stage_coverage = float(self.stage_coverage_map.sum()) / npixels
        
        # Relative representation: fraction of units covering each pixel
        self.stage_repr_map[:] = units_per_pixel / float(self.num_maps)
        self.stage_repr = self.stage_repr_map.mean() * self.num_maps
        
        self._coverage_computed = True
        self.out('Done!')
    
    # Create record arrays of per-unit and per-field data 
    
    def get_unit_data(self):
        """
        Build a recarray with one record per place unit that has fields
        
        Record columns (i.e., data fields):
        unit -- integer index of the unit's map
        max_x, max_y -- position where rate maximum occurs each map
        max_r -- value of the peak rate for each map
        num_fields -- number of fields
        avg_area -- average area of all fields in this map
        avg_diameter -- average diameter of all fields in this map
        coverage -- fractional coverage by area of all fields
        
        Units without any stored field are left out. Coverage is computed 
        first if that has not been done yet.
        """
        if not self._coverage_computed:
            self.compute_coverage()
        
        # Field count per unit; units with a None entry count as zero
        counts = N.array([0 if f is None else f.shape[0] 
            for f in self.fields], 'h')
        active = counts.astype(bool)
        
        # Restrict everything to the active units
        num_fields = counts[active]
        unit = N.arange(self.num_maps)[active]
        max_x = self.maxima[active, 0]
        max_y = self.maxima[active, 1]
        max_rate = self.maxima[active, 2]
        
        # Mean field area per active unit
        avg_area = N.array([
            N.array([mask.sum() for mask in unit_fields]).mean()
            for unit_fields in self.fields[active]])
        
        # Mean diameter of the equal-area circle per active unit
        avg_diameter = N.array([
            N.array([2*N.sqrt(mask.sum()/N.pi) 
                for mask in unit_fields]).mean()
            for unit_fields in self.fields[active]])
        
        # Fraction of the stage covered by each active unit
        covered_pixels = self.coverage_maps[active].sum(axis=-1).sum(axis=-1)
        coverage = covered_pixels / float(self.H*self.W)
        
        # Assemble the records array
        columns = [unit, max_x, max_y, max_rate, num_fields, avg_area, 
            avg_diameter, coverage]
        return N.rec.fromarrays(
            columns, 
            names='unit, max_x, max_y, max_r, num_fields, avg_area, '
                'avg_diameter, coverage',
            formats='l, d, d, d, h, d, d, d')
    
    def get_field_data(self):
        """
        Build a recarray with one record per stored place field
        
        Record columns (i.e., data fields):
        id -- unique integer id for each place field
        unit -- place unit whose response this field is a part
        area -- place field area (sum of the field mask)
        diameter -- diameter of the circle with the same area
        radius -- half of the diameter
        peak -- peak activity rate within each field
        average -- average activity rate across each field
        x, y -- position of the rate-weighted centroid 
        
        Coverage is computed first if that has not been done yet.
        """
        if not self._coverage_computed:
            self.compute_coverage()
        
        unit_indices = range(self.num_maps)
        
        # Total number of fields over all units
        total = sum(self.fields[m].shape[0] for m in unit_indices
            if self.fields[m] is not None)
        
        # Preallocate one column per data field
        field_id = N.arange(total)
        unit_id = N.empty(total, 'h')
        area, diameter, radius = [N.empty(total, 'd') for _ in range(3)]
        maximum, average = [N.empty(total, 'd') for _ in range(2)]
        center_x, center_y = [N.empty(total, 'd') for _ in range(2)]
        
        # Coordinate vectors shaped to broadcast across an (H, W) map
        # NOTE(review): this pairs row i with _yrange[i] while find_maxima()
        # flips the rows before reading positions -- confirm the row
        # orientation against StagingMap
        x_coords = self._xrange[N.newaxis,:]
        y_coords = self._yrange[:,N.newaxis]
        
        # Fill in one row per field
        row = 0
        for m in unit_indices:
            unit_fields = self.fields[m]
            if unit_fields is None:
                continue
            
            for mask in unit_fields:
                
                # Ratemap restricted to this field, and its total rate
                rates = mask * self.Map[m]
                total_rate = float(rates.sum())
                
                unit_id[row] = m
                
                # Coverage geometry
                area[row] = mask.sum()
                diameter[row] = 2*N.sqrt(area[row]/N.pi)
                radius[row] = diameter[row] / 2
                
                # Rate-dependent quantities
                maximum[row] = rates.max()
                average[row] = total_rate / area[row]
                center_x[row] = (x_coords * rates).sum() / total_rate
                center_y[row] = (y_coords * rates).sum() / total_rate
                
                row += 1
        
        # Assemble the records array (the formats string stores area as 
        # an integer column)
        columns = [field_id, unit_id, area, diameter, radius, maximum, 
            average, center_x, center_y]
        return N.rec.fromarrays(
            columns,
            names='id, unit, area, diameter, radius, peak, average, x, y',
            formats='l, l, l, d, d, d, d, d, d')
    
    # Traits properties and defaults
    
    def _maxima_default(self):
        """Default maxima: one all-zero [x, y, rate] row per map"""
        shape = (self.num_maps, 3)
        return N.zeros(shape, 'd')
        
    def _peaks_default(self):
        """Default peaks: one None placeholder per map"""
        placeholders = [None]
        return placeholders * self.num_maps
    
    def _fields_default(self):
        """Default fields: one None placeholder per map"""
        placeholders = [None]
        return placeholders * self.num_maps
    
    def __surround_default(self):
        """Relative (row, column) offsets of the 8 surrounding pixels
        
        The method name is the _<trait>_default pattern applied to the 
        private _surround trait, hence the double leading underscore.
        NOTE(review): presumably Traits resolves the name-mangled form of 
        this method -- confirm before renaming.
        """
        steps = (-1, 0, 1)
        # All 3x3 offset pairs in row-major order, minus the center (0, 0)
        return N.array([(a, b) for a in steps for b in steps 
            if (a, b) != (0, 0)])
    
    def _coverage_maps_default(self):
        """Default coverage maps: an all-False (num_maps, H, W) array"""
        shape = (self.num_maps, self.H, self.W)
        return N.zeros(shape, '?')
    
    def _single_maps_default(self):
        """Default single-field maps: an all-zero (num_maps, H, W) array"""
        shape = (self.num_maps, self.H, self.W)
        return N.zeros(shape, 'd')
    
    def _stage_coverage_map_default(self):
        """Default stage coverage map: uninitialized (H, W) boolean array"""
        shape = (self.H, self.W)
        return N.empty(shape, '?')
    
    def _stage_repr_map_default(self):
        """Default stage representation map: uninitialized (H, W) doubles"""
        shape = (self.H, self.W)
        return N.empty(shape, 'd')