from scipy.io import loadmat
import numpy as np
import matplotlib.pylab as plt
import sys
sys.path.append('../../')
from graphs.my_graph import set_plot, show
# from data_analysis.processing.signanalysis import * #gaussian_smoothing
from scipy.signal import convolve2d
import matplotlib.cm as cm
from dataset import get_dataset
from scipy.optimize import minimize

from scipy.ndimage.filters import gaussian_filter1d
def gaussian_smoothing(signal, idt_sbsmpl=10):
    """Smooth ``signal`` with a 1D Gaussian kernel.

    ``idt_sbsmpl`` is forwarded to ``gaussian_filter1d`` as the standard
    deviation of the kernel, expressed in samples.
    """
    sigma = idt_sbsmpl
    smoothed_signal = gaussian_filter1d(signal, sigma=sigma)
    return smoothed_signal

def get_time_max(t, data, debug=False, Nsmooth=1):
    """Return the time at which the spatially averaged signal peaks.

    ``data`` is averaged over its first axis (space, for the callers in this
    module), the resulting trace is smoothed with ``gaussian_smoothing``
    (width ``Nsmooth`` samples) and its last ``int(Nsmooth)`` samples are
    discarded before taking the argmax. With ``debug=True`` the raw trace,
    the smoothed trace and the detected maximum are plotted.
    """
    n_trim = int(Nsmooth)
    profile = np.mean(data, axis=0)
    smooth_profile = gaussian_smoothing(profile, Nsmooth)[:-n_trim]
    kept_t = t[:-n_trim]
    i_peak = np.argmax(smooth_profile)
    t_peak = kept_t[i_peak]
    if debug:
        plt.plot(t, profile)
        plt.plot(kept_t, smooth_profile)
        plt.plot([t_peak], [smooth_profile[i_peak]], 'D')
        show()
    return t_peak

def get_stim_center(time, space, data,
                    Nsmooth=4, debug=False, tmax=0., window=100.):
    """Return the spatial position of the peak response around ``tmax``.

    ``data`` is averaged over its second axis, restricted to the samples with
    ``tmax-window < time < tmax+window``. The resulting spatial profile is
    smoothed with ``gaussian_smoothing`` (width ``Nsmooth`` samples), its last
    ``int(Nsmooth)`` samples are dropped, and the ``space`` value of its
    maximum is returned. ``debug=True`` plots both profiles and the maximum.
    """
    n_trim = int(Nsmooth)
    in_window = (time > tmax-window) & (time < tmax+window)
    profile = np.mean(data[:, in_window], axis=1)
    smooth_profile = gaussian_smoothing(profile, Nsmooth)[:-n_trim]
    kept_space = space[:-n_trim]
    i_peak = np.argmax(smooth_profile)
    x_peak = kept_space[i_peak]
    if debug:
        plt.plot(space, profile)
        plt.plot(kept_space, smooth_profile)
        plt.plot([x_peak], [smooth_profile[i_peak]], 'D')
        show()
    return x_peak

