# -*- coding: utf-8 -*-
"""
Created on Thu Oct 26 19:09:25 2017

@author: porio
"""


import numpy as np
import matplotlib.pyplot as plt
import kpss
import tests
import scipy.stats
import matplotlib.ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def _plotScaling(ax, N, Y, fit, winFits, NN, pval, ylabel):
    """Draw one log-log scaling panel (R/S or DFA) on *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    N : ndarray
        Sequence lengths used as the x-axis.
    Y : array-like
        Statistic evaluated at each length in *N*.
    fit : object
        Overall fit exposing ``slope``, ``intercept`` and ``rvalue``; drawn
        in red as ``N**slope * exp(intercept)``.
    winFits : sequence of (slope, intercept) pairs
        One fit per moving window.
    NN : sequence of array-like
        Lengths spanned by each window in *winFits*; each windowed fit is
        drawn as ``2**(log2(n)*slope + intercept)`` over those lengths.
    pval : float
        p-value of the overall slope, printed in the annotation.
    ylabel : str
        Label for the y-axis.

    Returns
    -------
    tuple
        ``(xmin, xmax)`` of *ax* after the left limit is forced to 5, so the
        companion windowed-slopes panel can reuse the same x-range.
    """
    ax.loglog(N, Y, '.')
    # NOTE(review): the overall fit is drawn with exp() (natural log) while
    # the windowed fits use base 2; this presumably mirrors how
    # tests.hurst / tests.DFA fit their data -- confirm there.
    ax.plot(N, N**fit.slope*np.exp(fit.intercept), 'r-', lw=1)
    ax.set_ylabel(ylabel)
    for (slope, intercept), lengths in zip(winFits, NN):
        ax.plot(lengths, 2**(np.log2(lengths)*slope + intercept))
    ax.text(0.05, 0.95,
            'H=%g\nr=%g\np=%0.3g' % (fit.slope, fit.rvalue, pval),
            transform=ax.transAxes, va='top')
    ax.set_xlim(left=5)
    return ax.get_xlim()


def _plotWindowedSlopes(ax, N, winFits, surrMean, surrSd, sspan, xlim):
    """Plot moving-window slopes against the surrogate-data band on *ax*.

    The window slopes (first element of each pair in *winFits*) are drawn at
    ``N[sspan:-sspan]`` together with the surrogate mean (red) and a shaded
    band of +/- 2 surrogate standard deviations. The x-range is set to
    *xlim*.
    """
    Nwin = N[sspan:-sspan]
    ax.semilogx(Nwin, np.array(winFits)[:, 0], '.')
    ax.semilogx(Nwin, surrMean, 'r')
    ax.fill_between(Nwin, surrMean + 2*surrSd, surrMean - 2*surrSd, alpha=0.5)
    ax.set_xlabel("length of sequence ($n$)")
    ax.set_ylabel("slope ($H$ value)")
    ax.set_xlim(xlim)


def _plotAcorr(ax, ISIs, logx):
    """Plot the ISI autocorrelation on a log y-axis plus a return-map inset.

    Lags up to 50 are shown; the lower y-limit is made at most 0.89. With
    *logx* the x-axis is logarithmic and the inset return map
    (ISI_{n+1} vs ISI_n) is log-log; otherwise the x-axis is limited to
    lags 0..50 and the inset uses linear axes.
    """
    # NOTE(review): Axes.acorr does no mean removal by default
    # (detrend is "none"); confirm this is intended for the ISI series.
    ax.acorr(ISIs, maxlags=50, usevlines=False, marker='.')
    ax.set_yscale('log')
    y1, y2 = ax.get_ylim()
    ax.set_ylim((min(y1, 0.89), y2))
    ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%g'))
    ax.yaxis.set_minor_formatter(matplotlib.ticker.FormatStrFormatter('%g'))
    if logx:
        ax.set_xscale('log')
        ax.set_xlabel("Delay (log)")
    else:
        ax.set_xlim((0, 50))
        ax.set_xlabel("Delay")
    ax.set_ylabel("Correlation (log)")

    inset = inset_axes(ax, "35%", "50%", loc=1)
    drawReturnMap = inset.loglog if logx else inset.plot
    drawReturnMap(ISIs[1:], ISIs[:-1], '.', ms=1)
    inset.set_xlabel(r"$ISI_{n+1}$")
    inset.set_ylabel(r"$ISI_{n}$")


