from collections.abc import Sequence, Callable, Iterable
from functools import lru_cache
from typing import Optional, TypeVar, Type
import abstracttree.generics as generics
from abstracttree.generics import TreeLike, DownTreeLike
from abstracttree.mixins import Tree
from abstracttree.utils import eqv
# Generic type variable used in the signatures below: the node type accepted
# by as_tree's callables, the target type of convert_tree, and the adapter
# type returned by TreeAdapter.parent / TreeAdapter.children.
T = TypeVar("T")
def convert_tree(tree: DownTreeLike, required_type: Type[T] = Tree) -> T:
    """Convert a TreeLike to a powerful Tree.

    If needed, it uses a TreeAdapter.

    The conversion is attempted in this order:

    1. If ``tree`` already is an instance of ``required_type``, it is
       returned unchanged.
    2. If ``tree`` has an ``_abstracttree_`` method, its result is used.
    3. Otherwise ``tree`` is wrapped with :func:`as_tree`.

    Args:
        tree: The object to convert.
        required_type: Class the result must be an instance of. Defaults to
            :class:`Tree`. (Previously the default was the typing object
            ``Type[T]``, which made ``isinstance`` raise, so the argument
            was effectively mandatory.)

    Returns:
        ``tree`` itself, or its converted form, as an instance of
        ``required_type``.

    Raises:
        TypeError: If the converted object is not an instance of
            ``required_type``.
    """
    if isinstance(tree, required_type):
        return tree

    if hasattr(tree, "_abstracttree_"):
        converted = tree._abstracttree_()
    else:
        converted = as_tree(tree)

    if isinstance(converted, required_type):
        return converted
    # Report the type the caller passed in, not the intermediate conversion.
    raise TypeError(f"Unable to convert {type(tree)} to {required_type}")
def as_tree(
    obj: T,
    children: Callable[[T], Iterable[T]] = None,
    parent: Callable[[T], Optional[T]] = None,
    label: Callable[[T], str] = None,
) -> "TreeAdapter":
    """Wrap an arbitrary object in a tree adapter.

    The optional callables control how the adapter navigates ``obj``:

    * ``children(node)`` returns an iterable of child nodes,
    * ``parent(node)`` returns the parent node or ``None``,
    * ``label(node)`` returns the string used by ``str(adapter)``.

    Callables that are omitted fall back to the generic implementations
    chosen for ``type(obj)`` where available. The original object remains
    reachable through the adapter's ``value`` attribute.
    """
    # Pass arguments positionally: compile_adapter is lru_cached, and
    # keyword vs. positional calls would produce separate cache entries.
    adapter_cls = compile_adapter(type(obj), children, parent, label)
    return adapter_cls(obj)
@lru_cache(maxsize=None)
def compile_adapter(
    cls,
    children: Callable[[T], Iterable[T]] = None,
    parent: Callable[[T], Optional[T]] = None,
    label: Callable[[T], str] = None,
):
    """Build a TreeAdapter subclass specialised for objects of type ``cls``.

    * ``children`` / ``label``: when falsy, the implementation returned by
      ``generics.children.dispatch(cls)`` / ``generics.label.dispatch(cls)``
      is used instead.
    * ``parent``: when falsy and ``cls`` is a subclass of ``TreeLike``,
      ``generics.parent.dispatch(cls)`` is used. Otherwise it is kept as-is,
      possibly ``None``. In that case the adapter only knows parents it
      recorded while walking down through ``children``.

    Results are memoized with an unbounded ``lru_cache`` keyed on the
    arguments, so ``cls`` and the callables must be hashable. Note that
    passing a freshly created lambda on every call adds a new cache entry
    (and a new class) each time, and these entries are never evicted.
    """
    find_parent = parent
    if not find_parent and issubclass(cls, TreeLike):
        find_parent = generics.parent.dispatch(cls)
    find_children = children or generics.children.dispatch(cls)
    find_label = label or generics.label.dispatch(cls)

    class CustomTreeAdapter(TreeAdapter):
        child_func = staticmethod(find_children)
        parent_func = staticmethod(find_parent)
        label_func = staticmethod(find_label)

    return CustomTreeAdapter
# Alias for backwards compatibility: the older name ``astree`` refers to the
# very same function object as ``as_tree``.
astree = as_tree
class TreeAdapter(Tree):
    """Tree view over an arbitrary wrapped object.

    Navigation is delegated to three class-level callables, which
    specialised subclasses (as built by ``compile_adapter``) replace:

    * ``child_func(value)`` returns an iterable of child values,
    * ``parent_func(value)`` returns the parent value or ``None``,
    * ``label_func(value)`` returns what ``str(adapter)`` produces.

    Adapters produced by ``children`` remember the adapter they came from,
    so walking down and back up does not need ``parent_func``.
    """

    child_func = staticmethod(generics.children)
    parent_func = staticmethod(generics.parent)
    label_func = staticmethod(generics.label)

    def __init__(self, value: TreeLike, _parent=None):
        # ``_parent`` is filled in by ``children`` with the producing adapter;
        # adapters created directly start without a known parent.
        self._value, self._parent = value, _parent

    def __repr__(self) -> str:
        return f"{type(self).__qualname__}({self.value!r})"

    def __str__(self) -> str:
        wrapped = self._value
        return self.label_func(wrapped)

    @property
    def nid(self) -> int:
        """Node id of the wrapped object, as computed by ``generics.nid``."""
        return generics.nid(self._value)

    def __eq__(self, other) -> bool:
        """Return whether ``other`` denotes the same node.

        Delegates to ``abstracttree.utils.eqv`` with the adapters themselves.
        """
        return eqv(self, other)

    def __hash__(self) -> int:
        """Hash of the wrapped object; raises TypeError if it is unhashable."""
        return hash(self._value)

    @property
    def value(self):
        """The original, unwrapped object."""
        return self._value

    @property
    def parent(self: T) -> Optional[T]:
        """Parent adapter, or ``None`` if there is none.

        A parent recorded during downward traversal is returned directly.
        Otherwise ``parent_func`` (when set) is asked for the parent value,
        which is wrapped in a new adapter of the same class.
        """
        known = self._parent
        if known is not None:
            return known
        adapter_cls = type(self)
        find_parent = adapter_cls.parent_func
        if not find_parent:
            # No upward lookup configured (e.g. parent_func is None).
            return None
        parent_value = find_parent(self._value)
        return None if parent_value is None else adapter_cls(parent_value)

    @property
    def children(self: T) -> Sequence[T]:
        """A new list of child adapters, each linked back to ``self``."""
        adapter_cls = type(self)
        child_values = adapter_cls.child_func(self._value)
        return [adapter_cls(child_value, self) for child_value in child_values]