import matplotlib.pylab as plt
import numpy as np

def time_freq_plot(t, freqs, data, coefs, xunits='s', yunits=''):
    """Plot a signal, its time-frequency map and its time-averaged power.

    The figure is laid out on a single 3x7 grid:

    * top row (columns 0-5): the raw signal ``data`` against ``t``;
    * bottom two rows (columns 0-5): a filled contour map of ``coefs``;
    * bottom two rows (column 6): a horizontal bar chart of the mean power
      per frequency, with its y-limits aligned to the contour map.

    Parameters
    ----------
    t : array_like, shape (n_times,)
        Sample times. When ``xunits == 'ms'`` they are multiplied by 1e3,
        so they are presumably given in seconds (confirm with callers).
    freqs : array_like, shape (n_freqs,)
        Frequencies (labelled in Hz) of the rows of ``coefs``.
    data : array_like, shape (n_times,)
        Signal plotted against ``t`` in the top panel.
    coefs : array_like, shape (n_freqs, n_times)
        Time-frequency coefficients. The shape follows from
        ``contourf(t, freqs, coefs)``. They are plotted as-is, and their
        squared magnitude is averaged over time for the power panel.
    xunits : str, optional
        Time unit shown in the x label. ``'ms'`` also rescales ``t``.
    yunits : str, optional
        Y-axis label of the signal panel.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    # Accept lists as well as arrays; ``1e3 * list`` would raise TypeError.
    t = np.asarray(t)
    freqs = np.asarray(freqs)
    coefs = np.asarray(coefs)
    if xunits == 'ms':
        t = 1e3 * t

    fig = plt.figure(figsize=(12, 6))
    plt.subplots_adjust(wspace=.8, hspace=.5, bottom=.2)
    # All panels share one grid so they tile without overlapping.
    grid = (3, 7)

    # signal plot
    ax_signal = plt.subplot2grid(grid, (0, 0), colspan=6)
    ax_signal.plot(t, data)
    ax_signal.set_ylabel(yunits)
    ax_signal.set_xlim([t[0], t[-1]])

    # time frequency plot
    ax_tf = plt.subplot2grid(grid, (1, 0), rowspan=2, colspan=6)
    ax_tf.contourf(t, freqs, coefs, cmap='PRGn')
    ax_tf.set_xlabel('time (' + xunits + ')')
    ax_tf.set_ylabel('frequency (Hz)')

    # mean power over time, one bar per frequency, beside the map
    ax_power = plt.subplot2grid(grid, (1, 6), rowspan=2)
    # |coefs|**2 equals coefs**2 for real data and is the power for complex data.
    mean_power = (np.abs(coefs) ** 2).mean(axis=1)
    # Per-bar heights follow the local frequency spacing, so non-uniformly
    # spaced frequencies do not give overlapping or gapped bars.
    heights = np.abs(np.gradient(freqs))
    ax_power.barh(freqs, mean_power, height=heights, label='mean')
    # Line the bars up with the frequency axis of the contour map.
    ax_power.set_ylim(ax_tf.get_ylim())
    ax_power.set_xlabel('power')
    ax_power.legend(prop={'size': 'small'}, loc=(0.1, 1.1))
    return fig