# encoding: utf-8
"""
trends.py -- Analysis of trends in response changes across mismatch angle
Exported namespace: MismatchTrends
Created by Joe Monaco on 2010-02-17.
Copyright (c) 2009-2011 Johns Hopkins University. All rights reserved.
This software is provided AS IS under the terms of the Open Source MIT License.
See http://www.opensource.org/licenses/mit-license.php.
"""
# Library imports
import numpy as np
# Package imports
from . import CL_LABELS, CL_COLORS
from ..core.analysis import BaseAnalysis
from ..tools.stats import smooth_pdf
from ..tools import circstat
class MismatchTrends(BaseAnalysis):

    """
    Collate data from multiple MismatchAnalysis results to show trends in
    response changes across mismatch angles.

    collect_data() fills self.results with per-mismatch statistics and
    create_plots() renders them into the figures stored in self.figure.
    """

    label = "mismatch trends"

    def collect_data(self, *args):
        """Collate data from previous results for analysis and visualization

        Arguments:
        *args -- one or more objects whose `results` dictionary provides the
            keys 'mismatch' (an integer angle, or a tuple of angles),
            'rotcorr', 'diags_MIS' and 'response_tallies'

        Results stored in self.results (one entry per mismatch, ordered by
        ascending mismatch angle):
        mismatch_labels, N_mismatch, N_common, rotations_mean,
        rotations_sem, correlations_mean, correlations_sem, rotations_pdf,
        correlations_pdf, diagonals, categories, N_total

        Raises:
        ValueError -- if no results objects are passed in
        TypeError -- if a mismatch angle is neither an integer nor a tuple
        """
        if not args:
            raise ValueError('no results to collate')

        # Sort analysis results data and get list of mismatch angles
        self.out('Sorting results data according to mismatch...')
        sort_data = []
        for session in args:
            data = session.results
            angle = data['mismatch']
            if isinstance(angle, tuple):
                label = ', '.join([str(a) for a in angle])
                angle = min(angle)  # tuples sort by their smallest angle
            elif isinstance(angle, (int, np.integer)):
                label = str(angle)
            else:
                raise TypeError(
                    'bad mismatch angle type (%s)'%str(type(angle)))
            sort_data.append((angle, label, data))

        # Order on (angle, label) only: letting ties fall through to the
        # results dictionaries is meaningless and unorderable on Python 3
        sort_data.sort(key=lambda item: item[:2])
        _, angle_labels, data_list = np.array(sort_data, 'O').T
        self.results['mismatch_labels'] = angle_labels
        self.results['N_mismatch'] = N = len(data_list)

        # Mean and SEM of rotation angles and peak correlations
        self.out('Computing statistics of rotations and correlations...')
        self.results['rotations_mean'] = rot_mean = np.empty(N, 'd')
        self.results['rotations_sem'] = rot_sem = np.empty(N, 'd')
        self.results['correlations_mean'] = corr_mean = np.empty(N, 'd')
        self.results['correlations_sem'] = corr_sem = np.empty(N, 'd')
        self.results['N_common'] = N_common = np.empty(N, 'd')
        for i, data in enumerate(data_list):
            rots, corrs = data['rotcorr']
            rots = (np.pi/180)*rots  # degrees -> radians for circstat
            N_common[i] = rots.shape[0]
            rot_mean[i] = circstat.mean(rots)
            rot_sem[i] = circstat.std(rots) / np.sqrt(N_common[i])
            corr_mean[i] = np.mean(corrs)
            corr_sem[i] = np.std(corrs) / np.sqrt(N_common[i])
        rot_mean[rot_mean>np.pi] -= 2*np.pi # make distal rots negative
        rot_mean *= 180/np.pi # convert back to degrees
        # NOTE(review): rot_sem is not rescaled here, so it stays in whatever
        # units circstat.std returns for radian input (presumably radians)
        # while rot_mean is in degrees -- confirm whether that is intended.

        # Smoothed density distributions
        self.out('Computing smoothed density estimates...')
        rot_pdf = []
        corr_pdf = []
        for data in data_list:
            # Work on a copy so the in-place wrap below leaves the source
            # results untouched
            rots, corrs = data['rotcorr'].copy()
            rots[rots>180] -= 360 # make distal rots negative
            rot_pdf.append(smooth_pdf(rots))
            corr_pdf.append(smooth_pdf(corrs))
        self.results['rotations_pdf'] = np.array(rot_pdf, 'O')
        self.results['correlations_pdf'] = np.array(corr_pdf, 'O')

        # Population code rotation via correlation diagonals
        self.out('Collating correlation diagonals...')
        diags = [data['diags_MIS'] for data in data_list]
        self.results['diagonals'] = np.array(diags, 'O')

        # Response category distribution
        self.out('Collating categorical response distributions...')
        self.results['categories'] = categories = {}
        N_total = [
            sum([data['response_tallies'][key] for key in CL_LABELS])
            for data in data_list]
        for key in CL_LABELS:
            # Fraction of each mismatch's total tally that falls in `key`
            categories[key] = np.array(
                [data['response_tallies'][key]/float(total)
                 for data, total in zip(data_list, N_total)])
        self.results['N_total'] = np.array(N_total)

        # Good-bye!
        self.out('All done!')

    def create_plots(self):
        """Create trends plots for rotations, peak correlations and categorical
        remapping statistics.

        Reads the values stored in self.results by collect_data() and stores
        four figures in the self.figure dictionary under the keys
        'diagonals_trends', 'rotation_trends', 'response_trends' and
        'response_trends_pie'.
        """
        from pylab import figure, subplot, rcParams, draw
        from ..tools.images import tiling_dims

        def hold_on(axes):
            # Axes.hold() was removed in matplotlib 3.0, where overplotting
            # is always on; only call it on versions that still provide it
            if hasattr(axes, 'hold'):
                axes.hold(True)

        self.figure = {}
        res = self.results
        labels = res['mismatch_labels']
        N = res['N_mismatch']
        N_total = res['N_total']
        cats = res['categories']

        diagonals_figsize = 10, 5
        rotation_figsize = 13, 9
        category_figsize = 10, 7
        category_pie_figsize = 10, 10

        # Population mismatch correlation diagonals
        rcParams['figure.figsize'] = diagonals_figsize
        self.figure['diagonals_trends'] = f = figure()
        f.set_size_inches(diagonals_figsize)
        f.suptitle('Population Mismatch Correlations', fontsize=16)
        line_kwargs = dict(lw=2, aa=True)

        ax = subplot(111)
        hold_on(ax)
        for i in range(N):
            ax.plot(*res['diagonals'][i], label=labels[i], **line_kwargs)
        ax.axis('tight')
        ax.set_ylim(0,1)
        ax.set_xlabel('Rotation (degrees)')
        ax.set_ylabel('Diagonal Correlation')
        ax.legend(fancybox=True, loc=2)

        # Population and cluster rotations and correlations
        rcParams['figure.figsize'] = rotation_figsize
        self.figure['rotation_trends'] = f = figure()
        f.set_size_inches(rotation_figsize)
        f.suptitle('Cluster Rotation and Peak Correlation', fontsize=16)

        ax = subplot(221)
        hold_on(ax)
        for i in range(N):
            ax.plot(*res['rotations_pdf'][i], label=labels[i], **line_kwargs)
        ax.set_xlim(-180, 180)
        ax.set_xlabel('Rotation (degrees)')
        ax.set_ylabel('Pr[Rotation]')
        ax.legend(fancybox=True, loc=2)

        ax = subplot(223)
        hold_on(ax)
        for i in range(N):
            ax.plot(*res['correlations_pdf'][i], label=labels[i], **line_kwargs)
        ax.set_xlim(0, 1)
        ax.set_xlabel('Peak Correlation')
        ax.set_ylabel('Pr[Peak Correlation]')

        # Mean rotation (left y-axis) and mean peak correlation (right
        # y-axis) share one x-axis with a tick per mismatch
        err_kwargs = dict(fmt='-', ecolor='k', capsize=5, ms=3, elinewidth=1,
            lw=2, aa=True)
        ax1 = subplot(122)
        ax2 = ax1.twinx()
        x = np.arange(1, N+1)
        rh = ax1.errorbar(x, res['rotations_mean'], yerr=res['rotations_sem'],
            c=(0.2, 0.0, 0.8), **err_kwargs)
        ch = ax2.errorbar(x, res['correlations_mean'], yerr=res['correlations_sem'],
            c=(0.8, 0.0, 0.2), **err_kwargs)
        ax2.set_ylabel('Peak Correlation', rotation=270)
        ax2.set_ylim(0, 1)
        ax1.set_ylabel('Rotation (degrees)')
        ax1.set_xlabel('Mismatch')
        ax1.set_xlim(0.5, N+0.5)
        ax1.set_xticks(x)
        ax1.set_xticklabels(labels)
        ax1.legend((rh[0], ch[0]), ('Rotation', 'Peak Correlation'),
            fancybox=True, loc=0)

        # Category distribution stacked bar chart
        rcParams['figure.figsize'] = category_figsize
        self.figure['response_trends'] = f = figure()
        f.set_size_inches(category_figsize)
        f.suptitle('Response Changes: Trends', fontsize=16)
        plot_category_chart(subplot(111), res)

        # Category distribution pie charts
        rcParams['figure.figsize'] = category_pie_figsize
        self.figure['response_trends_pie'] = f = figure()
        f.set_size_inches(category_pie_figsize)
        f.suptitle('Response Changes', fontsize=16)
        r, c = tiling_dims(N)
        # labeldistance=100 pushes the wedge labels far outside the axes;
        # the names are shown by the legend drawn on the first pie instead
        pie_kwargs = dict(labels=CL_LABELS, colors=CL_COLORS, autopct='%.1f',
            pctdistance=0.8, labeldistance=100)
        for i in range(N):
            ax = subplot(r, c, i+1)
            tally = [cats[k][i] for k in CL_LABELS]
            ax.pie(tally, **pie_kwargs)
            ax.axis('equal')
            ax.set_xlim(-1.04, 1.04) # fixes weird edge clipping
            ax.set_title('%s (%d total)'%(labels[i], N_total[i]))
            if i == 0:
                ax.legend(fancybox=True, loc='right',
                    bbox_to_anchor=(0.1, -0.05))

        # Reset figure size
        draw()
        rcParams['figure.figsize'] = 9, 9
