# encoding: utf-8
"""
mismatch.py -- Analysis for double-rotation data, single mismatch angle
Exported namespace: MismatchAnalysis
Created by Joe Monaco on 2010-02-15.
Copyright (c) 2009-2011 Johns Hopkins University. All rights reserved.
This software is provided AS IS under the terms of the Open Source MIT License.
See http://www.opensource.org/licenses/mit-license.php.
"""
# Library imports
import os
import numpy as np
from matplotlib import cm
from enthought.traits.api import Float
# Package imports
from . import CL_COLORS, CL_LABELS
from ..session import VMOSession
from ..core.analysis import BaseAnalysis
from ..tools.images import array_to_image
from ..compare import (common_units, comparison_matrices,
correlation_matrix, correlation_diagonals, mismatch_rotation,
mismatch_response_tally)
class MismatchAnalysis(BaseAnalysis):
    """
    BaseAnalysis subclass for VMODoubleRotation experiment simulations

    Keyword arguments (forwarded by collect_data to load_dataset):
    mismatch -- angle or tuple of angles specifying which mismatch sessions
        should be loaded (default 45)
    load_dir -- name of the directory containing subject subdirectories (which
        must be of the form RatXX/), each holding STD.tar.gz and the
        MIS_XXX.tar.gz archives of VMOSession objects for various mismatch
        angles; if there are no RatXX/ subdirectories, the archives are read
        from load_dir itself
    cluster_criteria -- dictionary of cluster criteria for VMOSession;
        valid keys include min_spike_count and min_info_rate
    """

    label = "Mismatch Analysis"

    # Blur width passed to correlation_diagonals; collect_data overwrites it
    # with the loaded sessions' default_blur_width.
    blur = Float(4.3)

    def load_dataset(self, mismatch=45, load_dir=None, cluster_criteria=None,
        **kwargs):
        """Load session data, compute population matrices and return all the
        info needed to run the analysis.

        Arguments:
        mismatch -- integer mismatch angle or sequence of angles; for each
            angle m the archive 'MIS_%03d.tar.gz' % m is looked up
        load_dir -- data directory (see class docstring)
        cluster_criteria -- optional dict of VMOSession attributes to set on
            every loaded session; keys the session lacks are ignored

        Returns (sessions, session_pairs, R_pairs, total_clusts,
        clusters_included):
        sessions -- record array with integer fields rat, day (always 1),
            session (always 2) and angle, one row per loaded pair
        session_pairs -- object array of shape (N, 2) holding (STD, MIS)
        R_pairs -- list of comparison_matrices(STD, MIS) results
        total_clusts -- total number of common units over all pairs
        clusters_included -- dict mapping (rat, 1, 2, angle) to the common
            units of that pair

        Pairs without any common units are skipped. Raises ValueError if
        load_dir is missing or is not a directory.
        """
        from glob import glob
        from numbers import Integral
        if load_dir is None or not os.path.isdir(load_dir):
            raise ValueError('not a valid data directory!')
        if cluster_criteria is None:
            cluster_criteria = {}
        if isinstance(mismatch, Integral):
            mismatch = (mismatch,)

        # Subject directories in a stable (sorted) order; fall back to
        # treating load_dir itself as a single subject directory
        rat_dirs = [os.path.abspath(d)
            for d in sorted(glob(os.path.join(load_dir, 'Rat*')))
            if os.path.isdir(d)]
        if not rat_dirs:
            rat_dirs = [load_dir]
        multi_rat = len(rat_dirs) > 1

        rat, angle, pair_list, R_pairs = [], [], [], []
        total_clusts = 0
        clusters_included = {}
        for rat_dir in rat_dirs:
            # Keep the requested angle next to its archive path so the angle
            # never has to be re-parsed from a (possibly dotted) full path
            mis_files = [(m, os.path.join(rat_dir, 'MIS_%03d.tar.gz' % m))
                for m in mismatch]
            mis_files = [(m, f) for m, f in mis_files if os.path.exists(f)]
            if not mis_files:
                continue
            cur_rat = self._rat_number(rat_dir) if multi_rat else 0
            STD = self._load_session(os.path.join(rat_dir, 'STD.tar.gz'),
                cluster_criteria)
            for m, MISfile in mis_files:
                MIS = self._load_session(MISfile, cluster_criteria)
                common = common_units(STD, MIS)
                if not len(common):
                    continue
                self.out('Loading mismatch data:\n%s' % MISfile)
                pair_list.append((STD, MIS))
                R_pairs.append(comparison_matrices(STD, MIS))
                total_clusts += len(common)
                rat.append(cur_rat)
                angle.append(int(m))
                clusters_included[(cur_rat, 1, 2, int(m))] = common

        # Fill the object array element-wise so numpy never tries to descend
        # into the session objects; an empty result still has shape (0, 2)
        n = len(pair_list)
        session_pairs = np.empty((n, 2), dtype=object)
        for i, (STD, MIS) in enumerate(pair_list):
            session_pairs[i, 0] = STD
            session_pairs[i, 1] = MIS
        sessions = np.rec.fromarrays(
            [np.array(rat, int), np.ones(n, int), np.ones(n, int) + 1,
                np.array(angle, int)],
            names='rat,day,session,angle', formats='i,i,i,i')
        return \
            sessions, session_pairs, R_pairs, total_clusts, clusters_included

    def collect_data(self, category_criteria=None, **kwargs):
        """Load filtered sets of session data and perform a series of analyses

        Keyword arguments:
        category_criteria -- dictionary of kwargs for mismatch_response_tally
            to specify category thresholds (e.g., angle_tol, min_corr)

        All other keyword arguments are forwarded to load_dataset; see the
        class docstring for the session filtering criteria.

        Analysis is performed both per-rat and globally, with saved results
        data and image files (in <datadir>/images/) for both scopes.

        Raises ValueError if no session pair with common units was loaded.
        """
        if category_criteria is None:
            category_criteria = {}

        # Create images directory
        save_dir = os.path.join(self.datadir, 'images')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Load session data (load_dataset is the subclass-specific override).
        # MismatchAnalysis.load_dataset defaults to a 45-degree mismatch, so
        # record that value when the caller omits it (was a KeyError).
        self.results['mismatch'] = kwargs.get('mismatch', 45)
        sessions, session_pairs, R_pairs, total_clusts, clusters_included = \
            self.load_dataset(**kwargs)
        if not len(sessions):
            raise ValueError('no mismatch sessions with common units found')
        self.blur = session_pairs[0, 0].default_blur_width
        self.results['N_clusters'] = total_clusts
        self.results['clusters'] = clusters_included
        self.out('Loaded %d mismatch sessions with %d clusters!' %
            (len(sessions), total_clusts))

        # Sorted so per-rat results and plots come out in a stable order
        rats = sorted(set(sessions.rat))
        self.results['rats'] = rats
        # Row indices into sessions / session_pairs / R_pairs for each rat
        rat_rows = dict((r, np.nonzero(sessions.rat == r)[0]) for r in rats)

        # Collapse across days within rats for per-animal responses
        self.out('Collating response matrices for %d rat(s)...' % len(rats))
        R_rats = {}
        for rat in rats:
            R_rats[rat] = self._stack_responses(
                [R_pairs[i] for i in rat_rows[rat]])
            self.out('Finished collating responses for rat %d' % rat)

        # Collapse across all sessions to get global response matrices
        self.out('Collapsing all sessions into global response...')
        R_STD, R_MIS = self._stack_responses(R_pairs)

        # Auto- and cross-correlation matrices, diagonals and peak shifts
        self.out('Computing per-rat correlation matrices...')
        for rat in rats:
            self._store_correlations(R_rats[rat], '_rat%d' % rat, save_dir)
        self.out('Computing global correlation matrices...')
        self._store_correlations((R_STD, R_MIS), '', save_dir)

        # Compute actual rotation angles via maximal correlation per-rat
        self.out('Computing rotation angles and correlations...')
        for rat in rats:
            self.results['rotcorr_rat%d' % rat] = \
                mismatch_rotation(*R_rats[rat])
        self.results['rotcorr'] = mismatch_rotation(R_STD, R_MIS)

        # Tally up per-session categorical response changes; fresh dicts are
        # used so the global tally never aliases a per-rat tally
        self.out('Categorizing response changes across sessions...')
        tallies = {}
        for rat in rats:
            rat_tally = {}
            for ix in rat_rows[rat]:
                STD, MIS = session_pairs[ix]
                self._add_counts(rat_tally,
                    mismatch_response_tally(STD, MIS, sessions.angle[ix],
                        **category_criteria))
            self.out('Finished response counts for rat %d' % rat)
            self._add_counts(tallies, rat_tally)
            self.results['response_tallies_rat%d' % rat] = rat_tally
        self.results['response_tallies'] = tallies

        # Good-bye
        self.out('All done!')

    def create_plots(self):
        """Create per-rat and global figures of rotations, correlations,
        correlation diagonals, and response change pie charts

        Figures are stored in self.figure under 'rotcorr', 'diags' and
        'response_tally', plus 'rotcorr_rats', 'diags_rats' and
        'response_tally_rats' when more than one rat was analysed. The
        matplotlib 'figure.figsize' rcParam is restored on exit.
        """
        from pylab import figure, axes, axis, subplot, polar, pie, rcParams
        from ..tools.images import tiling_dims
        self.figure = {}
        res = self.results
        rats = sorted(res['rats'])  # sorted copy; leave results untouched
        multi_rat = len(rats) > 1
        r, c = tiling_dims(len(rats))
        rat_plots_size = 13, 9
        global_plots_size = 9, 9

        def new_figure(key, size, title):
            """Create a titled figure of the given size, registered as key."""
            rcParams['figure.figsize'] = size
            self.figure[key] = f = figure()
            f.set_size_inches(size)
            f.suptitle(title, fontsize=16)
            return f

        saved_figsize = list(rcParams['figure.figsize'])
        try:
            # Per-rat rotations/correlations polar cluster plots; rotations
            # are offset by +90 degrees (pi/2) before plotting and the tick
            # labels are shifted to match
            polar_kwargs = dict(marker='o', ms=10, mfc='b', mew=0,
                alpha=0.6, ls='', aa=True)
            deg_labels = [u'%d\xb0' % d
                for d in [-90, -45, 0, 45, 90, 135, 180, -135]]
            if multi_rat:
                new_figure('rotcorr_rats', rat_plots_size,
                    'Cluster Rotation and Peak Correlation')
                for i, rat in enumerate(rats):
                    ax = subplot(r, c, i + 1, polar=True)
                    rots, corrs = res['rotcorr_rat%d' % rat]
                    polar((np.pi/180)*rots + np.pi/2, corrs, **polar_kwargs)
                    ax.set_rmax(1.0)
                    ax.set_xticklabels(deg_labels)
                    ax.set_title('Rat %d' % rat)

            # Global polar cluster plot
            new_figure('rotcorr', global_plots_size,
                'Cluster Rotation and Peak Correlation (%d rats)' % len(rats))
            rots, corrs = res['rotcorr']
            polar_kwargs.update(ms=12)
            ax = subplot(111, polar=True)
            polar((np.pi/180)*rots + np.pi/2, corrs, **polar_kwargs)
            ax.set_rmax(1.0)
            ax.set_xticklabels(deg_labels)

            # Per-rat correlation diagonals plots
            diags_kwargs = dict(c='k', ls='-', aa=True, marker='')
            if multi_rat:
                new_figure('diags_rats', rat_plots_size,
                    'Correlation Diagonals')
                for i, rat in enumerate(rats):
                    ax = subplot(r, c, i + 1)
                    d_STD = res['diags_STD_rat%d' % rat]
                    d_MIS = res['diags_MIS_rat%d' % rat]
                    popshift = res['popshift_rat%d' % rat]
                    ax.plot(*d_STD, lw=1, label='STD', **diags_kwargs)
                    ax.plot(*d_MIS, lw=2, label='MIS', **diags_kwargs)
                    ax.set_title('Rat %d' % rat)
                    ax.set_xlim(d_STD[0][0], d_STD[0][-1])
                    ax.set_ylim(0, 1)
                    if i == 0:
                        ax.legend(fancybox=True)
                    ax.text(0.03, 0.92, 'Angle = %.1f' % popshift[0],
                        transform=ax.transAxes)
                    ax.text(0.03, 0.87, 'Peak = %.3f' % popshift[1],
                        transform=ax.transAxes)

            # Global diagonals plot
            new_figure('diags', global_plots_size,
                'Population Correlations (%d rats)' % len(rats))
            ax = axes()
            d_STD = res['diags_STD']
            d_MIS = res['diags_MIS']
            ax.plot(*d_STD, lw=1.5, label='STD', **diags_kwargs)
            ax.plot(*d_MIS, lw=3, label='MIS', **diags_kwargs)
            ax.set_xlim(d_STD[0][0], d_STD[0][-1])
            ax.set_ylim(0, 1)
            ax.legend(fancybox=True)
            ax.text(0.03, 0.95, 'Angle = %.1f' % res['popshift'][0],
                transform=ax.transAxes)
            ax.text(0.03, 0.92, 'Peak = %.3f' % res['popshift'][1],
                transform=ax.transAxes)

            # Per-rat response change tally pie charts (categories missing
            # from a tally are drawn as zero)
            pie_kwargs = dict(labels=CL_LABELS, colors=CL_COLORS,
                autopct='%.1f', pctdistance=0.8, labeldistance=100)
            if multi_rat:
                new_figure('response_tally_rats', rat_plots_size,
                    'Response Changes')
                for i, rat in enumerate(rats):
                    ax = subplot(r, c, i + 1)
                    rat_tally = res['response_tallies_rat%d' % rat]
                    tally = [rat_tally.get(k, 0) for k in CL_LABELS]
                    pie(tally, **pie_kwargs)
                    axis('equal')
                    ax.set_xlim(-1.04, 1.04)  # fixes weird edge clipping
                    ax.set_title('Rat %d (%d total)' % (rat, sum(tally)))
                    if i == 0:
                        ax.legend(fancybox=True, loc='right',
                            bbox_to_anchor=(0.0, 0.5))

            # Global response change tally pie chart
            tally = [res['response_tallies'].get(k, 0) for k in CL_LABELS]
            new_figure('response_tally', global_plots_size,
                'Response Changes (%d total)' % sum(tally))
            pie_kwargs.update(pctdistance=0.9)
            ax = axes()
            pie(tally, **pie_kwargs)
            axis('equal')
            ax.set_xlim(-1.04, 1.04)  # fixes weird edge clipping
            ax.legend(fancybox=True, loc='upper left',
                bbox_to_anchor=(-0.11, 1.05))
        finally:
            # Restore the caller's figure size rather than forcing 9x9, even
            # if plotting failed part-way
            rcParams['figure.figsize'] = saved_figsize

    # Private helpers

    def _store_correlations(self, R_pair, suffix, save_dir):
        """Compute and store correlation results for one (R_STD, R_MIS) pair.

        Stores C_STD, diags_STD, C_MIS, diags_MIS and popshift in
        self.results under keys ending in suffix ('' for the global scope,
        '_ratN' per rat) and saves vertically flipped images of both matrices
        as C_STD<suffix>.png / C_MIS<suffix>.png in save_dir.
        """
        R_STD, R_MIS = R_pair
        res = self.results
        res['C_STD' + suffix] = C_STD = correlation_matrix(R_STD)
        res['diags_STD' + suffix] = \
            correlation_diagonals(C_STD, centered=True, blur=self.blur)
        res['C_MIS' + suffix] = C_MIS = correlation_matrix(R_STD, R_MIS)
        res['diags_MIS' + suffix] = D_MIS = \
            correlation_diagonals(C_MIS, centered=True, blur=self.blur)
        # [row-0 value (plotted as the angle) at the peak of row 1, peak]
        res['popshift' + suffix] = \
            np.array([D_MIS[0, np.argmax(D_MIS[1])], D_MIS[1].max()])
        for name, C in (('C_STD', C_STD), ('C_MIS', C_MIS)):
            array_to_image(np.flipud(C),
                os.path.join(save_dir, '%s%s.png' % (name, suffix)),
                cmap=cm.jet, norm=False)

    @staticmethod
    def _stack_responses(pairs):
        """Concatenate a non-empty sequence of (R_STD, R_MIS) pairs along
        axis 0, returning a new (R_STD, R_MIS) tuple of arrays."""
        return (np.concatenate([p[0] for p in pairs]),
            np.concatenate([p[1] for p in pairs]))

    @staticmethod
    def _add_counts(total, counts):
        """Add each category count in counts into the dict total, in place."""
        for category, count in counts.items():
            total[category] = total.get(category, 0) + count

    @staticmethod
    def _load_session(path, criteria):
        """Load a VMOSession archive and set any criteria attributes that the
        session actually has."""
        session = VMOSession.fromfile(path)
        for key, value in criteria.items():
            if hasattr(session, key):
                setattr(session, key, value)
        return session

    @staticmethod
    def _rat_number(rat_dir):
        """Parse the trailing digits of a RatXX directory name as an int."""
        name = os.path.basename(os.path.normpath(rat_dir))
        digits = name[len(name.rstrip('0123456789')):]
        if not digits:
            raise ValueError('cannot parse rat number from %r' % rat_dir)
        return int(digits)