# encoding: utf-8
"""
session.py -- Post-hoc container for computing and storing VMOModel results
Exported namespace: VMOSession
Copyright (c) 2011 Johns Hopkins University. All rights reserved.
This software is provided AS IS under the terms of the Open Source MIT License.
See http://www.opensource.org/licenses/mit-license.php.
"""
# Library imports
import os
import numpy as np
from glob import glob
from numpy import pi
from scipy.signal import hilbert
# Package imports
from .vmo import VMOModel
from .double_rotation import VMODoubleRotation
from .tools.radians import radian, get_angle_histogram
from .tools.bash import CPrint
from .tools.path import unique_path
from .tools.filters import halfwave, circular_blur
from .tools.array_container import TraitedArrayContainer
# Traits imports
from enthought.traits.api import Trait, Instance, Array, Float, Int, false
class VMOSession(TraitedArrayContainer):
"""
A container for a completed VMOModel simulation object that automatically
computes input signal envelopes, median thresholds, a population matrix
and place cell spatial information for each trial.
"""
out = Instance(CPrint)
params = Trait(dict)
center = Array
dt = Float
trial = Int(1)
num_trials = Int
num_units = Int
angle = Array
alpha = Array
t = Array
x = Array
y = Array
laps = Array
N_laps = Int
E = Array(desc='units x track response matrix')
E_laps = Array(desc='units x track x laps response matrix')
thresh = Float
R = Array
R_laps = Array
I_rate = Array(desc='spatial information')
sortix = Array(desc='active unit index sorted by max responses')
active_units = Array(desc='active unit index')
is_mismatch = false(desc='whether this is a mismatch session')
mismatch = Int
# Firing rate smoothing parameters
default_blur_width = Float(4.3)
bins = Int(360)
# Saving time-series data
I_cache = Array(desc='saved input time-series')
E_cache = Array(desc='saved envelope time-series')
save_I = false(desc='save input time-series')
save_E = false(desc='save envelope time-series')
# Activity threshold for counting a unit as active
min_spike_count = Trait(0.05, Float, desc='min. fraction pop. max')
def __init__(self, model, **kwargs):
super(VMOSession, self).__init__(**kwargs)
try:
if not model.done:
raise ValueError, "model simulation must be completed"
except AttributeError:
raise ValueError, "argument must be a VMOModel object"
# Get basic data about this model simulation/session
self.trait_set( params=model.parameter_dict(),
num_trials=model.num_trials,
num_units=model.N_outputs,
dt=model.dt,
center=model.center)
if self.num_trials == 1:
pm = model.post_mortem()
else:
pm = model.post_mortem().get_trial_data(self.trial)
if hasattr(model, 'mismatch'):
self.is_mismatch = True
self.mismatch = int((180/pi)*model.mismatch[self.trial-1])
self.trait_set(alpha=pm.alpha, x=pm.x, y=pm.y, t=pm.t)
# Compute envelopes, thresholds, and population responses
self.out('Computing responses for trial %d of %d...'%(self.trial,
self.num_trials))
self._compute_envelopes(pm.I, model.track)
self._set_threshold()
self.compute_circle_responses()
self.out('Done!')
def _compute_envelope_timeseries(self, I):
"""Compute raw time-series signal envelopes of oscillatory drive from
synaptic drive matrix (timesteps x units).
"""
# Compute signal envelope via Hilbert transform
E = np.empty((I.shape[1], I.shape[0]), 'd')
for i in xrange(I.shape[1]):
H = hilbert(I[:,i])
Hhat = np.real(H).mean()
Hr = np.real(H) - Hhat # detrend real part
E[i] = np.sqrt(Hr**2 + np.imag(H)**2) + Hhat
# Cache the time-series before track binning if specified
if self.save_I:
self.out('Warning: saving input cache')
self.I_cache = I.T
if self.save_E:
self.out('Warning: saving envelopes cache')
self.E_cache = E
return E
def _compute_envelopes(self, I_theta, track):
"""Compute the amplitude envelopes for each of the output units
Required arguments:
I_theta -- synaptic drive time-series matrix for all outputs
track -- CircleTrackData object containing trajectory data
Session and per-lap envelope matrices are computed.
"""
# Compute envelope time-series
E = self._compute_envelope_timeseries(I_theta)
# Reduce envelope data to binned track angle histogram
t, alpha = self.t, self.alpha
angle = np.linspace(0, 2*pi, self.bins+1)
self.angle = angle[:-1]
# Get completed laps
lap_times = track.elapsed_time_from_timestamp(track.laps)
lap_ix = (lap_times<=self.t[-1]).nonzero()[0].max()
self.laps = lap_times[:lap_ix]
self.N_laps = lap_ix - 1 # only including *complete* laps, last lap is always incomplete
# Compute track responses: session- and lap-averages
self.E = np.zeros((self.num_units, self.bins), 'd')
self.E_laps = \
np.zeros((self.num_units, self.bins, self.N_laps), 'd')
for b in xrange(self.bins):
ix = np.logical_and(
alpha >= angle[b], alpha < angle[b+1]).nonzero()[0]
if len(ix):
self.E[:,b] = E[:,ix].mean(axis=1)
for lap in xrange(self.N_laps):
ix = reduce(np.logical_and,
[alpha >= angle[b], alpha < angle[b+1],
t >= self.laps[lap], t < self.laps[lap+1]]).nonzero()[0]
if len(ix):
self.E_laps[:,b,lap] = E[:,ix].mean(axis=1)
def _set_threshold(self):
"""Compute median peak inputs as an activity threshold
"""
self.thresh = np.median(self.E.max(axis=1))
def compute_circle_responses(self):
"""Top-level function to recompute the population matrix, information
rates and active place units.
"""
self._compute_population_matrix()
self._compute_spatial_information()
self._set_active_units()
def _compute_population_matrix(self):
"""Compute radial place field ratemaps for each output unit
"""
self.R = halfwave(self.E - self.thresh)
self.R_laps = halfwave(self.E_laps - self.thresh)
def _compute_spatial_information(self):
"""Compute overall spatial information for each output unit
Calculates bits/spike as (Skaggs et al 1993):
I(R|X) = (1/F) * Sum_i[p(x_i)*f(x_i)*log_2(f(x_i)/F)]
"""
self.I_rate = np.empty(self.num_units, 'd')
occ = get_angle_histogram(
self.x-self.center[0], self.y-self.center[1], self.bins)
occ *= self.dt # convert occupancy to seconds
p = occ/occ.sum()
for i in xrange(self.num_units):
f = self.R[i]
F = halfwave(self.E[i]-self.thresh).mean()
I = p*f*np.log2(f/F)/F
I[np.isnan(I)] = 0.0 # handle zero-rate bins
self.I_rate[i] = I.sum()
def _set_active_units(self):
"""Apply minimal firing rate threshold to determine which active units
are active.
"""
self.active_units = (
self.R.max(axis=1) >= self.min_spike_count*self.R.max()
).nonzero()[0]
self.sortix = self.active_units[
np.argsort(np.argmax(self.R[self.active_units], axis=1))]
def get_spatial_information(self, unit=None):
"""Get overall spatial information for the population or a single unit
"""
return np.squeeze(self.I_rate[unit])
def get_population_matrix(self, bins=None, norm=False, clusters=None,
smoothing=True, blur_width=None, inplace=False):
"""Retrieve the population response matrix for this session simulation
Keyword arguments:
bins -- recompute responses for a different number of bins (deprecated)
norm -- whether to integral normalize each unit's response
clusters -- optional index array for row-sorting the response matrix;
if not specified, a peak-location sort of the place-active subset
of the population is used by default
smoothing -- whether to do circular gaussian blur on ratemaps
blur_width -- width of gaussian window to use for smoothing; a value of
None defaults to default_blur_width
Returns (units, bins) matrix of population spatial responses.
"""
self.compute_circle_responses()
if clusters is None:
clusts = self._get_active_units()
elif type(clusters) in (np.ndarray, list):
clusts = np.asarray(clusters)
if inplace:
R = self.R[clusts]
else:
R = self.R[clusts].copy()
if smoothing:
if blur_width is None:
blur_width = self.default_blur_width
for Runit in R:
Runit[:] = circular_blur(Runit, blur_width)
if norm:
Rsum = np.trapz(R, axis=1).reshape(R.shape[0], 1)
Rsum[Rsum==0.0] = 1
R /= Rsum
return R
def get_population_lap_matrix(self, clusters=None, smoothing=True,
blur_width=None, inplace=False, **kwargs):
"""Construct concatentation of per-lap population response matrices
Keyword arguments:
clusters -- optional index array for row-sorting the response matrix;
if not specified, a peak-location sort of the place-active subset
of the population is used by default
smoothing -- whether to do circular gaussian blur on ratemaps
blur_width -- width of gaussian window to use for smoothing; a value of
None defaults to default_blur_width
Returns (N_clusts, bins, N_laps) response matrix.
"""
self.compute_circle_responses()
if clusters is None:
clusts = self._get_active_units()
elif type(clusters) in (np.ndarray, list):
clusts = np.asarray(clusters)
if inplace:
R = self.R_laps[clusts]
else:
R = self.R_laps[clusts].copy()
if smoothing:
if blur_width is None:
blur_width = self.default_blur_width
for Runit in R:
for Rlap in Runit.T:
Rlap[:] = circular_blur(Rlap, blur_width)
return R
def recover_cues(self):
"""Simulate a dummy model with identical cue configuration as was used
to create this session data. A post-mortem object is returned that can
be plotted using (e.g.) rat.oi_funcs.plot_external_cues.
"""
pdict = dict(
N_theta = 1,
N_outputs = 1,
monitoring = False,
N_cues_local = self.params['N_cues_local'],
N_cues_distal = self.params['N_cues_distal'],
local_cue_std = self.params['local_cue_std'],
distal_cue_std = self.params['distal_cue_std'],
refresh_fixed_points = False
)
if self.is_mismatch:
pdict.update(mismatch=[(np.pi/180)*self.mismatch])
klass = VMODoubleRotation
else:
klass = VMOModel
model = klass(**pdict)
model.advance()
return model.post_mortem()
@classmethod
def get_session_list(cls, model, **kwargs):
"""Convenience method to get a list of VMOSession objects for the
trials in a model object.
"""
res = []
if model.num_trials == 1:
res = VMOSession(model, **kwargs)
else:
res = []
for trial in xrange(1, model.num_trials+1):
res.append(VMOSession(model, trial=trial, **kwargs))
return res
@classmethod
def save_session_list(cls, session_list, save_dir):
"""Save all sessions in an experiment to the specified directory
"""
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for session in session_list:
if session.is_mismatch:
if session.mismatch == 0:
fn = 'STD.tar.gz'
else:
fn = 'MIS_%03d.tar.gz'%session.mismatch
session.tofile(os.path.join(save_dir, fn))
else:
fn = unique_path(os.path.join(save_dir, 'session_'),
ext='tar.gz')
session.tofile(fn)
@classmethod
def load_session_list(cls, load_dir):
"""Load all sessions from files found in the specified load directory
"""
files = glob(os.path.join(load_dir, '*.tar.gz'))
files.sort()
return [cls.fromfile(fn) for fn in files]
def _get_active_units(self):
"""Get the list of active place units
"""
return self.sortix
def _out_default(self):
return CPrint(prefix=self.__class__.__name__)