"""
Some plotting utilities to use scale bars rather than coordinate axes.
18 July 2010, C. Schmidt-Hieber, University College London

From the stfio module:
http://code.google.com/p/stimfit
"""

has_mpl = True

try:
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
except ImportError:
    has_mpl = False

import numpy as np

scale_dist_x = 0.02
scale_dist_y = 0.02
graph_width = 6.0
graph_height = 4.0
key_dist = 0.01

class timeseries(object):
    def __init__(self, section, dt, xunits="ms", yunits="mV",  
                 linestyle="-k", linewidth=1.0):
        if type(section)==np.ndarray:
            self.data = section
        else:
            self.data = section.asarray()
        self.dt = dt
        self.xunits = xunits
        self.yunits = yunits
        self.linestyle = linestyle
        self.linewidth = linewidth

    def x_trange(self, tstart, tend):
        return np.arange(int(tstart/self.dt), int(tend/self.dt), 1.0, 
                         dtype=np.float) * self.dt
                         
    def y_trange(self, tstart, tend):
        return self.data[int(tstart/self.dt):int(tend/self.dt)]
    
    def timearray(self):
        if len(self.data.shape)==1:
            return np.arange(0.0, len(self.data), 1.0) * self.dt
        else:
            return np.arange(0.0, self.data.shape[1], 1.0) * self.dt

    def duration(self):
        if len(self.data.shape)==1:
            return len(self.data) * self.dt
        else:
            return self.data.shape[1] * self.dt

    def interpolate(self, newtime, newdt):
        if len(self.data.shape) == 1:
            flin = \
                interpolate.interp1d(self.timearray(), self.data, 
                                     bounds_error=False, fill_value=0)
            return timeseries(flin(newtime), newdt)
        else:
            # interpolate each row individually:
            iparray = ma.zeros((self.data.shape[0], len(newtime)))
            for nrow, row in enumerate(self.data):
                flin = \
                    interpolate.interp1d(self.timearray(), row, 
                                         bounds_error=False, fill_value=0)
                iparray[nrow,:]=flin(newtime)
            return timeseries(iparray, newdt)
    
    def maskedarray(self, center, left, right):
        # check whether we have enough data left and right:
        if len(self.data.shape) > 1:
            mask = \
                np.zeros((self.data.shape[0], int((right+left)/self.dt)))
            maskedarray = \
                ma.zeros((self.data.shape[0], int((right+left)/self.dt)))
        else:
            mask = np.zeros((int((right+left)/self.dt)))
            maskedarray = ma.zeros((int((right+left)/self.dt)))
        offset = 0
        if center - left < 0:
            if len(self.data.shape) > 1:
                mask[:,:int((left-center)/self.dt)] = 1
            else:
                mask[:int((left-center)/self.dt)] = 1
            leftindex = 0
            offset = int((left-center)/self.dt)
        else:
            leftindex = int((center-left)/self.dt)
        if center + right >= len(self.data) * self.dt:
            endtime = len(self.data) * self.dt
            if len(self.data.shape) > 1:
                mask[:,-int((center+right-endtime)/self.dt):] = 1
            else:
                mask[-int((center+right-endtime)/self.dt):] = 1
            rightindex = int(endtime/self.dt)
        else:
            rightindex = int((center+right)/self.dt)
        for timest in range(leftindex, rightindex):
                if len(self.data.shape) > 1:
                    if timest-leftindex+offset < maskedarray.shape[1] and timest<self.data.shape[1]:
                        maskedarray[:,timest-leftindex+offset]=self.data[:,timest]
                else:
                    if timest-leftindex+offset < len(maskedarray):
                        maskedarray[timest-leftindex+offset]=self.data[timest]
        maskedarray.mask = ma.make_mask(mask)
        return timeseries(maskedarray, self.dt)
        
def average(tsl):
    # find fastest dt:
    dt_common = 1e12
    for ts in tsl:
        if ts.dt < dt_common:
            newtime = ts.timearray()
            dt_common = ts.dt
            
    # interpolate all series to new dt:
    tslip = [ts.interpolate(newtime, dt_common) for ts in tsl]
    if len(tslip[0].data.shape)==1:
        ave = np.empty((len(tslip), len(tslip[0].data)))
    else:
        ave = np.empty((len(tslip), tslip[0].data.shape[0], tslip[0].data.shape[1]))
        
    for its, ts in enumerate(tslip):
        if len(ts.data.shape)==1:
            ave[its] = ts.data
        else:
            ave[its,:,:] = ts.data[:,:]

    if len(ts.data.shape)==1:
        return timeseries(ma.mean(ave, axis=0), dt_common)
    else:
        avef = ma.zeros((tslip[0].data.shape[0], tslip[0].data.shape[1]))
        for nrow, row in enumerate(avef):
            avef[nrow,:] = ma.mean(ave[:,nrow,:], axis=0)
        return timeseries(avef, dt_common)

