# -*- coding: utf-8 -*-
import numpy as np
from scipy.signal import cwt, morlet2, butter, sosfiltfilt
import matplotlib.pyplot as plt




def compPWT(signal, fs=np.arange(50., 450.1, 3.), dt=1.e-4, word=5, tmin=.0):
    """
    Computes the Wavelet Power Spectrum (squared modulus of the complex Morlet
    wavelet transform) of a signal.

    Input:
     signal (1-D sequence to be analyzed)
     fs (frequencies at which to compute the cwt, in 1/units of dt)
     dt (sampling time step)
     word (Morlet wavelet order, i.e. the ``w`` parameter of scipy's morlet2; default=5)
     tmin (discard the signal before tmin, same units as dt)
    Output:
     pwt (array of shape (len(fs), len(signal) - ktmin), with ktmin = round(tmin/dt);
          row i is the wavelet power at frequency fs[i])
    """
    # Round instead of truncating: e.g. int(0.3/0.1) == 2 due to floating-point error.
    ktmin = int(np.round(tmin / dt))

    segment = np.asarray(signal)[ktmin:]
    segment = segment - np.mean(segment)  # remove the mean (DC component) before the transform

    samp_rate = 1. / dt  # sampling rate is the inverse of the time step
    # Scale s at which a morlet2 wavelet of order `word` is centred on frequency f:
    # f = word * samp_rate / (2 * pi * s)
    ss = word * samp_rate / (2. * np.pi * np.asarray(fs))

    # cwt forwards extra keyword arguments to the wavelet; without w=word the
    # wavelet would silently use morlet2's default w=5, mismatching the scales.
    coefs = cwt(segment, morlet2, widths=ss, w=word)
    return np.abs(coefs) ** 2

def compPhaseWT(signal, fs=np.arange(50., 450.1, 3.), dt=1.e-4, word=5, tmin=.0):
    """
    Computes the Wavelet Phase (argument of the complex Morlet wavelet
    transform) of a signal.

    Input:
     signal (1-D sequence to be analyzed)
     fs (frequencies at which to compute the cwt, in 1/units of dt)
     dt (sampling time step)
     word (Morlet wavelet order, i.e. the ``w`` parameter of scipy's morlet2; default=5)
     tmin (discard the signal before tmin, same units as dt)
    Output:
     phase (array of shape (len(fs), len(signal) - ktmin), with ktmin = round(tmin/dt);
            values in radians within [-pi, pi], row i corresponds to fs[i])
    """
    # Round instead of truncating: e.g. int(0.3/0.1) == 2 due to floating-point error.
    ktmin = int(np.round(tmin / dt))

    segment = np.asarray(signal)[ktmin:]
    segment = segment - np.mean(segment)  # remove the mean (DC component) before the transform

    samp_rate = 1. / dt  # sampling rate is the inverse of the time step
    # Scale s at which a morlet2 wavelet of order `word` is centred on frequency f:
    # f = word * samp_rate / (2 * pi * s)
    ss = word * samp_rate / (2. * np.pi * np.asarray(fs))

    # cwt forwards extra keyword arguments to the wavelet; without w=word the
    # wavelet would silently use morlet2's default w=5, mismatching the scales.
    return np.angle(cwt(segment, morlet2, widths=ss, w=word))

