# -*- coding: utf-8 -*-
"""
Created on Tue Mar  3 15:21:05 2020

@author: ocalvin
"""

import pandas as pd
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
# Ignore the warning, this is necessary for 3D plotting
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import LinearRegression

from plotting import raster
from .dpx import analyze_actbmp_timings, analyze_dpx


# Default figure styling: Arial as the sans-serif face, sans-serif as the
# font family, and TrueType (Type 42) font embedding for PDF output.
matplotlib.rcParams.update({
    'font.sans-serif': "Arial",
    'font.family': "sans-serif",
    'pdf.fonttype': 42,
})


'''
----------------------------------------------------------------------      
--------------------- Experiment 1: E/I Balance ----------------------
----------------------------------------------------------------------
'''

def cp_raster_plot(spike_dat_coll, *args, **kwargs):
    """
    Creates a raster plot figure for a cue-probe trial.

    The neuron bands associated with the cue ('C') and the probe ('P') are
    shaded along the y-axis, and their presentation windows are highlighted
    along the time axis.

    NOTE: Used to create Figure 2's E-G panels.

    Parameters
    ----------
    spike_dat_coll : SpikeDataCollector
        A recording object that records spikes during experiments
        and can load spiking information from csv's. The last entry of its
        ``raster_IDs`` is used to scale the cue bands onto the neuron axis.
    *args, **kwargs
        Passed through to ``raster``; try help(raster) for more information
        on the other parameters that can be set. The defaults chosen here
        for 'ylabel', 'yticklabels' and 'xlims' may be overridden.

    Returns
    -------
    fig : matplotlib.figure.Figure
       Returns the raster plot figure, so that it may be saved or
       manipulated.

    """
    # Uses set defaults in case user wants to overwrite them.
    kwargs.setdefault('ylabel', 'Associated Direction (in radians)')
    kwargs.setdefault('yticklabels',
                      ('0', '$\\pi$/2', '$\\pi$', '3$\\pi$/2', '2$\\pi$'))
    kwargs.setdefault('xlims', (0, 4000))

    # Cue positions are in units of pi radians: the neuron axis spans
    # 0..2*pi, so a cue at `val` sits at fraction val/2 of the last raster ID.
    _cues = {'C': 0.5, 'P': 1.5}
    _cue_colors = {'C': 'b', 'P': 'r'}
    _cue_w = 0.2
    _starts = (500, 3000)      # highlight window starts (cue, then probe)
    _ends = (1000, 3500)       # highlight window ends (cue, then probe)
    _map_alpha = 0.15
    _highlight_alpha = 0.30

    max_id = spike_dat_coll.raster_IDs[-1]
    cue_disp_map = dict()
    cue_highlight = dict()
    # The i-th cue (in _cues insertion order) uses the i-th time window.
    for i, (cue, val) in enumerate(_cues.items()):
        cue_disp_map[cue] = {'ybottom': int(max_id * ((val - _cue_w) / 2)),
                             'ytop': int(max_id * ((val + _cue_w) / 2)),
                             'color': _cue_colors.get(cue),
                             'alpha': _map_alpha}
        cue_highlight[cue] = {'start': _starts[i],
                              'end': _ends[i],
                              'alpha': _highlight_alpha}

    return raster(spike_dat_coll, cue_disp_map, cue_highlight,
                  *args, **kwargs)

'''
----------------------------------------------------------------------      
------------ Experiment 2: Interneuron AMPA/NMDA Balance -------------
----------------------------------------------------------------------
'''

