import numpy as np
# Public names exported by ``from <this module> import *``.
__all__ = ['find_node_in_tree', 'compute_branch_id', 'find_terminal_branches',
'find_oblique_branches', 'find_terminal_and_oblique_branches',
'Node', 'Tree']
# Integer point-type codes keyed by compartment name. The functions below
# compare them against ``node.content['p3d'].type``.
SWC_types = {'soma': 1, 'axon': 2, 'basal': 3, 'apical': 4, 'soma_contour': 16}
def find_node_in_tree(ID, tree):
    """Return the first node of ``tree`` whose ``index`` equals ``int(ID)``.

    Parameters
    ----------
    ID : anything accepted by ``int()``
        Identifier of the node to look for.
    tree : iterable of nodes
        Each node must expose an ``index`` attribute.

    Returns
    -------
    The matching node, or ``None`` when no node has that index.
    """
    wanted = int(ID)
    matches = (candidate for candidate in tree if candidate.index == wanted)
    return next(matches, None)
def compute_branch_id(tree):
    """Store a ``'branch_id'`` entry in the ``content`` dict of every node.

    Nodes are processed in the order ``tree`` yields them. A node without a
    parent receives the current counter value; a node whose parent has more
    than one child starts a new branch (the counter is incremented first);
    any other node copies the id already stored on its parent, so parents
    must be yielded before their children.

    Parameters
    ----------
    tree : iterable of nodes
        Each node must expose ``parent``, ``children`` and ``content``.
    """
    counter = 0
    for node in tree:
        parent = node.parent
        if parent is not None and len(parent.children) <= 1:
            # Continuation of the parent's branch: inherit its id.
            node.content['branch_id'] = parent.content['branch_id']
            continue
        if parent is not None:
            # First node after a branching point: open a new branch.
            counter += 1
        node.content['branch_id'] = counter
def find_terminal_branches(tree, point_types = (SWC_types['basal'], SWC_types['apical'])):
    """Flag the nodes that lie on terminal branches.

    Every node gets an ``'on_terminal_branch'`` entry in its ``content``
    dict (``False`` unless it is already present). A branch is examined when
    its first node has a parent with more than one child and a point type
    listed in ``point_types``. The branch is followed through single-child
    nodes to its last node; if that node has no children and its second
    coordinate (``xyz[1]``) is greater than zero, every node from the
    last node back to the first one is flagged ``True``.

    Parameters
    ----------
    tree : iterable of nodes
        Nodes expose ``parent``, ``children`` and ``content['p3d']`` (with
        ``type`` and ``xyz``).
    point_types : container of int
        Point types whose branches are considered.
    """
    for node in tree:
        # Keep a flag set while an ancestor branch head was processed.
        node.content.setdefault('on_terminal_branch', False)
        if node.parent is None:
            continue
        if node.content['p3d'].type not in point_types:
            continue
        if len(node.parent.children) <= 1:
            continue
        # Follow the unbranched run of nodes down to its last node.
        tip = node
        while len(tip.children) == 1:
            tip = tip.children[0]
        if len(tip.children) != 0 or not tip.content['p3d'].xyz[1] > 0:
            continue
        # Walk back up from the tip, flagging every node of the branch.
        cursor = tip
        while cursor != node:
            cursor.content['on_terminal_branch'] = True
            cursor = cursor.parent
        node.content['on_terminal_branch'] = True
def find_oblique_branches(tree, end_point_limits=(0, 240), max_angle=70):
    """Flag the nodes that lie on oblique branches.

    Every node visited gets an ``'on_oblique_branch'`` entry in its
    ``content`` dict. Nodes that are not on a terminal branch are flagged
    ``False``. For each terminal branch (identified by its first node, whose
    parent has more than one child) the branch is oblique when all of the
    following hold:

    * the first node's point type is ``SWC_types['apical']``;
    * the second coordinate of the branch's last node lies strictly between
      ``end_point_limits[0]`` and ``end_point_limits[1]``;
    * the angle between the x-axis and the straight line joining the first
      and last node, measured in the xy-plane and folded into [0, 90]
      degrees, is smaller than ``max_angle``.

    If the root has no ``'on_terminal_branch'`` entry,
    ``find_terminal_branches(tree)`` is run first with its default arguments.

    Parameters
    ----------
    tree : iterable of nodes with a ``root`` attribute
    end_point_limits : sequence of two numbers
        Exclusive lower and upper bound for the end point's second coordinate.
    max_angle : number
        Exclusive upper bound for the branch angle, in degrees.
    """
    if 'on_terminal_branch' not in tree.root.content:
        find_terminal_branches(tree)
    for node in tree:
        if not node.content['on_terminal_branch']:
            node.content['on_oblique_branch'] = False
        elif len(node.parent.children) > 1:
            # Follow the unbranched run of nodes down to the branch tip.
            child = node
            while len(child.children) == 1:
                child = child.children[0]
            start_point = node.content['p3d'].xyz[:2]
            end_point = child.content['p3d'].xyz[:2]
            x, y = end_point - start_point
            if x == 0 and y == 0:
                # Zero-length branch: the direction is undefined. NaN makes
                # the comparison below False, i.e. "not oblique".
                angle = np.nan
            else:
                # arctan2 on absolute values equals |arctan(y / x)| but does
                # not divide by zero when the branch is vertical (x == 0).
                angle = np.rad2deg(np.arctan2(np.abs(y), np.abs(x)))
            on_oblique = bool(
                node.content['p3d'].type == SWC_types['apical']
                and end_point[1] > end_point_limits[0]
                and end_point[1] < end_point_limits[1]
                and angle < max_angle
            )
            # Propagate the verdict from the tip back to the branch head.
            parent = child
            while parent is not None and parent != node:
                parent.content['on_oblique_branch'] = on_oblique
                parent = parent.parent
            node.content['on_oblique_branch'] = on_oblique
