import numpy as np
from scipy.interpolate import NearestNDInterpolator

import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.collections import LineCollection
from matplotlib import cm

from .morpho import Tree

__all__ = ['set_rc_defaults', 'remove_border', 'make_axes', 'plot_means_with_errorbars', 'plot_tree']

def set_rc_defaults():
    """
    Apply this module's default look to matplotlib's rc settings.

    Sets the font (Arial, size 10), the default line width and color, the axes
    line width with medium-sized titles and labels, and outward-pointing ticks
    on both the x and y axes.
    """
    # One entry per rc group, applied in order.
    rc_groups = (
        ('font', {'family': 'Arial', 'size': 10}),
        ('lines', {'linewidth': 1, 'color': 'k'}),
        ('axes', {'linewidth': 1, 'titlesize': 'medium', 'labelsize': 'medium'}),
        ('xtick', {'direction': 'out'}),
        ('ytick', {'direction': 'out'}),
    )
    for group, settings in rc_groups:
        plt.rc(group, **settings)
    # NOTE: the figure dpi is not overridden here (a dpi=300 setting used to be
    # present as commented-out code).


def remove_border(axes=None, top=False, right=False, left=True, bottom=True):
    """
    Minimize chartjunk by stripping out unnecessary plot borders and axis ticks

    The top/right/left/bottom keywords toggle whether the corresponding plot border
    (spine) is drawn. Ticks are shown only on the sides whose border is drawn.

    If `axes` is None, the current axes (plt.gca()) are used.
    """
    ax = axes if axes is not None else plt.gca()
    for side, visible in (('top', top), ('right', right), ('left', left), ('bottom', bottom)):
        ax.spines[side].set_visible(visible)
    # turn off all ticks
    ax.yaxis.set_ticks_position('none')
    ax.xaxis.set_ticks_position('none')
    # now re-enable the visible sides. tick_top()/tick_bottom() (and
    # tick_left()/tick_right()) each *move* the ticks to one side, so calling
    # both in sequence would leave ticks only on the last one: when both sides
    # are requested, ask for 'both' explicitly instead.
    if top and bottom:
        ax.xaxis.set_ticks_position('both')
    elif top:
        ax.xaxis.tick_top()
    elif bottom:
        ax.xaxis.tick_bottom()
    if left and right:
        ax.yaxis.set_ticks_position('both')
    elif left:
        ax.yaxis.tick_left()
    elif right:
        ax.yaxis.tick_right()


def make_axes(r, c, i, offset=[0.1,0.1], spacing=[0.1,0.1], border=[0.05,0.05]):
    """
    A better subplot.

    Creates the `i`-th axes (1-based, counted row by row starting from the top
    left) of a grid with `r` rows and `c` columns and returns it.

    `offset` is the margin at the left/bottom of the grid, `border` the margin
    at the right/top, and `spacing` the gap between neighbouring axes; each is
    an [x, y] pair, or a scalar that is used for both directions. The values
    are combined into the rectangle handed to plt.axes.
    """
    def _as_pair(value):
        # A scalar applies to both the horizontal and the vertical direction.
        return [value, value] if np.isscalar(value) else value

    offset = _as_pair(offset)
    spacing = _as_pair(spacing)
    border = _as_pair(border)

    # Size of one cell once margins and inter-axes gaps have been removed.
    width = (1 - offset[0] - spacing[0]*(c-1) - border[0]) / c
    height = (1 - offset[1] - spacing[1]*(r-1) - border[1]) / r

    idx = i - 1
    col = idx % c
    # Rows are counted from the top, whereas the rectangle is placed from the bottom.
    row_from_bottom = r - 1 - int(idx / c)

    left = offset[0] + (width + spacing[0]) * col
    bottom = offset[1] + (height + spacing[1]) * row_from_bottom
    return plt.axes([left, bottom, width, height])


