Source code for abstracttree.adapters

import ast
import functools
import operator
import pathlib
import xml.etree.ElementTree as ET
import zipfile
from collections.abc import Sequence, Mapping
from typing import TypeVar, Callable, Iterable, overload, Collection, Union

from .tree import Tree, DownTree

# Type variable standing for the arbitrary user object that an adapter wraps.
TWrap = TypeVar("TWrap")
# Union of the built-in string-like types; used in this module to keep
# strings from being treated as nested sequences.
BaseString = Union[str, bytes, bytearray]  # Why did they ever remove this type?


# Typing-only overloads for ``astree``. The ``children`` callback is typed as
# returning an ``Iterable`` (not ``Collection``) so the stubs agree with the
# implementation signature, which only ever iterates over the result.


# Standard conversion: no callbacks given.
@overload
def astree(obj) -> Tree: ...


# Only ``children`` given.
@overload
def astree(obj: TWrap, children: Callable[[TWrap], Iterable[TWrap]]) -> Tree: ...


# Both ``children`` and ``parent`` given.
@overload
def astree(
    obj: TWrap,
    children: Callable[[TWrap], Iterable[TWrap]],
    parent: Callable[[TWrap], TWrap],
) -> Tree: ...


def astree(
    obj: TWrap,
    children: Union[Callable[[TWrap], Iterable[TWrap]], None] = None,
    parent: Union[Callable[[TWrap], TWrap], None] = None,
) -> Tree:
    """Convert an arbitrary object into a tree.

    Args:
        obj: The object that becomes the node wrapped by the returned tree.
        children: Optional function mapping an object to its child objects.
        parent: Optional function mapping an object to its parent object.
            Only consulted when ``children`` is given as well.

    Returns:
        If ``children`` is omitted, the result of the standard conversion
        ``convert_tree(obj)``.
        If both ``children`` and ``parent`` are given, a ``TreeAdapter``
        subclass instance that uses them as functions.
        If only ``children`` is given, a ``StoredParent`` subclass instance
        that has ``obj`` as its root.
    """
    # Compare against None instead of relying on truthiness: a callable
    # object may legitimately be falsy (e.g. it defines __len__ or __bool__).
    if children is None:
        return convert_tree(obj)

    if parent is not None:

        class CustomTree(TreeAdapter):
            # staticmethod keeps plain functions from being bound to the
            # adapter instance when looked up through ``self``.
            child_func = staticmethod(children)
            parent_func = staticmethod(parent)

    else:

        class CustomTree(StoredParent):
            child_func = staticmethod(children)

    return CustomTree(obj)
@functools.singledispatch
def convert_tree(tree):
    """Low level conversion of tree object to an abstracttree.Tree.

    The default implementation ducktypes on `tree.parent` and `tree.children`.
    Classes can also define a method _abstracttree_ to override their conversion.

    Args:
        tree: Object to convert.

    Returns:
        The result of ``tree._abstracttree_()`` when that attribute exists and
        is truthy; otherwise a ``TreeAdapter`` when the object has both
        ``children`` and ``parent``, or a ``StoredParent`` when it only has
        ``children``.

    Raises:
        NotImplementedError: If the object offers none of the above hooks.
    """
    explicit_conversion = getattr(tree, "_abstracttree_", None)
    if explicit_conversion:
        return explicit_conversion()

    if hasattr(tree, "children"):
        if hasattr(tree, "parent"):
            return TreeAdapter(tree)
        return StoredParent(tree)

    # Name the offending type so the caller can see what failed to convert.
    raise NotImplementedError(
        f"Cannot convert object of type {type(tree).__name__!r} to a tree."
    )
@convert_tree.register
def _(tree: Tree):
    """Return a ``Tree`` unchanged."""
    return tree


@convert_tree.register
def _(tree: DownTree):
    """Wrap a ``DownTree`` in ``UpgradedTree``, which remembers parents."""
    return UpgradedTree(tree)


