mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
245 lines
8.2 KiB
Python
245 lines
8.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
# pyre-strict
|
|
import inspect
|
|
from collections import defaultdict
|
|
from collections.abc import Sequence as ABCSequence
|
|
from dataclasses import dataclass, fields, replace
|
|
from typing import Dict, Generator, List, Mapping, Sequence, Set, Type, Union
|
|
|
|
import libcst as cst
|
|
|
|
|
|
def _get_nodes() -> Generator[Type[cst.CSTNode], None, None]:
    """
    Yield every ``CSTNode`` subclass exposed as an attribute of ``libcst``.

    Attributes are visited in ``dir(cst)`` order. Dunder names and ``CSTNode``
    itself are skipped, and so is any attribute that is not a class.
    """

    for attr_name in dir(cst):
        is_dunder = attr_name.startswith("__") and attr_name.endswith("__")
        if is_dunder or attr_name == "CSTNode":
            continue

        candidate = getattr(cst, attr_name)
        try:
            is_node = issubclass(candidate, cst.CSTNode)
        except TypeError:
            # This isn't a class, so we don't care about it.
            continue

        if is_node:
            yield candidate
|
|
|
|
|
|
# Every node class yielded by _get_nodes(), materialised once so it can be
# iterated repeatedly below.
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = list(_get_nodes())

# For each node, its CSTNode-derived classes taken from the MRO and reversed,
# so the list runs from the most generic class down to the node itself.
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {}
for node in all_libcst_nodes:
    # Map the base classes for this node
    node_to_bases[node] = [
        base
        for base in reversed(inspect.getmro(node))
        if issubclass(base, cst.CSTNode)
    ]
|
|
|
|
|
|
def _get_most_generic_base_for_node(node: Type[cst.CSTNode]) -> Type[cst.CSTNode]:
    """
    Return the first entry of ``node_to_bases[node]`` that is itself a key of
    ``node_to_bases``.

    Raises ``KeyError`` if ``node`` has no entry in ``node_to_bases``.
    """
    # Ignore non-exported bases, a user couldn't specify these types
    # in type hints.
    exportable_bases = [
        base for base in node_to_bases[node] if base in node_to_bases
    ]
    return exportable_bases[0]
|
|
|
|
|
|
# Find the most generic version of each node that isn't CSTNode. Entries are
# inserted in all_libcst_nodes order.
nodebases: Dict[Type[cst.CSTNode], Type[cst.CSTNode]] = {
    node_type: _get_most_generic_base_for_node(node_type)
    for node_type in all_libcst_nodes
}
|
|
|
|
|
|
@dataclass(frozen=True)
class Usage:
    """
    Immutable set of flags describing how a node type is used in annotations.

    All flags default to ``False``. Instances are frozen, so updates are made
    by building a new instance (for example with ``dataclasses.replace``).
    """

    # The type appears in a union together with MaybeSentinel.
    maybe: bool = False
    # The type appears in a union together with None.
    optional: bool = False
    # The type appears as an element type of a sequence.
    sequence: bool = False
|
|
|
|
|
|
# Every node in all_libcst_nodes starts out with all usage flags cleared.
nodeuses: Dict[Type[cst.CSTNode], Usage] = {
    node_type: Usage() for node_type in all_libcst_nodes
}
|
|
|
|
|
|
def _is_maybe(typeobj: object) -> bool:
    """
    Return True when ``typeobj`` is a class deriving from ``cst.MaybeSentinel``.

    Anything ``issubclass`` rejects with ``TypeError`` (i.e. a non-class)
    yields False.
    """
    try:
        # pyre-ignore We wrap this in a TypeError check so this is safe
        result = issubclass(typeobj, cst.MaybeSentinel)
    except TypeError:
        return False
    return result
|
|
|
|
|
|
def _get_origin(typeobj: object) -> object:
    """
    Return ``typeobj.__origin__``, or ``None`` when the attribute is missing.

    A missing attribute means the object is not a union or sequence
    annotation, which is all the callers care about.
    """
    return getattr(typeobj, "__origin__", None)
