LibCST/libcst/codegen/gen_visitor_functions.py
2019-08-28 13:28:29 -07:00

245 lines
8.2 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import inspect
from collections import defaultdict
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass, fields, replace
from typing import Dict, Generator, List, Mapping, Sequence, Set, Type, Union
import libcst as cst
def _get_nodes() -> Generator[Type[cst.CSTNode], None, None]:
    """
    Yield every class exported from the top-level ``libcst`` package that
    subclasses ``cst.CSTNode``, excluding ``CSTNode`` itself.

    Dunder attributes are skipped, and exported objects that are not classes
    (functions, constants, ...) are silently ignored.
    """
    for attr_name in dir(cst):
        is_dunder = attr_name.startswith("__") and attr_name.endswith("__")
        if is_dunder or attr_name == "CSTNode":
            continue
        candidate = getattr(cst, attr_name)
        try:
            is_node_class = issubclass(candidate, cst.CSTNode)
        except TypeError:
            # issubclass() rejects non-class objects; we don't care about them.
            continue
        if is_node_class:
            yield candidate
# Every exported node class, in the order _get_nodes() yields them.
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = list(_get_nodes())

# For each node class: its CSTNode ancestors (CSTNode itself and the class
# itself included), ordered from most generic to most specific.
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {}
for node in all_libcst_nodes:
    # getmro() runs most-specific -> most-generic, so walk it backwards.
    node_to_bases[node] = [
        ancestor
        for ancestor in inspect.getmro(node)[::-1]
        if issubclass(ancestor, cst.CSTNode)
    ]
def _get_most_generic_base_for_node(node: Type[cst.CSTNode]) -> Type[cst.CSTNode]:
    """
    Return the most generic ancestor of ``node`` that is itself an exported
    node class (i.e. a key of ``node_to_bases``).

    Non-exported bases are ignored because a user couldn't name them in type
    hints. ``node_to_bases[node]`` is ordered generic-first and contains
    ``node`` itself, so the first exported entry is the answer.
    """
    return next(
        ancestor for ancestor in node_to_bases[node] if ancestor in node_to_bases
    )
# Map each node class to the most generic exported version of itself that
# isn't CSTNode.
nodebases: Dict[Type[cst.CSTNode], Type[cst.CSTNode]] = {
    node_cls: _get_most_generic_base_for_node(node_cls)
    for node_cls in all_libcst_nodes
}
@dataclass(frozen=True)
class Usage:
    """
    Immutable record of how a node type is referenced by other nodes' field
    annotations. Every flag defaults to False and is switched on as usages
    are discovered.
    """

    # Referenced in a Union that also contains MaybeSentinel.
    maybe: bool = False
    # Referenced in a Union that also contains None (Optional[...]).
    optional: bool = False
    # Referenced as the element type of a Sequence[...].
    sequence: bool = False
# Every node starts with no recorded usages. Usage is frozen and only ever
# replaced (never mutated), so all keys can safely share one default instance.
nodeuses: Dict[Type[cst.CSTNode], Usage] = dict.fromkeys(all_libcst_nodes, Usage())
def _is_maybe(typeobj: object) -> bool:
    """
    Return True if ``typeobj`` is ``cst.MaybeSentinel`` or a subclass of it.

    Objects that ``issubclass`` rejects (anything that isn't a class) are
    reported as False instead of raising.
    """
    try:
        # pyre-ignore We wrap this in a TypeError check so this is safe
        matches = issubclass(typeobj, cst.MaybeSentinel)
    except TypeError:
        matches = False
    return matches
def _get_origin(typeobj: object) -> object:
    """
    Return ``typeobj.__origin__`` (e.g. ``Union`` for ``Union[...]``, the
    sequence ABC for ``Sequence[...]``), or None when the object has no such
    attribute -- i.e. it is not a union or sequence we care about.
    """
    # pyre-ignore getattr with a default only swallows AttributeError
    return getattr(typeobj, "__origin__", None)
