# -*- coding: utf-8 -*-
"""
Created on Thu Mar 17 15:53:10 2016

@author: Subhasis
"""


import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d


def plot_3d_points(graph):
    """Plot the nodes of the tree as points in 3D.

    Each node is drawn with its own ``ax.plot`` call (marker ``'o'``) at
    the position given by its ``'x'``, ``'y'`` and ``'z'`` attributes.

    Parameters
    ----------
    graph
        Object whose ``nodes(data=True)`` yields ``(node, attr)`` pairs,
        where ``attr`` maps ``'x'``, ``'y'`` and ``'z'`` to coordinates.

    Returns
    -------
    The newly created matplotlib figure.

    Raises
    ------
    KeyError
        If a node lacks one of the ``'x'``, ``'y'``, ``'z'`` attributes.
    """
    fig = plt.figure()
    # ``fig.gca(projection='3d')`` no longer accepts keyword arguments in
    # recent Matplotlib releases; ``add_subplot`` creates the same 3D axes.
    ax = fig.add_subplot(111, projection='3d')
    for _, attr in graph.nodes(data=True):
        # Only the points are drawn here; edges would need cylinders.
        ax.plot([attr['x']], [attr['y']], [attr['z']], 'o')
    return fig

    
def closest(G, x, y, z):
    """Return the node closest to the position x, y, z.

    Parameters
    ----------
    G
        Object whose ``nodes(data=True)`` yields ``(node, attr)`` pairs,
        where ``attr`` maps ``'x'``, ``'y'`` and ``'z'`` to coordinates.
    x, y, z
        Coordinates of the query position.

    Returns
    -------
    (node, distance)
        The nearest node and its Euclidean distance from the query
        position.  When several nodes are equally near, the first one
        encountered is returned.  For a graph without nodes the result is
        ``(-1, inf)``.
    """
    nearest = -1
    # Start from infinity so that arbitrarily distant nodes are still
    # considered (a finite starting value would silently exclude them).
    dmin = float('inf')
    for n, attr in G.nodes(data=True):
        dx = attr['x'] - x
        dy = attr['y'] - y
        dz = attr['z'] - z
        d = np.sqrt(dx**2 + dy**2 + dz**2)
        # Strict comparison keeps the first of several equidistant nodes.
        if d < dmin:
            dmin = d
            nearest = n
    return nearest, dmin
    
    
def plot_3d_lines(graph, ax=None, color='k', alpha=0.7):
    """Plot the neuronal morphology tree in 3D.

    Every node whose ``'p'`` attribute names another node of the graph
    (and that node has ``'x'``, ``'y'``, ``'z'`` attributes) is drawn as a
    line segment from the node to that parent.  Any other node is drawn as
    a single ``'o'`` marker.

    Parameters
    ----------
    graph
        Object whose ``nodes(data=True)`` yields ``(node, attr)`` pairs,
        where ``attr`` maps ``'x'``, ``'y'``, ``'z'`` to coordinates and
        optionally ``'p'`` to the parent node.
    ax
        Axes to draw on.  If None, a new figure with 3D axes is created
        and a ``'pick_event'`` callback is connected to its canvas.
    color
        A single color, or a dict.  If color is a dict then the color of
        each segment is looked up by the ``'s'`` (structure type)
        attribute of its child node.
    alpha
        Transparency passed to every ``ax.plot`` call.

    Returns
    -------
    The axes that were drawn on.
    """
    def show_node(event):
        # Print the graph node nearest to the picked point.
        # NOTE(review): the lines below are drawn without a ``picker``
        # setting, so this callback does not appear to be triggered as
        # written - confirm before relying on it.
        ln = event.artist
        xdata = ln.get_xdata()
        ydata = ln.get_ydata()
        zdata = ln.get_zdata()
        ind = event.ind
        node, d = closest(graph, xdata[ind], ydata[ind], zdata[ind])
        print('onpick points:', node, d)

    if ax is None:
        fig = plt.figure()
        # ``fig.gca(projection='3d')`` no longer accepts keyword arguments
        # in recent Matplotlib releases; ``add_subplot`` is equivalent.
        ax = fig.add_subplot(111, projection='3d')
        ax.figure.canvas.mpl_connect('pick_event', show_node)

    # Build the node -> attributes mapping from ``nodes(data=True)``
    # instead of using ``graph.node``, which newer NetworkX no longer has.
    node_attrs = dict(graph.nodes(data=True))
    color_by_type = isinstance(color, dict)

    for _, attr in graph.nodes(data=True):
        c = color[attr['s']] if color_by_type else color
        parent = node_attrs.get(attr.get('p'))
        has_parent_pos = (parent is not None and
                          all(key in parent for key in ('x', 'y', 'z')))
        if has_parent_pos:
            ax.plot([attr['x'], parent['x']],
                    [attr['y'], parent['y']],
                    [attr['z'], parent['z']],
                    color=c, ls='-', alpha=alpha)
        else:
            # No usable parent (e.g. the root): mark the node itself.
            ax.plot([attr['x']], [attr['y']], [attr['z']], 'o', color=c,
                    alpha=alpha)
    return ax
    
def mark_leaf_nodes(graph, ax, color='r', marker='o'):
    """Mark and label every node of degree 1 on axes ax.

    Each such node is drawn with the given marker and color at its
    ``'x'``, ``'y'``, ``'z'`` position and annotated with ``str(node)``.

    Parameters
    ----------
    graph
        Object providing ``nodes(data=True)`` and ``degree(node)``.
    ax
        Axes to draw on (needs ``plot`` and ``text``).
    color, marker
        Passed to ``ax.plot`` for every marked node.

    Returns
    -------
    The axes ``ax`` that was passed in.
    """
    # Iterate ``nodes(data=True)`` rather than indexing ``graph.node``,
    # which newer NetworkX releases no longer provide.
    for node, attr in graph.nodes(data=True):
        if graph.degree(node) != 1:
            continue
        x, y, z = attr['x'], attr['y'], attr['z']
        ax.plot([x], [y], [z], marker=marker, color=color)
        ax.text(x, y, z, str(node))
    return ax

def plot_nodes(g, nodes, ax, color='r', marker='^'):
    """Mark the given nodes of g on axes ax and label them.

    Each node is drawn with the given marker and color at its ``'x'``,
    ``'y'``, ``'z'`` position and annotated with ``str(node)``.

    Useful for labeling specific nodes in two stages:

    ax = plot_3d_lines(g)
    ax = plot_nodes(g, nodes, ax)

    Parameters
    ----------
    g
        Object whose ``nodes(data=True)`` yields ``(node, attr)`` pairs.
    nodes
        Iterable of the nodes to mark.
    ax
        Axes to draw on (needs ``plot`` and ``text``).
    color, marker
        Passed to ``ax.plot`` for every marked node.

    Returns
    -------
    The axes ``ax`` that was passed in, so the usage shown above works.

    Raises
    ------
    KeyError
        If a requested node is not in ``g`` or lacks a coordinate.
    """
    # Build the node -> attributes mapping from ``nodes(data=True)``
    # instead of using ``g.node``, which newer NetworkX no longer has.
    node_attrs = dict(g.nodes(data=True))
    for n in nodes:
        attr = node_attrs[n]
        x, y, z = attr['x'], attr['y'], attr['z']
        ax.plot([x], [y], [z], marker=marker, color=color)
        ax.text(x, y, z, str(n))
    return ax