def smrz_EI_survey(fldrs, pyr_gs, int_gs, aff_gs):
    '''
    Creates a multipanel summary of network functioning over
    a wide range of parameter values.

    One figure is created and shown for each unique afferent AMPA
    conductance. Each figure is an RGB image over the grid of
    interneuron / pyramidal NMDA conductances run at that afferent value:

    * red   - the 'jumped' value reported by analyze_actbmp_timings,
    * blue  - 'cueLast' minus 'jumped' (legend: "Stays"),
    * black - neither channel set (legend: "Fails Prior To"),
    * white - runs whose 'spiralled' flag equals 1.

    NOTE: Used to create Figure 3.

    Parameters
    ----------
    fldrs : list(string)
        List of the data folders that need to be described. Each entry is
        concatenated directly with 'TC_comp.npz' and 'TC.npz', so folder
        paths must end with a path separator.
    pyr_gs : list(float) or array-like
        List of the pyramidal cell NMDA conductances for each file.
    int_gs : list(float) or array-like
        List of the interneuron NMDA conductances for each file.
    aff_gs : list(float) or array-like
        List of the afferent AMPA conductances for each file.

    Returns
    -------
    None. Each figure is displayed with plt.show().

    '''
    # Work on arrays: with plain lists `aff_gs == a` is a single bool and
    # `pyr_gs[msk]` cannot be indexed by a mask.
    pyr_gs = np.asarray(pyr_gs)
    int_gs = np.asarray(int_gs)
    aff_gs = np.asarray(aff_gs)

    # Gather the per-file timing summaries used by the colour map.
    cue_last, jumped, spiralled = [], [], []
    for f in fldrs:
        temp = analyze_actbmp_timings(f+'TC_comp.npz', f+'TC.npz')
        cue_last.append(temp['cueLast'])
        jumped.append(temp['jumped'])
        spiralled.append(temp['spiralled'])
    cue_last = np.array(cue_last)
    jumped = np.array(jumped)
    spiralled = np.array(spiralled)

    # One figure per afferent condition.
    for a in np.unique(aff_gs):
        msk = aff_gs == a
        sub_pyr, sub_int = pyr_gs[msk], int_gs[msk]
        n_pyr = np.unique(sub_pyr).size
        n_int = np.unique(sub_int).size

        # Spiralled runs are painted white; otherwise red = jumped and
        # blue = cueLast - jumped, so runs with neither come out black.
        ok = spiralled[msk] != 1
        red = np.where(ok, jumped[msk], 1)
        green = np.where(ok, 0, 1)
        blue = np.where(ok, cue_last[msk] - jumped[msk], 1)
        # NOTE(review): rows are sized by the interneuron grid, while the
        # y-axis is ticked/labelled as pyramidal g. This only agrees when the
        # grid is square - confirm the file ordering for non-square grids.
        color_map = np.stack((red, green, blue), axis=-1)
        color_map = color_map.reshape(n_int, n_pyr, 3)
        # Flip rows so the largest pyramidal g is drawn at the top, matching
        # the reversed y tick labels below.
        color_map = np.flipud(color_map)

        _, ax = plt.subplots()
        ax.imshow(color_map, alpha=0.9)
        ax.set_ylabel('Pyramidal NMDA g')
        ax.set_xlabel('Interneuron NMDA g')

        pyr_ticks = np.linspace(sub_pyr.min(), sub_pyr.max(), num=3)[::-1]
        ax.set_yticks(np.linspace(0, n_pyr - 1, num=3))
        ax.set_yticklabels(['{:0.2f}'.format(x) for x in pyr_ticks])

        int_ticks = np.linspace(sub_int.min(), sub_int.max(), num=3)
        ax.set_xticks(np.linspace(0, n_int - 1, num=3))
        ax.set_xticklabels(['{:0.2f}'.format(x) for x in int_ticks])

        ax.set_title('Afferent AMPA g ' + str(round(a, 4)))

        # Creates the legend
        handles = [mpatches.Patch(color='blue', label='Stays'),
                   mpatches.Patch(color='red', label='Jumps'),
                   mpatches.Patch(color='black', label='Fails Prior To')]
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.10),
                  handles=handles, ncol=3)
        plt.show()




'''
----------------------------------------------------------------------      
-------------- Experiment 3: DPX Model NMDA Antagonism ---------------
----------------------------------------------------------------------
'''

def dpx_raster_plot(spike_dat_coll, trial_type, *args, **kwargs):
    """
    Creates a raster plot figure for a dpx trial.

    The neuron bands of all four DPX stimuli (A, B, X, Y) are shaded along
    the y-axis. The letters of ``trial_type`` are highlighted in time: the
    first letter in the first window and the second letter in the second.

    NOTE: Used to create Figures 5A-C and Supplemental 1A-C.

    Parameters
    ----------
    spike_dat_coll : SpikeDataCollector
        A recording object that records spikes during experiments
        and can load spiking information from csv's. The last entry of its
        ``raster_IDs`` is used to scale the cue bands onto the neuron axis.

    trial_type : string
        The type of trial that is being plotted. For example, a
        BX trial would be 'BX'. At most two letters (cue, then probe).

    *args, **kwargs
        Passed through to ``raster``; try help(raster) for more information
        on the other parameters that can be set. The defaults chosen here
        for 'ylabel', 'yticklabels' and 'xlims' may be overridden.

    Returns
    -------
    fig : matplotlib.figure.Figure
       Returns the raster plot figure, so that it may be saved or
       manipulated.

    Raises
    ------
    ValueError
        If ``trial_type`` has more letters than there are highlight windows.

    """
    # Uses set defaults in case user wants to overwrite them.
    kwargs.setdefault('ylabel', 'Associated Direction (in radians)')
    kwargs.setdefault('yticklabels',
                      ('0', '$\\pi$/2', '$\\pi$', '3$\\pi$/2', '2$\\pi$'))
    kwargs.setdefault('xlims', (0, 6600))

    # Default parameters for the DPX trials. Stimulus positions are in units
    # of pi radians: the neuron axis spans 0..2*pi, so a stimulus at `val`
    # sits at fraction val/2 of the last raster ID.
    _cues = {'A': 0.3, 'B': 0.7, 'X': 1.3, 'Y': 1.7}
    _cue_colors = {'A': 'b', 'B': 'r', 'X': 'b', 'Y': 'r'}
    _cue_w = 0.2
    _starts = (500, 5500)      # highlight window starts (cue, then probe)
    _ends = (1500, 6000)       # highlight window ends (cue, then probe)
    _map_alpha = 0.15
    _highlight_alpha = 0.30

    if len(trial_type) > len(_starts):
        raise ValueError(
            "trial_type must have at most {} letters (cue then probe); "
            "got {!r}".format(len(_starts), trial_type))

    max_id = spike_dat_coll.raster_IDs[-1]
    cue_disp_map = {
        cue: {'ybottom': int(max_id * ((val - _cue_w) / 2)),
              'ytop': int(max_id * ((val + _cue_w) / 2)),
              'color': _cue_colors.get(cue),
              'alpha': _map_alpha}
        for cue, val in _cues.items()
    }

    # The i-th letter of the trial type is highlighted in the i-th window.
    cue_highlight = dict()
    for i, letter in enumerate(trial_type):
        cue_highlight[letter] = {'start': _starts[i],
                                 'end': _ends[i],
                                 'alpha': _highlight_alpha}

    return raster(spike_dat_coll, cue_disp_map, cue_highlight,
                  *args, **kwargs)

