# -*- coding: utf-8 -*-
"""
Created on Thu Oct 26 17:07:27 2017

@author: porio
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

def windowedKS(serie,label='data',Lmin=200,maxW=20,plot=True,ax=None,saveplot=False):
    """
    Windowed Kolmogorov-Smirnov test for data stationarity.

    The differences between consecutive values of `serie` (inter-event
    intervals) are grouped into equal-duration time windows, and every pair of
    windows is compared with a two-sample (non-parametric) KS test to see
    whether they follow the same distribution.

    Parameters
    ----------
    serie : ndarray-like
        Sorted event times (it is searched with ``np.searchsorted`` and
        differenced with ``np.diff``). The plot axes are scaled by 1/1000 and
        labelled in seconds, which presumes the times are in milliseconds.
    label : string
        Figure title and base of the filename used to save the plot.
        Default = 'data'
    Lmin : integer
        Target minimum number of events per window. It is capped at 1/6 of
        the data length. Default = 200
    maxW : integer
        Maximum number of windows. Default = 20
    plot : Boolean
        If plot==True, a plot will be generated. Default = True
    ax : matplotlib axis object
        Axis to make the plot. If ax=None, figure number 2 is cleared and a
        new axis is created in it. Default=None
    saveplot : Boolean
        If True (and plot is True), the figure is saved in a .png file with
        the name label + "-stat.png". Default: False.

    Returns
    -------
    Table : ndarray, shape (W, W, 2)
        Pairwise results of the KS test between the W windows.
        ``Table[:,:,0]`` holds the KS statistic and ``Table[:,:,1]`` the
        p-value. p-values greater than 0.05 indicate that the corresponding
        windows are likely to have the same distribution.
    """
    ISIs = np.diff(serie)
    N = len(serie)
    Lmin = min(N/6, Lmin)
    # Total duration: maximum value rounded to the nearest multiple of 20
    T = np.around(np.max(serie)/2, -1)*2
    Nwindows = np.min((maxW, N/Lmin))
    windowLength = T/Nwindows
    # NOTE(review): np.arange excludes T itself, so there is one limit per
    # window start and no closing limit; the data after the last limit is not
    # compared. Kept as-is so the shape of the returned table does not change.
    windowlimits = [np.searchsorted(serie, v) for v in np.arange(0, T, windowLength)]

    # Windows usually hold different numbers of intervals, so they are kept in
    # a plain list: building an ndarray from ragged sequences raises
    # ValueError in recent NumPy versions.
    ISIn = [ISIs[i:j] for i, j in zip(windowlimits[:-1], windowlimits[1:])]
    Table = np.array([[stats.ks_2samp(a1, a2) for a1 in ISIn] for a2 in ISIn])

    if plot:
        if ax is None:
            plt.figure(2, figsize=(8, 6))
            plt.clf()
            plt.suptitle(label + '   %g spikes'%(len(ISIs)+1))
            ax = plt.subplot(111)
        fig = ax.figure

        # p-values are shown saturated at 0.15 to highlight the 0.05 threshold
        cax = ax.imshow(Table[:, :, 1], interpolation='none',
                        extent=[0, T/1000, T/1000, 0], vmax=0.15)
        # Draw on the requested axis/figure rather than on whatever happens to
        # be pyplot's "current" one.
        fig.colorbar(cax, ax=ax, extend='max', ticks=(0, 0.05, 0.1, 0.15))

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Time (s)')
        ax.text(1.35, 0.5, 'p-value KS test', rotation='vertical',
                ha='center', va='center', transform=ax.transAxes)
        if saveplot:
            fig.savefig(label + '-stat.png', dpi=150)

    return Table
    
def hurst(series,slope_span = 7):
    """
    Rescaled-range (R/S) analysis of a time series to estimate the Hurst
    exponent.

    For each segment length n the series is cut into ``len(series)//n``
    consecutive segments of n samples (left-over samples at the end are not
    used), the R/S statistic is computed in every segment and averaged.
    A single 'H' value is obtained in Hfull[0] with p-value in Hfull[3].
    A progress line ('ready', n) is printed for every segment length.

    Parameters
    ----------
    series : array
        data to be analyzed

    slope_span : integer (default=7)
        half-width of moving windows to calculate moving slopes

    Returns
    -------
    N : ndarray of int32, nominally 5*(int(np.log2(len(series)))-4) + 1 values
         Segment lengths for which R/S is calculated: 2**k truncated to
         integer, for k from 3 up to floor(log2(len(series)/2)) in steps of 0.2
    E : List, same length as N
         Mean R/S statistic for each value of N
    H : List of ndarrays. Length = len(N) - 2*slope_span
         Fit coefficients (slope, intercept) of log2(E) against log2(N) in
         moving windows of 2*slope_span + 1 consecutive points
    NN : List of arrays. Same length as H
         Segment lengths employed in the calculation of each of the
         coefficients in H
    Hfull : LinregressResult, length 5
            Full linear regression results between log(N) and log(E).
            (slope, intercept, rvalue, pvalue, stderr)

    """
    n_samples = len(series)
    max_exp = int(np.floor(np.log2(n_samples/2)))
    exponents = np.arange(3, max_exp+0.2, 0.2)  # log2 of the segment lengths

    N = np.int32(2**exponents)  # segment lengths
    E = []

    for n in N:
        # One row per segment
        chunks = np.resize(series, (n_samples//n, n))
        centered = chunks - np.mean(chunks, axis=-1)[:, None]
        walk = np.cumsum(centered, axis=1)
        spread = walk.max(1) - walk.min(1)  # range of the cumulative deviations
        sd = np.std(chunks, axis=1)
        E.append(np.mean(spread/sd))

        print('ready', n)

    # Moving slopes: every fit uses 2*slope_span + 1 consecutive points
    width = 2*slope_span + 1
    starts = range(len(N) - 2*slope_span)
    NN = [N[k:k+width] for k in starts]
    H = [np.ma.polyfit(np.log2(N[k:k+width]), np.log2(E[k:k+width]), 1)
         for k in starts]

    Hfull = stats.linregress(np.log(N), np.log(E))

    return N, E, H, NN, Hfull

def DFA(series,slope_span = 7):
    """
    Detrended Fluctuation Analysis (DFA) of a time series.

    For each segment length n the series is cut into ``len(series)//n``
    consecutive segments of n samples (left-over samples at the end are not
    used). Each segment is cumulatively summed, a straight line is fitted and
    removed, and the root-mean-square of the residual is averaged over the
    segments.
    A single 'H' value is obtained in DFfull[0] with p-value in DFfull[3].
    A progress line ('ready', n) is printed for every segment length.

    Parameters
    ----------
    series : array
        data to be analyzed

    slope_span : integer (default=7)
        half-width of moving windows to calculate moving slopes

    Returns
    -------
    N : ndarray of int32, nominally 5*(int(np.log2(len(series)))-4) + 1 values
         Segment lengths for which the fluctuation is calculated: 2**k
         truncated to integer, for k from 3 up to floor(log2(len(series)/2))
         in steps of 0.2
    F : List, same length as N
         Detrended Fluctuation for each value of N
    H : List of ndarrays. Length = len(N) - 2*slope_span
         Fit coefficients (slope, intercept) of log2(F) against log2(N) in
         moving windows of 2*slope_span + 1 consecutive points
    NN : List of arrays. Same length as H
         Segment lengths employed in the calculation of each of the
         coefficients in H
    DFfull : LinregressResult, length 5
            Full linear regression results between log(N) and log(F).
            (slope, intercept, rvalue, pvalue, stderr)

    """
    n_samples = len(series)
    max_exp = int(np.floor(np.log2(n_samples/2)))
    exponents = np.arange(3, max_exp+0.2, 0.2)  # log2 of the segment lengths

    N = np.int32(2**exponents)  # segment lengths
    F = []

    for n in N:
        t = np.arange(n)
        # One column per segment; cumulative sum runs along each segment
        profile = np.cumsum(np.resize(series, (n_samples//n, n)).T, axis=0)
        # Straight-line fit of every column at once
        slope, intercept = np.polyfit(t, profile, 1)
        residual = profile - slope*t[:, None] - intercept
        rms = np.sqrt(np.sum(residual**2, 0)/n)
        F.append(np.mean(rms))

        print('ready', n)

    DFfull = stats.linregress(np.log(N), np.log(F))

    # Moving slopes: every fit uses 2*slope_span + 1 consecutive points
    width = 2*slope_span + 1
    starts = range(len(N) - 2*slope_span)
    NN = [N[k:k+width] for k in starts]
    H = [np.ma.polyfit(np.log2(N[k:k+width]), np.log2(F[k:k+width]), 1)
         for k in starts]

    return N, F, H, NN, DFfull

def hurstS(series,slope_span = 7,mul=100):
    """
    R/S and DFA for Surrogate data.

    `mul` surrogate series are built by randomly permuting `series` (with
    ``np.random.permutation``, i.e. NumPy's global random state), and the R/S
    and DFA statistics are computed on all of them for every segment length.
    For building a single confidence interval, use the slopes in HfS[:,0]
    (R/S) and DFfS[:,0] (DFA).
    A progress line ('ready surr', n) is printed for every segment length.

    Parameters
    ----------
    series : array
        data to be analyzed

    slope_span : integer (default=7)
        half-width of moving windows to calculate moving slopes

    mul : integer (default 100)
        Number of surrogate random series.

    Returns
    -------
    surrHm : ndarray. Length = len(N) - 2*slope_span
        Mean, across surrogates, of the R/S moving slopes. Each slope is
        fitted on 2*slope_span + 1 consecutive segment lengths.
    surrHsd : ndarray. Length = len(N) - 2*slope_span
        Standard deviation, across surrogates, of the R/S moving slopes.
    surrDFm : ndarray. Length = len(N) - 2*slope_span
        Mean, across surrogates, of the DFA moving slopes.
    surrDFsd : ndarray. Length = len(N) - 2*slope_span
        Standard deviation, across surrogates, of the DFA moving slopes.
    HfS : 2-D ndarray, shape (mul,5)
        Full linear regression results for the R/S on each of the surrogate
        series. (slope, intercept, rvalue, pvalue, stderr)
    DFfS : 2-D ndarray, shape (mul,5)
        Full linear regression results for the DFA on each of the surrogate
        series. (slope, intercept, rvalue, pvalue, stderr)
    """
    n_samples = len(series)
    max_exp = int(np.floor(np.log2(n_samples/2)))
    exponents = np.arange(3, max_exp+0.2, 0.2)  # log2 of the segment lengths

    # Shuffling keeps the value distribution but destroys temporal order
    surrogates = [np.random.permutation(series) for _ in range(mul)]

    N = np.int32(2**exponents)  # segment lengths
    rs_by_length = []
    df_by_length = []

    for n in N:
        t = np.arange(n)
        # Axes: (surrogate, segment, sample within segment)
        chunks = np.array([np.resize(surr, (n_samples//n, n))
                           for surr in surrogates])

        # --- R/S statistic, averaged over segments ---
        centered = chunks - np.mean(chunks, axis=-1)[:, :, None]
        walk = np.cumsum(centered, axis=-1)
        spread = walk.max(-1) - walk.min(-1)
        sd = np.std(chunks, axis=-1)
        rs_by_length.append(np.mean(spread/sd, -1))

        # --- Detrended fluctuation, averaged over segments ---
        # Axes become (surrogate, sample within segment, segment)
        profile = np.swapaxes(np.cumsum(chunks, axis=-1), 1, 2)
        coefs = np.array([np.polyfit(t, prof, 1) for prof in profile])
        residual = (profile - coefs[:, [0], :]*t[None, :, None]
                    - coefs[:, [1], :])
        rms = np.sqrt(np.sum(residual**2, 1)/n)
        df_by_length.append(np.mean(rms, -1))

        print('ready surr', n)

    # One row per surrogate, one column per segment length
    E = np.asarray(rs_by_length).T
    F = np.asarray(df_by_length).T

    Hfull = np.array([stats.linregress(np.log(N), np.log(row)) for row in E])
    DFfull = np.array([stats.linregress(np.log(N), np.log(row)) for row in F])

    # Moving slopes: every fit uses 2*slope_span + 1 consecutive points
    width = 2*slope_span + 1
    starts = range(len(N) - 2*slope_span)
    H = [[np.ma.polyfit(np.log2(N[k:k+width]), np.log2(row[k:k+width]), 1)[0]
          for row in E]
         for k in starts]
    Hs = [[np.ma.polyfit(np.log2(N[k:k+width]), np.log2(row[k:k+width]), 1)[0]
           for row in F]
          for k in starts]

    return (np.mean(H, -1), np.std(H, -1),
            np.mean(Hs, -1), np.std(Hs, -1),
            Hfull, DFfull)