def prettyNumber(f):
    fScaled = f
    if fScaled < 1:
        correct = 10.0
    else:
        correct = 1.0

    # set stepsize
    nZeros = int(np.log10(fScaled))
    prev10e = 10.0**nZeros / correct
    next10e = prev10e * 10

    if fScaled / prev10e  > 7.5:
        return next10e
    elif fScaled / prev10e  > 5.0:
        return 5 * prev10e
    else:
        return round(fScaled/prev10e) * prev10e
    
def plot_scalebars(ax, div=3.0, labels=True, 
                    xunits="", yunits="", nox=False, 
                    sb_xoff=0, sb_yoff=0, rotate_yslabel=False, 
                    linestyle="-k", linewidth=4.0,
                    textcolor='k', textweight='normal'):
    # print dir(ax.dataLim)
    xmin = ax.dataLim.xmin
    xmax = ax.dataLim.xmax
    ymin = ax.dataLim.ymin
    ymax = ax.dataLim.ymax
    xscale = xmax-xmin
    yscale = ymax-ymin

    xoff = (scale_dist_x + sb_xoff) * xscale
    yoff = (scale_dist_y - sb_yoff) * yscale

    # plot scale bars:
    xlength = prettyNumber((xmax-xmin)/div)
    xend_x, xend_y = xmax, ymin
    if not nox:
        xstart_x, xstart_y = xmax-xlength, ymin
        scalebarsx = [xstart_x+xoff, xend_x+xoff]
        scalebarsy = [xstart_y-yoff, xend_y-yoff]
    else:
        scalebarsx=[xend_x+xoff,]
        scalebarsy=[xend_y-yoff]
    
    ylength = prettyNumber((ymax-ymin)/div)
    yend_x, yend_y = xmax, ymin+ylength
    scalebarsx.append(yend_x+xoff)
    scalebarsy.append(yend_y-yoff)
        
    ax.plot(scalebarsx, scalebarsy, linestyle, linewidth=linewidth, solid_joinstyle='miter')

    if labels:
        # if textcolor is not None:
        #     color = "\color{%s}" % textcolor
        # else:
        #     color = ""
        if not nox:
            # xlabel
            if xlength >=1:
                xlabel = r"%d$\,$%s" % (xlength, xunits)
            else:
                xlabel = r"%g$\,$%s" % (xlength, xunits)
            xlabel_x, xlabel_y = xmax-xlength/2.0, ymin
            xlabel_y -= key_dist*yscale
            ax.text(xlabel_x+xoff, xlabel_y-yoff, xlabel, ha='center', va='top',
                    weight=textweight, color=textcolor) #, [pyx.text.halign.center,pyx.text.valign.top])
        # ylabel
        if ylength >=1:
            ylabel = r"%d$\,$%s" % (ylength,yunits)
        else:
            ylabel = r"%g$\,$%s" % (ylength,yunits)
        if not rotate_yslabel:
            ylabel_x, ylabel_y = xmax, ymin + ylength/2.0
            ylabel_x += key_dist*xscale
            ax.text(ylabel_x+xoff, ylabel_y-yoff, ylabel, ha='left', va='center',
                    weight=textweight, color=textcolor)
        else:
            ylabel_x, ylabel_y = xmax, ymin + ylength/2.0
            ylabel_x += key_dist*xscale
            ax.text(ylabel_x+xoff, ylabel_y-yoff, ylabel, ha='center', va='top', rotation=90,
                    weight=textweight, color=textcolor)


def xFormat(x, res, data_len, width):
    points = int(width/2.5 * res)
    part = float(x) / data_len
    return int(part*points)

def yFormat(y):
    return y