def get_data(dataset_index,
             t0=-150, t1=100, debug=False,
             Nsmooth=2,
             smoothing=None):
    """Load one recording, smooth it, crop it in time and center it.

    Parameters
    ----------
    dataset_index : int
        Index into the list returned by ``get_dataset()``; the selected entry
        must provide a ``'filename'`` pointing to a .mat file with a
        ``matNL`` struct holding ``stim1``, ``time`` and ``space`` fields.
    t0, t1 : float
        Bounds of the kept time window (strict inequalities), in the units of
        the ``time`` field (presumably ms -- TODO confirm).
    debug : bool
        Forwarded to ``get_time_max`` / ``get_stim_center`` (diagnostic plots).
    Nsmooth : int
        Size of the default ``Nsmooth x Nsmooth`` box smoothing kernel.
    smoothing : 2D array, optional
        Custom kernel passed to ``convolve2d``; overrides ``Nsmooth``.

    Returns
    -------
    new_time : time axis of the kept window, shifted so the peak time is 0.
    space : space axis shifted so the peak position is 0.
    new_data : smoothed data (space x time) restricted to the kept window.
    """
    # loading data (the dataset list is built only once)
    dataset_info = get_dataset()[dataset_index]
    print(dataset_info)
    delay = 0  # dataset_info['delay'] -- delay correction currently disabled
    f = loadmat(dataset_info['filename'])
    data = 1e3*f['matNL'][0]['stim1'][0]
    # blanking non-finite samples (NaN and +/-inf): an infinite value would
    # otherwise spread through the 2D convolution and bias the centering
    data[~np.isfinite(data)] = 0
    time = f['matNL'][0]['time'][0].flatten()
    space = f['matNL'][0]['space'][0].flatten()
    if smoothing is None:
        # box kernel normalized to unit sum
        smoothing = np.ones((Nsmooth, Nsmooth))/Nsmooth**2
    smooth_data = convolve2d(data, smoothing, mode='same')
    # apply time conditions (axis 1 of the data is time)
    cond = (time > t0-delay) & (time < t1-delay)
    new_time, new_data = np.array(time[cond]), np.array(smooth_data[:, cond])
    # peak time of the spatially averaged signal, then peak position around it
    tmax = get_time_max(new_time, new_data, debug=debug)
    x_center = get_stim_center(new_time, space, new_data, debug=debug,
                               tmax=tmax)
    return new_time-tmax, space-x_center, new_data

def reformat_model_data_for_comparison(model_data_filename,
                                       time_exp, space_exp, data_exp,
                                       model_normalization_factor=None,
                                       with_global_normalization=False,
                                       with_local_normalization=False):
    """Resample model output and experimental data onto a common grid.

    The model file is loaded, its spatial axis is centered and restricted to
    the experimental range, and its time axis is centered on the peak of the
    spatially averaged signal (``get_time_max``, as for the data). Then:

    - the experimental data are subsampled in space onto the model positions
      (nearest experimental position for each model position);
    - the model data are subsampled in time onto the experimental time points
      (nearest model sample, kept only when it lies less than one
      experimental time step away; otherwise the column stays at 0).

    Parameters
    ----------
    model_data_filename : str
        ``.npy`` file storing ``(args, t, X, Fe_aff, Fe, Fi, muVn)`` as an
        object array. It is loaded with pickling enabled: trusted files only.
    time_exp, space_exp, data_exp : arrays
        Experimental time axis, space axis and (space x time) data, e.g. as
        returned by ``get_data``.
    model_normalization_factor : float, optional
        Divisor of the model data under global normalization; defaults to the
        maximum of the resampled model data.
    with_global_normalization : bool
        Divide model and data by their respective global maxima.
    with_local_normalization : bool
        Divide each spatial row by its maximum over time (only used when
        ``with_global_normalization`` is False). Rows whose maximum is 0 are
        left unchanged.

    Returns
    -------
    time_exp, space, model_data_common_sampling, exp_data_common_sampling
        Both data arrays have shape ``(len(space), len(time_exp))``.
    """
    # loading the model output. The file mixes a Python object (the model
    # ``args``) with arrays, so it is an object array that numpy >= 1.16.3
    # only loads with allow_pickle=True. Unpickling can execute arbitrary
    # code: never point this at untrusted files.
    args2, t, X, Fe_aff, Fe, Fi, muVn = np.load(model_data_filename,
                                                 allow_pickle=True)
    t = 1e3*np.asarray(t)  # bringing to ms (model time presumably in s)
    # centering the spatial axis: half the extent plus half a grid step
    X = np.asarray(X)-(args2.X_extent/2.+args2.X_extent/args2.X_discretization/2.)

    # restricting the model to the spatial range covered by the experiment
    space_exp = np.asarray(space_exp)
    Xcond = (X >= space_exp.min()) & (X <= space_exp.max())
    space, new_muVn = X[Xcond], muVn.T[Xcond, :]
    # centering over time in the same way as for the data
    t = t-get_time_max(t, new_muVn)

    # spatial subsampling of the data matching the model discretization:
    # nearest experimental position for every model position
    nearest_x = np.argmin(np.abs(space[:, np.newaxis]-space_exp[np.newaxis, :]),
                          axis=1)
    exp_data_common_sampling = np.array(np.asarray(data_exp)[nearest_x, :],
                                        dtype=float)

    # temporal subsampling of the model matching the data discretization:
    # nearest model sample (not merely the first one within dt_exp, which
    # would bias the model almost one data step early)
    dt_exp = time_exp[1]-time_exp[0]
    model_data_common_sampling = np.zeros((len(space), len(time_exp)))
    for i, nt in enumerate(time_exp):
        time_distance = np.abs(t-nt)
        i_nearest = np.argmin(time_distance)
        if time_distance[i_nearest] < dt_exp:
            model_data_common_sampling[:, i] = new_muVn[:, i_nearest]

    if with_global_normalization:
        if model_normalization_factor is None:
            model_normalization_factor = model_data_common_sampling.max()
        model_data_common_sampling /= model_normalization_factor
        exp_data_common_sampling /= exp_data_common_sampling.max()
    elif with_local_normalization:
        # normalizing each spatial row by its maximum over time; a zero
        # maximum (e.g. no model sample matched) would produce NaN/inf, so
        # such rows are left unchanged
        for array in (model_data_common_sampling, exp_data_common_sampling):
            row_max = array.max(axis=1, keepdims=True)
            array /= np.where(row_max != 0, row_max, 1.)

    return time_exp, space, model_data_common_sampling, exp_data_common_sampling
        

