#encoding: utf-8
"""
realign.py -- A two-dimensional sweep of realignment magnitudes and
within-module variances to explore the parameters under which positional
and rate remapping occur.
Written by Joe Monaco, 08/04/2008. Updated 01/21/09.
Copyright (c) 2008 Columbia University. All rights reserved.
"""
# Library imports
from IPython.kernel import client as IPclient
import numpy as N, scipy as S, os, gc
# Package imports
from ..place_network import PlaceNetworkStd
from ..ratemap import CheckeredRatemap
from ..dmec import GridCollection
from ..tools.interp import BilinearInterp2D
from ..tools.string import snake2title
from .sweep import SingleNetworkSweep
from .compare import compare_AB
# Enthought imports
from enthought.traits.api import Enum
from enthought.chaco.api import ArrayPlotData, HPlotContainer, Plot
def run_sample_point(save_file, d_x, d_y):
    """
    Simulate one realigned sample point on an ipengine and save its ratemap.

    This function is sent to the ipengines as an IPclient.MapTask by
    RealignmentSweep.collect_data. It therefore resolves its free names from
    the engine namespace set up there with mec.push/mec.execute, not from this
    module:
    - objects: EC, W, pdict
    - sweep settings: nmodules, freq_modules, x_type, y_type
    - per-module realignment arrays: delta_phi, delta_psi, ell_mags,
      ell_angles, zoom_scales (whichever the sweep axes need)
    Any helper it uses must be defined inside it for the same reason.

    Arguments:
    save_file -- path the resulting CheckeredRatemap is written to
    d_x -- x-axis realignment factor (or module count for a 'modules' axis)
    d_y -- y-axis realignment factor (or module count for a 'modules' axis)
    """
    gc.collect()

    # A 'modules' sweep axis overrides the pushed number of modules
    num_mods = nmodules
    if x_type == 'modules':
        num_mods = int(d_x)
    elif y_type == 'modules':
        num_mods = int(d_y)
    module_masks = EC.get_modules(num_mods, freq_sort=freq_modules)

    # Start from unmodified grids and switch on the transforms the axes use
    EC.reset()
    axis_types = (x_type, y_type)
    if 'ellipticity' in axis_types:
        EC.ellipticity = True
    if 'zoom' in axis_types:
        EC.zoom = True

    def realign(kind, factor, mod, mask):
        # Scale module `mod`'s realignment parameters by `factor` and apply
        # them to the grids selected by `mask`; other kinds are no-ops
        if kind == 'shift':
            EC.shift(factor * delta_phi[mod], mask=mask)
        elif kind == 'rotate':
            EC.rotate(factor * delta_psi[mod], mask=mask)
        elif kind == 'ellipticity':
            EC.ell_mag[mask] = factor * ell_mags[mod]
            EC.ell_angle[mask] = factor * ell_angles[mod]
        elif kind == 'zoom':
            EC.zoom_scale[mask] = 1 + factor * (zoom_scales[mod] - 1)

    # Per module: apply the x-axis realignment, then the y-axis realignment
    for mod, mask in enumerate(module_masks):
        realign(x_type, d_x, mod, mask)
        realign(y_type, d_y, mod, mask)

    # Run the realigned network and store its spatial map
    net = PlaceNetworkStd(EC=EC, W=W, **pdict)
    net.advance()
    sample_map = CheckeredRatemap(net)
    sample_map.compute_coverage()
    sample_map.tofile(save_file)
