#encoding: utf-8
"""
grid.analysis.sweep -- AbstractAnalysis subclass for exploring spatial map properties
by sweeping a 2D region of parameter space with random sampling.
Written by Joe Monaco, 04/30/2008.
Copyright (c) 2008 Columbia University. All rights reserved.
"""
# Library imports
from IPython.kernel import client as IPclient
import numpy as N, scipy as S, os
# Package imports
from .. import PlaceNetworkStd, CheckeredRatemap, GridCollection
from ..core.analysis import AbstractAnalysis
from ..tools.interp import BilinearInterp2D
from ..tools.cmap_ui import ColormapControl
from ..tools.string import snake2title
# Enthought imports
from enthought.traits.api import Enum, Bool, Button
from enthought.traits.ui.api import View, Group, Item, Include
from enthought.chaco.api import ArrayPlotData, HPlotContainer, Plot
from enthought.enable.component_editor import ComponentEditor
def run_sample_point(**kwargs):
    """Simulate (or reload) a single sample point and summarize its place map.

    This function is submitted as a map task by
    SingleNetworkSweep.collect_data. Besides its keyword arguments it reads
    the names ``PlaceNetworkStd``, ``CheckeredRatemap``, ``W`` and ``EC``
    from the global namespace it executes in; collect_data sets these up on
    the engines before farming out tasks.

    Keyword arguments:
    save_file -- optional path; if given and the file exists, the place map
        is loaded from it instead of running a simulation, otherwise the
        newly computed place map is written to it
    All remaining keyword arguments are passed to the PlaceNetworkStd
    constructor.

    Returns a dictionary of scalar values with the keys 'sparsity',
    'stage_coverage', 'stage_repr', 'peak_rate' (per-map attributes),
    'max_rate', 'num_fields', 'coverage' (means over the unit data) and
    'area', 'diameter', 'peak', 'average' (means over the field data). The
    unit or field means are 0.0 when the corresponding record array is empty.
    """
    # Imported locally so that the function does not depend on these modules
    # having been imported into whatever namespace it ends up executing in.
    import gc
    import os

    gc.collect()

    # 'save_file' is a cache location, not a model parameter, so it must be
    # removed before the remaining kwargs are forwarded to the network.
    do_save = 'save_file' in kwargs
    save_file = kwargs.pop('save_file', None)

    if do_save and os.path.exists(save_file):
        # Re-use pre-existing data for this point. (The original code called
        # self.out() here, but there is no 'self' in a module-level function.)
        print('Loading found data:\n%s' % save_file)
        pmap = CheckeredRatemap.fromfile(save_file)
    else:
        # Run the simulation and save the results
        model = PlaceNetworkStd(W=W, EC=EC, **kwargs)
        model.advance()
        pmap = CheckeredRatemap(model)
        pmap.compute_coverage()
        if do_save:
            pmap.tofile(save_file)

    # Get field and unit data record arrays
    fdata = pmap.get_field_data()
    udata = pmap.get_unit_data()

    # Collate the per-map data
    sample = {
        'sparsity': pmap.sparsity,
        'stage_coverage': pmap.stage_coverage,
        'stage_repr': pmap.stage_repr,
        'peak_rate': pmap.peak_rate,
    }

    # Collate the per-unit and per-field data as (sample key, record column)
    # pairs; an empty record array yields 0.0 for each of its keys.
    averaged = (
        (udata, (('max_rate', 'max_r'),
                 ('num_fields', 'num_fields'),
                 ('coverage', 'coverage'))),
        (fdata, (('area', 'area'),
                 ('diameter', 'diameter'),
                 ('peak', 'peak'),
                 ('average', 'average'))),
    )
    for records, columns in averaged:
        has_records = records.shape[0] != 0
        for key, column in columns:
            if has_records:
                sample[key] = records[column].mean()
            else:
                sample[key] = 0.0

    return sample
