#encoding: utf-8
"""
grid.analysis.altmodels -- Analysis simulating model variants for comparison
Exports: ModelComparison
Written by Joe Monaco, 02/05/2011.
Copyright (c) 2011 Johns Hopkins University. All rights reserved.
"""
# Library imports
from scipy.stats import sem
import os, numpy as np
import matplotlib as mpl
import matplotlib.pylab as plt
# Package imports
from ..place_network import PlaceNetworkStd
from ..core.analysis import AbstractAnalysis
from ..tools.images import array_to_image
from ..ratemap import CheckeredRatemap
from ..dmec import GridCollection
from .compare import compare_AB
from .map_funcs import get_tuned_weights
class ModelComparison(AbstractAnalysis):

    """
    Load a standard simulation from pre-existing data (or simulate a new map)
    and then simulate several model variants to compare place-field size
    and location differences.

    The variants are labeled 'std' (standard simulation), 'fwd' (feedforward
    inhibition), 'tuned' (associatively tuned weights) and 'rec' (recurrent
    excitation).

    See core.analysis.AbstractAnalysis documentation and collect_data method
    signature and docstring for usage.
    """

    label = "alt models"

    def collect_data(self, load_dir=None, alpha=0.3, gamma=1.0, rec_tuned=False):
        """Run a standard simulation and then variants using the same network

        Keyword arguments:
        load_dir -- if loading pre-existing network, set directory here; it
            must contain phi.npy, psi.npy, spacing.npy and W.npy
        alpha -- learning parameter for tuned weights (get_tuned_weights)
        gamma -- gain of recurrent excitation (based on overlap)
        rec_tuned -- whether recurrent variant is based on tuned output (True)
            or the standard output (False)

        Raises ValueError if load_dir is specified but is not a directory.

        For every model type, the sparsity, number of fields, field area,
        field diameter and field x/y positions are stored in self.results
        under '<model>_<quantity>' keys. Each stage looks for a cached
        '<model>_map.tar.gz' in the current working directory and loads it
        instead of recomputing when it is found; otherwise the resulting map
        is written with pmap.tofile('<model>_map').
        """
        self.results['model_types'] = ('std', 'fwd', 'tuned', 'rec')

        if load_dir is not None:
            if not os.path.isdir(load_dir):
                raise ValueError('invalid load directory')
            self.results['load_dir'] = os.path.abspath(load_dir)
            self.out('Loading network from\n%s...'%self.results['load_dir'])
            os.chdir(load_dir)
            try:
                load = np.load
                EC = GridCollection(
                    _phi=load('phi.npy'), _psi=load('psi.npy'),
                    spacing=load('spacing.npy'))
                model = PlaceNetworkStd(
                    EC=EC, W=load('W.npy'), refresh_weights=False)
            finally:
                # Always return to the data directory, even if loading fails,
                # so later relative paths do not end up in load_dir
                os.chdir(self.datadir)
        else:
            self.out('Creating new grid inputs and place network...')
            EC = GridCollection()
            model = PlaceNetworkStd(EC=EC)
        W = model.W

        def get_norms(M):
            """Return the Euclidean norm of M taken along its first axis."""
            return np.sqrt((M**2).sum(axis=0))

        def store_data(prefix, pmap):
            """Store summary data for pmap in self.results under prefix and
            write the map to file unless a cached archive already exists.
            """
            udata = pmap.get_unit_data()
            fdata = pmap.get_field_data()
            self.results['%s_sparsity'%prefix] = pmap.sparsity
            self.results['%s_num_fields'%prefix] = udata['num_fields']
            self.results['%s_area'%prefix] = fdata['area']
            self.results['%s_diameter'%prefix] = fdata['diameter']
            self.results['%s_x'%prefix] = fdata['x']
            self.results['%s_y'%prefix] = fdata['y']
            if not os.path.exists('%s_map.tar.gz'%prefix):
                pmap.tofile('%s_map'%prefix)

        # Get input strength map
        self.out('Computing grid input strengths...')
        EC_R = EC.get_z_stack()
        EC_norms = get_norms(EC_R)
        np.save('EC_norms.npy', EC_norms)
        array_to_image(EC_norms, 'EC_norms.png', cmap=mpl.cm.gray_r)
        array_to_image(EC_norms, 'EC_norms_jet.png', cmap=mpl.cm.jet)

        # Run the standard simulation
        if not os.path.exists('std_map.tar.gz'):
            self.out('Running standard simulation...')
            model.advance()
            pmap = CheckeredRatemap(model)
        else:
            self.out('Loading standard simulation data...')
            pmap = CheckeredRatemap.fromfile('std_map.tar.gz')
        store_data('std', pmap)
        std_num_active = pmap.num_active
        self.out('Standard active units = %d'%std_num_active)
        R = pmap.Map
        array_to_image(get_norms(R), 'std_norms.png', cmap=mpl.cm.gray_r)
        array_to_image(get_norms(R), 'std_norms_jet.png', cmap=mpl.cm.jet)

        def sparsity_match_threshold(Map):
            """Threshold Map in place so that its number of active units
            drops below that of the standard simulation.

            Candidate thresholds are 0 followed by the per-unit peak values
            of Map in ascending order. A unit counts as active when its peak
            exceeds 20% of the population peak and more than 50 pixels exceed
            20% of its own peak. The first candidate that yields fewer active
            units than std_num_active is subtracted from Map, with negative
            values clipped to zero.
            """
            self.out('Searching for sparsity-matching threshold...')
            N = Map.shape[0]

            # Test activity peaks as thresholds to find sparsity-matching
            # threshold
            peaks = np.array([Map[i].max() for i in range(N)], 'd')
            peaks.sort()
            probe = np.empty(Map.shape, 'd') # probe workspace
            thresh = 0
            num_active = 0
            for i in range(N):
                # Reset the workspace and apply the candidate threshold
                # *before* counting. Thresholding at the end of the previous
                # pass would be wiped out by this reset, so every pass would
                # count the same unthresholded map.
                probe[:] = Map
                if thresh:
                    probe -= thresh
                    probe[probe<0] = 0
                pop_peak = probe.max()
                num_active = 0
                for j in range(N):
                    unit_peak = probe[j].max()
                    if unit_peak > 0.2*pop_peak:
                        if (probe[j]>0.2*unit_peak).sum() > 50:
                            num_active += 1
                self.out.printf('%d '%num_active)
                if num_active < std_num_active:
                    self.out.printf('\n')
                    self.out('... sparsity match at %.4f ...'%thresh)
                    break
                thresh = peaks[i] # get next peak to test
            del probe

            if num_active >= std_num_active:
                self.out.printf('\n')
            if thresh:
                Map -= thresh
                Map[Map<0] = 0

        # Run feedforward inhibition simulation
        if not os.path.exists('fwd_map.tar.gz'):
            self.out('Computing feedforward model variant...')
            R[:] = 0 # using R matrix as a spatial map workspace
            for i in range(model.N_CA):
                R[i] = model.beta * \
                    (W[i].reshape(model.N_EC, 1, 1) * EC_R).sum(axis=0)

            # Feedforward inhibition as sparsity-matching threshold
            sparsity_match_threshold(R)
            pmap.reset()
            pmap.compute_coverage()
            self.out('Feedforward active units = %d'%pmap.num_active)
        else:
            self.out('Loading feedforward model data...')
            pmap = CheckeredRatemap.fromfile('fwd_map.tar.gz')
            R = pmap.Map
        array_to_image(get_norms(R), 'fwd_norms.png', cmap=mpl.cm.gray_r)
        store_data('fwd', pmap)

        # Run associatively tuned simulation
        if not os.path.exists('tuned_map.tar.gz'):
            self.out('Running input tuned simulation (alpha = %.2f)...'%alpha)
            model.W = get_tuned_weights(
                CheckeredRatemap.fromfile('std_map.tar.gz'), W, EC, alpha,
                grow_synapses=True)
            model.reset()
            model.advance()
            pmap = CheckeredRatemap(model)
            pmap.compute_coverage()
            self.out('Tuned active units = %d'%pmap.num_active)
        else:
            self.out('Loading input tuned model data...')
            pmap = CheckeredRatemap.fromfile('tuned_map.tar.gz')
        R = pmap.Map
        array_to_image(get_norms(R), 'tuned_norms.png', cmap=mpl.cm.gray_r)
        store_data('tuned', pmap)

        # Run recurrent excitation simulation
        if not os.path.exists('rec_map.tar.gz'):
            # Construct the E-E weight matrix
            self.out('Constructing E-E weight matrix...')
            if rec_tuned:
                # pmap and R still refer to the input-tuned map from above
                self.out('--> Using input-tuned output as base')
            else:
                self.out('--> Using standard output as base')
                pmap = CheckeredRatemap.fromfile('std_map.tar.gz')
                R = pmap.Map

            # Shape names are kept distinct from W, the weight matrix
            n_units, n_rows, n_cols = R.shape
            single_maps = pmap.single_maps
            map_sums = [single_maps[i].sum() for i in range(n_units)]
            J = np.zeros((n_units, n_units), 'd')
            for i in range(n_units):
                for j in range(i+1, n_units):
                    # Symmetric weight: gain-scaled overlap of the two maps,
                    # normalized by the smaller of the two map sums
                    overlap = gamma * (single_maps[i] * single_maps[j]).sum()
                    if overlap > 0:
                        overlap = overlap / min(map_sums[i], map_sums[j])
                    J[i,j] = J[j,i] = overlap

            # Add in first-order recurrent excitation across the map
            self.out('Adding first-order recurrent excitation to map...')
            for i in range(n_rows):
                for j in range(n_cols):
                    R[:,i,j] += np.dot(R[:,i,j], J) # feedforward
                    R[:,i,j] += np.dot(R[:,i,j], J) # feedback

            # Feedforward threshold to maintain activity level
            sparsity_match_threshold(R)
            pmap.reset()
            pmap.compute_coverage()
            self.out('Recurrent active units = %d'%pmap.num_active)
        else:
            self.out('Loading recurrent model data...')
            pmap = CheckeredRatemap.fromfile('rec_map.tar.gz')
            R = pmap.Map
        array_to_image(get_norms(R), 'rec_norms.png', cmap=mpl.cm.gray_r)
        store_data('rec', pmap)

        # Good-bye!
        self.out('All done!')

    def create_plots(self, legend=False):
        """Create the model comparison figure from the collected results

        The figure (stored as self.figure['altmodels']) contains one panel per
        model type showing place fields drawn as circles, a panel of mean
        number of fields against mean field area (with s.e.m. error bars),
        and a bar chart of remapping and turnover of each variant relative to
        the standard model. Remapping/turnover values are cached in
        remapping.npy in the data directory.

        Keyword arguments:
        legend -- whether to add a legend to the remapping/turnover bar chart

        Output is logged to figure.log in the data directory while plotting.
        """
        # Move into data directory and start logging
        os.chdir(self.datadir)
        self.out.outfd = open('figure.log', 'w')

        try:
            # Set up main figure for plotting
            self.figure = {}
            figsize = 8, 10
            plt.rcParams['figure.figsize'] = figsize
            self.figure['altmodels'] = f = plt.figure(figsize=figsize)
            f.suptitle(self.label.title())

            # Load data
            data = self.results
            models = data['model_types']
            getval = lambda pre, k: data[pre + '_' + k]

            # Log some data
            def print_mean_sem(value, arr):
                """Log a scalar as is, or an array as mean +/- s.e.m."""
                if isinstance(arr, float):
                    self.out('%s = %.4f'%(value, arr))
                else:
                    self.out('%s = %.4f +/- %.4f'%(value, arr.mean(), sem(arr)))

            for prefix in models:
                for val in ('sparsity', 'num_fields', 'area', 'diameter'):
                    key = prefix + '_' + val
                    print_mean_sem(key, data[key])

            # Draw place fields as circles
            def draw_circle_field_plots(ax, prefix):
                """Plot the fields of the given model on ax as center markers
                with circles of the field diameters; returns ax.
                """
                x = getval(prefix, 'x')
                y = getval(prefix, 'y')
                d = getval(prefix, 'diameter')
                ax.plot(x, y, 'k+', ms=6, aa=False)
                for i in range(len(x)):
                    ell = mpl.patches.Ellipse((x[i], y[i]), d[i], d[i],
                        fill=False, lw=1, ec='k')
                    ell.clip_box = ax.bbox
                    ax.add_artist(ell)
                ax.axis("image")
                ax.set_xlim(0, 100)
                ax.set_ylim(0, 100)
                ax.set_title(prefix)
                return ax

            # Render place field plots
            rows = 3
            cols = 2
            for i, prefix in enumerate(models):
                draw_circle_field_plots(plt.subplot(rows, cols, i+1), prefix)

            # Statistics plot
            ax = plt.subplot(rows, cols, 5)
            markers = "ods^"
            for i, prefix in enumerate(models):
                a = getval(prefix, 'area')
                nf = getval(prefix, 'num_fields')
                ax.errorbar(a.mean(), nf.mean(), xerr=sem(a), yerr=sem(nf),
                    fmt=markers[i], ecolor='k', elinewidth=1, capsize=4,
                    ms=6, mfc='k', mec='k', mew=1)
            ax.set_xlabel('area')
            ax.set_ylabel('num. fields')

            # Remapping data
            if os.path.exists('remapping.npy'):
                self.out('Loading remapping/turnover values...')
                remapping, turnover = np.load('remapping.npy')
            else:
                self.out('Computing remapping/turnover measures...')
                pmaps = [CheckeredRatemap.fromfile('%s_map.tar.gz'%p)
                    for p in models]
                remapping = []
                turnover = []
                # Compare every variant against the standard map (first entry)
                for pm in pmaps[1:]:
                    cmpAB = compare_AB(pmaps[0], pm)
                    remapping.append(cmpAB['remapping'])
                    turnover.append(cmpAB['turnover'])
                np.save('remapping.npy', np.array([remapping, turnover]))

            self.out('Remapping: %s'%str(remapping))
            self.out('Turnover: %s'%str(turnover))

            # Set up bar plot data: one remapping/turnover pair of bars per
            # model variant, placed either side of the integer tick position
            ax = plt.subplot(rows, cols, 6)
            xticklabels = models[1:]
            ngroups = len(xticklabels)
            bar_w = 1/float(ngroups)
            left = []
            height = []
            for c in range(ngroups):
                left.extend([c-bar_w, c])
                height.extend([remapping[c], turnover[c]])

            # Render the bar chart and legend
            bar_cols = mpl.cm.gray([0.25, 0.6]*ngroups)
            bar_h = ax.bar(left, height, width=bar_w,
                ec='k', color=bar_cols, linewidth=0, ecolor='k', aa=False)
            if legend:
                ax.legend(bar_h[:2], ['Remapping', 'Turnover'], loc=1)
            ax.hlines(1.0, xmin=-0.5, xmax=ngroups-0.5, linestyle=':',
                color='k')
            ax.set_xlim(-0.5, ngroups-0.5)
            ax.set_ylim(0.0, 1.1)
            ax.set_xticks(np.arange(ngroups))
            ax.set_xticklabels(xticklabels)

            plt.draw()
        finally:
            # Restore the global figure size and stop logging to file even if
            # plotting failed part-way through
            plt.rcParams['figure.figsize'] = \
                plt.rcParamsDefault['figure.figsize']
            self.out.outfd.close()