|
|
|
|
|
|
def _get_args(typeobj: object) -> List[object]:
    """
    Return the type arguments of ``typeobj`` as a list.

    The arguments are read from ``typeobj.__args__``. Objects without that
    attribute are not parameterised annotations and yield an empty list.
    """
    try:
        # pyre-ignore We wrap this in a AttributeError check so this is safe
        type_args = typeobj.__args__
    except AttributeError:
        # Don't care, not a union or sequence
        return []
    # __args__ is a tuple; copy it into a list so the value matches the
    # declared return type.
    return list(type_args)
|
|
|
|
|
|
def _is_sequence(typeobj: object) -> bool:
    """
    Return True when the origin of ``typeobj`` is ``typing.Sequence`` or
    ``collections.abc.Sequence``.
    """
    origin = _get_origin(typeobj)
    # pyre-ignore Pyre doesn't know about collections.abc.Sequence
    return origin is Sequence or origin is ABCSequence
|
|
|
|
|
|
def _is_union(typeobj: object) -> bool:
    """Return True when the origin of ``typeobj`` is ``typing.Union``."""
    return _get_origin(typeobj) is Union
|
|
|
|
|
|
def _calc_node_usage(typeobj: object) -> None:
    """
    Record in ``nodeuses`` how node types are used inside the annotation
    ``typeobj``.

    For a union, each argument found in ``all_libcst_nodes`` is flagged
    ``maybe`` if the union also contains ``MaybeSentinel`` and ``optional`` if
    it also contains ``None``. For a sequence, each such argument is flagged
    ``sequence``. Arguments that are not known nodes are walked recursively so
    nested annotations are covered. Flags are only ever switched on, never off.
    """
    if _is_union(typeobj):
        union_args = _get_args(typeobj)
        has_maybe = any(_is_maybe(arg) for arg in union_args)
        # typing stores a union's ``None`` member as the NoneType class, not
        # the None instance, so an isinstance() check against NoneType can
        # never match here. Compare by identity instead.
        none_type = type(None)
        has_none = any(arg is None or arg is none_type for arg in union_args)

        for node in union_args:
            if node in all_libcst_nodes:
                current = nodeuses[node]
                nodeuses[node] = replace(
                    current,
                    maybe=current.maybe or has_maybe,
                    optional=current.optional or has_none,
                )
            else:
                _calc_node_usage(node)

    if _is_sequence(typeobj):
        for node in _get_args(typeobj):
            if node in all_libcst_nodes:
                nodeuses[node] = replace(nodeuses[node], sequence=True)
            else:
                _calc_node_usage(node)
|
|
|
|
|
|
# Feed the annotation of every dataclass field on every node (except the
# "_metadata" field) through the usage calculation.
for node in all_libcst_nodes:
    for field in fields(node) or []:
        if field.name != "_metadata":
            _calc_node_usage(field.type)
|
|
|
|
|
|
# Module name -> names to import from it. Covers every node whose name does
# not start with "Base", together with that node's most generic base.
imports: Mapping[str, Set[str]] = defaultdict(set)
for node, base in nodebases.items():
    if not node.__name__.startswith("Base"):
        for x in (node, base):
            imports[x.__module__].add(x.__name__)