def AllPlots(trace, spikes, time1, label="", timeWindow=(800, 1000),
             sspan=7, nSurr=100):
    """Create a 3x3 summary figure of the statistics of a spike train.

    Panels (subplot codes in parentheses):
      (331) trace segment inside *timeWindow*: V, plus Z on a twin axis
            when *trace* has two columns;
      (334) ISI versus spike time;
      (337) windowed KS test drawn by ``tests.windowedKS``;
      (332)/(335) R/S analysis: log-log scaling plot with overall and
            windowed fits, then windowed slopes against surrogate data;
      (333)/(336) the same for DFA;
      (338) ISI autocorrelation, log-linear, with a return-map inset;
      (339) ISI autocorrelation, log-log, with a log-log return-map inset.
    The suptitle shows *label*, the spike count and the KPSS statistic.

    Parameters
    ----------
    trace : ndarray
        Signal sampled at *time1*. 1-D: plotted as V. 2-D: column 0 is
        plotted as V and column 1 as Z.
    spikes : array-like
        Spike times; ISIs are ``np.diff(spikes)``. Presumably in ms (they
        are divided by 1000 for the "Time (s)" axis) -- confirm with caller.
    time1 : ndarray
        Time stamps of *trace* (axis labelled in ms).
    label : str, optional
        First line of the figure suptitle.
    timeWindow : (float, float), optional
        Open interval of *time1* shown in the trace panel. The default
        (800, 1000) is the previously hard-coded window.
    sspan : int, optional
        Span passed to ``tests.hurst``, ``tests.DFA`` and ``tests.hurstS``;
        also trims ``N[sspan:-sspan]`` for the windowed-slope panels
        (default 7).
    nSurr : int, optional
        Third argument of ``tests.hurstS``, presumably the number of
        surrogate series (default 100).

    The figure is drawn through pyplot's state machine; nothing is returned.
    """
    ISIs = np.diff(spikes)

    # %% analyses
    kpssval, kpssP, _ = kpss.kpssTest(ISIs)
    # NOTE(review): 0.5 and 0.001 look like the limits at which
    # kpss.kpssTest saturates its p-value -- confirm in that module.
    if kpssP == 0.5:
        kpsstext = 'pvalue > 0.5'
    elif kpssP == 0.001:
        kpsstext = 'pvalue < 0.001'
    else:
        kpsstext = 'pvalue = %g' % kpssP

    # Keep each analysis's lengths separate: the DFA call used to overwrite
    # N/NN from the R/S call, so the R/S panels were drawn against the DFA
    # lengths.
    Nrs, E, H, NNrs, Hfull = tests.hurst(ISIs, sspan)      # R/S analysis
    Ndfa, F, Hf, NNdfa, DFfull = tests.DFA(ISIs, sspan)    # DFA analysis

    # R/S and DFA on surrogate data
    (surrHm, surrHsd, surrDFm, surrDFsd,
     surrHf, surrDFf) = tests.hurstS(ISIs, sspan, nSurr)

    # One-sided p-values of the overall slopes: normal survival function
    # using the mean/std of the surrogate overall slopes.
    pvalHf = scipy.stats.norm.sf(Hfull[0], loc=np.mean(surrHf[:, 0]),
                                 scale=np.std(surrHf[:, 0]))
    pvalDFf = scipy.stats.norm.sf(DFfull[0], loc=np.mean(surrDFf[:, 0]),
                                  scale=np.std(surrDFf[:, 0]))

    # %% plots
    plt.figure(figsize=(12, 9))

    ax1 = plt.subplot(331)   # V and Z traces
    timeW = np.where((time1 > timeWindow[0]) & (time1 < timeWindow[1]))[0]
    if trace.ndim > 1:
        ax1.plot(time1[timeW], trace[timeW, 0])
        ax1b = ax1.twinx()
        ax1b.plot(time1[timeW], trace[timeW, 1], 'g')
        # Z limits/ticks come from the whole trace, not only the window.
        zmin = np.around(trace[:, 1].min(), 2)
        zmax = np.around(trace[:, 1].max(), 2)
        ax1.set_ylim((-0.5, 1.05))
        # The 4*zmax upper limit squeezes Z toward the bottom of the panel
        # (assuming zmax > 0).
        ax1b.set_ylim((zmin, 4*zmax))
        ax1b.set_yticks((zmin, (zmin + zmax)/2, zmax))
        ax1b.yaxis.set_label_coords(-0.1, 0.2)
        ax1b.set_ylabel("Z", color='g', weight='bold')
    else:
        ax1.plot(time1[timeW], trace[timeW])
    ax1.set_yticks((0, 0.5, 1))
    ax1.set_xlabel("Time (ms)")
    ax1.set_ylabel("V", weight='bold')

    ax2 = plt.subplot(334)   # ISI vs spike time
    ax2.plot(np.asarray(spikes[1:])/1000, ISIs, '.', ms=1)
    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("ISI (ms)")

    ax3 = plt.subplot(337)   # windowed KS test
    # NOTE(review): this label is hard-coded and independent of *label*;
    # confirm it is intended.
    tests.windowedKS(spikes, label="stoch V - Det Z", maxW=15, ax=ax3,
                     plot=True)

    # R/S statistics and its moving-window slopes
    xlimRS = _plotScaling(plt.subplot(332), Nrs, E, Hfull, H, NNrs, pvalHf,
                          "mean rescaled range\n $<RS(n)>$")
    _plotWindowedSlopes(plt.subplot(335), Nrs, H, surrHm, surrHsd, sspan,
                        xlimRS)

    # DFA and its moving-window slopes
    xlimDFA = _plotScaling(plt.subplot(333), Ndfa, F, DFfull, Hf, NNdfa,
                           pvalDFf, "Detrended Fluctuation")
    _plotWindowedSlopes(plt.subplot(336), Ndfa, Hf, surrDFm, surrDFsd, sspan,
                        xlimDFA)

    # Autocorrelation of ISIs: log-linear (338) and log-log (339)
    _plotAcorr(plt.subplot(338), ISIs, logx=False)
    _plotAcorr(plt.subplot(339), ISIs, logx=True)

    # len(spikes) is the spike count (the previous "+1" overcounted by one).
    plt.suptitle(label
                 + '\n%g spikes   KPSS LEVEL=%g  ' % (len(spikes), kpssval)
                 + kpsstext, y=0.99)

    plt.tight_layout(rect=(0, 0, 1, 0.97))