mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Summary: This PR removes the `typing_extensions` and `typing_inspect` dependencies as we can now rely on the built-in `typing` module since Python 3.9. Test Plan: existing tests
556 lines
20 KiB
Python
556 lines
20 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import re
|
|
from dataclasses import dataclass, fields
|
|
from typing import Generator, List, Optional, Sequence, Set, Tuple, Type, Union
|
|
|
|
import libcst as cst
|
|
from libcst import ensure_type, parse_expression
|
|
from libcst.codegen.gather import all_libcst_nodes, typeclasses
|
|
|
|
# Every name exported by the top-level ``libcst`` package. Bare names found in a
# type repr that appear here are treated as references to LibCST classes.
CST_DIR: Set[str] = {*dir(cst)}
# Matches a class object's repr such as "<class 'pkg.mod.Name'>" and captures
# the dotted path between the quotes.
CLASS_RE = r"<class \'(.*?)\'>"
# Matches "typing.Union[X, NoneType]" where X contains no comma, capturing X.
OPTIONAL_RE = r"typing\.Union\[([^,]*?), NoneType]"
|
|
|
|
|
|
class CleanseFullTypeNames(cst.CSTTransformer):
    """
    Normalize a parsed ``repr()`` of a dataclass field type into the spelling
    used by the generated matchers.

    The transformer does four things:

    - strips module prefixes;
    - turns ``NoneType`` back into ``None``;
    - quotes names exported by libcst as forward-reference strings;
    - drops subscript entries that name a ``*Sentinel`` type.
    """

    def leave_Call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BaseExpression:
        # A ForwardRef('X') repr (or the older _ForwardRef spelling) collapses
        # back to its first argument, which is the quoted string itself.
        callee = updated_node.func
        is_forward_ref = isinstance(callee, cst.Name) and any(
            callee.deep_equals(cst.Name(ref_name))
            for ref_name in ("_ForwardRef", "ForwardRef")
        )
        return updated_node.args[0].value if is_forward_ref else updated_node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> Union[cst.Attribute, cst.Name]:
        # Keep only the last dotted component, e.g. libcst.x.y.Name -> Name.
        unqualified = updated_node.attr
        return unqualified

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> Union[cst.Name, cst.SimpleString]:
        name = updated_node.value
        if name == "NoneType":
            # typing prints None as NoneType inside unions; undo that.
            return updated_node.with_changes(value="None")
        # Names exported by libcst (other than sentinels) become quoted
        # forward references; every other name is returned unchanged.
        is_libcst_export = name in CST_DIR and not name.endswith("Sentinel")
        return cst.SimpleString(repr(name)) if is_libcst_export else updated_node

    def leave_SubscriptElement(
        self, original_node: cst.SubscriptElement, updated_node: cst.SubscriptElement
    ) -> Union[cst.SubscriptElement, cst.RemovalSentinel]:
        index = updated_node.slice
        if (
            isinstance(index, cst.Index)
            and isinstance(index.value, cst.Name)
            and "Sentinel" in index.value.value
        ):
            # Sentinel types (e.g. MaybeSentinel) are not supported in matchers.
            return cst.RemoveFromParent()
        # Resetting the comma drops any trailing comma carried over from the repr.
        return updated_node.with_changes(comma=cst.MaybeSentinel.DEFAULT)
|
|
|
|
|
|
class RemoveTypesFromGeneric(cst.CSTTransformer):
    """
    Drop every subscript element whose value is a bare ``Name`` listed in
    ``values``, e.g. removing ``DoNotCareSentinel`` from a ``Union[...]``.
    """

    def __init__(self, values: Sequence[str]) -> None:
        # Let the CSTTransformer base initialize its own state before adding
        # ours; overriding __init__ without this skips the base initializer.
        super().__init__()
        # Set of type names to strip; a set gives O(1) membership checks.
        self.values: Set[str] = set(values)

    def leave_SubscriptElement(
        self, original_node: cst.SubscriptElement, updated_node: cst.SubscriptElement
    ) -> Union[cst.SubscriptElement, cst.RemovalSentinel]:
        slc = updated_node.slice
        if (
            isinstance(slc, cst.Index)
            and isinstance(slc.value, cst.Name)
            and slc.value.value in self.values
        ):
            # This element names one of the unwanted types, so remove it.
            return cst.RemoveFromParent()
        return updated_node
