from shared import *
from hrtf_analysis import *
from models import *
import gc
class AllPairsModel(object):
    '''
    All-pairs binaural coincidence-detection model.

    Initialise this object with an hrtfset, a cochlear range (cfmin, cfmax,
    cfN), a range of gains (gain_max in dB, gain_N) and a range of delays
    (delay_max, delay_N), and optionally:

    * a model for the coincidence detector neurons (cd_model),
    * a model for the filter neurons (filtergroup_model).

    Structure built here:

    * Each of the two sound channels goes through a Gammatone bank with cfN
      centre frequencies (erbspace(cfmin, cfmax, cfN)).
    * Every cochlear channel is copied gain_N times, scaled by gains spanning
      -gain_max..+gain_max dB. The second ear gets the gains in reverse
      order, i.e. the reciprocal of the first ear's gain.
    * Every (channel, gain) pair of the two ears projects onto 2*delay_N-1
      coincidence detectors, one per relative delay.

    gain_N must be a positive odd integer so the gain range is symmetric
    about 0 dB.

    The __call__ method returns a count (see docstring of that method).
    '''
    def __init__(self, hrtfset, cfmin, cfmax, cfN,
                 gain_max, gain_N, delay_max, delay_N,
                 cd_model=standard_cd_model,
                 filtergroup_model=standard_filtergroup_model,
                 ):
        # An even gain_N would yield gain_N-1 gains below while nchannels
        # claims gain_N, which only fails later inside the simulation.
        if gain_N < 1 or gain_N % 2 != 1:
            raise ValueError('gain_N must be a positive odd integer, got %r'
                             % (gain_N,))
        self.hrtfset = hrtfset
        self.cfmin, self.cfmax, self.cfN = cfmin, cfmax, cfN
        self.cd_model = cd_model
        self.filtergroup_model = filtergroup_model
        self.gain_max = gain_max
        self.gain_N = gain_N
        self.delay_max = delay_max
        self.delay_N = delay_N
        self.num_indices = hrtfset.num_indices
        cf = erbspace(cfmin, cfmax, cfN)
        # Dummy stereo sound; __call__ swaps in the real one via
        # soundinput.source before each run.
        sound = Sound((silence(1*ms), silence(1*ms)))
        soundinput = DoNothingFilterbank(sound)

        # Gain table: gain_N linear gains from -gain_max to +gain_max dB,
        # symmetric so that gains[::-1] == 1/gains.
        m = (gain_N + 1) // 2
        gains_dB = linspace(0, gain_max, m)
        gains = 10**(gains_dB/20)
        gains = hstack((1/gains[::-1], gains[1:]))
        allgains = reshape(gains, (1, 1, gains.size))

        def apply_gains(y):
            # y is (nsamples, 2*ncf). The first ncf columns come from sound
            # channel 0 and the rest from channel 1, given Repeat(...) and the
            # doubled cf list below. Returns (nsamples, 2*ncf*gain_N), laid
            # out as [ear][channel][gain].
            nsamples = y.shape[0]
            ncf = y.shape[1] // 2
            y = reshape(y, (nsamples, 2*ncf, 1))
            y1 = y[:, :ncf, :]*allgains
            y2 = y[:, ncf:, :]*allgains[:, :, ::-1]
            y = hstack((y1, y2))
            return reshape(y, (nsamples, -1))

        gfb = Gammatone(Repeat(soundinput, cfN), hstack((cf, cf)))
        gains_fb = FunctionFilterbank(gfb, apply_gains)
        # apply_gains multiplies the channel count by gain_N.
        gains_fb.nchannels = gfb.nchannels*gain_N
        compress = filtergroup_model['compress']
        # Half-wave rectification followed by the model's compression.
        cochlea = FunctionFilterbank(gains_fb,
                                     lambda x: compress(clip(x, 0, Inf)))

        # Filter neurons: one per (ear, channel, gain).
        eqs = Equations(filtergroup_model['eqs'],
                        **filtergroup_model['parameters'])
        G = FilterbankGroup(cochlea, 'target_var', eqs,
                            threshold=filtergroup_model['threshold'],
                            reset=filtergroup_model['reset'],
                            refractory=filtergroup_model['refractory'])

        # Coincidence detectors: one per (channel, gain, relative delay).
        ndelays = 2*delay_N - 1
        nfilters = cfN*gain_N  # filter neurons per ear
        cd_eqs = Equations(cd_model['eqs'], **cd_model['parameters'])
        cd = NeuronGroup(nfilters*ndelays, cd_eqs,
                         threshold=cd_model['threshold'],
                         reset=cd_model['reset'],
                         refractory=cd_model['refractory'],
                         clock=G.clock)

        # Delay pattern over the ndelays detectors of one (channel, gain):
        # for the first delay_N-1 detectors only the second ear is delayed,
        # at the middle one neither ear is delayed, and after that only the
        # first ear is delayed. float() strips units (assumed to be seconds
        # for Brian's C.delay).
        left_delays = hstack((zeros(delay_N-1),
                              linspace(0, float(delay_max), delay_N)))
        right_delays = left_delays[::-1]
        cd_weight = cd_model['weight']
        C = Connection(G, cd, 'target_var', delay=True, max_delay=delay_max)
        for i, j, dl, dr in zip(repeat(arange(nfilters), ndelays),
                                arange(nfilters*ndelays),
                                tile(left_delays, nfilters),
                                tile(right_delays, nfilters)):
            # Filter neuron i (first ear) and its partner i+nfilters (second
            # ear) both drive detector j.
            C[i, j] = cd_weight
            C[i+nfilters, j] = cd_weight
            C.delay[i, j] = dl
            C.delay[i+nfilters, j] = dr

        self.soundinput = soundinput
        self.filtergroup = G
        self.synchronygroup = cd
        self.synapses = C
        self.counter = SpikeCounter(cd)
        self.network = Network(G, cd, C, self.counter)

    def __call__(self, sound, index=None, **indexkwds):
        '''
        Apply all pairs filtering group to given sound, which should be a
        stereo sound unless you specify the HRTF index, or coordinates of
        the HRTF index as keyword arguments, in which case it should be a mono
        sound which will have the given HRTF applied to it. You can also
        specify index=hrtf. Returns the count of the neurons in the synchrony
        group with shape (cfN, gain_N, delay_N*2-1).
        '''
        # The HRTF check must come first: an HRTF object is also "not None"
        # and must not be used as an index into the hrtfset.
        if isinstance(index, HRTF):
            hrtf = index
        elif index is not None:
            hrtf = self.hrtfset[index]
        elif indexkwds:
            hrtf = self.hrtfset(**indexkwds)
        else:
            hrtf = None
        if hrtf is not None:
            sound = hrtf(sound)
        self.soundinput.source = sound
        self.network.reinit()
        self.filtergroup_model['init'](self.filtergroup,
                                       self.filtergroup_model['parameters'])
        self.cd_model['init'](self.synchronygroup, self.cd_model['parameters'])
        self.network.run(sound.duration, report='stderr')
        # Detector index is (channel*gain_N + gain)*(2*delay_N-1) + delay,
        # matching the connection loop in __init__.
        count = reshape(self.counter.count,
                        (self.cfN, self.gain_N, self.delay_N*2-1))
        return count
