# encoding: utf-8
"""
compare.py -- Functions for comparing double-rotation session responses

Exported namespace: mismatch_response_tally, mismatch_rotation, 
    cluster_mismatch_rotation, population_spatial_correlation, 
    correlation_matrix, correlation_diagonals, common_units

Created by Joe Monaco on 2010-02-11.

Copyright (c) 2009-2011 Johns Hopkins University. All rights reserved.

This software is provided AS IS under the terms of the Open Source MIT License. 
See http://www.opensource.org/licenses/mit-license.php.
"""

# Library imports
import numpy as np
from scipy.stats import pearsonr

# Package imports 
from .trajectory import CircleTrackData
from .session import VMOSession
from .tools.radians import get_angle_array, circle_diff
from .tools.filters import circular_blur


# Functions to compare responses between two sessions

def mismatch_response_tally(STD, MIS, mismatch, angle_tol=0.5, min_corr=0.4,
    **kwargs):
    """Categorize and tally spatial response changes in mismatch session
    
    Both angle tolerance and minimum correlation criteria must be met for a 
    response change to classify as coherent.
    
    Required arguments:
    STD, MIS -- VMOSession objects for STD and MIS session pair
    mismatch -- total mismatch angle for the cue rotation (in degrees)
    
    Keyword arguments:
    angle_tol -- tolerance for matching a cue rotation, as a fraction of the 
        single-cue rotation angle (half the total mismatch)
    min_corr -- minimum peak correlation across rotations for a response 
        change to count as coherent
    
    Returns dictionary of tallies: local, distal, amb, on, off.
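    
    Example (illustrative; assumes std and mis are VMOSession objects for a 
    180-degree double-rotation STD/MIS pair, loaded elsewhere):
        tally = mismatch_response_tally(std, mis, 180)
        n_coherent = tally['local'] + tally['distal']
        n_remapped = tally['on'] + tally['off']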
    """
    if not hasattr(STD, '_get_active_units') or \
        type(MIS) is not type(STD):
        raise ValueError('invalid or non-matching session data inputs')
        
    # Count coherent vs non-coherent response changes
    tally = dict(local=0, distal=0, amb=0)
    rotations = mismatch_rotation(STD, MIS, degrees=False, **kwargs)
    mismatch *= (np.pi/180) / 2 # single-cue rotation (half the mismatch) in radians
    for rot, corr in rotations.T:
        if corr < min_corr:
            tally['amb'] += 1
        else:
            angle_dev = min(
                abs(circle_diff(rot, mismatch)), 
                abs(circle_diff(rot, -mismatch)))
            if angle_dev <= angle_tol * mismatch:
                if rot < np.pi:
                    tally['local'] += 1
                else:
                    tally['distal'] += 1
            else:
                tally['amb'] += 1
    
    # Count remapping (on/off) response changes
    STDclusts = set(STD._get_active_units())
    MISclusts = set(MIS._get_active_units())
    tally['off'] = len(STDclusts.difference(MISclusts))
    tally['on'] = len(MISclusts.difference(STDclusts))
    
    return tally

def mismatch_rotation(STD, MIS, degrees=True, **kwargs):
    """Computes rotation angle and peak correlation for active clusters between
    two sessions
    
    Returns two-row per-cluster array: angle, correlation.
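    
    Example (illustrative; std and mis are a hypothetical STD/MIS 
    VMOSession pair):
        angles, corrs = mismatch_rotation(std, mis)
        coherent = corrs >= 0.4     # e.g., a peak-correlation criterion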
    """
    Rstd, Rmis = comparison_matrices(STD, MIS, **kwargs)
    units, bins = Rstd.shape
    return np.array([cluster_mismatch_rotation(Rstd[c], Rmis[c], degrees=degrees)
        for c in xrange(units)]).T
        
def cluster_mismatch_rotation(Rstd, Rmis, degrees=True):
    """Find the rotation angle for a single cluster ratemap
    
    Ratemap inputs must be one-dimensional arrays representing the whole track.
    
    Keyword arguments:
    degrees -- whether to return the angle in degrees (default) or radians
    
    Returns (angle, correlation) tuple.
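    
    Example (illustrative; assumes 1-degree spatial bins):
        R = np.zeros(360)
        R[100:130] = 1.0                        # idealized place field
        angle, corr = cluster_mismatch_rotation(R, np.roll(R, 45))
        # peak correlation (~1.0) is found at a 45-degree offset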
    """
    bins = Rstd.shape[0]
    angle = get_angle_array(bins, degrees=degrees)
    corr = np.empty(bins, 'd')
    for offset in xrange(bins):
        MISrot = np.concatenate((Rmis[offset:], Rmis[:offset]))
        corr[offset] = pearsonr(Rstd, MISrot)[0]
    return angle[np.argmax(corr)], corr.max()

def population_spatial_correlation(STD, MIS, **kwargs):
    """Return whole population spatial correlation between two matrices
    """
    A, B = comparison_matrices(STD, MIS, **kwargs)
    return pearsonr(A.flatten(), B.flatten())[0]


# Functions to compute and operate on population correlation matrices

def correlation_matrix(SD, cross=None, **kwargs):
    """Compute a spatial correlation matrix of population-rate vectors
    
    Returns (bins, bins) correlation matrix.
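    
    Example (illustrative; std and mis are a hypothetical session pair):
        C_auto = correlation_matrix(std)              # spatial autocorrelation
        C_cross = correlation_matrix(std, cross=mis)  # STD vs. MIS comparison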
    """
    # Validate arguments and compute population response matrices
    R, R_ = comparison_matrices(SD, cross, **kwargs)
    
    # Compute the correlation matrix
    N_units, bins = R.shape
    C = np.empty((bins, bins), 'd')
    for i in xrange(bins):
        R_i = R[:,i] / np.sqrt(np.dot(R[:,i], R[:,i]))
        for j in xrange(bins):
            R_j = R_[:,j] / np.sqrt(np.dot(R_[:,j], R_[:,j]))
            C[i,j] = np.dot(R_i, R_j)
            
    # Fix any NaN's resulting from silent population responses (rare!)
    C[np.isnan(C)] = 0.0
    return C

def correlation_diagonals(C, use_median=True, centered=False, blur=None):
    """Return the angle bins and diagonals of a correlation matrix
    
    Keyword arguments:
    use_median -- whether to use the median diagonal correlation to collapse 
        the diagonals; if use_median=False, the average is used
    centered -- whether to center the diagonals on [-180, 180]
    blur -- if not None, width in degrees of the Gaussian blur applied to the
        collapsed diagonal array
    
    Returns a (2, bins+1) array: row 0 is angle in degrees, row 1 is the 
    collapsed diagonal correlation, with the first point wrapped around to 
    the end.
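    
    Example (illustrative; identity matrix as a degenerate correlation matrix):
        angles, D = correlation_diagonals(np.eye(48))
        # D is 1.0 at the 0-degree point (and the wrapped 360-degree point)
        # and 0.0 elsewhere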
    """
    bins = C.shape[0]
    if C.shape != (bins, bins):
        raise ValueError('correlation matrix must be square')
    f = np.median if use_median else np.mean
    D = np.empty(bins+1, 'd')
    d = np.empty(bins, 'd')
    offset = bins // 2 if centered else 0
        
    # Loop through and collapse correlation diagonals
    for b0 in xrange(bins):
        for b1 in xrange(bins):
            d[b1] = C[b1, (offset + b0 + b1) % bins]
        D[b0] = f(d)
    if blur is not None:
        D[:bins] = circular_blur(D[:bins], blur)
        
    # Wrap the last point around to the beginning
    D[-1] = D[0]
    last = 180 if centered else 360
    a = np.r_[get_angle_array(bins, degrees=True, zero_center=centered), last]
    return np.array([a, D])


# Functions on lists of VMOSession objects (e.g., full five-session double-
# rotation experiments)

def comparison_matrices(SD, cross, **kwargs):
    """Validate multiple types of arguments for use as comparanda

    SD and cross must be the same type of object (unless cross is None for an
    autocomparison): VMOSession instances, or previously computed population 
    matrices.

    When both SD and cross are VMOSession objects, a clusters list of their 
    common active units is automatically created and passed to the 
    get_population_matrix call. This may be overridden by passing in your own 
    clusters list as a keyword argument. Additional keyword arguments are 
    passed to get_population_matrix.

    Returns two valid (units, bins) population matrix references.
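    
    Example (illustrative; std and mis are a hypothetical session pair):
        Rstd, Rmis = comparison_matrices(std, mis)   # common clusters only
        A, B = comparison_matrices(Rstd, None)       # array pass-through (A is B)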
    """
    if type(SD) is np.ndarray:
        if SD.ndim != 2:
            raise ValueError('expecting 2-dim population matrix')
        R = R_ = SD
        if type(cross) is np.ndarray:
            if cross.shape == SD.shape:
                R_ = cross
            else:
                raise ValueError('non-matching population matrices')
    elif hasattr(SD, '_get_active_units'):
        kwargs['norm'] = False
        if type(cross) is type(SD):
            if 'clusters' not in kwargs:
                kwargs['clusters'] = common_units(SD, cross)
        elif cross is not None:
            raise ValueError('non-matching session object types')
        R = R_ = SD.get_population_matrix(**kwargs)
        if cross is not None:
            R_ = cross.get_population_matrix(**kwargs)
            if R_.shape != R.shape:
                raise ValueError('population matrix size mismatch')
    else:
        raise ValueError('expecting VMOSession objects or population matrices')
    return R, R_

def common_units(*SD_list):
    """Get a list of the active units that are common to a set of sessions
    """
    # Allow a single python list to be passed in
    if len(SD_list) == 1 and type(SD_list[0]) is list:
        SD_list = SD_list[0]
    
    # Get a list of sets of clusters for each data object
    clust_list = []
    for SD in SD_list:
        clust_list.append(set(SD._get_active_units()))

    # Find the common clusters
    common = clust_list[0]
    for i in xrange(1, len(SD_list)):
        common = common.intersection(clust_list[i])
        
    return list(common)