def _plot_data_model_comparison(args, time, space, exp_data, model_data,
                                fig_filename):
    """Contour plots of the normalized data (top) and model (bottom) panels."""
    fig, AX = plt.subplots(2, figsize=(4.5, 5))
    plt.subplots_adjust(bottom=.23, top=.97, right=.85, left=.3)

    plt.sca(AX[0])
    c = AX[0].contourf(time, space, exp_data,
                       np.linspace(exp_data.min(), exp_data.max(), args.Nlevels),
                       cmap=cm.viridis)
    plt.colorbar(c, label='norm. VSD',
                 ticks=.5*np.arange(3))
    set_plot(AX[0], xticks_labels=[], ylabel='space (mm)')

    plt.sca(AX[1])
    # the corner-pixel trick below is display-only: apply it to a copy so the
    # caller's model array (and hence the residual) is left untouched
    model_to_plot = np.array(model_data, dtype=float)
    # to have the zero at the same color level as in the data panel
    factor = np.abs(exp_data.min()/exp_data.max())
    model_to_plot[-1, -1] = -factor*model_to_plot.max()

    c2 = AX[1].contourf(time, space, model_to_plot,
                        np.linspace(model_to_plot.min(), model_to_plot.max(),
                                    args.Nlevels),
                        cmap=cm.viridis)
    plt.colorbar(c2, label='norm. $\\delta V_N$',
                 ticks=.5*np.arange(3))
    set_plot(AX[1], xlabel='time (ms)', ylabel='space (mm)')

    if args.save:
        fig.savefig(fig_filename)
    else:
        show()