def DPX_smry(files, values, title='DPX Performance',
             xlabel='Percent Change',
             ylabel='Percent Errors', ymax=None):
    '''
    Creates summary images for the DPX performance over a set of conditions.

    NOTE: Used to create Figures 5D-F, 8, 10A, and Supplemental 1D-F

    Parameters
    ----------
    files : list(string)
        List of the files that need to be loaded and analyzed. These should
        be DPX files.
    values : list(float)
        Parameter values that are paired with the file and will be used to
        plot along the x-axis.
    title : string, optional
        The string label for the plot's title. The default is
        'DPX Performance'.
    xlabel : string, optional
        The string label for the x-axis. The default is 'Percent Change'.
    ylabel : string, optional
        The string label for the y-axis. The default is 'Percent Errors'.
    ymax : float, optional
        If specified, then the plot's area will be from -2 to ymax. Otherwise
        it will plot based on the observed values. The default is None.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The summary figure, with one line per trial type (AX, AY, BX, BY).

    '''
    trial_types = ('AX', 'AY', 'BX', 'BY')
    # Per-trial-type line styles; AX uses matplotlib's default solid line.
    styles = {'AX': dict(color='black', marker='o'),
              'AY': dict(color='blue', linestyle='-.', marker='v'),
              'BX': dict(color='red', linestyle='--', marker='s'),
              'BY': dict(color='purple', linestyle=':', marker='*')}

    # analyze_dpx returns one value per trial type (presumably an error
    # fraction); scale it to a percentage. One entry per file.
    errors = {t: [] for t in trial_types}
    for f in files:
        vals = analyze_dpx(f, plot=False)
        for t in trial_types:
            errors[t].append(vals[t] * 100)

    # Plot the values
    fig, ax = plt.subplots()
    for t in trial_types:
        ax.plot(values, errors[t], label=t, **styles[t])

    # Add labels
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if ymax is not None:
        ax.set_ylim(bottom=-2, top=ymax)
    ax.set_title(title)
    ax.legend(loc=0)

    return fig

def DPX_rt_smry(files, values,
                func='median', title='DPX Performance',
                xlabel='Percent Change', ylabel='Reaction Time (ms)',
                cue_start=5500):
    '''
    Creates summary images for the agent's reaction times
    over the various trial types.

    NOTE: Used to create Figures 6B-D and 9.

    Parameters
    ----------
    files : list(string)
        List of the csv files that need to be loaded and analyzed. Each must
        have 'cue', 'probe' and 'rt' columns.
    values : list(float)
        Parameter values that are paired with the file and will be used to
        plot along the x-axis.
    func : string or callable, optional
        Which function should be used to aggregate the reaction times:
        'mean', 'median', or a callable taking the rt column. The default
        is 'median'.
    title : string, optional
        The string label for the plot's title. The default is
        'DPX Performance'.
    xlabel : string, optional
        The string label for the x-axis. The default is 'Percent Change'.
    ylabel : string, optional
        The string label for the y-axis. The default is
        'Reaction Time (ms)'.
    cue_start : float, optional
        Subtracted from each aggregated rt, so the plotted reaction times
        are measured from this onset. The default is 5500.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The summary figure, with one line per trial type (AX, AY, BX, BY).

    Raises
    ------
    ValueError
        If ``func`` is neither 'mean', 'median' nor a callable.

    '''
    # Resolve the aggregation up front; an unknown name previously left
    # the aggregator unbound and failed later with UnboundLocalError.
    if callable(func):
        fun = func
    elif func == 'median':
        fun = np.median
    elif func == 'mean':
        fun = np.mean
    else:
        raise ValueError(
            "func must be 'median', 'mean' or a callable; got {!r}"
            .format(func))

    trial_types = ('AX', 'AY', 'BX', 'BY')
    # Per-trial-type line styles; AX uses matplotlib's default solid line.
    styles = {'AX': dict(color='black', marker='o'),
              'AY': dict(color='blue', linestyle='-.', marker='v'),
              'BX': dict(color='red', linestyle='--', marker='s'),
              'BY': dict(color='purple', linestyle=':', marker='*')}

    # One aggregated reaction time per file for each cue/probe combination.
    rts = {t: [] for t in trial_types}
    for f in files:
        df = pd.read_csv(f)
        for t in trial_types:
            cue, probe = t
            ss = df[(df.cue == cue) & (df.probe == probe)]
            rts[t].append(fun(ss.rt) - cue_start)

    # Plot the values
    fig, ax = plt.subplots()
    for t in trial_types:
        ax.plot(values, rts[t], label=t, **styles[t])

    # Add labels
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc=0)

    return fig