def plot_means_with_errorbars(x, y, mode='sem', ax=None, **kwargs):
    """
    Plot the column-wise mean of `y` against `x`, with a vertical error bar at each point.

    Parameters
    ----------
    x : sequence
        Abscissa values, one per column of `y`.
    y : array-like, 2D
        Samples along axis 0, one column per entry of `x`. NaNs are ignored.
    mode : str
        'sem' (default) draws the standard error of the mean, i.e. the standard
        deviation divided by the square root of the number of non-NaN samples
        in each column; any other value draws the standard deviation.
    ax : axes, optional
        Axes to draw into; defaults to the current axes (plt.gca()).
    **kwargs
        Forwarded to every ax.plot call. A 'label' entry, if present, is
        applied only to the line connecting the means, so that the error bars
        do not produce additional legend entries.
    """
    if ax is None:
        ax = plt.gca()
    y = np.asarray(y)
    Ym = np.nanmean(y, axis=0)
    Ys = np.nanstd(y, axis=0)
    if mode == 'sem':
        # nanmean/nanstd skip NaNs, so the sample count must skip them too:
        # dividing by y.shape[0] would underestimate the SEM of columns with NaNs.
        n_valid = np.sum(~np.isnan(y), axis=0)
        # A column with no valid sample gives NaN here, as nanstd already does.
        with np.errstate(divide='ignore', invalid='ignore'):
            Ys = Ys / np.sqrt(n_valid)
    lbl = kwargs.pop('label', None)
    for xi, ym, ys in zip(x, Ym, Ys):
        ax.plot([xi, xi], [ym-ys, ym+ys], **kwargs)
    mean_kwargs = dict(kwargs)
    if lbl is not None:
        mean_kwargs['label'] = lbl
    ax.plot(x, Ym, 'o-', **mean_kwargs)


def _plot_tree_fast(tree, type_ids=(1,2,3,4), scalebar_length=None, cmap=None, points=None, values=None,
                    cbar_levels=None, cbar_ticks=10, cbar_orientation='vertical', cbar_label='', ax=None,
                    bounds=None):
    """
    Draw the x-y projection of `tree`, adding one LineCollection per branch.

    Branches whose first node has a type not in `type_ids` are skipped. When the
    first node of a branch has a parent, the parent is prepended so that the
    branch is drawn connected to it. Line widths are half the node diameters.

    Parameters
    ----------
    tree
        Object exposing `branches` (sequences of nodes with `type`, `parent`,
        `x`, `y`, `z` and `diam` attributes) and `bounds` (x and y limits).
    type_ids
        Node types to draw.
    scalebar_length
        If not None, a vertical scale bar of this length is drawn and labelled
        in micrometers.
    cmap, points, values
        If `points` or `values` is None, every branch is drawn in a single
        color chosen from its type: `cmap` may be None (built-in colors for
        types 1-4: soma, axon, basal, apical), a dict keyed by type, or a
        callable taking the type. Otherwise each node is colored by the entry
        of `values` at the nearest of the 3D `points`, mapped through `cmap`
        with a normalization spanning values.min() to values.max() (so `values`
        must provide min()/max(), e.g. a numpy array).
    cbar_levels
        A colorbar is drawn only when branches are colored by `values` and this
        is not None; the value itself is not otherwise used.
    cbar_ticks
        Either the number of (rounded) ticks to spread over the range of
        `values`, or an explicit sequence of tick positions.
    cbar_orientation, cbar_label
        Orientation of the colorbar and the label put on its long axis.
    ax
        Axes to draw into; defaults to the current axes (plt.gca()).
    bounds
        Optional (xlim, ylim) pair. If None, `tree.bounds` is used and the
        axes are set to equal scaling.
    """
    if ax is None:
        ax = plt.gca()

    uniform_color_branches = points is None or values is None
    if uniform_color_branches:
        if cmap is None:
            default_colors = [
                [0,0,0],    # soma
                [.2,.2,.2], # axon
                [.7,0,.7],  # basal
                [0,.7,0]    # apical
            ]
            color_fun = lambda type_id: default_colors[type_id-1]
        elif isinstance(cmap, dict):
            color_fun = lambda key: cmap[key]
        else:
            color_fun = cmap
    else:
        interp = NearestNDInterpolator(points, values)
        norm = colors.Normalize(vmin=values.min(), vmax=values.max())

    # Last collection added to the axes; stays None when no branch is drawn.
    line = None
    for branch in tree.branches:
        first = branch[0]
        if first.type not in type_ids:
            continue
        nodes = list(branch)
        if first.parent is not None:
            # start the polyline at the parent so the branch attaches to it
            nodes.insert(0, first.parent)
        xyzd = np.array([[node.x, node.y, node.z, node.diam] for node in nodes])

        if first.parent is not None and first.parent.type == 1:
            # the parent is a type-1 (soma) node: use the diameter of the
            # branch's first node for the connecting segment instead of the
            # parent's own diameter
            xyzd[0,-1] = xyzd[1,-1]

        xy = xyzd[:,:2].reshape(-1, 1, 2)
        segments = np.concatenate([xy[:-1], xy[1:]], axis=1)
        if uniform_color_branches:
            lc = LineCollection(segments, linewidths=xyzd[:,-1]/2, colors=color_fun(first.type))
        else:
            lc = LineCollection(segments, linewidths=xyzd[:,-1]/2, cmap=cmap, norm=norm)
            lc.set_array(interp(xyzd[:,:3]))
        line = ax.add_collection(lc)

    if bounds is not None:
        ax.set_xlim(bounds[0])
        ax.set_ylim(bounds[1])
    else:
        ax.set_xlim(tree.bounds[0])
        ax.set_ylim(tree.bounds[1])
        ax.axis('equal')

    if scalebar_length is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        x = xlim[0] / 1.5
        y = (ylim[1] - scalebar_length) / 2
        ax.plot(x + np.zeros(2), y + np.array([0, scalebar_length]), 'k', lw=2)
        # Plain scalar arithmetic: np.diff(xlim) would yield a 1-element array,
        # and passing such an array as a text coordinate relies on an
        # array-to-scalar conversion that recent NumPy versions deprecate.
        x_span = xlim[1] - xlim[0]
        ax.text(x - x_span/15, y + scalebar_length/2, r'{} $\mu$m'.format(scalebar_length), fontsize=12,
                horizontalalignment='center', verticalalignment='center', rotation=90)

    # The colorbar needs a drawn collection to take its colormap and norm from;
    # skip it when no branch matched `type_ids`.
    if not uniform_color_branches and cbar_levels is not None and line is not None:
        if np.isscalar(cbar_ticks):
            ticks = np.round(np.linspace(values.min(), values.max(), cbar_ticks))
        else:
            ticks = cbar_ticks
        cbar = plt.colorbar(line, ax=ax, fraction=0.1, shrink=0.5, aspect=30, ticks=ticks, orientation=cbar_orientation)
        if cbar_orientation == 'vertical':
            cbar.ax.set_ylabel(cbar_label)
        else:
            cbar.ax.set_xlabel(cbar_label)