@convert_tree.register
def _(tree: pathlib.PurePath):
    """Wrap a pathlib path in ``PathTree``."""
    return PathTree(tree)


@convert_tree.register
def _(zf: zipfile.ZipFile):
    """Wrap the root path of an opened zip archive in ``PathTree``."""
    return PathTree(zipfile.Path(zf))


@convert_tree.register
def _(path: zipfile.Path):
    """Wrap a path inside a zip archive in ``PathTree``."""
    return PathTree(path)


@convert_tree.register
def _(tree: Sequence):
    """Wrap a (nested) sequence in ``SequenceTree``."""
    return SequenceTree(tree)


@convert_tree.register(str)
@convert_tree.register(bytes)
@convert_tree.register(bytearray)
def _(_: BaseString):
    """Refuse string-like objects, which would otherwise match ``Sequence``."""
    raise NotImplementedError(
        "astree(x: str | bytes | bytearray) is unsafe, "
        "because x is infinitely recursively iterable."
    )


@convert_tree.register
def _(tree: Mapping):
    """Wrap a mapping in ``MappingTree``; the root node gets ``None`` as key."""
    return MappingTree((None, tree))


@convert_tree.register
def _(cls: type, invert=False):
    """Wrap a class in ``TypeTree``, or ``InvertedTypeTree`` if ``invert``."""
    if invert:
        return InvertedTypeTree(cls)
    return TypeTree(cls)


@convert_tree.register
def _(node: ast.AST):
    """Wrap an ``ast`` node in ``AstTree``."""
    return AstTree(node)


@convert_tree.register
def _(element: ET.Element):
    """Wrap an XML element in ``XmlTree``."""
    return XmlTree(element)


@convert_tree.register
def _(tree: ET.ElementTree):
    """Wrap the root element of an XML document in ``XmlTree``."""
    return XmlTree(tree.getroot())


class TreeAdapter(Tree):
    """Adapter around an object that can report its children and its parent.

    ``child_func`` and ``parent_func`` are applied to the wrapped ``node``;
    by default they read its ``children`` and ``parent`` attributes.
    Subclasses may replace them (as static methods or other callables).
    """

    __slots__ = "node"
    child_func: Callable[[TWrap], Iterable[TWrap]] = operator.attrgetter("children")
    parent_func: Callable[[TWrap], TWrap] = operator.attrgetter("parent")

    def __init__(self, node):
        self.node = node

    def __repr__(self):
        return f"{type(self).__name__}({self.node})"

    def __str__(self):
        return str(self.node)

    def __eq__(self, other):
        # Equality compares the wrapped nodes by value.
        return isinstance(other, type(self)) and self.node == other.node

    @property
    def nid(self):
        """Identifier of this node: ``id()`` of the wrapped object."""
        return id(self.node)

    def eqv(self, other):
        """Return True if ``other`` wraps the very same object (identity)."""
        return isinstance(other, type(self)) and self.node is other.node

    @property
    def children(self):
        """List of adapters (same class as ``self``) around the child objects."""
        cls = self.__class__
        child_func = self.child_func
        return [cls(c) for c in child_func(self.node)]

    @property
    def parent(self):
        """Adapter around the parent object, or None when there is none."""
        parent = self.parent_func(self.node)
        if parent is not None:
            return self.__class__(parent)
        else:
            return None