def _get_args(typeobj: object) -> List[object]:
    """
    Return the type arguments of a parameterized typing construct as a list.

    ``Union[A, B]`` gives ``[A, B]`` and ``Sequence[A]`` gives ``[A]``.
    Objects without ``__args__`` (plain classes, None, ...) give ``[]``.
    """
    # pyre-ignore We read __args__ defensively, so this is safe
    args = getattr(typeobj, "__args__", None)
    # __args__ is a tuple; copy it into a list so the result matches the
    # declared List return type. Some typing versions may also expose
    # __args__ = None on unparameterized generics; callers iterate the result,
    # so treat that the same as "no arguments".
    return list(args) if args is not None else []
def _is_sequence(typeobj: object) -> bool:
    """
    Return True if ``typeobj`` is a parameterized ``Sequence[...]``.

    Both ``typing.Sequence`` and ``collections.abc.Sequence`` are accepted as
    the origin, since which one appears depends on the Python version.
    """
    origin = _get_origin(typeobj)
    # pyre-ignore Pyre doesn't know about collections.abc.Sequence
    return any(origin is candidate for candidate in (Sequence, ABCSequence))
def _is_union(typeobj: object) -> bool:
    """Return True if ``typeobj`` is a ``Union[...]`` (``Optional[X]`` included)."""
    origin = _get_origin(typeobj)
    return origin is Union
def _calc_node_usage(typeobj: object) -> None:
    """
    Record in ``nodeuses`` how node types appear in the annotation ``typeobj``.

    - A node that is a direct member of a ``Union`` gets ``maybe`` set when
      that ``Union`` also contains ``MaybeSentinel``, and ``optional`` set when
      it also contains ``None`` (i.e. ``Optional[...]``).
    - A node that is the element type of a ``Sequence[...]`` gets ``sequence``
      set.
    - Any other member/element is recursed into, so nested annotations such
      as ``Optional[Sequence[Node]]`` are handled.

    Flags are only ever switched on, never cleared, so calling this for every
    field accumulates usages across all nodes.
    """
    if _is_union(typeobj):
        members = _get_args(typeobj)
        has_maybe = any(_is_maybe(member) for member in members)
        # typing stores the None of Optional[X] as the class NoneType, so this
        # must be an identity check. isinstance(member, type(None)) asks
        # whether the *class* is an instance of NoneType, which is never true,
        # so ``optional`` was never recorded.
        has_none = any(member is type(None) for member in members)
        for member in members:
            if member in all_libcst_nodes:
                usage = nodeuses[member]
                nodeuses[member] = replace(
                    usage,
                    maybe=usage.maybe or has_maybe,
                    optional=usage.optional or has_none,
                )
            else:
                _calc_node_usage(member)
    if _is_sequence(typeobj):
        for member in _get_args(typeobj):
            if member in all_libcst_nodes:
                nodeuses[member] = replace(nodeuses[member], sequence=True)
            else:
                _calc_node_usage(member)
# Accumulate usage flags from every field annotation of every node; the
# _metadata field is not part of the scan.
for node in all_libcst_nodes:
    for field in fields(node):
        if field.name != "_metadata":
            _calc_node_usage(field.type)
# Group the names the generated module must import (under TYPE_CHECKING) by
# defining module: every node not named Base*, together with its most
# generic exported base.
imports: Mapping[str, Set[str]] = defaultdict(set)
for node, base in nodebases.items():
    if not node.__name__.startswith("Base"):
        imports[node.__module__].add(node.__name__)
        imports[base.__module__].add(base.__name__)
