# -*- coding: utf-8 -*-
"""
Created on Thu Oct 26 19:09:25 2017
@author: porio
"""
import numpy as np
import matplotlib.pyplot as plt
import kpss
import tests
import scipy.stats
import matplotlib.ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
def _kpss_text(kpssP):
    """Return a readable description of the KPSS p-value *kpssP*.

    Values at or beyond 0.5 / 0.001 are reported as one-sided bounds.
    Presumably ``kpss.kpssTest`` clips its p-value to that range when the
    statistic falls outside its lookup table -- TODO confirm.
    """
    if kpssP >= 0.5:
        return 'pvalue > 0.5'
    if kpssP <= 0.001:
        return 'pvalue < 0.001'
    return 'pvalue = %g' % kpssP


def _plot_trace(ax, trace, time1, tWindow):
    """Plot V (and Z on a twin y-axis when *trace* is 2-D) for tWindow[0] < time1 < tWindow[1]."""
    timeW = np.where((time1 > tWindow[0]) & (time1 < tWindow[1]))[0]
    if trace.ndim > 1:
        # Column 0 is plotted as V, column 1 as Z (per the axis labels below).
        ax.plot(time1[timeW], trace[timeW, 0])
        axZ = ax.twinx()
        axZ.plot(time1[timeW], trace[timeW, 1], 'g')
        zmin = np.around(trace[:, 1].min(), 2)
        zmax = np.around(trace[:, 1].max(), 2)
        ax.set_ylim((-0.5, 1.05))
        # Upper Z limit is 4*zmax, presumably to keep the Z trace in the lower
        # part of the panel, away from V.
        axZ.set_ylim((zmin, 4 * zmax))
        axZ.set_yticks((zmin, (zmin + zmax) / 2, zmax))
        # Move the Z label from the twin (right) axis over to the left side.
        axZ.yaxis.set_label_coords(-0.1, 0.2)
        axZ.set_ylabel("Z", color='g', weight='bold')
    else:
        ax.plot(time1[timeW], trace[timeW])
    # NOTE(review): the original indentation was lost; these ticks are applied
    # for both 1-D and 2-D traces -- confirm that was the intent.
    ax.set_yticks((0, 0.5, 1))
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("V", weight='bold')


def _plot_scaling(axFit, axSlopes, N, Y, windowFits, NN, fullFit, pval,
                  surrMean, surrSD, sspan, ylabel):
    """Draw one scaling analysis (R/S or DFA) on a pair of axes.

    *axFit* gets the log-log curve Y(N), the overall fit *fullFit* (annotated
    with its slope, r and *pval*) and the windowed fits in *windowFits*.
    *axSlopes* gets the windowed slopes against the surrogate mean +/- 2 SD
    band, and reuses the x range of *axFit*.
    """
    axFit.loglog(N, Y, '.')
    # NOTE(review): the overall fit is drawn with a natural-log intercept
    # (np.exp), while the windowed fits below use base 2. Confirm that both
    # match how tests.hurst / tests.DFA compute their fits.
    axFit.plot(N, N**fullFit.slope * np.exp(fullFit.intercept), 'r-', lw=1)
    axFit.set_ylabel(ylabel)
    for (h0, h1), dat in zip(windowFits, NN):
        axFit.plot(dat, 2**(np.log2(dat) * h0 + h1))
    axFit.text(0.05, 0.95,
               'H=%g\nr=%g\np=%0.3g' % (fullFit.slope, fullFit.rvalue, pval),
               transform=axFit.transAxes, va='top')
    axFit.set_xlim(left=5)
    xlims = axFit.get_xlim()

    # Windowed slopes exist only for the central N values (sspan trimmed at each end).
    Ncentral = N[sspan:-sspan]
    axSlopes.semilogx(Ncentral, np.array(windowFits)[:, 0], '.')
    axSlopes.semilogx(Ncentral, surrMean, 'r')
    axSlopes.fill_between(Ncentral, surrMean + 2 * surrSD, surrMean - 2 * surrSD,
                          alpha=0.5)
    axSlopes.set_xlabel("length of sequence ($n$)")
    axSlopes.set_ylabel("slope ($H$ value)")
    axSlopes.set_xlim(xlims)