def reduce(ydata, dy, maxres, xoffset=0, width=graph_width):
    x_last = xFormat(0, maxres, len(ydata), width)
    y_last = yFormat(ydata[0])
    y_max = y_last
    y_min = y_last
    x_next = 0
    y_next = 0
    xrange = list()
    yrange = list()
    xrange.append(x_last)
    yrange.append(y_last)
    for (n,pt) in enumerate(ydata[:-1]):
        x_next = xFormat(n+1, maxres, len(ydata), width)
        y_next = yFormat(ydata[n+1])
        # if we are still at the same pixel column, only draw if this is an extremum:
        if (x_next == x_last):
            if (y_next < y_min):
                y_min = y_next
            if (y_next > y_max):
                y_max = y_next
        else:
            # else, always draw and reset extrema:
            if (y_min != y_next):
                xrange.append(x_last)
                yrange.append(y_min)
                y_last = y_min
            if (y_max != y_next):
                xrange.append(x_last)
                yrange.append(y_max)
                y_last = y_max
            xrange.append(x_next)
            yrange.append(y_next)
            y_min = y_next
            y_max = y_next
            x_last = x_next
            y_last = y_next
    trace_len_pts  = width/2.5 * maxres
    trace_len_time = len(ydata) * dy
    dt_per_pt = trace_len_time / trace_len_pts
    xrange = np.array(xrange)*dt_per_pt + xoffset
    
    return xrange, yrange

def plot_traces(traces, pulses=None,
                 xmin=None, xmax=None, ymin=None, ymax=None, xoffset=0,
                 maxres = None,
                 sb_yoff=0, sb_xoff=0, linestyle_sb = "-k",
                 dashedline=None, sagline=None, rotate_yslabel=False,
                 textcolor='k', textweight='normal'):
    
    Fig = plt.figure(dpi=maxres)
    Fig.patch.set_alpha(0.0)
    border = 0.1
    pulseprop = 0.1
    if pulses is not None and len(pulses) > 0:
        prop = 1.0-pulseprop-border
    else:
        prop = 1.0-border
    ax = Fig.add_axes([0.0,(1.0-prop),1.0-border,prop], alpha=0.0)

    for trace in traces:
        if maxres is None:
            xrange = trace.timearray()+xoffset
            yrange = trace.data
        else:
            xrange, yrange = reduce(trace.data, trace.dt, maxres=maxres)
            xrange += xoffset
        ax.plot(xrange, yrange, trace.linestyle, lw=trace.linewidth)

    if xmin is not None:
        phantomrect_x0 = xmin
    else:
        phantomrect_x0 = ax.dataLim.xmin
        
    if xmax is not None:
        phantomrect_x1 = xmax
    else:
        phantomrect_x1 = ax.dataLim.xmax

    if ymin is not None:
        phantomrect_y0 = ymin
    else:
        phantomrect_y0 = ax.dataLim.ymin

    if ymax is not None:
        phantomrect_y1 = ymax
    else:
        phantomrect_y1 = ax.dataLim.ymax

    pr = ax.plot([phantomrect_x0, phantomrect_x1], [phantomrect_y0, phantomrect_y1], alpha=0.0)

    xscale = ax.dataLim.xmax-ax.dataLim.xmin
    yscale = ax.dataLim.ymax-ax.dataLim.ymin
    if dashedline is not None:
        ax.plot([ax.dataLim.xmin, ax.dataLim.xmax],[dashedline, dashedline], 
                "--k", linewidth=linewidth*2.0)
        gridline_x, gridline_y = ax.dataLim.xmax, dashedline
        gridline_x += key_dist*xscale
        xoff = scale_dist_x * xscale


    if sagline is not None:
        ax.plot([ax.dataLim.xmin, ax.dataLim.xmax],[sagline, sagline], 
                "--k", linewidth=linewidth*2.0)
        gridline_x, gridline_y = ax.dataLim.xmax, sagline
        gridline_x += key_dist*xscale
        xoff = scale_dist_x * xscale

    plot_scalebars(ax, linestyle=linestyle_sb, xunits=traces[0].xunits, yunits=traces[0].yunits,
                   textweight=textweight, textcolor=textcolor)

    if pulses is not None and len(pulses) > 0:
        axp = Fig.add_axes([0.0,0.0,1.0-border,pulseprop+border/2.0], sharex=ax)
        for pulse in pulses:
            xrange = pulse.timearray()
            yrange = pulse.data
            axp.plot(xrange, yrange, pulse.linestyle, linewidth=pulse.linewidth)
        plot_scalebars(axp, linestyle=linestyle_sb, nox=True, yunits=pulses[0].yunits,
                       textweight=textweight, textcolor=textcolor)
        for o in axp.findobj():
            o.set_clip_on(False)
        axp.axis('off')

    if xmin is None:
        xmin = ax.dataLim.xmin
    if xmax is None:
        xmax = ax.dataLim.xmax
    if ymin is None:
        ymin = ax.dataLim.ymin
    if ymax is None:
        ymax = ax.dataLim.ymax

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    for o in ax.findobj():
        o.set_clip_on(False)
    ax.axis('off')
    
    return Fig