def find_terminal_and_oblique_branches(swc_file, terminal_point_types = (SWC_types['basal'], SWC_types['apical']),
                                       branch_types = ('terminal', 'oblique'),
                                       end_point_limits=(0, 240), max_angle=70):
    """Load an SWC file with btmorph and annotate its nodes.

    The file is read into a ``btmorph.STree2``; every node then receives a
    ``'branch_id'`` entry and, depending on ``branch_types``,
    ``'on_terminal_branch'`` and ``'on_oblique_branch'`` entries in its
    ``content`` dict.

    Parameters
    ----------
    swc_file : path passed to ``STree2.read_SWC_tree_from_file``
    terminal_point_types : iterable of int
        Point types for which terminal branches are reported.
    branch_types : container of str
        Which annotations to compute: ``'terminal'`` and/or ``'oblique'``.
    end_point_limits, max_angle
        Forwarded to ``find_oblique_branches``.

    Returns
    -------
    The annotated ``btmorph.STree2`` instance.
    """
    # Imported lazily so the rest of the module works without btmorph.
    import btmorph
    tree = btmorph.STree2()
    tree.read_SWC_tree_from_file(swc_file)
    compute_branch_id(tree)
    if 'oblique' in branch_types:
        # Oblique detection needs the apical terminal branches, so apical is
        # added temporarily when the caller did not request it.
        if SWC_types['apical'] not in terminal_point_types:
            # tuple() lets callers pass a list or set without a TypeError.
            point_types = (SWC_types['apical'],) + tuple(terminal_point_types)
        else:
            point_types = terminal_point_types
        find_terminal_branches(tree, point_types)
        find_oblique_branches(tree, end_point_limits, max_angle)
        if 'terminal' not in branch_types:
            # Terminal flags were only a means to find obliques: clear them.
            for node in tree:
                node.content['on_terminal_branch'] = False
        elif SWC_types['apical'] not in terminal_point_types:
            # Undo the temporary apical terminal flags added above.
            for node in tree:
                if node.content['p3d'].type == SWC_types['apical']:
                    node.content['on_terminal_branch'] = False
    else:
        find_terminal_branches(tree, terminal_point_types)
    return tree
class Node (object):
    """A morphology point that knows its parent and its children.

    Parameters
    ----------
    x, y, z : number
        Coordinates of the point.
    diam : number
        Diameter at the point.
    node_type : int
        Point type code.
    node_id : int
        Identifier of the point.
    """

    def __init__(self, x, y, z, diam, node_type, node_id):
        self._x = x
        self._y = y
        self._z = z
        # Always a float array so later coordinate assignments are not
        # truncated when the node was built from integers.
        self._xyz = np.array([x, y, z], dtype=float)
        self._diam = diam
        self._node_type = node_type
        self._node_id = node_id
        self._parent = None
        self._children = []

    @property
    def id(self):
        """Identifier given at construction time."""
        return self._node_id

    @property
    def type(self):
        """Point type code given at construction time."""
        return self._node_type

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, value):
        self._x = value
        self._xyz[0] = value

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, value):
        self._y = value
        self._xyz[1] = value

    @property
    def z(self):
        return self._z

    @z.setter
    def z(self, value):
        self._z = value
        self._xyz[2] = value

    @property
    def diam(self):
        return self._diam

    @diam.setter
    def diam(self, value):
        """Set the diameter; raises ``ValueError`` unless it is > 0."""
        if value <= 0:
            raise ValueError('Diameter must be > 0')
        self._diam = value

    @property
    def xyz(self):
        """Coordinates as a length-3 array (kept in sync with x, y, z)."""
        return self._xyz

    @xyz.setter
    def xyz(self, value):
        # Copy into a private float array so that later x/y/z assignments
        # neither mutate the caller's object nor fail on immutable input.
        coords = np.array(value, dtype=float)
        if coords.shape != (3,):
            raise ValueError('xyz must contain exactly three coordinates')
        self._xyz = coords
        self._x, self._y, self._z = coords

    @property
    def children(self):
        return self._children

    def add_to_children(self, node):
        """Append ``node`` to the children (once) and make self its parent."""
        if node not in self._children:
            self._children.append(node)
        node.parent = self

    def remove_from_children(self, node):
        """Remove ``node`` from the children.

        Raises ``ValueError`` when ``node`` is not a child of this node.
        """
        self._children.remove(node)
        # Drop the back-reference too, otherwise the removed node would
        # still claim this node as parent and could not be re-attached.
        if node._parent is self:
            node._parent = None

    @property
    def parent(self):
        return self._parent

    @parent.setter
    def parent(self, value):
        """Re-parent this node; ``None`` detaches it from the tree."""
        if self._parent is value:
            return
        old_parent = self._parent
        if old_parent is not None:
            old_parent.remove_from_children(self)
        self._parent = value
        if value is not None:
            value.add_to_children(self)
