import numpy as np
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
from scipy import signal
import random

import pickle
import os

from mfm import MFM

def fig_3():
    """
    Build Figure 3: time series and power spectra for three model conditions.

    The conditions are, in order, ``{'DD': False}``, ``{'DD': True}`` and
    ``{'DD': True, 'cDBS': True, 'cDBS_amp': 4.13}``; the figure labels them
    'naive', 'DD' and 'cDBS'. The figure is created through pyplot; it is
    neither shown nor returned (the caller is expected to call ``plt.show()``).
    """
    def get_data():
        """
        Loads data from file, or generates data if no file found

        Returns
        -------
        data_dict : dict
            Dictionary with keys 'data', 't' and 'dt' (the cached payload).
        data : list
            One mean-subtracted time series per condition.
        t : array
            Sample index multiplied by ``dt`` for the last generated series.
        dt : number
            Value of ``params['dt']`` of the last model that was run.
        """
        # First, look for saved data
        # NOTE: despite the '.npz' extension this cache is a pickle file; the
        # name is kept so previously generated caches remain usable.
        path = 'data/fig_3.npz'
        if os.path.isfile(path):
            print('Loading data from cache')
            # NOTE(review): unpickling can execute arbitrary code; only load
            # cache files that were generated locally by this script.
            with open(path, 'rb') as cache_file:
                data_dict = pickle.load(cache_file)

            data = data_dict['data']
            t    = data_dict['t']
            dt   = data_dict['dt']

        # If no data, generate data
        else:
            print('Generating data for Figure 3')
            conditions = [{'DD':False},
                          {'DD':True},
                          {'DD':True,'cDBS':True,'cDBS_amp':4.13}]

            for c in conditions:
                c['tstop'] = 100

            data = []
            for condition in conditions:
                mfm = MFM(**condition)
                mfm.run()

                dt = mfm.params['dt']
                time_series = mfm.S[:, mfm.struct['p2']]
                time_series = np.split(time_series, 5)[-1]  # get last 5th
                # Subtract out-of-place: the slice may be a view into the
                # model's own array, which an in-place `-=` would modify.
                time_series = time_series - np.mean(time_series)
                t = np.arange(len(time_series)) * dt

                data.append(time_series)

            data_dict = {
                'data' : data,
                't'    : t,
                'dt'   : dt
            }

            cache_dir = os.path.dirname(path)
            if cache_dir:
                # exist_ok avoids the check-then-create race of isdir/makedirs
                os.makedirs(cache_dir, exist_ok=True)
            with open(path, 'wb') as cache_file:
                pickle.dump(data_dict, cache_file)

        return data_dict, data, t, dt

    def plot(data, t, dt):
        """
        Draws the time series (left column) and their spectra (right column)

        Parameters
        ----------
        data : list
            Time series in the order naive, DD, cDBS.
        t : array
            Time axis shared by all series in ``data``.
        dt : number
            Sampling interval; ``1/dt`` is passed to ``signal.welch`` as the
            sampling frequency.
        """
        # Figures
        # Timeseries Figure
        #-----------------------------------------------------------------------
        # Create figure and subplots
        fig = plt.figure(figsize=(10,4))

        gs = gridspec.GridSpec(1,2)

        gs0 = gridspec.GridSpecFromSubplotSpec(3,1, subplot_spec=gs[0])
        ax0 = [plt.subplot(gs0[0,0])]
        ax0.append(plt.subplot(gs0[1,0],sharey=ax0[0]))
        ax0.append(plt.subplot(gs0[2,0],sharey=ax0[0]))

        gs1 = gridspec.GridSpecFromSubplotSpec(2,1, subplot_spec=gs[1],hspace=0.5)
        ax1 = [plt.subplot(gs1[0,0])]
        ax1.append(plt.subplot(gs1[1,0],sharex=ax1[0],sharey=ax1[0]))

        c = 'k'
        ax0[2].set_xlabel('time (s)')
        ax0[1].set_ylabel('LFP (mV)')

        ts_legends = ['naive','DD','cDBS']
        for ts_ax, series, label in zip(ax0, data, ts_legends):
            ts_ax.plot(t, series, label=label, c=c)
            leg = ts_ax.legend(bbox_to_anchor=(1.02,1.3), loc='upper right',
                               handlelength=0, handletextpad=0, borderpad=0,
                               frameon=False)
            # The legend is used as a plain text label, so hide its handles.
            # Newer Matplotlib names the attribute `legend_handles`; older
            # releases only provide `legendHandles`.
            handles = getattr(leg, 'legend_handles', None)
            if handles is None:
                handles = leg.legendHandles
            for item in handles:
                item.set_visible(False)

            ts_ax.set_xlim((10,10.5))
            ts_ax.set_ylim((-2.5,2.5))
            ts_ax.set_yticks([-2,0,2])

        # Remove spines
        for upper_ax in ax0[:-1]:
            upper_ax.spines['bottom'].set_visible(False)
            upper_ax.xaxis.set_ticks([])

        # PSD Figure
        #------------------------------------------------------------------------
        ax1[1].set_xlabel('frequency (Hz)')

        fmax = 100
        # For each series (naive, DD, cDBS): the panels it is drawn on and its
        # colour. DD appears in both panels as the common reference.
        psd_targets = [
            ((ax1[0],),        'C0'),
            ((ax1[0], ax1[1]), 'C3'),
            ((ax1[1],),        'C2'),
        ]
        for series, (panels, color) in zip(data, psd_targets):
            f, Pxx_den = signal.welch(series, 1/dt, nperseg=4096)
            # NOTE(review): welch already returns a power spectral density, so
            # squaring before the dB conversion doubles the dB values. Kept
            # as-is to leave the plotted values unchanged; confirm intent.
            Pxx_den = 10*np.log10(Pxx_den**2)
            below_fmax = f < fmax
            for panel in panels:
                panel.plot(f[below_fmax], Pxx_den[below_fmax], color=color)

        ax1[0].set_xlim((0,100))

        # Trailing spaces with ha='right' shift the label towards the gap
        # between the two spectra panels.
        ax1[0].set_ylabel('power (dB/Hz)     ',ha='right')

        gs.tight_layout(fig,w_pad=2)

        # Add legends
        ax1[0].legend(['naive','DD'], frameon=False, borderpad=0, bbox_to_anchor=(1.05,1.05))
        ax1[1].legend(['DD','cDBS'],  frameon=False, borderpad=0, bbox_to_anchor=(1.05,1.05))

        # Add subfigure labels (a, b)
        fontdict={'size': 'large',
                  'weight' : 'bold'}
        fig.text(0.0, 1, 'a', fontdict=fontdict, verticalalignment='top')
        fig.text(0.5, 1, 'b', fontdict=fontdict, verticalalignment='top')

    data_dict, data, t, dt = get_data()

    plot(data, t, dt)

def main():
    """
    Seeds Python's `random` module, builds Figure 3 and shows it
    """
    seed = 0
    random.seed(seed)

    fig_3()

    plt.show()

# Build and display the figure only when run as a script, not on import.
if __name__ == '__main__':
    main()