if __name__ == '__main__':
    # Demo: run the all-pairs model on white noise filtered through a random
    # HRTF of one IRCAM subject, then plot the coincidence counts.
    from scipy.ndimage import gaussian_filter
    hrtfdb = get_ircam()
    subject = 1002
    hrtfset = hrtfdb.load_subject(subject)
    index = randint(hrtfset.num_indices)
    cfmin, cfmax, cfN = 150*Hz, 5*kHz, 80
    gain_max, gain_N = 8.0, 61
    delay_N = 35
    delay_max = delay_N/samplerate
    # Change this to 10*second for equivalent picture to the paper
    sound = whitenoise(200*ms).atlevel(80*dB)
    apmodel = AllPairsModel(hrtfset, cfmin, cfmax, cfN,
                            gain_max, gain_N, delay_max, delay_N)
    count = apmodel(sound, index)

    # Complicated code to plot the output nicely
    freqlabels = array([150*Hz, 1*kHz, 2*kHz, 3*kHz, 4*kHz, 5*kHz])
    fig_mew = 1  # marker edge width (in points)

    # Per-channel ITD/ILD of the chosen HRTF, converted to indices on the
    # model's delay and gain axes (drawn as '+' reference markers).
    itd, ild = hrtfset_itd_ild(hrtfset, cfmin, cfmax, cfN)
    delays = array([itd[index][i] for i in range(cfN)])
    gains = array([ild[index][i] for i in range(cfN)])
    gains = 20*log10(gains)
    delays = -array(delays*samplerate, dtype=int)+delay_N-1
    arrgains = linspace(-gain_max, gain_max, gain_N)
    gains = digitize(gains, 0.5*(arrgains[1:]+arrgains[:-1]))
    gains = gain_N-1-gains
    delay_extent_ms = float(delay_N/samplerate/msecond)

    def label_channels(freqlabels):
        '''Label the y axis by channel number, or by the given frequencies.'''
        if freqlabels is None:
            yticks([])
            ylabel('Channel')
        else:
            cf = erbspace(cfmin, cfmax, cfN)
            # Channel whose centre frequency bin contains each label.
            j = digitize(freqlabels, .5*(cf[1:]+cf[:-1]))
            yticks(j, list(map(str, array(freqlabels, dtype=int))))
            ylabel('Channel (Hz)')

    def dofig(count, blur=0, blurmode='reflect', freqlabels=None):
        '''
        Draw two panels of the (cfN, gain_N, 2*delay_N-1) count array:
        the max over gains (channel vs delay) on the left and the max over
        delays (channel vs gain) on the right. The result is optionally
        smoothed with a Gaussian of width ``blur``.
        '''
        counts = array(count, dtype=float).reshape(
            (cfN, gain_N, delay_N*2-1))

        subplot(121)
        by_delay = gaussian_filter(amax(counts, axis=1), blur, mode=blurmode)
        imshow(by_delay, origin='lower', interpolation='nearest',
               aspect='auto',
               extent=(-delay_extent_ms, delay_extent_ms, 0, cfN))
        plot((delays-delay_N)/samplerate/msecond, arange(cfN), '+',
             color=(0, 0, 0), mew=fig_mew)
        plot((argmax(by_delay, axis=1)-delay_N)/samplerate/msecond,
             arange(cfN), 'x', color=(1, 1, 1), mew=fig_mew)
        axis((-delay_extent_ms, delay_extent_ms, 0, cfN))
        xlabel('Delay (ms)')
        label_channels(freqlabels)

        subplot(122)
        by_gain = gaussian_filter(amax(counts, axis=2), blur, mode=blurmode)
        imshow(by_gain, origin='lower', interpolation='nearest',
               aspect='auto')
        plot(gains, arange(cfN), '+', color=(0, 0, 0), mew=fig_mew)
        plot(argmax(by_gain, axis=1), arange(cfN), 'x', color=(1, 1, 1),
             mew=fig_mew)
        axis('tight')
        xlabel('Relative gain (dB)')
        xticks([0, (gain_N-1)//2, gain_N-1],
               [str(min(arrgains)), '0', str(max(arrgains))])
        label_channels(freqlabels)

    dofig(count, freqlabels=freqlabels)
    figure()
    dofig(count, blur=1)
    figure()
    dofig(count, blur=2)
    show()