class PathTree(TreeAdapter):
    """Tree adapter for ``pathlib`` paths and ``zipfile.Path`` objects."""

    __slots__ = ()
    # Fallback ids for paths that cannot be lstat-ed, keyed by str(path).
    # Class-level, hence shared by all PathTree instances.
    _custom_nids = {}

    @property
    def nid(self):
        """Identifier of this path.

        Uses the negated inode number when ``lstat`` works; otherwise hands
        out a counter value (0, 1, 2, ...) remembered per path string.
        """
        try:
            # Doesn't work on zipfile
            st = self.node.lstat()
        except (FileNotFoundError, AttributeError):
            return self._custom_nids.setdefault(str(self.node), len(self._custom_nids))
        else:
            return -st.st_ino

    def eqv(self, other):
        """Return True if ``other`` is a compatible adapter for an equal path.

        Paths are compared by value rather than identity. The isinstance
        guard mirrors the other adapters and avoids an AttributeError when
        ``other`` has no ``node`` attribute.
        """
        return isinstance(other, type(self)) and self.node == other.node

    @staticmethod
    def parent_func(path):
        """Return the parent path, or None at the top of the hierarchy."""
        # pathlib makes parent an infinite, but we want None
        parent = path.parent
        if path != parent:
            return parent
        else:
            return None

    @property
    def children(self):
        """List of ``PathTree`` children; empty if the directory is unreadable."""
        try:
            return list(map(type(self), self.child_func(self.node)))
        except PermissionError:
            return []

    @staticmethod
    def child_func(p):
        """Return the directory entries of ``p``, or ``()`` if not a directory.

        Pure paths (``pathlib.PurePath``) perform no I/O and therefore have
        neither ``is_dir`` nor ``iterdir``; they are treated as leaves
        instead of raising AttributeError.
        """
        is_dir = getattr(p, "is_dir", None)
        if is_dir is not None and is_dir():
            return p.iterdir()
        return ()

    @property
    def root(self):
        """Adapter around the anchor of the wrapped path."""
        return type(self)(type(self.node)(self.node.anchor))


class StoredParent(Tree):
    """Adapter for objects that only know their children.

    The parent adapter is remembered at the moment a child adapter is created
    through ``children``; an adapter constructed directly has no parent.
    """

    __slots__ = "node", "_parent"
    child_func: Callable[[TWrap], Iterable[TWrap]] = operator.attrgetter("children")

    def __init__(self, node, parent=None):
        self.node = node
        self._parent = parent

    def __repr__(self):
        return f"{type(self).__name__}({self.node})"

    def __str__(self):
        return str(self.node)

    def __eq__(self, other):
        # Equality compares the wrapped nodes by value; parents are ignored.
        return isinstance(other, type(self)) and self.node == other.node

    @property
    def nid(self):
        """Identifier of this node: ``id()`` of the wrapped object."""
        return id(self.node)

    def eqv(self, other):
        """Return True if ``other`` wraps the very same object (identity)."""
        return isinstance(other, type(self)) and self.node is other.node

    @property
    def children(self):
        """List of child adapters, each storing ``self`` as its parent."""
        cls = type(self)
        child_func = self.child_func
        return [cls(c, self) for c in child_func(self.node)]

    @property
    def parent(self):
        """The adapter this one was created from, or None."""
        return self._parent


class UpgradedTree(StoredParent):
    """Adds stored parents to a wrapped tree; nid/eqv delegate to that tree."""

    __slots__ = ()

    @property
    def nid(self):
        return self.node.nid

    def eqv(self, other):
        return isinstance(other, type(self)) and self.node.eqv(other.node)


class SequenceTree(StoredParent):
    """Tree view of nested sequences; non-sequences and strings are leaves."""

    __slots__ = ()

    @staticmethod
    def child_func(seq):
        """Return ``seq`` itself if it is a non-string sequence, else ``()``."""
        if isinstance(seq, Sequence) and not isinstance(seq, BaseString):
            return seq
        else:
            return ()

    def __str__(self):
        # ``is_leaf`` is not defined in this module; presumably inherited
        # from Tree.
        if self.is_leaf:
            return str(self.node)
        else:
            cls_name = type(self.node).__name__
            return f"{cls_name}[{len(self.node)}]"