class SingleNetworkSweep(AbstractAnalysis, ColormapControl):

    """
    Analyze a 2D grid sample of single-trial network simulations across
    parameter space.

    collect_data farms one simulation per grid point out to ipengines and
    stores an interpolated map of each summary statistic in self.results;
    create_plots displays one of those maps, selectable through the traits
    defined below.

    See core.analysis.AbstractAnalysis documentation and collect_data method
    signature and docstring for usage.
    """

    label = 'Single Sweep'

    save_current_plot = Button
    show_sample_points = Bool(True)

    #
    # These traits must be kept up-to-date with the data made available in
    # the field and unit record arrays of PlaceMap:
    #
    # display_data -- the actual data to display in the figure plot
    # map_data -- the subset of per-map data
    # unit_data -- the subset of unit-averaged data
    # field_data -- the subset of field-averaged data
    #
    display_data = Enum('sparsity', 'stage_coverage', 'stage_repr',
        'peak_rate', 'max_rate', 'num_fields', 'coverage', 'area',
        'diameter', 'peak', 'average')
    map_data = Enum('sparsity', 'stage_coverage', 'stage_repr',
        'peak_rate', 'none')
    unit_data = Enum('max_rate', 'num_fields', 'coverage', 'none')
    field_data = Enum('area', 'diameter', 'peak', 'average', 'none')

    traits_view = \
        View(
            Group(
                Item('figure', label='Data Map', height=450,
                    editor=ComponentEditor()),
                Group(
                    Group(
                        Item('map_data', style='custom'),
                        Item('unit_data', style='custom'),
                        Item('field_data', style='custom'),
                        label='Data to Display',
                        show_border=True),
                    Group(
                        Include('colormap_group'),
                        Group(
                            Item('show_sample_points'),
                            label='Samples',
                            show_border=True),
                        Item('save_current_plot', show_label=False),
                        show_border=False),
                    show_border=False,
                    orientation='horizontal'),
                layout='split',
                orientation='vertical',
                show_border=False),
            title='Single Network Sweep',
            kind='live',
            resizable=True,
            width=0.6,
            height=0.8,
            buttons=['Cancel', 'OK'])

    def collect_data(self, x_density=10, x_bounds=(0.5,8), x_param='J0',
        y_density=10, y_bounds=(0,2.5), y_param='phi_lambda', save_maps=True,
        **kwargs):
        """Store placemap data from a grid-sampled 2D region of parameter space

        The same network weights and grid inputs are used for the simulation
        at each point.

        Keyword arguments:
        x_density -- number of evenly spaced sample values along the x-axis
        x_bounds -- (min, max) bounds on the parameter specified by x_param
        x_param -- string name of PlaceNetwork parameter to sweep along the
            x-axis
        y_density -- ibid for y-axis
        y_bounds -- ibid for y_param
        y_param -- ibid for y-axis
        save_maps -- whether each sample point is given a 'save_file' path
            under <datadir>/data, which run_sample_point uses to save (or
            reload) the place map for that point

        Any additional keyword arguments override the default model
        parameters passed to PlaceNetworkStd.

        Raises KeyError if x_param or y_param is not a user trait of
        PlaceNetworkStd.
        """
        # Validate the sweep parameters first, so that a bad name fails
        # before the engines are reset or any results are stored
        all_params = PlaceNetworkStd().traits(user=True).keys()
        if x_param not in all_params:
            raise KeyError(
                'x_param (%s) is not a PlaceNetwork parameter'%x_param)
        if y_param not in all_params:
            raise KeyError(
                'y_param (%s) is not a PlaceNetwork parameter'%y_param)

        # Store bounds and sweep parameters
        self.results['x_bounds'] = N.array(x_bounds)
        self.results['y_bounds'] = N.array(y_bounds)
        self.results['x_param'] = x_param
        self.results['y_param'] = y_param

        # Get ipcontroller clients
        mec = self.get_multiengine_client()
        tc = self.get_task_client()

        # Setup namespace on ipengine instances
        self.out('Setting up ipengines for task-farming...')
        mec.clear_queue()
        mec.reset()
        mec.execute('import gc, os')
        mec.execute('from grid_remap.place_network import PlaceNetworkStd')
        mec.execute('from grid_remap.dmec import GridCollection')
        mec.execute('from grid_remap.ratemap import CheckeredRatemap')

        # Set default model parameters, then update with keyword arguments
        pdict = dict(   growl=False,
                        refresh_weights=False,
                        refresh_orientation=False,
                        refresh_phase=False
                        )
        pdict.update(kwargs)

        # Send some network weights and a grid cell object
        self.out('Pushing network weights and grid configuration...')
        EC = GridCollection()
        mec.push(dict(W=PlaceNetworkStd(EC=EC, **pdict).W,
            spacing=EC.spacing, phi=EC._phi, psi=EC._psi))
        self.out('...and reconstructing grid collection...')
        mec.execute(
            'EC = GridCollection(spacing=spacing, _phi=phi, _psi=psi)')

        # Build the sample grid according to specifications
        pts_x = N.linspace(x_bounds[0], x_bounds[1], x_density)
        pts_y = N.linspace(y_bounds[0], y_bounds[1], y_density)
        x_grid, y_grid = N.meshgrid(pts_x, pts_y)
        pts = N.c_[x_grid.flatten(), y_grid.flatten()]
        self.results['samples'] = pts

        def interpolate_data(z, density=256):
            """Interpolate value z across sample points with *density* points
            """
            M = N.empty((density, density), 'd')
            x_range = N.linspace(x_bounds[0], x_bounds[1], num=density)
            # y runs from upper to lower bound, so that row 0 of the output
            # holds the largest y value
            y_range = N.linspace(y_bounds[1], y_bounds[0], num=density)
            f = BilinearInterp2D(x=pts_x, y=pts_y, z=z)
            for j, x in enumerate(x_range):
                for i, y in enumerate(y_range):
                    M[i,j] = f(x, y)
            return M

        # Execute data collection process for each sample point
        self.out('Initiating task farming...')
        save_dir = os.path.join(self.datadir, 'data')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        tasks = []
        for i, p in enumerate(pts):
            self.out('Point %d: %s = %.4f, %s = %.4f'%(i, x_param, p[0],
                y_param, p[1]))
            # Give every task its own parameter dictionary rather than
            # mutating one dictionary that all submitted tasks refer to
            task_params = dict(pdict)
            task_params[x_param], task_params[y_param] = p
            if save_maps:
                task_params['save_file'] = \
                    os.path.join(save_dir, 'point_%03d.tar.gz'%i)
            tasks.append(
                tc.run(
                    IPclient.MapTask(run_sample_point, kwargs=task_params)))
        tc.barrier(tasks)

        # Collate sample data returned from map task
        samples = [tc.get_task_result(t_id) for t_id in tasks]

        # Populate result arrays for interpolation; these keys are the ones
        # returned by run_sample_point and listed in the display_data trait
        data_keys = ('sparsity', 'stage_coverage', 'stage_repr', 'peak_rate',
            'max_rate', 'num_fields', 'coverage', 'area', 'diameter', 'peak',
            'average')
        collated = dict(
            (key, N.array([pt[key] for pt in samples])) for key in data_keys)

        # Create interpolated maps for the collated data
        def dot():
            self.out.printf('.', color='purple')

        self.out('Creating interpolated parameter maps for collected data')
        dot()
        for key in data_keys:
            self.results[key] = interpolate_data(collated[key])
            dot()
        self.out.printf('\n')

        # Good-bye!
        self.out('All done!')

    def create_plots(self):
        """Create a simple 2D image plot of the parameter sweep

        Builds self.figure as a horizontal container holding the main plot
        (image, contour lines and sample-point scatter of the currently
        selected display_data) followed by a colorbar.
        """
        # Figure is horizontal container for main plot + colorbar
        self.figure = \
            container = HPlotContainer(fill_padding=True, padding=25,
                bgcolor='linen')

        # Convert old data sets to the new generalized style
        if 'J0_bounds' in self.results:
            self.results['x_bounds'] = self.results['J0_bounds']
            self.results['x_param'] = 'J0'
        if 'lambda_bounds' in self.results:
            self.results['y_bounds'] = self.results['lambda_bounds']
            self.results['y_param'] = 'phi_lambda'

        # Data and bounds for main plot
        raw_data = self.results[self.display_data]
        data = ArrayPlotData(image=self.get_rgba_data(raw_data), raw=raw_data,
            x=self.results['samples'][:,0], y=self.results['samples'][:,1])
        x_range = tuple(self.results['x_bounds'])
        y_range = tuple(self.results['y_bounds'])
        bounds = dict(xbounds=x_range, ybounds=y_range)

        # Create main plot
        p = Plot(data)
        p.img_plot('image', name='sweep', origin='top left', **bounds)
        p.contour_plot('raw', name='contour', type='line', origin='top left',
            **bounds)
        p.plot(('x', 'y'), name='samples', type='scatter', marker='circle',
            color=(0.5, 0.6, 0.7, 0.4), marker_size=2)

        # Tweak main plot
        p.title = snake2title(self.display_data)
        p.x_axis.orientation = 'bottom'
        p.x_axis.title = snake2title(self.results['x_param'])
        p.y_axis.title = snake2title(self.results['y_param'])
        p.plots['samples'][0].visible = self.show_sample_points

        # Add main plot and colorbar to figure
        container.add(p)
        container.add(
            self.get_colorbar_plot(bounds=(raw_data.min(), raw_data.max())))

        # Set radio buttons
        self.unit_data = self.field_data = 'none'

    # Traits notifications for the interactive GUI

    def _cmap_notify_changed(self):
        """Respond to changes to the colormap specification by updating
        the figure plot
        """
        self._update_figure_plot()

    def _save_current_plot_fired(self):
        """Save the current plots as PNG when the button is pressed
        """
        self.save_plots(fmt='png')

    def _show_sample_points_changed(self, new):
        """Toggle visibility of the sample-point scatter in the main plot
        """
        # Nothing to toggle until the figure exists (same guard as in
        # _update_figure_plot)
        if self.figure is None:
            return
        self.figure.components[0].plots['samples'][0].visible = new
        self.figure.request_redraw()

    def _update_figure_plot(self):
        """Refresh the main plot and colorbar for the current display_data
        """
        if self.figure is None:
            return

        # Update data for the main plot
        raw_data = self.results[self.display_data]
        main_plot = self.figure.components[0]
        main_plot.data.set_data('image', self.get_rgba_data(raw_data))
        main_plot.data.set_data('raw', raw_data)
        main_plot.title = snake2title(self.display_data)

        # Remove old colorbar and add new one
        del self.figure.components[1]
        self.figure.add(
            self.get_colorbar_plot(bounds=(raw_data.min(), raw_data.max())))
        self.figure.request_redraw()

    def _display_data_changed(self, old, new):
        """Show the newly selected data, or revert if it is not available
        """
        if new in self.results:
            self._update_figure_plot()
        else:
            self.display_data = old
            self.out('This analysis does not contain \'%s\' data'%new,
                error=True)

    # The three handlers below keep the radio-button groups mutually
    # exclusive: choosing a real value in one group resets the other two to
    # 'none' and makes the chosen value the displayed data.

    def _map_data_changed(self, new):
        if new != 'none':
            self.unit_data = self.field_data = 'none'
            self.display_data = new

    def _unit_data_changed(self, new):
        if new != 'none':
            self.map_data = self.field_data = 'none'
            self.display_data = new

    def _field_data_changed(self, new):
        if new != 'none':
            self.unit_data = self.map_data = 'none'
            self.display_data = new