def plot1(cwtPow, max_power=None, fs=np.arange(50., 450.1, 3.), tmin=.0, dt=.1, nlevels=30,
          Ncycs=1, flims=None, xlims=None, colorbar=True, cmap="hot", period=125.):
    """
    Plot the filled contour of a wavelet power scalogram for one single simulation.

    The time axis is shown as cycle phase: columns ktmin: of cwtPow are mapped
    linearly onto [-pi, -pi + 2*pi*Ncycs].

    Input:
     cwtPow (2-D array, rows = frequencies fs, columns = time samples; e.g. compPWT output)
     max_power (upper end of the colour scale; default = maximum of cwtPow)
     fs (frequencies of the rows of cwtPow)
     tmin (start of the plotted window, same units as dt)
     dt (time step, same units as period)
     nlevels (number of contour levels)
     Ncycs (number of cycles shown)
     flims ((fmin, fmax) y-limits; default/empty = (fs[0], fs[-1]))
     xlims ((xmin, xmax) x-limits in radians; default/empty = full Ncycs range)
     colorbar (draw a colourbar if true)
     cmap (matplotlib colormap name)
     period (duration of one cycle, same units as dt; default 125.)
    Note: cwtPow[:, ktmin:] must have exactly round(Ncycs * round(period/dt)) columns,
    otherwise matplotlib rejects the mismatched x axis.
    """
    ktmin = int(np.round(tmin / dt))
    n_per_cycle = int(np.round(period / dt))
    x_end = -np.pi + 2. * np.pi * Ncycs
    phase_axis = np.linspace(-np.pi, x_end, int(np.round(Ncycs * n_per_cycle)))

    if max_power is None:
        max_power = np.max(cwtPow)
    # NOTE(review): the +1 offset is absolute, so for data whose maximum is << 1
    # most of the colour range stays unused; kept to preserve existing plots.
    levels = np.linspace(0., max_power + 1., nlevels)
    plt.contourf(phase_axis, fs, np.abs(cwtPow[:, ktmin:]), levels=levels,
                 cmap=cmap, rasterized=True)

    if colorbar:
        plt.colorbar()

    if flims is None or len(flims) == 0:
        plt.ylim(fs[0], fs[-1])
    else:
        plt.ylim(flims[0], flims[1])

    if xlims is None or len(xlims) == 0:
        plt.xlim(-np.pi, x_end)
    else:
        plt.xlim(xlims[0], xlims[1])

def plotPhase(cwtPow, max_power=None, fs=np.arange(50., 450.1, 3.), tmin=.0, dt=.1, nlevels=30,
              Ncycs=1, flims=None, xlims=None, colorbar=True, cmap="hot", period=125.):
    """
    Plot the filled contour of a wavelet phase map (e.g. compPhaseWT output)
    for one single simulation.

    The time axis is shown as cycle phase: columns ktmin: of cwtPow are mapped
    linearly onto [-pi, -pi + 2*pi*Ncycs].

    Input:
     cwtPow (2-D array of phases, rows = frequencies fs, columns = time samples)
     max_power (unused; kept for signature compatibility with plot1)
     fs (frequencies of the rows of cwtPow)
     tmin (start of the plotted window, same units as dt)
     dt (time step, same units as period)
     nlevels (number of contour levels, passed to contourf as an integer)
     Ncycs (number of cycles shown)
     flims ((fmin, fmax) y-limits; default/empty = (fs[0], fs[-1]))
     xlims ((xmin, xmax) x-limits in radians; default/empty = full Ncycs range)
     colorbar (draw a colourbar if true)
     cmap (matplotlib colormap name)
     period (duration of one cycle, same units as dt; default 125.)
    Note: cwtPow[:, ktmin:] must have exactly round(Ncycs * round(period/dt)) columns,
    otherwise matplotlib rejects the mismatched x axis.
    """
    ktmin = int(np.round(tmin / dt))
    n_per_cycle = int(np.round(period / dt))
    x_end = -np.pi + 2. * np.pi * Ncycs
    phase_axis = np.linspace(-np.pi, x_end, int(np.round(Ncycs * n_per_cycle)))

    # An integer `levels` lets matplotlib choose about that many level boundaries
    # itself (previously hard-coded to 30, ignoring nlevels).
    plt.contourf(phase_axis, fs, cwtPow[:, ktmin:], levels=nlevels,
                 cmap=cmap, rasterized=True)

    if colorbar:
        plt.colorbar()

    if flims is None or len(flims) == 0:
        plt.ylim(fs[0], fs[-1])
    else:
        plt.ylim(flims[0], flims[1])

    if xlims is None or len(xlims) == 0:
        plt.xlim(-np.pi, x_end)
    else:
        plt.xlim(xlims[0], xlims[1])