class MappingTree(StoredParent):
    """Tree view of nested mappings.

    Every node wraps a ``(key, value)`` pair. A mapping value yields its
    items as children; any other non-None value yields a single child
    ``(value, None)``, which itself has no children.
    """

    __slots__ = ()

    @staticmethod
    def child_func(item):
        """Return the child pairs of the ``(key, value)`` pair ``item``."""
        _, value = item
        if isinstance(value, Mapping):
            return value.items()
        elif value is not None:
            return [(value, None)]
        else:
            return []

    @property
    def key(self):
        """First element of the wrapped pair."""
        key, _ = self.node
        return key

    @property
    def mapping(self):
        """Second element of the wrapped pair."""
        _, mapping = self.node
        return mapping

    def __str__(self):
        key, mapping = self.node
        # ``is_root`` is not defined in this module; presumably inherited
        # from Tree.
        if self.is_root:
            cls_name = type(mapping).__name__
            return f"{cls_name}[{len(mapping)}]"
        else:
            return str(key)


class TypeTree(StoredParent):
    """Class hierarchy as a tree: the children of a class are its subclasses."""

    __slots__ = ()

    @property
    def nid(self):
        return id(self.node)

    @staticmethod
    def child_func(cls):
        """Return the direct subclasses of ``cls``."""
        return cls.__subclasses__()

    @property
    def parents(self):
        """Adapters around all direct base classes (there may be several)."""
        return list(map(type(self), self.node.__bases__))

    def __str__(self):
        return self.node.__qualname__


class InvertedTypeTree(TypeTree):
    """Class hierarchy turned upside down: children are the base classes."""

    __slots__ = ()

    @staticmethod
    def child_func(cls):
        """Return the direct base classes of ``cls``."""
        return cls.__bases__


class AstTree(StoredParent):
    """Tree view of a Python abstract syntax tree."""

    __slots__ = ()
    # Placeholder printed where a field holds a node shown as a child.
    CONT = "↓"
    # Node kinds that are rendered inline instead of becoming child nodes.
    SINGLETON = Union[ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop]

    @classmethod
    def child_func(cls, node):
        """Return the child nodes of ``node`` that pass ``is_child``."""
        return tuple(node for node in ast.iter_child_nodes(node) if cls.is_child(node))

    @classmethod
    def is_child(cls, node):
        """Return True for AST nodes that are not of a ``SINGLETON`` kind."""
        return isinstance(node, ast.AST) and not isinstance(node, cls.SINGLETON)

    def __str__(self):
        # ``is_leaf`` is not defined in this module; presumably inherited
        # from Tree.
        if self.is_leaf:
            return ast.dump(self.node)
        else:
            format_value = self.format_value
            args = [f"{name}={format_value(field)}" for name, field in ast.iter_fields(self.node)]
            joined_args = ", ".join(args)
            return f"{type(self.node).__name__}({joined_args})"

    @classmethod
    def format_value(cls, field):
        """Format one field value of a node for ``__str__``.

        Child nodes become ``CONT``, other AST nodes are dumped in full,
        non-string sequences are formatted element by element, and anything
        else is shown via ``repr``.
        """
        if cls.is_child(field):
            field_str = cls.CONT
        elif isinstance(field, ast.AST):
            field_str = ast.dump(field)
        elif isinstance(field, Sequence) and not isinstance(field, BaseString):
            field = [cls.format_value(f) for f in field]
            field_str = "[" + ", ".join(field) + "]"
        else:
            field_str = repr(field)
        return field_str


class XmlTree(StoredParent):
    """Tree view of an ``xml.etree.ElementTree`` element."""

    # Declared for consistency with the other adapters in this module.
    __slots__ = ()

    @staticmethod
    def child_func(element):
        """Return the element itself; iterating it yields its sub-elements."""
        return element

    def __str__(self):
        """Render the opening tag with attributes, stripped text, closing tag and tail."""
        element = self.node
        output = [f"<{element.tag}"]

        for k, v in element.items():
            output.append(f" {k}={v!r}")
        output.append(">")

        # text/tail may be None; only non-blank content is emitted.
        if text := element.text and str(element.text).strip():
            output.append(text)
        output.append(f"</{element.tag}>")

        if tail := element.tail and str(element.tail).strip():
            output.append(tail)
        return "".join(output)