# Plot helper functions
def plot_category_chart(ax, res, show_legend=True):
    """Plot stacked bar chart of remapping response categories

    Each category in CL_LABELS is drawn as a filled band whose height at
    every mismatch is that category's response fraction; bands are stacked
    in CL_LABELS order from 0 up to 1.

    Arguments:
    ax -- axis to plot into
    res -- trends results dictionary; the keys 'mismatch_labels',
        'N_mismatch' and 'categories' are read

    Keyword arguments:
    show_legend -- whether to draw legend

    Returns the axis that was passed in.
    """
    from matplotlib.patches import Polygon

    labels = res['mismatch_labels']
    N = res['N_mismatch']
    cats = res['categories']
    x = np.arange(1, N+1)

    # Row i holds the lower edge of band i and row i+1 its upper edge
    stacked = np.empty((len(CL_LABELS)+1, N), 'd')
    stacked[0] = 0.0
    stacked[-1] = 1.0  # top edge is pinned to 1 rather than accumulated
    for i in range(1, len(CL_LABELS)):
        stacked[i] = stacked[i-1] + cats[CL_LABELS[i-1]]

    poly_kwargs = dict(aa=True, lw=1.5, alpha=0.85)
    # Closed outline: left-to-right along the lower edge, then back
    # right-to-left along the upper edge
    xloop = np.concatenate((x, x[::-1]))
    for i, cat_label in enumerate(CL_LABELS):
        yloop = np.concatenate((stacked[i], stacked[i+1,::-1]))
        p = Polygon(np.c_[xloop, yloop], fc=CL_COLORS[i],
            ec=CL_COLORS[i], label=cat_label, **poly_kwargs)
        ax.add_artist(p)
        p.set_clip_box(ax.bbox)

    ax.axis([1, N, 0, 1])
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_xlabel('Mismatch')
    ax.set_ylabel('Response Fraction')
    if show_legend:
        ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1.0))
    return ax