"""
Short rewriting of the wavelet transform module of SciPy
correspondance: yann.zerlaut@iit.it
"""
import numpy as np
from scipy import signal
def ricker(t, f, t0):
"""
Ricker wavelet of frequency 'f' centered in t0 and over and signal length.
"""
fact = (np.pi**2) * (f**2) * ((t - t0)**2)
y = (1.0 - 2.0 * fact) * np.exp(-fact)
return y
def make_ricker_of_right_size(freq, dt, with_t=False, factor_freq=2.):
"""
returns a ricker of size int(factor_freq*/(freq*dt))
centered in the middle of the array (for use with convolve)
Note factor_freq = 2 covers well the extent of the ricker
"""
tstop = factor_freq/freq
t = np.arange(int(tstop/dt))*dt
if with_t:
return t, ricker(t, freq, t[-1]/2.)
else:
return ricker(t, freq, t[-1]/2.)
def my_cwt(data, frequencies, dt):
"""
Continuous wavelet transform, adapted from:
https://github.com/scipy/scipy/blob/v0.18.1/scipy/signal/wavelets.py#L311-L365
Performs a continuous wavelet transform on `data`,
using the `wavelet` function. A CWT performs a convolution
with `data` using the `wavelet` function, which is characterized
by a frequency parameter.
Parameters
----------
data : (N,) ndarray
data on which to perform the transform.
frequencies : (M,) sequence
Widths to use for transform.
Returns
-------
cwt: (M, N) ndarray
Will have shape of (len(frequencies), len(data)).
width[ii]), mode='same')
"""
output = np.zeros([len(frequencies), len(data)])
for ind, freq in enumerate(frequencies):
wavelet_data = make_ricker_of_right_size(freq, dt)
# the wavelets have different integrals
# conv_number compensates for the number of summed points (i.e. also integral of wavelet)
conv_number = signal.convolve(np.ones(len(data)), np.ones(len(wavelet_data)),
mode='same')
# the sliding mean that depends on the frequency
sliding_mean = signal.convolve(data, np.ones(len(wavelet_data)),
mode='same')/conv_number
# the final convolution
output[ind, :] = signal.convolve(data-sliding_mean, wavelet_data,
mode='same')/conv_number
return output
def illustration_plot(t, freqs, data, coefs, dt, tstop, freq1, freq2, freq3):
"""
a plot to illustrate the output of the wavelet analysis
"""
import matplotlib.pylab as plt
fig = plt.figure(figsize=(12,6))
plt.subplots_adjust(wspace=.8, hspace=.5, bottom=.2)
# signal plot
plt.subplot2grid((3, 8), (0,0), colspan=6)
plt.plot(1e3*t, data, 'k-', lw=2)
plt.ylabel('signal')
for f, tt in zip([freq2, freq1, freq3], [200,500,800]):
plt.annotate(str(int(f))+'Hz', (tt, data.max()))
plt.xlim([1e3*t[0], 1e3*t[-1]])
# time frequency power plot
ax1 = plt.subplot2grid((3, 8), (1,0), rowspan=2, colspan=6)
c = plt.contourf(1e3*t, freqs, coefs, cmap='PRGn', aspect='auto')
plt.xlabel('time (ms)')
plt.ylabel('frequency (Hz)')
plt.yticks([10, 40, 70, 100]);
# inset with legend
acb = plt.axes([.4, .4, .02, .2])
plt.colorbar(c, cax=acb, label='coeffs (a.u.)', ticks=[-1, 0, 1])
# mean power plot over intervals
plt.subplot2grid((3, 8), (1, 6), rowspan=2)
for t1, t2 in zip([0,300e-3,700e-3], [300e-3,700e-3, 1000e-3]):
cond = (t>t1) & (t<t2)
plt.barh(freqs, np.power(coefs[:,cond],2).mean(axis=1)*dt,\
label='t$\in$['+str(int(1e3*t1))+','+str(int(1e3*t2))+']')
plt.legend(prop={'size':'small'}, loc=(0.1,1.1))
plt.yticks([10, 40, 70, 100]);
plt.xticks([]);
plt.xlabel(' mean \n power \n (a.u.)')
# max of power over intervals
plt.subplot2grid((3, 8), (1, 7), rowspan=2)
for t1, t2 in zip([0,300e-3,600e-3], [300e-3,600e-3, 1000e-3]):
cond = (t>t1) & (t<t2)
plt.barh(freqs, np.power(coefs[:,cond],2).max(axis=1)*dt,\
label='t$\in$['+str(int(1e3*t1))+','+str(int(1e3*t2))+']')
plt.yticks([10, 40, 70, 100]);
plt.xticks([]);
plt.xlabel(' max. \n power \n (a.u.)');
return fig
if __name__ == '__main__':
import numpy as np
import matplotlib.pylab as plt
plt.style.use('ggplot')
# temporal sampling
dt, tstop = 1e-4, 1.
t = np.arange(int(tstop/dt))*dt
# ### artificially generated signal, transient oscillations
freq1, width1, freq2, width2, freq3, width3 = 10., 100e-3, 40., 40e-3, 70., 20e-3
data = 3.2+np.cos(2*np.pi*freq1*t)*np.exp(-(t-.5)**2/2./width1**2)+\
np.cos(2*np.pi*freq2*t)*np.exp(-(t-.2)**2/2./width2**2)+\
np.cos(2*np.pi*freq3*t)*np.exp(-(t-.8)**2/2./width3**2)
# ### adding colored noise to test robustness
nl = 0e-2 # noise level
data += nl*np.convolve(np.exp(-np.arange(1000)*dt/400e-3),\
np.random.randn(len(t)), mode='same') # a slow one
data += nl*np.convolve(np.exp(-np.arange(1000)*dt/5e-3),\
np.random.randn(len(t)), mode='same') # a faster one
# Continuous Wavelet Transform analysis
freqs = np.linspace(1, 90, 1e2)
coefs = my_cwt(data, freqs, dt)
illustration_plot(t, freqs, data, coefs, dt, tstop, freq1, freq2, freq3)
plt.show()