|
|
|
|
# Lines of the generated module, joined with newlines when printed. It starts
# with the fixed preamble: license header, pyre directive, provenance note and
# the unconditional imports.
# NOTE(review): the provenance line names gen_matcher_classes; confirm that is
# really the module this script lives in.
generated_code: List[str] = [
    "# Copyright (c) Facebook, Inc. and its affiliates.",
    "#",
    "# This source code is licensed under the MIT license found in the",
    "# LICENSE file in the root directory of this source tree.",
    "",
    "# pyre-strict",
    "",
    "# This file was generated by libcst.codegen.gen_matcher_classes",
    "from typing import Optional, Union, TYPE_CHECKING",
    "",
    "from libcst._maybe_sentinel import MaybeSentinel",
    "from libcst._removal_sentinel import RemovalSentinel",
    "from libcst._typed_visitor_base import mark_no_op",
]
|
|
|
|
# Import the types we use. These have to be type guarded since it would
# cause an import cycle otherwise.
generated_code.append("")
generated_code.append("")
generated_code.append("if TYPE_CHECKING:")
for module, objects in imports.items():
    # One parenthesised import per module, with the names sorted so the
    # output does not depend on set ordering. Emitted lines are indented in
    # four-space steps.
    generated_code.append(f"    from {module} import (  # noqa: F401")
    generated_code.append(f"        {', '.join(sorted(objects))}")
    generated_code.append("    )")
|
|
|
|
# Generate the base visit_ methods
#
# Emits class CSTTypedBaseFunctions with, for every node whose name does not
# start with "Base" (sorted by class name): a visit_<Node> stub, plus a
# visit_<Node>_<field> / leave_<Node>_<field> pair for each dataclass field
# other than "_metadata". Method lines are indented four spaces and their
# bodies eight, so each body nests under its def in the generated module.
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedBaseFunctions:")
for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
    name = node.__name__
    if name.startswith("Base"):
        continue
    generated_code.extend(
        [
            "",
            "    @mark_no_op",
            f'    def visit_{name}(self, node: "{name}") -> Optional[bool]:',
            "        pass",
        ]
    )
    for field in fields(node) or []:
        if field.name == "_metadata":
            continue
        # The visit stub is emitted first, then the matching leave stub.
        for hook in ("visit", "leave"):
            generated_code.extend(
                [
                    "",
                    "    @mark_no_op",
                    f'    def {hook}_{name}_{field.name}(self, node: "{name}") -> None:',
                    "        pass",
                ]
            )
|
|
|
|
# Generate the visitor leave_ methods
#
# Emits class CSTTypedVisitorFunctions with one leave_<Node> stub returning
# None for every node whose name does not start with "Base", sorted by class
# name. Method lines are indented four spaces and their bodies eight, so each
# body nests under its def in the generated module.
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):")
for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
    name = node.__name__
    if name.startswith("Base"):
        continue
    generated_code.extend(
        [
            "",
            "    @mark_no_op",
            f'    def leave_{name}(self, original_node: "{name}") -> None:',
            "        pass",
        ]
    )
|
|
|
|
# Generate the transformer leave_ methods
#
# Emits class CSTTypedTransformerFunctions with one leave_<Node> method for
# every node whose name does not start with "Base", sorted by class name.
# Each method returns updated_node. Its return annotation is a Union of the
# node's most generic base, plus MaybeSentinel / RemovalSentinel when the
# recorded usages of the node or of that base call for them. Method lines are
# indented four spaces and their bodies eight, so each body nests under its
# def in the generated module.
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):")
generated_code.append("    pass")
for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
    name = node.__name__
    if name.startswith("Base"):
        continue
    generated_code.append("")
    generated_code.append("    @mark_no_op")

    valid_return_types: List[str] = [f'"{nodebases[node].__name__}"']
    node_uses = nodeuses[node]
    base_uses = nodeuses[nodebases[node]]
    if node_uses.maybe or base_uses.maybe:
        valid_return_types.append("MaybeSentinel")
    # A node used as an optional value or as a sequence element may be
    # replaced by a RemovalSentinel.
    if (
        node_uses.optional
        or node_uses.sequence
        or base_uses.optional
        or base_uses.sequence
    ):
        valid_return_types.append("RemovalSentinel")

    return_annotation = f'Union[{", ".join(valid_return_types)}]'
    generated_code.append(
        f'    def leave_{name}(self, original_node: "{name}", '
        f'updated_node: "{name}") -> {return_annotation}:'
    )
    generated_code.append("        return updated_node")
|
|
|
|
if __name__ == "__main__":
    # Output the code: one generated line per output line, on stdout.
    print("\n".join(generated_code))
|