class Tree (object):
    """Tree of ``Node`` objects built from the rows of an SWC file.

    Each row is read as ``id, type, x, y, z, diameter, parent_id``. The first
    row becomes the root. Iterating over the tree yields the nodes in
    depth-first pre-order.

    Attributes
    ----------
    xy_ratio : float
        Extent in x divided by extent in y, computed over the points of type
        2, 3 and 4 only (inf/nan when the y extent is zero).
    bounds : ndarray of shape (3, 2)
        ``[min, max]`` of x, y and z over the same points.
    branches : list of list of Node
        Unbranched runs of nodes; each run ends at a branching point or at a
        node without children.

    Parameters
    ----------
    swc_file : anything accepted by ``numpy.loadtxt``

    Raises
    ------
    ValueError
        If the file contains no point of type 2, 3 or 4.
    KeyError
        If a row refers to a parent id that is not present in the file.
    """

    def __init__(self, swc_file):
        from collections import OrderedDict
        # ndmin=2 keeps the array two-dimensional for single-row files.
        data = np.loadtxt(swc_file, ndmin=2)
        idx = (data[:,1] == 2) | (data[:,1] == 3) | (data[:,1] == 4)
        if not idx.any():
            raise ValueError('no points of type 2, 3 or 4 found')
        x = data[idx, 2]
        y = data[idx, 3]
        z = data[idx, 4]
        self.xy_ratio = (x.max() - x.min()) / (y.max() - y.min())
        self.bounds = np.array([[x.min(), x.max()], [y.min(), y.max()], [z.min(), z.max()]])
        nodes = OrderedDict()
        parent_ids = OrderedDict()
        for row in data:
            node_id = int(row[0])
            node_type = int(row[1])
            px, py, pz = row[2:5]
            diam = row[5]
            nodes[node_id] = Node(px, py, pz, diam, node_type, node_id)
            parent_ids[node_id] = int(row[6])
        # Link in a second pass so a child may appear before its parent in
        # the file; children are still attached in file order.
        for node_id, parent_id in parent_ids.items():
            if parent_id > 0:
                nodes[node_id].parent = nodes[parent_id]
        # The root is the first point listed in the file.
        self._root = next(iter(nodes.values()))
        self.branches = []
        self._make_branches(self.root, self.branches)

    @property
    def root(self):
        return self._root

    def _gather_nodes(self, node, node_list):
        """Append ``node`` and its descendants to ``node_list`` in pre-order.

        Uses an explicit stack instead of recursion so that long unbranched
        chains do not exceed the interpreter's recursion limit.
        """
        if node is None:
            return
        stack = [node]
        while stack:
            current = stack.pop()
            if current is None:
                continue
            node_list.append(current)
            # Reversed so that the first child is popped (visited) first.
            stack.extend(reversed(current.children))

    def __iter__(self):
        # Snapshot the nodes first so that edits made while iterating do
        # not change the traversal.
        nodes = []
        self._gather_nodes(self.root, nodes)
        for n in nodes:
            yield n

    def _make_branches(self, node, branches):
        """Append to ``branches`` every unbranched run at or below ``node``.

        Runs are appended in pre-order: the run starting at ``node`` first,
        then the runs below each child of its last node, in child order.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            branch = []
            while len(current.children) == 1:
                branch.append(current)
                current = current.children[0]
            branch.append(current)
            branches.append(branch)
            pending.extend(reversed(current.children))