'''
----------------------------------------------------------------------      
----------------------------- SUPPLEMENT -----------------------------
----------------------------------------------------------------------
'''

def cross_study_err_rates(file):
    '''
    Creates a scatter-plot matrix of error-rate differences between trial
    types, with a least-squares regression line and R^2 in each panel.

    For every trial type T the difference column 'Scz'+T minus 'Con'+T is
    formed (presumably patient minus control rates; one row per study -
    verify against the data file). Panel (i, j) plots type j's difference
    on x against type i's on y:

    * filled markers - rows with a difference for every trial type,
    * open markers   - every row where this x/y pair is available; the
      regression is fit on these points.

    NOTE: Used to make Supplemental 2.

    Parameters
    ----------
    file : string
        The csv data file that is used for the analysis. It must contain
        'Scz<T>' and 'Con<T>' columns for T in AX, BX, AY, BY.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The scatter-matrix figure, so that it may be saved or manipulated.

    '''
    data = pd.read_csv(file)
    err_types = ['AX', 'BX', 'AY', 'BY']
    xlims = [-0.2, 0.4]
    ylims = [-0.11, 0.2]
    reg_range = np.linspace(xlims[0], xlims[1], 100)

    # Create error differences and flag rows missing any of them.
    incomplete = np.zeros(len(data), dtype=bool)
    for err in err_types:
        data['diff'+err] = data['Scz'+err] - data['Con'+err]
        incomplete = incomplete | data['diff'+err].isnull().to_numpy()
    data['incomplete'] = incomplete
    complete = ~incomplete

    n_types = len(err_types)
    fig, s_plots = plt.subplots(n_types, n_types, squeeze=False)
    for ind, err in enumerate(err_types):
        y_col = data['diff'+err]

        for ind2, err2 in enumerate(err_types):
            x_col = data['diff'+err2]
            sp = s_plots[ind, ind2]

            # Filled markers: rows complete across every trial type.
            sp.scatter(x_col.to_numpy()[complete], y_col.to_numpy()[complete],
                       color='black')

            # Open markers: all rows where this particular pair exists.
            subset = (x_col.notna() & y_col.notna()).to_numpy()
            x_dat = x_col.to_numpy()[subset]
            y_dat = y_col.to_numpy()[subset]
            sp.scatter(x_dat, y_dat, color='black', facecolors='none')

            # Axes setup
            sp.set_xlim(xlims)
            sp.set_ylim(ylims)
            sp.axhline(y=0, c='gray', lw=0.5, zorder=0, ls='--')
            sp.axvline(x=0, c='gray', lw=0.5, zorder=0, ls='--')
            if ind == (n_types - 1):
                sp.set_xlabel(err2)
            if ind2 == 0:
                sp.set_ylabel(err)

            # A fit and R^2 need at least two points; skip the panel's fit
            # rather than failing (0 points) or reporting a NaN R^2 (1 point).
            if x_dat.size < 2:
                continue

            # Fit regression
            x_col_2d = x_dat.reshape(-1, 1)
            model = LinearRegression().fit(x_col_2d, y_dat)
            reg_ys = model.predict(reg_range.reshape(-1, 1))
            sp.plot(reg_range, reg_ys, lw=1, color='black')
            r_sq = str(np.round(model.score(x_col_2d, y_dat), 2))
            # NOTE(review): y=0.32 lies above ylims[1] (0.2), so this label
            # is placed above the panel's data area - confirm intended
            # placement (transform=sp.transAxes would survive limit changes).
            sp.text(-0.10, 0.32, r'$R^2$ = ' + r_sq)

    return fig