# encoding: utf-8
"""
placemap_viewer.py -- An interactive GUI interface for individual spatial maps
Created by Joe Monaco on 04-30-2008.
Copyright (c) 2008 Columbia University. All rights reserved.
"""
# Library imports
import numpy as N, scipy as S
from matplotlib import cm
# Package imports
from .ratemap import PlaceMap
from .tools.images import array_to_rgba
from .tools.stats import integer_hist
from .tools.bash import CPrint
# Traits imports
from enthought.traits.api import HasTraits, Instance, Trait, TraitError, \
Property, Enum, Int, Float, Range, Delegate
from enthought.traits.ui.api import View, Group, Item, Heading
# Chaco imports
from enthought.chaco.api import ArrayPlotData, Plot, BasePlotContainer, VPlotContainer, copper
from enthought.enable.component_editor import ComponentEditor
class PlaceMapViewer(HasTraits):

    """
    Chaco viewer for placemap data

    The viewer shows four tabs: stage-wide maps, single-unit maps, and the
    distributions of per-unit and per-field statistics. Note that the
    constructor brings up the view itself (by calling view()) as soon as the
    data has been loaded from the PlaceMap object.

    Constructor arguments:
    pmap -- PlaceMap (or subclass) object to view

    Public methods:
    view -- Bring up the Chaco View window for looking at data
    """

    # Console output
    out = Instance(CPrint)

    # Reference to PlaceMap object
    PMap = Trait(PlaceMap)

    # Stage map traits
    stage_map = Instance(Plot)
    stage_map_type = Enum('representation', 'coverage', 'field_centers')
    sparsity = Delegate('PMap')
    num_active = Delegate('PMap')
    stage_coverage = Delegate('PMap')
    stage_repr = Delegate('PMap')
    peak_rate = Delegate('PMap')

    # Unit map traits; _unit is the row index into udata for the currently
    # selected unit, or -1 if the unit has no row there. The public 'unit'
    # trait is added in __init__, once the number of maps is known.
    _unit = Int
    unit_map = Instance(Plot)
    unit_map_type = Enum('ratemap', 'single', 'fields')
    num_fields = Int
    coverage = Float
    avg_area = Float
    avg_diameter = Float
    max_rate = Float

    # Unit data traits
    unit_data_plots = Instance(BasePlotContainer)
    unit_bins = Range(low=5, high=50, value=20)

    # Field data traits
    field_data_plots = Instance(BasePlotContainer)
    field_bins = Range(low=5, high=50, value=20)

    # Chaco view definition
    traits_view = \
        View(
            Group(
                Group(
                    Item('stage_map_type'),
                    Item('stage_map', editor=ComponentEditor(), show_label=False),
                    Group(
                        Item('sparsity', style='readonly'),
                        Item('num_active', style='readonly'),
                        Item('stage_coverage', label='Coverage', style='readonly'),
                        Item('stage_repr', label='Representation', style='readonly'),
                        Item('peak_rate', style='readonly'),
                        label='Stage Coding',
                        show_border=True),
                    label='Stage Maps',
                    orientation='v'),
                Group(
                    Item('unit_map_type'),
                    Item('unit', style='custom'),
                    Item('unit_map', editor=ComponentEditor(), show_label=False),
                    Group(
                        Item('max_rate', style='readonly'),
                        Item('num_fields', style='readonly'),
                        Item('coverage', style='readonly'),
                        Item('avg_area', label='Mean Field Area', style='readonly'),
                        Item('avg_diameter', label='Mean Field Diameter', style='readonly'),
                        label='Place Unit',
                        show_border=True),
                    label='Unit Maps',
                    orientation='v'),
                Group(
                    Heading('Distributions of Single-Unit Properties'),
                    Item('unit_data_plots', editor=ComponentEditor(), show_label=False),
                    Item('unit_bins', label='Bins'),
                    label='Unit Data'),
                Group(
                    Heading('Distributions of Single-Field Properties'),
                    Item('field_data_plots', editor=ComponentEditor(), show_label=False),
                    Item('field_bins', label='Bins'),
                    label='Field Data'),
                layout='tabbed'),
            title='Placemap Viewer',
            resizable=True,
            height=800,
            width=700,
            kind='live',
            buttons=['Cancel', 'OK'])

    def __init__(self, pmap, **traits):
        """Load field and unit data from pmap and bring up the viewer

        Arguments:
        pmap -- PlaceMap (or subclass) object to view

        Remaining keyword arguments are passed on as trait values. If pmap is
        rejected by the PMap trait, an error message is printed and the
        constructor returns without loading data or showing the view.
        """
        HasTraits.__init__(self, **traits)
        try:
            self.PMap = pmap
        except TraitError:
            self.out('PlaceMap subclass instance required', error=True)
            return
        self.fdata = self.PMap.get_field_data()
        self.udata = self.PMap.get_unit_data()

        # The valid unit range depends on the data, so the trait is added here
        self.add_trait('unit', Range(low=0, high=self.PMap.num_maps-1))

        # Resolve the data row for the initial unit before reading its values;
        # row 0 of udata is not guaranteed to belong to unit 0.
        self._unit = self._lookup_unit_index(self.unit)
        self._update_unit_values()

        self.out('Bringing up place-map visualization...')
        self.view()
        self.out('Done!')

    def view(self):
        """Bring up the Chaco View window for looking at data"""
        self.configure_traits()

    # Plot creation methods

    def _stage_map_default(self):
        """Create the stage plot holding all three stage-wide sub-plots

        The representation image, coverage image and field-center scatter are
        all added to one Plot; only the representation image starts visible.
        """
        # RGBA maps
        rep_map = array_to_rgba(self.PMap.stage_repr_map, cmap=cm.hot)
        cov_map = array_to_rgba(self.PMap.stage_coverage_map, cmap=cm.gray)

        # Data sources and plot object
        data = ArrayPlotData(fields_x=self.fdata['x'], fields_y=self.fdata['y'],
            fields_z=self.fdata['peak'], rep=rep_map, cov=cov_map)
        p = Plot(data)

        # Plot the field centers
        p.plot(('fields_x', 'fields_y', 'fields_z'), name='centers', type='cmap_scatter',
            marker='dot', marker_size=5, color_mapper=copper, line_width=1, fill_alpha=0.6)

        # Plot the representation and coverage maps
        p.img_plot('rep', name='rep', xbounds=(0, self.PMap.W), ybounds=(0, self.PMap.H),
            origin='top left')
        p.img_plot('cov', name='cov', xbounds=(0, self.PMap.W), ybounds=(0, self.PMap.H),
            origin='top left')

        # Start with only the representation map visible
        p.plots['cov'][0].visible = False
        p.plots['centers'][0].visible = False

        # Plot tweaks
        p.aspect_ratio = 1.0
        p.y_axis.title = 'Y (cm)'
        p.x_axis.title = 'X (cm)'
        p.x_axis.orientation = 'bottom'
        p.title = 'Stage Maps'
        return p

    def _unit_map_default(self):
        """Create the image plot for the currently selected unit and map type"""
        # Set the initial unit map
        data = ArrayPlotData(unit_map=self._get_unit_map_data())
        p = Plot(data)

        # Plot the map
        p.img_plot('unit_map', name='unit', xbounds=(0, self.PMap.W), ybounds=(0, self.PMap.H),
            origin='top left')

        # Plot tweaks
        p.aspect_ratio = 1.0
        p.y_axis.title = 'Y (cm)'
        p.x_axis.title = 'X (cm)'
        p.x_axis.orientation = 'bottom'
        p.title = 'Single Unit Maps'
        return p

    def _unit_data_plots_default(self):
        """Create the stacked polygon plots of the per-unit distributions"""
        # Plot data and vertical container object; every plot shares one
        # ArrayPlotData, so updating it refreshes all of them
        data = ArrayPlotData(**self._get_unit_plots_data())
        container = VPlotContainer()

        # Add individual distribution plots to container
        for key in ('avg_diameter', 'avg_area', 'coverage', 'max_r', 'num_fields'):
            p = Plot(data)
            p.plot((key+'_bins', key), name=key, type='polygon', edge_width=2,
                edge_color='mediumblue', face_color='lightsteelblue')
            p.x_axis.title = key
            p.y_axis.title = 'count'
            p.padding = [50, 30, 20, 40]
            if key == 'num_fields':
                p.x_axis.tick_interval = 1
            container.add(p)
        return container

    def _field_data_plots_default(self):
        """Create the stacked polygon plots of the per-field distributions"""
        # Plot data and vertical container object; every plot shares one
        # ArrayPlotData, so updating it refreshes all of them
        data = ArrayPlotData(**self._get_field_plots_data())
        container = VPlotContainer()

        # Add individual distributions plots to container
        for key in ('area', 'diameter', 'average', 'peak'):
            p = Plot(data)
            p.plot((key+'_bins', key), name=key, type='polygon', edge_width=2,
                edge_color='red', face_color='salmon')
            p.x_axis.title = key
            p.y_axis.title = 'count'
            p.padding = [50, 30, 20, 40]
            container.add(p)
        return container

    # Plot update methods

    def _update_stage_map(self):
        """Handle switching between different stage maps"""
        # Map type -> (name of the sub-plot to show, plot title)
        modes = {
            'representation': ('rep', 'Relative Representation'),
            'coverage': ('cov', 'Total Stage Coverage'),
            'field_centers': ('centers', 'Place Field Centroids'),
        }
        renderers = dict((name, self.stage_map.plots[name][0])
            for name in ('rep', 'cov', 'centers'))

        # Update and equalize bounds for all subplots
        for renderer in renderers.values():
            renderer.bounds = self.stage_map.bounds

        # Values are compared by equality (dict lookup), not by identity
        active, title = modes[self.stage_map_type]
        self.stage_map.title = title

        # Toggle plot visibility and redraw
        for name, renderer in renderers.items():
            renderer.visible = (name == active)
        self.stage_map.request_redraw()

    def _update_unit_map(self):
        """Update current image source and title; then redraw the plot"""
        self.unit_map.data.set_data('unit_map', self._get_unit_map_data())
        self.unit_map.title = '%s of Unit %d'%(self.unit_map_type.capitalize(), self.unit)
        self.unit_map.request_redraw()

    def _update_unit_values(self):
        """Update the scalar readonly values"""
        if self._unit == -1:
            # No row in the unit data for this unit: zero the field statistics
            # and take the maximum rate from the PlaceMap maxima array instead
            self.num_fields = 0
            self.coverage = self.avg_area = self.avg_diameter = 0.0
            self.max_rate = float(self.PMap.maxima[self.unit, 2])
        else:
            row = self.udata[self._unit]
            self.num_fields = int(row['num_fields'])
            self.coverage = float(row['coverage'])
            self.avg_area = float(row['avg_area'])
            self.avg_diameter = float(row['avg_diameter'])
            self.max_rate = float(row['max_r'])

    def _lookup_unit_index(self, unit):
        """Helper function to find the row of udata describing a unit

        Returns the index of the first row whose 'unit' entry equals the given
        unit number, or -1 if there is no such row.
        """
        matches = (self.udata['unit'] == unit).nonzero()[0]
        if matches.shape[0]:
            # Plain int so the value is acceptable to the Int trait
            return int(matches[0])
        return -1

    def _get_unit_map_data(self):
        """Helper function to get RGBA array for current unit and map type"""
        map_type = self.unit_map_type
        if map_type == 'ratemap':
            # Rate maps share one color scale, capped at the overall peak rate
            map_data = array_to_rgba(self.PMap.Map[self.unit], cmap=cm.jet,
                norm=False, cmax=self.peak_rate)
        elif map_type == 'single':
            map_data = array_to_rgba(self.PMap.single_maps[self.unit], cmap=cm.hot)
        elif map_type == 'fields':
            map_data = array_to_rgba(self.PMap.coverage_maps[self.unit], cmap=cm.gray)
        else:
            raise ValueError('unknown unit map type: %r' % (map_type,))
        return map_data

    def _binned_counts(self, values, bins):
        """Helper function to histogram values into equal-width bins

        Returns a (counts, centers) tuple where centers holds the mid-point
        of each bin rather than its edges.
        """
        counts, edges = N.histogram(values, bins=bins)
        half_width = (edges[1] - edges[0]) / 2.0
        return counts, edges[:-1] + half_width

    def _close_polygons(self, data):
        """Helper function to add 0-value end-points for polygon display

        Arrays stored under keys ending in 'bins' get their first and last
        values repeated; all other arrays get a 0 at each end. Returns a new
        dictionary with the same keys.
        """
        closed = {}
        for key, values in data.items():
            if key.endswith('bins'):
                closed[key] = N.r_[values[0], values, values[-1]]
            else:
                closed[key] = N.r_[0, values, 0]
        return closed

    def _get_unit_plots_data(self):
        """Helper function for getting unit data distributions"""
        # Integer distribution for number of fields
        data = {}
        data['num_fields_bins'], data['num_fields'] = integer_hist(self.udata['num_fields'])

        # Continuous distributions of other unit statistics
        for key in ('avg_area', 'avg_diameter', 'coverage', 'max_r'):
            data[key], data[key + '_bins'] = \
                self._binned_counts(self.udata[key], self.unit_bins)
        return self._close_polygons(data)

    def _get_field_plots_data(self):
        """Helper function for getting field data distributions"""
        # Continuous distributions of place field properties
        data = {}
        for key in ('area', 'diameter', 'average', 'peak'):
            data[key], data[key + '_bins'] = \
                self._binned_counts(self.fdata[key], self.field_bins)
        return self._close_polygons(data)

    # Map traits notifications

    def _unit_bins_changed(self):
        """Update plot data for unit distributions"""
        data = self._get_unit_plots_data()
        # All plots in the container were built on the same ArrayPlotData
        plot_data = self.unit_data_plots.components[0].data
        for key in data:
            plot_data.set_data(key, data[key])

    def _field_bins_changed(self):
        """Update plot data for field distributions"""
        data = self._get_field_plots_data()
        # All plots in the container were built on the same ArrayPlotData
        plot_data = self.field_data_plots.components[0].data
        for key in data:
            plot_data.set_data(key, data[key])

    def _stage_map_type_changed(self):
        """Switch the visible stage map when the map type changes"""
        self._update_stage_map()

    def _unit_map_type_changed(self):
        """Redraw the unit map when the map type changes"""
        self._update_unit_map()

    def _unit_changed(self):
        """Update the unit map and scalar values"""
        self._unit = self._lookup_unit_index(self.unit)
        self._update_unit_map()
        self._update_unit_values()

    # Output object default

    def _out_default(self):
        """Create the default console printer, prefixed with the class name"""
        return CPrint(prefix=self.__class__.__name__, color='purple')