"""
The analysis functions for EEE stimulation data.
1. Plateau Amplitude
2. Plateau Duration
3. Spike Numbers on plateau
4. Interspike intervals
5. EPSPs amplitude and time stamp

Note: new analysis functions are added on August 1st, 2018
Plateu amplitude and duration measurements in basal dendrites with soma

Author: Peng Penny Gao
penggao.1987@gmail.com

Contributors:
salvadordura@gmail.com
joe.w.graham@gmail.com
"""

import json
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
import seaborn as sns

########################################
### Function: to measure plateau duration
########################################
def meas_platdur(data, thresh = 10, dt = 0.025):
    """Measures plateau duration (ms)

    Parameters:
    -----------
        data: voltage trace
        thresh (mV): measure the duration of the plateau which are bigger than baseline + thresh
        dt: the sampling rate (default is 0.025ms or 40kHz)
    Return:
    -----------
        platur (ms): plateau duration in ms
    """
    # Make sure dt = 0.025 and there is no stimulation for the first 50 ms
    # If there is anything before 50ms, change the time points for baseline
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    above = [val for val in stable if val > (baseline + thresh)]
    platdur = dt * len(above)
    return platdur

########################################
### Function: to calculate the number of Spikes
########################################
def spike_count(data, thresh = 0):
    """Measure the number of Spikes

    Parameters:
    -----------
    data: list
        The electrical recording data from soma in the CA229 model.
        It was recorded in h.Vector() originally, then saved as list into json
        file duirng the batch processing of simulation.

    thresh: int
        The threshold to detect spikes. By default: if the membrane potential
        is larger than 0 mV, it will be detected as a spike.

    Return:
    -----------
    count: int
        Number of spikes in the data.
    """
    spike_flag = False
    count = 0
    for idx, val in enumerate(data[:-1]):
        if spike_flag == False and val < thresh and data[idx+1] >= thresh:
            spike_flag = True
        elif spike_flag == True and val >= thresh and data[idx+1] < thresh:
            count += 1
            spike_flag = False
    return count

########################################
### Function: to calculate spike interval
########################################
def IST_spikes (data, dt = 0.025):
    """
    Get the interspike intervals, spike_idx, spike_mvalue and spike_midx
    Return:
    -----------
    spike_mvalue:
        The list of minimum values between each two spikes
    spike_midx:
        The list of timepoints of minimum values
    ISTs:
        The average time of interspike intervals
    """
    spikes = get_EPSPs(data, 60, dt)
    IST = []
    count = len(spikes)
    spike_mvalue = []
    spike_midx = []
    if count <= 1:
        ISTs = 0
    else:
        for i in range(count-1):
            IST.append(spikes[i+1][0]-spikes[i][0])
            idx1 = int(round((spikes[i][0])/dt))
            idx2 = int(round((spikes[i+1][0])/dt))
            temp = np.min(data[idx1:idx2])
            spike_mvalue.append(temp)
            spike_midx.append(np.argmin(data[idx1:idx2]) + idx1)

        ISTs = np.mean(IST)
    return ISTs , spike_mvalue, spike_midx

########################################
### Function: to analyze EPSPs
########################################
def get_EPSPs(data, thresh = 2, dt = 0.025):
    """
    Get the maximum index and value of each EPSP peak
    Return:
    -----------
    spikes:
        A list of turple.
        Each turple has the index of max EPSP, and value of max EPSP.
    """
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    EPSPs = []
    for idx, val in enumerate(stable[:-1]):
        if (val >= (baseline + thresh)) and (val > stable[idx-1]) and (val >= stable[idx+1]):
            time = (idx + 6000) * dt
            EPSPs.append((time, val))
    return EPSPs

########################################
### Function: to analyze spikes
########################################
def single_spike(data, dt = 0.025):
    """
    Get the maximum index and value of each EPSP peak
    Return:
    -----------
    spikes:
        A list of turple.
        Each turple has the index of max EPSP, and value of max EPSP.
    """
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    peak_v = np.amax(stable) - baseline
    peak_t = (np.argmax(stable) + 6000) * dt
    return peak_v, peak_t

######################################
# Get the value index in data when the value is the cloest to target value
def get_closest (data, target):
    """Get the index of value in data which is closest to target.
    """
    t = data.index(min(data, key=lambda x:abs(x-target)))
    return t