def get_residual(args,
                 new_time, space, new_data,
                 Nsmooth=2,
                 fn='../ring_model/data/example_data.npy',
                 model_normalization_factor=None,
                 with_plot=False,
                 fig_filename='/Users/yzerlaut/Desktop/temp.svg'):
    """Sum of squared differences between normalized data and model.

    Data and model are brought onto a common grid with global normalization
    (``reformat_model_data_for_comparison``) and compared pixel by pixel.

    Parameters
    ----------
    args : namespace
        Only read when ``with_plot`` is True (``Nlevels`` and ``save``).
    new_time, space, new_data : arrays
        Experimental axes and (space x time) data, e.g. from ``get_data``.
    Nsmooth : int
        Unused; kept for backward compatibility of the signature.
    fn : str
        Model data file.
    model_normalization_factor : float, optional
        Forwarded to ``reformat_model_data_for_comparison``.
    with_plot : bool
        Also draw data/model contour plots (does not affect the result).
    fig_filename : str
        Output path when ``args.save`` is set.

    Returns
    -------
    float
        The residual ``sum((data - model)**2)``.
    """
    new_time, space,\
        model_data_common_sampling,\
        exp_data_common_sampling =\
                reformat_model_data_for_comparison(fn,
                    new_time, space, new_data,
                    model_normalization_factor=model_normalization_factor,
                    with_global_normalization=True)

    # computed before any plotting so the result never depends on with_plot
    residual = np.sum((exp_data_common_sampling-model_data_common_sampling)**2)

    if with_plot:
        _plot_data_model_comparison(args, new_time, space,
                                    exp_data_common_sampling,
                                    model_data_common_sampling,
                                    fig_filename)

    return residual

if __name__=='__main__':

    import argparse
    parser=argparse.ArgumentParser(description=
            """
            """,
            formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--Nsmooth", help="for data plots", type=int, default=1)
    parser.add_argument("-s", "--save", help="save fig", action="store_true")
    parser.add_argument("-a", "--analyze", help="analyze", action="store_true")
    parser.add_argument("-p", "--plot", help="plot analysis", action="store_true")
    parser.add_argument("-d", "--debug", help="with debugging", action="store_true")
    parser.add_argument("--space", help="space residual", action="store_true")
    parser.add_argument("--time", help="temporal residual", action="store_true")
    parser.add_argument("--model_filename", '-f', type=str,
                        default='../ring_model/data/example_data.npy')
    parser.add_argument("--data_index", '-df', type=int,
                        default=1)
    parser.add_argument("--t0", type=float, default=-np.inf)
    parser.add_argument("--t1", type=float, default=np.inf)
    parser.add_argument("--Nlevels", type=int, default=20)
    args = parser.parse_args()

    # residual mode: --space takes precedence over --time, otherwise the full
    # space-time residual is computed
    if args.space:
        residual_function_name = 'get_space_residual'
    elif args.time:
        residual_function_name = 'get_time_residual'
    else:
        residual_function_name = 'get_residual'

    # get_space_residual / get_time_residual are not defined in this module
    # and no active import provides them: fail early with a clear message
    # instead of a NameError raised after the data have been loaded
    if residual_function_name not in globals():
        parser.error('%s is not available in this module'
                     % residual_function_name)

    if residual_function_name == 'get_time_residual':
        # 1x4 box kernel along time only, normalized to unit sum like the
        # default Nsmooth x Nsmooth kernel of get_data
        new_time, space, new_data = get_data(args.data_index,
                                             smoothing=np.ones((1, 4))/4.,
                                             t0=args.t0, t1=args.t1,
                                             debug=args.debug)
    else:
        new_time, space, new_data = get_data(args.data_index,
                                             Nsmooth=args.Nsmooth,
                                             t0=args.t0, t1=args.t1,
                                             debug=args.debug)

    if residual_function_name == 'get_space_residual':
        print(get_space_residual(args,
                                 new_time, space, new_data,
                                 Nsmooth=args.Nsmooth,
                                 fn=args.model_filename,
                                 with_plot=True))
    elif residual_function_name == 'get_time_residual':
        print(get_time_residual(args,
                                new_time, space, new_data,
                                Nsmooth=2,
                                fn=args.model_filename,
                                with_plot=True))
    else:
        print(get_residual(args,
                           new_time, space, new_data,
                           Nsmooth=args.Nsmooth,
                           fn=args.model_filename,
                           with_plot=True))