def _plot_acorr(ax, ISIs, maxlags, logDelay):
    """Plot the ISI autocorrelation (Axes.acorr) on a log y-axis, plus an inset return map.

    With *logDelay* the delay axis is logarithmic and the inset is log-log;
    otherwise the delay axis spans [0, maxlags] and the inset is linear.
    """
    ax.acorr(ISIs, maxlags=maxlags, usevlines=False, marker='.')
    ax.set_yscale('log')
    y1, y2 = ax.get_ylim()
    # Make sure the y range reaches down to at least 0.89.
    ax.set_ylim((min(y1, 0.89), y2))
    ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%g'))
    ax.yaxis.set_minor_formatter(matplotlib.ticker.FormatStrFormatter('%g'))
    if logDelay:
        ax.set_xscale('log')
        ax.set_xlabel("Delay (log)")
    else:
        ax.set_xlim((0, maxlags))
        ax.set_xlabel("Delay")
    ax.set_ylabel("Correlation (log)")

    # Return map: ISI_{n+1} on x against ISI_n on y.
    axIn = inset_axes(ax, "35%", "50%", loc=1)
    plotter = axIn.loglog if logDelay else axIn.plot
    plotter(ISIs[1:], ISIs[:-1], '.', ms=1)
    axIn.set_xlabel(r"$ISI_{n+1}$")
    axIn.set_ylabel(r"$ISI_{n}$")


def AllPlots(trace, spikes, time1, label="", tWindow=(800, 1000)):
    """Build a 3x3 summary figure of a spike train and its ISI statistics.

    The inter-spike intervals are ``np.diff(spikes)``.

    Panels:
      * 331 -- the trace inside *tWindow*;
      * 334 -- ISI vs. time;
      * 337 -- windowed KS test (``tests.windowedKS``);
      * 332/335 -- R/S analysis and its windowed slopes vs. surrogates;
      * 333/336 -- DFA and its windowed slopes vs. surrogates;
      * 338/339 -- ISI autocorrelation with linear / log delay axes, each
        with a return-map inset.

    The title reports the spike count and the KPSS level statistic and
    p-value.

    Parameters
    ----------
    trace : ndarray
        1-D (V only), or 2-D with V in column 0 and Z in column 1, sampled
        at *time1*.
    spikes : array-like
        Spike times. They are divided by 1000 for the "Time (s)" axis and
        differenced for the "ISI (ms)" axis, so presumably in ms.
    time1 : ndarray
        Sample times of *trace* (labelled ms).
    label : str
        Prefix for the figure title.
    tWindow : (float, float)
        Open interval of *time1* shown in the trace panel. The default
        reproduces the previously hard-coded 800-1000 window.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    ISIs = np.diff(spikes)

    #%% analyses
    kpssval, kpssP, _ = kpss.kpssTest(ISIs)
    kpsstext = _kpss_text(kpssP)
    sspan = 7
    # Distinct names: previously DFA's N/NN silently overwrote the R/S ones
    # before the R/S panels were drawn.
    N_rs, E, H, NN_rs, Hfull = tests.hurst(ISIs, sspan)  # R/S analysis
    N_df, F, Hf, NN_df, DFfull = tests.DFA(ISIs, sspan)  # DFA analysis
    # R/S and DFA on 100 surrogate series
    surrHm, surrHsd, surrDFm, surrDFsd, surrHf, surrDFf = tests.hurstS(ISIs, sspan, 100)
    # One-sided p-values of the overall fits, from a normal distribution fitted
    # to column 0 of the surrogate fits (presumably their slopes -- confirm).
    pvalHf = scipy.stats.norm.sf(Hfull[0], loc=np.mean(surrHf[:, 0]),
                                 scale=np.std(surrHf[:, 0]))
    pvalDFf = scipy.stats.norm.sf(DFfull[0], loc=np.mean(surrDFf[:, 0]),
                                  scale=np.std(surrDFf[:, 0]))

    #%% plots
    fig = plt.figure(figsize=(12, 9))

    _plot_trace(plt.subplot(331), trace, time1, tWindow)  # V (and Z) traces

    ax2 = plt.subplot(334)  # ISI vs time
    ax2.plot(np.asarray(spikes[1:]) / 1000, ISIs, '.', ms=1)
    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("ISI (ms)")

    ax3 = plt.subplot(337)  # windowed KS test
    # NOTE(review): this label is hard-coded and ignores *label* -- confirm intent.
    tests.windowedKS(spikes, label="stoch V - Det Z", maxW=15, ax=ax3, plot=True)

    axRS = plt.subplot(332)
    axRSslopes = plt.subplot(335)
    _plot_scaling(axRS, axRSslopes, N_rs, E, H, NN_rs, Hfull, pvalHf,
                  surrHm, surrHsd, sspan, "mean rescaled range\n $<RS(n)>$")

    axDF = plt.subplot(333)
    axDFslopes = plt.subplot(336)
    _plot_scaling(axDF, axDFslopes, N_df, F, Hf, NN_df, DFfull, pvalDFf,
                  surrDFm, surrDFsd, sspan, "Detrended Fluctuation")

    _plot_acorr(plt.subplot(338), ISIs, 50, logDelay=False)
    _plot_acorr(plt.subplot(339), ISIs, 50, logDelay=True)

    # Spike count is len(spikes) (previously printed as len(spikes)+1).
    fig.suptitle(label + '\n%g spikes KPSS LEVEL=%g ' % (len(spikes), kpssval) + kpsstext,
                 y=0.99)
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    return fig