class RealignmentSweep(SingleNetworkSweep):

    """
    Analyze a 2D grid sample of single-trial network simulations across
    realignment magnitudes or variances in A-B environment comparisons.

    See core.analysis.AbstractAnalysis documentation and collect_data method
    signature and docstring for usage.
    """

    label = 'Realign Sweep'

    display_data = Enum('remapping', 'rate_remapping', 'turnover', 'sparsity',
        'stage_coverage', 'stage_repr', 'peak_rate', 'max_rate', 'num_fields',
        'coverage', 'area', 'diameter', 'peak', 'average')
    map_data = Enum('remapping', 'rate_remapping', 'turnover', 'sparsity',
        'stage_coverage', 'stage_repr', 'peak_rate', 'none')

    def collect_data(self, x_type='shift', y_type='rotate', x_density=10,
        y_density=10, nmodules=1, freq_modules=False, x_max=None, y_max=None,
        **kwargs):
        """
        Store placemap data from a regular 2D grid sample of parameter space
        for realignment magnitudes or variances (spatial phase vs. orientation
        by default).

        The same network is used for the simulation at each point, and each
        sample is compared to a reference (A) spatial map.

        Keyword arguments:
        x_type -- realignment type along x axis; must be one of 'shift',
            'rotate', 'ellipticity', 'zoom', or 'modules' (default 'shift')
        y_type -- realignment type along y axis (default 'rotate')
        x_density -- number of x_type samples along the x bounds (10); for a
            'modules' axis the density is set to nmodules instead
        y_density -- number of y_type samples along the y bounds (10)
        nmodules -- number of independent alignment modules; used as max number
            of modules if x_type or y_type is set to 'modules'
        freq_modules -- whether modules are spatial frequency partitions
        x_max -- set upper bound for extent of x_type realignment along x axis
            (shift should be a 2-tuple value); for 'ellipticity' and 'zoom'
            it is only used when nmodules == 1
        y_max -- set upper bound for extent of y_type realignment along y axis
        Remaining keyword arguments are passed on to PlaceNetworkStd.

        Raises ValueError if x_type or y_type is not a valid realignment type.
        """
        # Parse the realignment types
        realignment_types = ('shift', 'rotate', 'ellipticity', 'zoom', 'modules')
        if x_type not in realignment_types:
            raise ValueError('invalid realignment type specification (x_type)')
        if y_type not in realignment_types:
            raise ValueError('invalid realignment type specification (y_type)')

        # Split cortical population into modules
        self.results['nmodules'] = nmodules = int(nmodules)
        self.results['freq_modules'] = freq_modules
        self.results['x_type'] = x_type
        self.results['y_type'] = y_type

        # Make data directory
        map_dir = os.path.join(self.datadir, 'data')
        if not os.path.exists(map_dir):
            os.makedirs(map_dir)

        # Set default model parameters: no refreshing of weights, phases or
        # orientations, since the same network is reused at every sample point
        pdict = dict(   refresh_weights=False,
                        refresh_phase=False,
                        refresh_orientation=False   )
        pdict.update(kwargs)

        # Simulate reference spatial map for environment A
        self.out('Simulating reference spatial map...')
        EC = GridCollection()
        model = PlaceNetworkStd(EC=EC, **pdict)
        model.advance()
        A = CheckeredRatemap(model)
        A.compute_coverage()
        A.tofile(os.path.join(map_dir, 'map_A'))

        # Setup namespace on ipengine instances
        self.out('Setting up ipengines for task-farming...')
        mec = self.get_multiengine_client()
        tc = self.get_task_client()
        mec.clear_queue()
        mec.reset()
        mec.execute('import gc')
        mec.execute('from grid_remap.place_network import PlaceNetworkStd')
        mec.execute('from grid_remap.dmec import GridCollection')
        mec.execute('from grid_remap.ratemap import CheckeredRatemap')

        # Send network weights, grid configuration and sweep info; these are
        # the engine-side globals that run_sample_point reads
        self.out('Pushing network weights and grid configuration...')
        mec.push(dict(  W=model.W,
                        pdict=pdict,
                        spacing=EC.spacing,
                        phi=EC._phi,
                        psi=EC._psi,
                        nmodules=nmodules,
                        freq_modules=freq_modules,
                        x_type=x_type,
                        y_type=y_type   ))
        mec.execute('EC = GridCollection(spacing=spacing, _phi=phi, _psi=psi)')

        # Set up modular realignment parameters, pushing data out to engines.
        # Sample coordinates are realignment factors in [0, 1] (or module
        # counts in [1, nmodules]) that scale the per-module parameters.
        self.results['bounds'] = bounds = N.array([[0, 1]]*2, 'd')
        density = [int(x_density), int(y_density)]
        r_max = (x_max, y_max)
        r_type = (x_type, y_type)
        for i in 0, 1:
            if r_type[i] == 'shift':
                if nmodules == 1 and r_max[i] is not None:
                    delta_phi = N.array([r_max[i]], 'd')
                elif nmodules > 1 and r_max[i] is not None:
                    delta_phi = N.array(r_max[i], 'd')
                else:
                    grid_scale = None
                    if freq_modules and r_type[1-i] == 'modules':
                        grid_scale = 60.0 # cf. lab notebook @ p.147
                    delta_phi = \
                        N.array([GridCollection.get_delta_phi(scale=grid_scale)
                            for _ in xrange(nmodules)])
                mec.push(dict(delta_phi=delta_phi))
                self.results[r_type[i] + '_params'] = delta_phi
                self.out('Pushed shift parameters:\n%s'%str(delta_phi))
            elif r_type[i] == 'rotate':
                if nmodules == 1 and r_max[i] is not None:
                    delta_psi = N.array([r_max[i]], 'd')
                elif nmodules > 1 and r_max[i] is not None:
                    delta_psi = N.array(r_max[i], 'd')
                else:
                    delta_psi = N.array([GridCollection.get_delta_psi()
                        for _ in xrange(nmodules)])
                mec.push(dict(delta_psi=delta_psi))
                self.results[r_type[i] + '_params'] = delta_psi
                self.out('Pushed rotate parameters:\n%s'%str(delta_psi))
            elif r_type[i] == 'ellipticity':
                # NOTE: r_max is only honored for single-module sweeps here;
                # multi-module sweeps always draw random ellipticities
                if nmodules == 1 and r_max[i] is not None:
                    ell_mags = N.array([r_max[i]], 'd')
                    ell_angles = N.array([0.0])
                else:
                    ell_mags = N.array([GridCollection.get_ellipticity()
                        for _ in xrange(nmodules)])
                    ell_angles = N.array([GridCollection.get_elliptic_angle()
                        for _ in xrange(nmodules)])
                mec.push(dict(ell_mags=ell_mags, ell_angles=ell_angles))
                self.results[r_type[i] + '_params'] = \
                    N.c_[ell_mags, ell_angles]
                self.out('Pushed ellipticity parameters:\n' +
                    'Flattening: %s\nAngles: %s'%(str(ell_mags),
                        str(ell_angles)))
            elif r_type[i] == 'zoom':
                # NOTE: as for ellipticity, r_max is only used if nmodules == 1
                if nmodules == 1 and r_max[i] is not None:
                    zoom_scales = N.array([r_max[i]], 'd')
                else:
                    zoom_scales = N.array([GridCollection.get_zoom_scale()
                        for _ in xrange(nmodules)])
                mec.push(dict(zoom_scales=zoom_scales))
                self.results[r_type[i] + '_params'] = zoom_scales
                self.out('Pushed zoom parameters:\n%s'%str(zoom_scales))
            elif r_type[i] == 'modules':
                # One sample per module count: 1, 2, ..., nmodules
                density[i] = nmodules
                bounds[i] = 1, nmodules
                self.out('Setting up modularity sweep for %d modules'%nmodules)

        # Build the sample grid according to specifications
        pts_x = N.linspace(bounds[0,0], bounds[0,1], density[0])
        pts_y = N.linspace(bounds[1,0], bounds[1,1], density[1])
        x_grid, y_grid = N.meshgrid(pts_x, pts_y)
        pts = N.c_[x_grid.flatten(), y_grid.flatten()]
        self.results['samples'] = pts

        # Initialize stage map sample data arrays
        nsamples = density[0] * density[1]
        self.results['remapping_samples'] = remapping = N.empty(nsamples, 'd')
        self.results['rate_remapping_samples'] = rate_remapping = N.empty(nsamples, 'd')
        self.results['turnover_samples'] = turnover = N.empty(nsamples, 'd')
        self.results['sparsity_samples'] = sparsity = N.empty(nsamples, 'd')
        self.results['stage_coverage_samples'] = stage_coverage = N.empty(nsamples, 'd')
        self.results['stage_repr_samples'] = stage_repr = N.empty(nsamples, 'd')
        self.results['peak_rate_samples'] = peak_rate = N.empty(nsamples, 'd')
        # Per-unit/per-field averages default to zero for maps without units
        # or fields (see the shape checks in the collation loop)
        self.results['max_rate_samples'] = max_rate = N.zeros(nsamples, 'd')
        self.results['num_fields_samples'] = num_fields = N.zeros(nsamples, 'd')
        self.results['coverage_samples'] = coverage = N.zeros(nsamples, 'd')
        self.results['area_samples'] = area = N.zeros(nsamples, 'd')
        self.results['diameter_samples'] = diameter = N.zeros(nsamples, 'd')
        self.results['peak_samples'] = peak = N.zeros(nsamples, 'd')
        self.results['average_samples'] = average = N.zeros(nsamples, 'd')

        # Method for creating interpolated maps of collated data
        def interpolate_data(z, pixels=256):
            """Bilinearly interpolate sample values z onto a pixels x pixels map

            Columns run from the lower to the upper x bound; rows run from the
            upper to the lower y bound, so row 0 is the top of the map.
            """
            M = N.empty((pixels,)*2, 'd')
            f = BilinearInterp2D(x=pts_x, y=pts_y, z=z)
            x_range = N.linspace(bounds[0,0], bounds[0,1], num=pixels)
            y_range = N.linspace(bounds[1,1], bounds[1,0], num=pixels)
            for j, x in enumerate(x_range):
                for i, y in enumerate(y_range):
                    M[i,j] = f(x, y)
            return M

        # Execute data collection process for each sample point
        tasks = []
        for i, p in enumerate(pts):
            self.out('Submitting: d_%s = %.2f, d_%s = %.2f'%
                (x_type, p[0], y_type, p[1]))
            save_file = os.path.join(map_dir, 'map_%03d.tar.gz'%i)
            tasks.append(
                tc.run(
                    IPclient.MapTask(run_sample_point,
                        args=(save_file, float(p[0]), float(p[1])))))
        tc.barrier(tasks)

        # Collate data return from task farming
        for i in xrange(nsamples):
            self.out('Loading data from map %d for analysis...'%i)
            B = CheckeredRatemap.fromfile(os.path.join(map_dir, 'map_%03d.tar.gz'%i))

            # Get field and unit data record arrays
            fdata = B.get_field_data()
            udata = B.get_unit_data()

            # Collate the stage map data
            sparsity[i] = B.sparsity
            stage_coverage[i] = B.stage_coverage
            stage_repr[i] = B.stage_repr
            peak_rate[i] = B.peak_rate

            # Collate the per-unit data
            if udata.shape[0] != 0:
                max_rate[i] = udata['max_r'].mean()
                num_fields[i] = udata['num_fields'].mean()
                coverage[i] = udata['coverage'].mean()

            # Collate the per-field data
            if fdata.shape[0] != 0:
                area[i] = fdata['area'].mean()
                diameter[i] = fdata['diameter'].mean()
                peak[i] = fdata['peak'].mean()
                average[i] = fdata['average'].mean()

            # Compute remapping strength from map A
            cmp_AB = compare_AB(A, B)
            remapping[i] = cmp_AB['remapping']
            rate_remapping[i] = cmp_AB['rate_remapping']
            turnover[i] = cmp_AB['turnover']

        # Create interpolated maps for the collated data; each map is stored
        # under the name of its '<name>_samples' source array
        def dot():
            self.out.printf('.', color='purple')
        self.out('Creating interpolated parameter maps for collected data'); dot()
        for name in ('remapping', 'rate_remapping', 'turnover', 'sparsity',
            'stage_coverage', 'stage_repr', 'peak_rate', 'max_rate',
            'num_fields', 'coverage', 'area', 'diameter', 'peak', 'average'):
            self.results[name] = interpolate_data(self.results[name + '_samples'])
            dot()
        self.out.printf('\n')

        # Good-bye!
        self.out('All done!')

    def create_plots(self):
        """
        Create a simple 2D image plot of the parameter sweep.

        The figure is a horizontal container holding the main plot and a
        colorbar spanning the map's value range. The main plot shows an image
        and contour of the selected display_data map, plus the sample points.
        """
        # Figure is horizontal container for main plot + colorbar
        self.figure = \
            container = HPlotContainer(fill_padding=True, padding=25,
                bgcolor='linen')

        # Data for main plot
        raw_data = self.results[self.display_data]
        samples = self.results['samples']
        data = ArrayPlotData(image=self.get_rgba_data(raw_data), raw=raw_data,
            x=samples[:,0], y=samples[:,1])

        # Plot bounds: collect_data stores a single 2x2 'bounds' array (row 0
        # for the x axis, row 1 for the y axis); per-axis keys are a fallback
        try:
            sweep_bounds = self.results['bounds']
        except KeyError:
            x_range = tuple(self.results['x_bounds'])
            y_range = tuple(self.results['y_bounds'])
        else:
            x_range = tuple(sweep_bounds[0])
            y_range = tuple(sweep_bounds[1])
        bounds = dict(xbounds=x_range, ybounds=y_range)

        # Create main plot; origin='top left' matches the interpolated maps,
        # whose first row corresponds to the upper y bound
        p = Plot(data)
        p.img_plot('image', name='sweep', origin='top left', **bounds)
        p.contour_plot('raw', name='contour', type='line', origin='top left', **bounds)
        p.plot(('x', 'y'), name='samples', type='scatter', marker='circle',
            color=(0.5, 0.6, 0.7, 0.4), marker_size=4)

        # Axis titles follow the realignment types of the sweep; results
        # lacking type info are labeled as the default shift/rotate sweep
        axis_titles = {'shift': 'Spatial Phase (cm)',
                       'rotate': 'Orientation (rads)'}
        try:
            x_type, y_type = self.results['x_type'], self.results['y_type']
        except KeyError:
            x_type, y_type = 'shift', 'rotate'

        # Tweak main plot
        p.title = snake2title(self.display_data)
        p.x_axis.orientation = 'bottom'
        p.x_axis.title = axis_titles.get(x_type, snake2title(x_type))
        p.y_axis.title = axis_titles.get(y_type, snake2title(y_type))
        p.plots['samples'][0].visible = self.show_sample_points

        # Add main plot and colorbar to figure
        container.add(p)
        container.add(
            self.get_colorbar_plot(bounds=(raw_data.min(), raw_data.max())))

        # Set radio buttons
        self.unit_data = self.field_data = 'none'