# Assemble the generated module line by line, starting with its license
# header and imports.
generated_code: List[str] = []
generated_code.append("# Copyright (c) Facebook, Inc. and its affiliates.")
generated_code.append("#")
generated_code.append(
    "# This source code is licensed under the MIT license found in the"
)
generated_code.append("# LICENSE file in the root directory of this source tree.")
generated_code.append("")
generated_code.append("# pyre-strict")
generated_code.append("")
# Credit the generator that actually produces this output: this module,
# libcst/codegen/gen_visitor_functions.py.
generated_code.append(
    "# This file was generated by libcst.codegen.gen_visitor_functions"
)
generated_code.append("from typing import Optional, Union, TYPE_CHECKING")
generated_code.append("")
generated_code.append("from libcst._maybe_sentinel import MaybeSentinel")
generated_code.append("from libcst._removal_sentinel import RemovalSentinel")
generated_code.append("from libcst._typed_visitor_base import mark_no_op")
# Import the types we use. These have to be type guarded since it would
# cause an import cycle otherwise.
generated_code.append("")
generated_code.append("")
generated_code.append("if TYPE_CHECKING:")
for module, objects in imports.items():
    # The generated signatures reference these names only inside string
    # annotations, hence the noqa for "imported but unused".
    generated_code.append(f"    from {module} import (  # noqa: F401")
    generated_code.append(f"        {', '.join(sorted(objects))}")
    generated_code.append("    )")
# Generate the base visit_ methods. CSTTypedBaseFunctions is the parent of both
# generated visitor and transformer classes: for each node it gets a
# visit_<Node> hook plus visit_/leave_ hooks for every field.
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedBaseFunctions:")
for node in sorted(nodebases, key=lambda n: n.__name__):
    name = node.__name__
    # Base* classes get no hooks of their own.
    if name.startswith("Base"):
        continue
    # Members are emitted at 4 spaces and bodies at 8 so each ``pass`` is
    # nested inside its ``def``; otherwise the generated module won't parse.
    generated_code.append("")
    generated_code.append("    @mark_no_op")
    generated_code.append(
        f'    def visit_{name}(self, node: "{name}") -> Optional[bool]:'
    )
    generated_code.append("        pass")
    for field in fields(node):
        # The _metadata field gets no hooks, as elsewhere in this script.
        if field.name == "_metadata":
            continue
        # One visit_ and one leave_ hook per field, in that order.
        for hook in ("visit", "leave"):
            generated_code.append("")
            generated_code.append("    @mark_no_op")
            generated_code.append(
                f'    def {hook}_{name}_{field.name}(self, node: "{name}") -> None:'
            )
            generated_code.append("        pass")
# Generate the visitor leave_ methods: each leave_<Node> takes only the
# original node and returns None.
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):")
for node in sorted(nodebases, key=lambda n: n.__name__):
    name = node.__name__
    # Base* classes get no hooks of their own.
    if name.startswith("Base"):
        continue
    # 4-space member / 8-space body indentation keeps the generated code valid.
    generated_code.append("")
    generated_code.append("    @mark_no_op")
    generated_code.append(
        f'    def leave_{name}(self, original_node: "{name}") -> None:'
    )
    generated_code.append("        pass")
# Generate the transformer leave_ methods. Each leave_<Node> is annotated to
# return the node's most generic exported base, widened with MaybeSentinel
# and/or RemovalSentinel according to the usage flags recorded for the node
# or that base.
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):")
# Keeps the class body non-empty even if no leave_ methods are emitted below.
generated_code.append("    pass")
for node in sorted(nodebases, key=lambda n: n.__name__):
    name = node.__name__
    # Base* classes get no hooks of their own.
    if name.startswith("Base"):
        continue
    base = nodebases[node]
    node_uses = nodeuses[node]
    base_uses = nodeuses[base]
    valid_return_types: List[str] = [f'"{base.__name__}"']
    # Used next to MaybeSentinel in some Union annotation.
    if node_uses.maybe or base_uses.maybe:
        valid_return_types.append("MaybeSentinel")
    # Used in an Optional or as a Sequence element.
    if (
        node_uses.optional
        or node_uses.sequence
        or base_uses.optional
        or base_uses.sequence
    ):
        valid_return_types.append("RemovalSentinel")
    # 4-space member / 8-space body indentation keeps the generated code valid.
    generated_code.append("")
    generated_code.append("    @mark_no_op")
    generated_code.append(
        f'    def leave_{name}(self, original_node: "{name}", updated_node: "{name}") -> Union[{", ".join(valid_return_types)}]:'
    )
    generated_code.append("        return updated_node")
if __name__ == "__main__":
    # Write the whole generated module to stdout, one generated line per line.
    output = "\n".join(generated_code)
    print(output)