|
|
|
|
|
|
def _remove_types(
    oldtype: cst.BaseExpression, values: Sequence[str]
) -> cst.BaseExpression:
    """
    Return a copy of the type expression ``oldtype`` with every subscript
    element that names one of ``values`` removed.
    """
    remover = RemoveTypesFromGeneric(values)
    pruned = oldtype.visit(remover)
    return ensure_type(pruned, cst.BaseExpression)
|
|
|
|
|
|
class MatcherClassToLibCSTClass(cst.CSTTransformer):
    """
    Rewrite quoted forward references such as ``"Name"`` into ``cst.Name``
    attribute accesses when the quoted text is a name exported by libcst.
    """

    def leave_SimpleString(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> Union[cst.SimpleString, cst.Attribute]:
        referenced = updated_node.evaluated_value
        if referenced not in CST_DIR:
            # Not a libcst export; leave the string literal untouched.
            return updated_node
        return cst.Attribute(cst.Name("cst"), cst.Name(referenced))
|
|
|
|
|
|
def _convert_match_nodes_to_cst_nodes(
    matchtype: cst.BaseExpression,
) -> cst.BaseExpression:
    """
    Rewrite a type expression so that quoted forward references to matcher
    classes become ``cst.<Name>`` references to the LibCST node classes.
    """
    converted = matchtype.visit(MatcherClassToLibCSTClass())
    return ensure_type(converted, cst.BaseExpression)
|
|
|
|
|
|
def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.SubscriptElement:
    """
    Build a ``MatchIfTrue[...]`` subscript element suitable for appending to a
    ``Union[...]``.
    """
    # MatchIfTrue's predicate receives the real LibCST node and returns a
    # bool. The quoted matcher names in ``oldtype`` therefore get mapped back
    # to the CSTNode classes they mirror, which works because matcher and node
    # names correspond 1:1.
    node_type = _convert_match_nodes_to_cst_nodes(oldtype)
    predicate = cst.Subscript(
        cst.Name("MatchIfTrue"),
        slice=(cst.SubscriptElement(cst.Index(node_type)),),
    )
    return cst.SubscriptElement(cst.Index(predicate))
|
|
|
|
|
|
def _add_generic(name: str, oldtype: cst.BaseExpression) -> cst.BaseExpression:
    """Wrap ``oldtype`` as the single parameter of generic ``name``: ``name[oldtype]``."""
    sole_param = cst.SubscriptElement(cst.Index(oldtype))
    return cst.Subscript(cst.Name(name), (sole_param,))
|
|
|
|
|
|
class AddLogicMatchersToUnions(cst.CSTTransformer):
    """
    Append ``OneOf[...]`` and ``AllOf[...]`` members to every ``Union[...]``.

    Each wrapper carries the union with ``DoNotCareSentinel`` removed.
    """

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.Subscript:
        if not updated_node.value.deep_equals(cst.Name("Union")):
            return updated_node

        # Strip DoNotCareSentinel so OneOf/AllOf only wrap concrete types.
        # This deliberately operates on ``updated_node``, so any members
        # already added to nested unions are carried into the wrapped type.
        concrete_only_expr = _remove_types(updated_node, ["DoNotCareSentinel"])
        logic_members = [
            cst.SubscriptElement(cst.Index(_add_generic(wrapper, concrete_only_expr)))
            for wrapper in ("OneOf", "AllOf")
        ]
        return updated_node.with_changes(
            slice=[*updated_node.slice, *logic_members]
        )
|
|
|
|
|
|
class AddWildcardsToSequenceUnions(cst.CSTTransformer):
    """
    Append ``AtLeastN[...]`` and ``AtMostN[...]`` members to any ``Union[...]``
    that is the sole parameter of a ``Sequence[...]``.

    Sequences nested inside ``MatchIfTrue[...]`` are left unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        # MatchIfTrue subscripts currently being visited (non-empty means we
        # are somewhere inside one).
        self.in_match_if_true: Set[cst.CSTNode] = set()
        # Union subscripts found directly under a Sequence, to be widened on leave.
        self.fixup_nodes: Set[cst.Subscript] = set()

    def visit_Subscript(self, node: cst.Subscript) -> None:
        if node.value.deep_equals(cst.Name("MatchIfTrue")):
            # Everything below this point must stay as-is.
            self.in_match_if_true.add(node)
            return
        if not node.value.deep_equals(cst.Name("Sequence")):
            return
        if self.in_match_if_true:
            # No AtLeastN/AtMostN inside MatchIfTrue type blocks, even for
            # sequence types.
            return

        if len(node.slice) != 1:
            raise Exception(
                "Unexpected number of sequence elements inside Sequence type "
                + "annotation!"
            )

        element_slice = node.slice[0].slice
        if not isinstance(element_slice, cst.Index):
            return
        candidate = element_slice.value
        # Only a Union that is the Sequence's direct parameter gets fixed up.
        if isinstance(candidate, cst.Subscript) and candidate.value.deep_equals(
            cst.Name("Union")
        ):
            self.fixup_nodes.add(candidate)

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.Subscript:
        self.in_match_if_true.discard(original_node)
        if original_node not in self.fixup_nodes:
            return updated_node

        self.fixup_nodes.remove(original_node)
        # Wrap the pre-transformation union so the wildcards do not contain
        # any members added while visiting its children.
        wildcards = [
            cst.SubscriptElement(cst.Index(_add_generic(wildcard, original_node)))
            for wildcard in ("AtLeastN", "AtMostN")
        ]
        return updated_node.with_changes(slice=[*updated_node.slice, *wildcards])
|
|
|
|
|
|
def _get_do_not_care() -> cst.SubscriptElement:
    """
    Build a ``DoNotCareSentinel`` member ready to be placed inside a Union.
    """
    sentinel_name = cst.Name("DoNotCareSentinel")
    return cst.SubscriptElement(cst.Index(sentinel_name))
|
|
|
|
|
|
def _get_match_metadata() -> cst.SubscriptElement:
    """
    Build a ``MetadataMatchType`` member ready to be placed inside a Union.
    """
    metadata_name = cst.Name("MetadataMatchType")
    return cst.SubscriptElement(cst.Index(metadata_name))
|
|
|
|
|
|
def _get_wrapped_union_type(
    node: cst.BaseExpression,
    addition: cst.SubscriptElement,
    *additions: cst.SubscriptElement,
) -> cst.Subscript:
    """
    Wrap ``node`` together with one or more extra members in ``Union[...]``.

    ``addition`` is a separate required parameter so that, for type safety,
    callers must always supply at least one extra member.
    """
    members: List[cst.SubscriptElement] = [
        cst.SubscriptElement(cst.Index(node)),
        addition,
    ]
    members.extend(additions)
    return cst.Subscript(cst.Name("Union"), members)
|
|
|
|
|
|
# List of global aliases we've already generated, so we don't redefine types
|
|
_global_aliases: Set[str] = set()
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Alias:
|
|
name: str
|
|
type: str
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Field:
|
|
name: str
|
|
type: str
|
|
aliases: List[Alias]
|
|
|
|
|
|
def _get_raw_name(node: cst.CSTNode) -> Optional[str]:
    """
    Return the plain type name referenced by ``node``, or ``None``.

    Handles a ``Name``, a ``str``-valued ``SimpleString``, or a
    ``SubscriptElement``/``Index`` that ultimately wraps one of those.
    """
    if isinstance(node, cst.Name):
        return node.value
    if isinstance(node, cst.SimpleString):
        literal = node.evaluated_value
        # Byte strings are not type names.
        return literal if isinstance(literal, str) else None
    if isinstance(node, cst.SubscriptElement):
        return _get_raw_name(node.slice)
    if isinstance(node, cst.Index):
        return _get_raw_name(node.value)
    return None
|
|
|
|
|
|
def _get_alias_name(node: cst.CSTNode) -> Optional[str]:
    """
    Derive a ``...MatchType`` alias name for a type expression.

    - A simple name ``X`` maps to ``XMatchType``.
    - ``Union[A, B]`` maps to ``AOrBMatchType``, but only when every member
      has a plain name.
    - Anything else has no alias name and yields ``None``.
    """
    if isinstance(node, (cst.Name, cst.SimpleString)):
        return f"{_get_raw_name(node)}MatchType"
    if not isinstance(node, cst.Subscript) or not node.value.deep_equals(
        cst.Name("Union")
    ):
        return None

    member_names: List[str] = []
    for element in node.slice:
        raw = _get_raw_name(element)
        if raw is None:
            # A member without a plain name means the union can't be named.
            return None
        member_names.append(raw)
    return "Or".join(member_names) + "MatchType"
|
|
|
|
|
|
def _wrap_clean_type(
    aliases: List[Alias], name: Optional[str], value: cst.Subscript
) -> cst.BaseExpression:
    """
    Make the widened type ``value`` optional via ``DoNotCareSentinel``.

    - If ``name`` is given, ``value`` is recorded in ``aliases`` under that
      name and ``Union[name, DoNotCareSentinel]`` is returned.
    - Otherwise ``DoNotCareSentinel`` is appended directly to ``value``'s
      members.
    """
    if name is None:
        # No alias name available: widen the union in place.
        return value.with_changes(slice=[*value.slice, _get_do_not_care()])

    alias_source = cst.Module(body=()).code_for_node(value)
    aliases.append(Alias(name=name, type=alias_source))
    return _get_wrapped_union_type(cst.Name(name), _get_do_not_care())
|
|
|
|
|
|
def _get_clean_type_from_expression(
    aliases: List[Alias], typecst: cst.BaseExpression
) -> cst.BaseExpression:
    """
    Widen a non-Union, non-Sequence type with the metadata and MatchIfTrue
    matcher members, and make it optional via DoNotCareSentinel.

    An alias is recorded in ``aliases`` when the type can be named.
    """
    alias_name = _get_alias_name(typecst)
    widened = _get_wrapped_union_type(
        typecst, _get_match_metadata(), _get_match_if_true(typecst)
    )
    return _wrap_clean_type(aliases, alias_name, widened)
|
|
|
|
|
|
def _maybe_fix_sequence_in_union(
    aliases: List[Alias], typecst: cst.SubscriptElement
) -> cst.SubscriptElement:
    """
    Clean up a Union member when that member is itself a subscripted type
    (e.g. ``Sequence[...]``); return any other member unchanged.
    """
    index = typecst.slice
    if not isinstance(index, cst.Index) or not isinstance(index.value, cst.Subscript):
        return typecst

    inner = index.value
    cleaned = _get_clean_type_from_subscript(aliases, inner)
    return cst.ensure_type(typecst.deep_replace(inner, cleaned), cst.SubscriptElement)
|
|
|
|
|
|
def _get_clean_type_from_union(
    aliases: List[Alias], typecst: cst.Subscript
) -> cst.BaseExpression:
    """
    Widen a ``Union[...]`` type in place and make it optional via
    DoNotCareSentinel.

    Subscripted members are cleaned recursively, then the metadata and
    MatchIfTrue matcher members are appended.
    """
    alias_name = _get_alias_name(typecst)
    members = [_maybe_fix_sequence_in_union(aliases, member) for member in typecst.slice]
    members.append(_get_match_metadata())
    members.append(_get_match_if_true(typecst))
    return _wrap_clean_type(aliases, alias_name, typecst.with_changes(slice=members))
|
|
|
|
|
|
def _get_clean_type_from_subscript(
    aliases: List[Alias], typecst: cst.Subscript
) -> cst.BaseExpression:
    """
    Dispatch cleaning of a subscripted type according to its generic head.

    - ``Union[...]`` is widened in place.
    - ``Sequence[X]`` has ``X`` cleaned recursively, then the result is wrapped
      together with DoNotCareSentinel and a MatchIfTrue member.
    - Any other generic (e.g. ``Literal``) is widened as an opaque expression.
    """
    head = typecst.value

    if head.deep_equals(cst.Name("Union")):
        # Unions can be extended directly with the extra matcher members.
        return _get_clean_type_from_union(aliases, typecst)

    if not head.deep_equals(cst.Name("Sequence")):
        # Don't handle other types like "Literal", just widen them.
        return _get_clean_type_from_expression(aliases, typecst)

    # Sequence: widen (and alias) its single element type.
    if len(typecst.slice) != 1:
        raise Exception("Logic error, Sequence shouldn't have more than one param!")
    only_index = typecst.slice[0].slice
    if not isinstance(only_index, cst.Index):
        raise Exception("Logic error, expecting Index for only Sequence element!")
    element_type = only_index.value

    if isinstance(element_type, cst.Subscript):
        cleaned_element = _get_clean_type_from_subscript(aliases, element_type)
    elif isinstance(element_type, (cst.Name, cst.SimpleString)):
        cleaned_element = _get_clean_type_from_expression(aliases, element_type)
    else:
        raise Exception("Logic error, unexpected type in Sequence!")

    return _get_wrapped_union_type(
        typecst.deep_replace(element_type, cleaned_element),
        _get_do_not_care(),
        _get_match_if_true(typecst),
    )
|
|
|
|
|
|
def _get_clean_type_and_aliases(
    typeobj: object,
) -> Tuple[str, List[Alias]]:  # noqa: C901
    """
    Turn a dataclass field's type object into the annotation source text used
    by the generated matcher.

    Returns a tuple ``(annotation_source, aliases)``, where ``aliases`` are the
    type aliases the annotation refers to.
    """
    # Render the type object as parseable source: unwrap "<class '...'>"
    # reprs and spell single-member None unions as typing.Optional[...].
    typestr = re.sub(CLASS_RE, r"\1", repr(typeobj))
    typestr = re.sub(OPTIONAL_RE, r"typing.Optional[\1]", typestr)

    # Parse with LibCST and normalize names (see CleanseFullTypeNames).
    typecst = parse_expression(typestr).visit(CleanseFullTypeNames())

    # Widen the type with MetadataMatchType / MatchIfTrue members, collecting
    # any aliases created along the way.
    aliases: List[Alias] = []
    if isinstance(typecst, cst.Subscript):
        clean_type = _get_clean_type_from_subscript(aliases, typecst)
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        clean_type = _get_clean_type_from_expression(aliases, typecst)
    else:
        raise Exception("Logic error, unexpected top level type!")

    # Two post-passes, in this order:
    # 1. Add OneOf/AllOf to every Union, so e.g. OneOf[SomeType] is accepted
    #    wherever SomeType was.
    # 2. Add AtLeastN/AtMostN to Sequence[Union[...]]. This relies on pass 1
    #    having normalized the sequences we care about to that shape.
    for transformer_cls in (AddLogicMatchersToUnions, AddWildcardsToSequenceUnions):
        clean_type = ensure_type(clean_type.visit(transformer_cls()), cst.CSTNode)

    # Render with a default Module so the code can be emitted as text.
    return cst.Module(body=()).code_for_node(clean_type), aliases
|
|
|
|
|
|
def _get_fields(node: Type[cst.CSTNode]) -> Generator[Field, None, None]:
    """
    Given a CSTNode class, yield one ``Field`` per dataclass field, skipping
    the internal ``_metadata`` field.

    Each ``Field`` carries:

    - the field name;
    - the cleaned annotation text;
    - only those aliases that have not been emitted before, by this node or by
      any node processed earlier. Emitted names are tracked in the
      module-level ``_global_aliases`` set.
    """
    for field in fields(node):
        if field.name == "_metadata":
            continue

        fieldtype, aliases = _get_clean_type_and_aliases(field.type)
        new_aliases: List[Alias] = []
        for alias in aliases:
            if alias.name in _global_aliases:
                continue
            # Register the name before yielding. De-duplication then holds even
            # if the caller never resumes this generator, and a name repeated
            # within one field is emitted only once.
            _global_aliases.add(alias.name)
            new_aliases.append(alias)

        yield Field(name=field.name, type=fieldtype, aliases=new_aliases)
|
|
|
|
|
|
# Lines of the generated ``libcst.matchers`` module, in output order.
# Starts with the license header, imports, and the re-exported helpers.
generated_code: List[str] = [
    "# Copyright (c) Meta Platforms, Inc. and affiliates.",
    "#",
    "# This source code is licensed under the MIT license found in the",
    "# LICENSE file in the root directory of this source tree.",
    "",
    "",
    "# This file was generated by libcst.codegen.gen_matcher_classes",
    "from dataclasses import dataclass",
    "from typing import Literal, Optional, Sequence, Union",
    "import libcst as cst",
    "",
    "from libcst.matchers._matcher_base import AbstractBaseMatcherNodeMeta, BaseMatcherNode, DoNotCareSentinel, DoNotCare, TypeOf, OneOf, AllOf, DoesNotMatch, MatchIfTrue, MatchRegex, MatchMetadata, MatchMetadataIfTrue, ZeroOrMore, AtLeastN, ZeroOrOne, AtMostN, SaveMatchedNode, extract, extractall, findall, matches, replace",
    "from libcst.matchers._decorators import call_if_inside, call_if_not_inside, visit, leave",
    "from libcst.matchers._visitors import MatchDecoratorMismatch, MatcherDecoratableTransformer, MatcherDecoratableVisitor",
]

# Names listed in the generated module's ``__all__``. Later sections add the
# generated matcher classes.
all_exports: Set[str] = {
    # From libcst.matchers._matcher_base (AbstractBaseMatcherNodeMeta is
    # imported but intentionally not exported).
    "BaseMatcherNode",
    "DoNotCareSentinel",
    "DoNotCare",
    "OneOf",
    "AllOf",
    "DoesNotMatch",
    "MatchIfTrue",
    "MatchRegex",
    "MatchMetadata",
    "MatchMetadataIfTrue",
    "TypeOf",
    "ZeroOrMore",
    "AtLeastN",
    "ZeroOrOne",
    "AtMostN",
    "SaveMatchedNode",
    "extract",
    "extractall",
    "findall",
    "matches",
    "replace",
    # From libcst.matchers._decorators
    "call_if_inside",
    "call_if_not_inside",
    "visit",
    "leave",
    # From libcst.matchers._visitors
    "MatchDecoratorMismatch",
    "MatcherDecoratableTransformer",
    "MatcherDecoratableVisitor",
}
|
|
|
|
# Emit the abstract matcher root, then one empty matcher class per
# typeclass; each of those subclasses _NodeABC.
generated_code.extend(
    [
        "",
        "",
        "class _NodeABC(metaclass=AbstractBaseMatcherNodeMeta):",
        "    __slots__ = ()",
    ]
)

for base in typeclasses:
    generated_code.extend(["", "", f"class {base.__name__}(_NodeABC):", "    pass"])
    all_exports.add(base.__name__)
|
|
|
|
|
|
# Emit a shared MetadataMatchType alias that every generated field annotation
# can refer to.
generated_code.extend(
    ["", "", "MetadataMatchType = Union[MatchMetadata, MatchMetadataIfTrue]"]
)
|
|
|
|
|
|
# Emit one frozen matcher dataclass per concrete LibCST node, preceded by any
# type aliases its fields need.
for node in all_libcst_nodes:
    # Abstract ``Base*`` node classes get no matcher class of their own.
    if node.__name__.startswith("Base"):
        continue

    # Inherit from every typeclass matcher the node subclasses, in
    # ``typeclasses`` order, then BaseMatcherNode.
    classes: List[str] = [tc.__name__ for tc in typeclasses if issubclass(node, tc)]
    classes.append("BaseMatcherNode")

    node_fields = list(_get_fields(node))

    # Aliases must be defined before the class that references them.
    new_aliases = [alias for field in node_fields for alias in field.aliases]
    if new_aliases:
        # Blank-line separator before the alias block.
        generated_code.append("")
        generated_code.append("")
        generated_code.extend(f"{alias.name} = {alias.type}" for alias in new_aliases)

    generated_code.append("")
    generated_code.append("")
    generated_code.append("@dataclass(frozen=True, eq=False, unsafe_hash=False)")
    generated_code.append(f'class {node.__name__}({", ".join(classes)}):')
    all_exports.add(node.__name__)

    # Every field defaults to DoNotCare() so unspecified attributes match anything.
    generated_code.extend(
        f"    {field.name}: {field.type} = DoNotCare()" for field in node_fields
    )

    # Special metadata field, present on every matcher. The class body is
    # therefore never empty.
    generated_code.append(
        "    metadata: Union[MetadataMatchType, DoNotCareSentinel, OneOf[MetadataMatchType], AllOf[MetadataMatchType]] = DoNotCare()"
    )
|
|
|
|
|
|
# Emit a sorted __all__ so flake8 is satisfied and
# "from libcst.matchers import *" exposes exactly the exported names.
generated_code.append("__all__ = " + repr(sorted(all_exports)))
|
|
|
|
|
|
if __name__ == "__main__":
    # Write the generated module to stdout, one entry per line.
    print(*generated_code, sep="\n")
|