Source code for abstracttree.predicates

from dataclasses import dataclass
from typing import Callable

from .tree import AbstractTree, NodeItem


class Predicate(Callable[[AbstractTree, NodeItem], bool]):
    """Base class for composable ``(node, item) -> bool`` filters.

    Predicates can be combined with ``|`` (union) and ``&`` (intersection).
    ``None`` is accepted as an operand on either side: a union with ``None``
    yields ``None`` and an intersection with ``None`` yields the predicate
    itself (presumably ``None`` stands for "no filtering" -- confirm against
    the callers that consume the ``keep`` argument).
    """

    __slots__ = ()

    def __or__(self, other):
        """Return ``self | other``; ``None`` if ``other`` is ``None``."""
        if other is None:
            return None
        return PredicateUnion(self, other)

    def __ror__(self, other):
        """Return ``other | self`` for left operands that do not handle ``|``.

        Mirrors ``__or__`` so that ``None | pred`` and ``callable | pred``
        behave like their flipped counterparts instead of raising TypeError.
        """
        if other is None:
            return None
        # Keep the left operand first so evaluation order matches the expression.
        return PredicateUnion(other, self)

    def __and__(self, other):
        """Return ``self & other``; ``self`` if ``other`` is ``None``."""
        if other is None:
            return self
        return PredicateIntersection(self, other)

    def __rand__(self, other):
        """Return ``other & self`` for left operands that do not handle ``&``.

        Mirrors ``__and__`` so that ``None & pred`` returns ``pred``.
        """
        if other is None:
            return self
        # Keep the left operand first so evaluation order matches the expression.
        return PredicateIntersection(other, self)
class PredicateUnion(Predicate):
    """Predicate that accepts when at least one wrapped predicate accepts."""

    # A bare string declares a single slot named "preds".
    __slots__ = "preds"

    def __init__(self, *preds):
        """Store the predicates to combine, in the order given."""
        self.preds = preds

    def __call__(self, node: AbstractTree, item: NodeItem):
        """Return True on the first accepting predicate, else False.

        Predicates are tried in order and evaluation stops at the first
        truthy result. With no predicates at all the result is False.
        """
        for accepts in self.preds:
            if accepts(node, item):
                return True
        return False
class PredicateIntersection(Predicate):
    """Predicate that accepts only when every wrapped predicate accepts."""

    # A bare string declares a single slot named "preds".
    __slots__ = "preds"

    def __init__(self, *preds):
        """Store the predicates to combine, in the order given."""
        self.preds = preds

    def __call__(self, node: AbstractTree, item: NodeItem):
        """Return False on the first rejecting predicate, else True.

        Predicates are tried in order and evaluation stops at the first
        falsy result. With no predicates at all the result is True.
        """
        for accepts in self.preds:
            if not accepts(node, item):
                return False
        return True
class RemoveDuplicates(Predicate):
    """Remove duplicates in case of cycles or multiparent trees.

    Stateful: each instance remembers the ``nid`` of every node it has
    accepted, so an instance rejects a given ``nid`` after its first visit.
    """

    # A bare string declares a single slot named "seen".
    __slots__ = "seen"

    def __init__(self):
        """Start with an empty record of visited node ids."""
        self.seen = set()

    def __call__(self, node: AbstractTree, item):
        """Return True the first time ``node.nid`` is encountered, else False."""
        nid = node.nid
        if nid in self.seen:
            return False
        # First visit: remember the id so later occurrences are rejected.
        self.seen.add(nid)
        return True
class PreventCycles(Predicate):
    """Prevent looping cyclic trees.

    It might yield nodes more than once, but will not repeat children.
    This is mostly useful when trying to plot cyclic trees.
    """

    __slots__ = "seen", "duplicates"

    def __init__(self):
        """Start with no visited ids and no ids known to repeat."""
        self.seen = set()
        self.duplicates = set()

    def __call__(self, node: AbstractTree, item):
        """Decide whether ``node`` should be kept.

        A node is rejected only when it has been seen before *and* its parent
        has already been recorded as a duplicate. Otherwise it is kept, and
        its id is recorded in ``duplicates`` (on a repeat visit) or in
        ``seen`` (on a first visit).
        """
        nid = node.nid
        parent = node.parent
        visited_before = nid in self.seen
        parent_is_duplicate = parent is not None and parent.nid in self.duplicates

        if parent_is_duplicate and visited_before:
            return False

        # Repeat visits mark the node as a duplicate; first visits mark it seen.
        if visited_before:
            self.duplicates.add(nid)
        else:
            self.seen.add(nid)
        return True
# Once Python 3.10 is the minimum supported version, add slots=True to the dataclass below.
@dataclass(frozen=True)
class MaxDepth(Predicate):
    """Limit iteration to a certain depth.

    Can be passed to the ``keep`` argument of methods such as
    ``tree.iter_tree()``. Items whose ``depth`` is at most ``depth`` are kept.

    >>> from littletree import Node
    >>> tree = Node(identifier='root').path.create(['a', 'b', 'c', 'd']).root
    >>> [node.identifier for node in tree.nodes.preorder(keep=MaxDepth(3))]
    ['root', 'a', 'b', 'c']
    """

    # Largest item depth that is still accepted (inclusive).
    depth: int

    def __call__(self, _node, item):
        """Return True when ``item.depth`` does not exceed the limit."""
        current = item.depth
        limit = self.depth
        return current <= limit