def _plot_tree_btmorph(tree, type_ids=(1,2,3,4), ax=None):
    """
    Draw the x-y projection of a tree whose nodes carry their data in `content`.

    Every node that has a parent and whose `content['p3d'].type` is in
    `type_ids` is joined to its parent by a line whose width is the node's
    radius; segments between two type-1 nodes are skipped. Nodes flagged
    'on_oblique_branch' are drawn in green, those flagged 'on_terminal_branch'
    in magenta, and all others in black. A vertical scale bar of fixed length
    100 is added near the right-hand side and labelled in micrometers.

    If `ax` is None, a new figure with a single axes is created.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # (content flag, color) pairs, checked in order; the first flag set wins.
    flag_colors = (('on_oblique_branch', 'g'), ('on_terminal_branch', 'm'))

    min_x, max_x = 0, 0
    for node in tree:
        if node.parent is None:
            continue
        p3d = node.content['p3d']
        if p3d.type not in type_ids:
            continue
        if p3d.type == 1 and node.parent.content['p3d'].type == 1:
            continue

        parent_xy = node.parent.content['p3d'].xyz[:2]
        xy = p3d.xyz[:2]
        # horizontal extent of the drawn nodes, used to place the scale bar
        max_x = max(max_x, xy[0])
        min_x = min(min_x, xy[0])

        col = 'k'
        for flag, flag_color in flag_colors:
            if flag in node.content and node.content[flag]:
                col = flag_color
                break

        ax.plot([parent_xy[0], xy[0]], [parent_xy[1], xy[1]], color=col, linewidth=p3d.radius)

    width = max_x - min_x
    dx = 100
    bar_x = max_x - width / 10
    label_x = max_x - width / 6.5
    ax.plot(bar_x + np.zeros(2), 50 + np.array([0,dx]), 'k', lw=1)
    ax.text(label_x, 50 + dx/2, r'{} $\mu$m'.format(dx), horizontalalignment='center',
            verticalalignment='center', rotation=90)
    ax.axis('equal')



def plot_tree(tree, type_ids=(1,2,3,4), scalebar_length=None, cmap=None, points=None, values=None,
              cbar_levels=None, cbar_ticks=10, cbar_orientation='vertical', cbar_label='', ax=None, bounds=None):
    """
    Plot a morphology tree, choosing the drawing routine from the type of `tree`.

    A `Tree` instance is drawn by _plot_tree_fast, which receives every
    argument. Any other object is handed to _plot_tree_btmorph, which only
    uses `tree`, `type_ids` and `ax`; the remaining arguments are ignored in
    that case.
    """
    if not isinstance(tree, Tree):
        _plot_tree_btmorph(tree, type_ids, ax)
        return
    _plot_tree_fast(tree,
                    type_ids=type_ids,
                    scalebar_length=scalebar_length,
                    cmap=cmap,
                    points=points,
                    values=values,
                    cbar_levels=cbar_levels,
                    cbar_ticks=cbar_ticks,
                    cbar_orientation=cbar_orientation,
                    cbar_label=cbar_label,
                    ax=ax,
                    bounds=bounds)