# Convenience function to reorganize results data
def get_module_columns(res, module_dim='y', which='remapping'):
    """Get matrix of columns of line data from results samples to plot

    Arguments:
    res -- results dict from a completed RealignmentSweep analysis object
        (needs the 'samples' array and the '<which>_samples' array)
    module_dim -- set to 'x' or 'y' to specify modularity axis (any value
        other than 'y' is treated as 'x')
    which -- which data to retrieve ('remapping', 'turnover', etc.)

    Returns modules array (integer module counts), sweep (realignment) array,
    and column data matrix of shape (len(sweep), len(modules)) in which
    column m holds the *which* samples for modules[m] along the sweep.
    """
    pts, data = res['samples'], res[which + '_samples']

    # Column of the samples array that holds the module counts
    mod_dim = int(module_dim == 'y')
    mod_values = N.unique(pts[:,mod_dim])
    modules = mod_values.astype('i')

    # Realignment values swept for the first module (the sample grid is
    # regular, so every module shares the same sweep)
    sweep = pts[pts[:,mod_dim] == mod_values[0], 1-mod_dim]

    # One column of sweep data per module; match against the raw sample
    # values rather than the int-truncated module counts
    lines = N.empty((len(sweep), len(modules)), 'd')
    for m, value in enumerate(mod_values):
        lines[:,m] = data[pts[:,mod_dim] == value]
    return modules, sweep, lines