#!/usr/bin/env python
""" plot.py - plotting routines for the basic figures
    Copyright (C) 2013 Shane Lee and Stephanie Jones

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import numpy as np
import os
import fileio
import specfn
import spikefn

class FigTemplate():
    """ Five-panel figure layout used for the basic simulation summary

        Attributes set by the constructor:
          f:  the matplotlib figure (7.5 x 10)
          gs: dict holding the single 8-row x 100-column GridSpec under 'all'
          ax: dict of axes keyed by panel name ('raster_L2', 'raster_L5',
              'dipole_L2', 'dipole_L5', 'spec_dipole')
    """
    def __init__(self):
        # NOTE: this changes the global matplotlib font size, not just the
        # font size of this figure
        mpl.rc('font', size=8)
        self.f = plt.figure(figsize=(7.5, 10))

        grid = gs.GridSpec(8, 100, left=0.15, right=0.95, bottom=0.1, top=0.95)
        self.gs = {'all': grid}

        new_axis = self.f.add_subplot
        self.ax = {}

        # rasters take two grid rows each, dipoles one row each; all four use
        # only the first 80 columns
        self.ax['raster_L2'] = new_axis(grid[:2, :80])
        self.ax['raster_L5'] = new_axis(grid[2:4, :80])
        self.ax['dipole_L2'] = new_axis(grid[4:5, :80])
        self.ax['dipole_L5'] = new_axis(grid[5:6, :80])

        # the spectrogram takes the last two rows at full width
        # (presumably so a colorbar can take space from it - confirm)
        self.ax['spec_dipole'] = new_axis(grid[6:, :])

        self.labels()

    def labels(self):
        """ Apply the fixed axis labels to every panel
        """
        # panels stacked above the spectrogram: y label only, x tick labels
        # blanked
        upper_panels = (
            ('raster_L2', 'L2 pyramidal (black), fs (red)'),
            ('raster_L5', 'L5 pyramidal (black), fs (red)'),
            ('dipole_L2', 'L2 dipole (nAm)'),
            ('dipole_L5', 'L5 dipole (nAm)'),
        )

        for panel_name, ylabel in upper_panels:
            panel = self.ax[panel_name]
            panel.set_ylabel(ylabel)
            panel.set_xticklabels([])

        # bottom panel is the only one carrying an x label
        spec_panel = self.ax['spec_dipole']
        spec_panel.set_ylabel('Frequency (Hz)')
        spec_panel.set_xlabel('Time (ms)')

    def close(self):
        """ Close the underlying figure via pyplot
        """
        plt.close(self.f)

    def save(self, fpng):
        """ Write the figure to the file fpng at 250 dpi
        """
        self.f.savefig(fpng, dpi=250)

def _scatter_raster(ax, cells, yticks, first_row, color):
    """ Scatter each entry of cells on ax as one raster row, drawing the k-th
        entry at height yticks[first_row + k] with '|' markers in color
    """
    for row, spk_cell in enumerate(cells, first_row):
        y = yticks[row] * np.ones(len(spk_cell))
        ax.scatter(spk_cell, y, marker='|', s=2, color=color)


def plot_simulation(dsub):
    """ Plot spike rasters, layer dipoles and the dipole spectrogram of a run

        Reads data.pkl and spikes.txt from <cwd>/data/<dsub>, draws them on a
        FigTemplate figure and saves it as spec.png in that same directory.

        dsub: name of the simulation subdirectory under <cwd>/data

        Returns None. Prints the path of the saved file on success. The
        figure is closed whether or not plotting and saving succeed, so a
        failed run does not leave an open figure behind in pyplot. Exceptions
        from loading, plotting or saving propagate to the caller.
    """
    d = os.path.join(os.getcwd(), 'data', dsub)
    f = os.path.join(d, "data.pkl")
    fspikes = os.path.join(d, "spikes.txt")
    fpng = os.path.join(d, "spec.png")

    x = fileio.pkl_load(f)

    # presumably one sequence of spike times per cell, grouped by cell type
    # - verify against spikefn.spikes_from_file
    spikes = spikefn.spikes_from_file(fspikes, x['gid_dict'])

    # per-layer cell counts; pyramidal cells take the first rows of each
    # raster and basket cells follow directly after them
    n_L2_pyr = len(spikes['L2_pyramidal'])
    n_L5_pyr = len(spikes['L5_pyramidal'])
    N_L2 = n_L2_pyr + len(spikes['L2_basket'])
    N_L5 = n_L5_pyr + len(spikes['L5_basket'])

    # NOTE(review): rows start at yticks[0], which is also the lower y-limit,
    # so the first cell sits on the axis edge while the two extra slots stay
    # empty at the top. The N + 2 spacing suggests a one-slot margin on each
    # side was intended (row offset of 1) - confirm before changing output.
    yticks = {
        'L2': np.linspace(0, 1, N_L2 + 2),
        'L5': np.linspace(0, 1, N_L5 + 2),
    }

    fig = FigTemplate()

    try:
        # rasters: pyramidal in black, basket in red
        _scatter_raster(fig.ax['raster_L2'], spikes['L2_pyramidal'], yticks['L2'], 0, 'k')
        _scatter_raster(fig.ax['raster_L2'], spikes['L2_basket'], yticks['L2'], n_L2_pyr, 'r')

        _scatter_raster(fig.ax['raster_L5'], spikes['L5_pyramidal'], yticks['L5'], 0, 'k')
        _scatter_raster(fig.ax['raster_L5'], spikes['L5_basket'], yticks['L5'], n_L5_pyr, 'r')

        fig.ax['raster_L2'].set_ylim((yticks['L2'][0], yticks['L2'][-1]))
        fig.ax['raster_L5'].set_ylim((yticks['L5'][0], yticks['L5'][-1]))

        # dipole
        fig.ax['dipole_L2'].plot(x['t'], x['dipole_L2'])
        fig.ax['dipole_L5'].plot(x['t'], x['dipole_L5'])

        # the L2 panel reuses the L5 panel's y-limits so that both dipoles
        # are drawn on the same scale
        # NOTE(review): an L2 trace exceeding the L5 range would be clipped
        ylims = fig.ax['dipole_L5'].get_ylim()
        fig.ax['dipole_L2'].set_ylim(ylims)

        # every panel shares the time extent of the simulation
        tlim = (x['t'][0], x['t'][-1])

        # the object returned by pspec_ax is handed to colorbar as its mappable
        pc = specfn.pspec_ax(fig.ax['spec_dipole'], x['fspec'], x['spec'], tlim)
        fig.f.colorbar(pc, ax=fig.ax['spec_dipole'])

        for axh in fig.ax.values():
            axh.set_xlim(tlim)

        fig.save(fpng)
        print("Saved file {}".format(fpng))

    finally:
        # release the figure even when plotting or saving raised
        fig.close()

if __name__ == '__main__':
    # Script entry point: render the summary figure for one hard-coded
    # simulation directory under ./data
    plot_simulation('gamma_L5weak_L2weak')