########################################
### Function: to measure plateau amplitude
########################################
def meas_platamp(data, dt = 0.025):
    """Measures plateau amplitude (average of voltage - baseline while volt
        trace is above baseline + thresh)"""
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    # above = [val for val in stable if val > (baseline + thresh)]
    spike_num = spike_count(stable)
    if spike_num == 0:
        platamp = max(stable) - baseline
        platdur = 0
        ISI = 0
    elif spike_num == 1: # Maybe filtering would be better?
        spikegap = 5 # ms to skip after spike
        idx = data.index(max(data)) + int(spikegap/dt)
        platamp = data[idx] - baseline
        ISI = 0
    else:
        ISI, spike_mvalue, spike_midx = IST_spikes(data, dt)
        platamp = spike_mvalue[-1] - baseline
    return ISI , platamp

########################################
### Function: to measure plateau amplitude for TTX condition
########################################
def TTX_platamp(data):
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    above = [idx for idx, val in enumerate(stable) if val > (baseline+15)]
    # platamp = max(stable) - baseline
    if len(above) <= 15:
        platamp = max(stable) - baseline
    else:
        idx1 = above[0]
        idx2 = above[-1]
        platamp = stable[int(0.75*idx2 + 0.25*idx1)] - baseline
    return platamp

########################################
### Function: to measure the plateau ampltiude and duration
###           in soma and basal dendrites where the inputs are located.
### Using the time stamp for soma plateau amplitude measurement to determine
### the plateau ampltiude in dendrite.
########################################
def soma_platamp_TTX(data):
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    above = [idx for idx, val in enumerate(stable) if val > (baseline+15)]
    # platamp = max(stable) - baseline
    if len(above) <= 15:
        platamp = np.max(stable) - baseline
        idx = np.argmax(stable)
    else:
        idx1 = above[0]
        idx2 = above[-1]
        idx = int(0.75*idx2 + 0.25*idx1)
        platamp = stable[idx] - baseline
    return idx+6000, platamp

def soma_platdur_TTX(data, dt = 0.025):
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    idx, amp = soma_platamp_TTX(data)
    above = [idx for idx, val in enumerate(stable) if val > (baseline+amp/2.0)]
    platdur = dt * len(above)
    return platdur

def TTX_dend_plat(data, idx, dt = 0.025):
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    platamp = data[idx] - baseline
    threshold = platamp*0.5
    above = [val for val in stable if val > (baseline + threshold)]
    platdur = dt * len(above)
    return platamp, platdur

def soma_plat(data, dt = 0.025):
    """Measures plateau amplitude (average of voltage - baseline while volt
        trace is above baseline + thresh)"""
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    spike_num = spike_count(stable)
    if spike_num == 0:
        platamp = max(stable) - baseline
        platdur = 0
        idx = data.index(max(data))
    elif spike_num == 1: # Maybe filtering would be better?
        spikegap = 5 # ms to skip after spike
        idx = data.index(max(data)) + int(spikegap/dt)
        platamp = data[idx] - baseline
    else:
        ISI, spike_mvalue, spike_midx = IST_spikes(data, dt)
        platamp = spike_mvalue[-1] - baseline
        idx = spike_midx[-1]
    return idx, platamp

def dend_plat(data, idx, dt = 0.025):
    """Measures plateau amplitude (average of voltage - baseline while volt
        trace is above baseline + thresh)"""
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    platamp = data[idx] - baseline
    threshold = platamp*0.5
    above = [val for val in stable if val > (baseline + threshold)]
    platdur = dt * len(above)
    return platamp, platdur

def v_curr_inj(data):
    baseline = np.mean(data[4000:6000])
    stable = data[6000:]
    amp = np.max(stable) - baseline
    return amp

#######################################
# Color
#######################################
# Scale the RGB values to the [0, 1] range, which is the format matplotlib accepts.
def tableau (num = 0):
    "The number has to be smaller than 21."
    tableau21 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
                 (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
                 (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
                 (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
                 (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
                 (0, 0, 0)]
    for i in range(len(tableau21)):
        r, g, b = tableau21[i]
        tableau21[i] = (r / 255., g / 255., b / 255.)
    return tableau21[num]