# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
from dataclasses import dataclass, fields
from typing import Generator, List, Optional, Sequence, Set, Tuple, Type, Union

import libcst as cst
from libcst import ensure_type, parse_expression
from libcst.codegen.gather import all_libcst_nodes, typeclasses

# Every public name exported by the libcst package.
CST_DIR: Set[str] = set(dir(cst))
# Matches the repr of a class object (e.g. "<class 'libcst.Name'>") and captures
# the dotted name. The capture group is required: the substitution in
# _get_clean_type_and_aliases replaces each match with r"\1", and re.sub raises
# re.error when the pattern has no group 1.
CLASS_RE = r"<class '(.*?)'>"
# Matches the repr of a two-element Union whose second member is NoneType.
OPTIONAL_RE = r"typing\.Union\[([^,]*?), NoneType]"


class CleanseFullTypeNames(cst.CSTTransformer):
    """
    Normalize a parsed type repr: unwrap forward references, strip module
    qualifiers, quote LibCST names and drop sentinel entries from subscripts.
    """

    def leave_Call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BaseExpression:
        # Convert forward ref repr back to a SimpleString.
        if isinstance(updated_node.func, cst.Name) and (
            updated_node.func.deep_equals(cst.Name("_ForwardRef"))
            or updated_node.func.deep_equals(cst.Name("ForwardRef"))
        ):
            return updated_node.args[0].value
        return updated_node

    def leave_Attribute(
        self, original_node: cst.Attribute, updated_node: cst.Attribute
    ) -> Union[cst.Attribute, cst.Name]:
        # Unwrap all attributes, so things like libcst.x.y.Name becomes Name
        return updated_node.attr

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> Union[cst.Name, cst.SimpleString]:
        value = updated_node.value
        if value == "NoneType":
            # This is special-cased in typing, un-special case it.
            return updated_node.with_changes(value="None")
        if value in CST_DIR and not value.endswith("Sentinel"):
            # If this isn't a typing define and it isn't a builtin, convert it to
            # a forward ref string.
            return cst.SimpleString(repr(value))
        return updated_node

    def leave_SubscriptElement(
        self, original_node: cst.SubscriptElement, updated_node: cst.SubscriptElement
    ) -> Union[cst.SubscriptElement, cst.RemovalSentinel]:
        slc = updated_node.slice
        if isinstance(slc, cst.Index):
            val = slc.value
            if isinstance(val, cst.Name) and "Sentinel" in val.value:
                # We don't support maybes in matchers.
                return cst.RemoveFromParent()
        # Simple trick to kill trailing commas
        return updated_node.with_changes(comma=cst.MaybeSentinel.DEFAULT)


class RemoveTypesFromGeneric(cst.CSTTransformer):
    """
    Remove every subscript element that is a bare name listed in ``values``.
    """

    def __init__(self, values: Sequence[str]) -> None:
        super().__init__()
        self.values: Set[str] = set(values)

    def leave_SubscriptElement(
        self, original_node: cst.SubscriptElement, updated_node: cst.SubscriptElement
    ) -> Union[cst.SubscriptElement, cst.RemovalSentinel]:
        slc = updated_node.slice
        if isinstance(slc, cst.Index):
            val = slc.value
            if isinstance(val, cst.Name) and val.value in self.values:
                # This type matches, so out it goes
                return cst.RemoveFromParent()
        return updated_node


def _remove_types(
    oldtype: cst.BaseExpression, values: Sequence[str]
) -> cst.BaseExpression:
    """
    Given a BaseExpression from a type, return a new BaseExpression that does
    not refer to any types listed in values.
    """
    return ensure_type(
        oldtype.visit(RemoveTypesFromGeneric(values)), cst.BaseExpression
    )


class MatcherClassToLibCSTClass(cst.CSTTransformer):
    """
    Replace quoted LibCST class names with ``cst.<Name>`` attribute accesses.
    """

    def leave_SimpleString(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> Union[cst.SimpleString, cst.Attribute]:
        value = updated_node.evaluated_value
        if value in CST_DIR:
            return cst.Attribute(cst.Name("cst"), cst.Name(value))
        return updated_node


def _convert_match_nodes_to_cst_nodes(
    matchtype: cst.BaseExpression,
) -> cst.BaseExpression:
    """
    Given a BaseExpression in a type, convert this to a new BaseExpression that
    refers to LibCST nodes instead of forward references to matcher nodes.
    """
    return ensure_type(matchtype.visit(MatcherClassToLibCSTClass()), cst.BaseExpression)


def _get_match_if_true(oldtype: cst.BaseExpression) -> cst.SubscriptElement:
    """
    Construct a MatchIfTrue type node appropriate for going into a Union.
    """
    return cst.SubscriptElement(
        cst.Index(
            cst.Subscript(
                cst.Name("MatchIfTrue"),
                slice=(
                    cst.SubscriptElement(
                        cst.Index(
                            # MatchIfTrue takes in the original node type,
                            # and returns a boolean. So, lets convert our
                            # quoted classes (forward refs to other
                            # matchers) back to the CSTNode they refer to.
                            # We can do this because there's always a 1:1
                            # name mapping.
                            _convert_match_nodes_to_cst_nodes(oldtype)
                        ),
                    ),
                ),
            )
        )
    )


def _add_generic(name: str, oldtype: cst.BaseExpression) -> cst.BaseExpression:
    """
    Wrap ``oldtype`` in a single-parameter generic, producing ``name[oldtype]``.
    """
    return cst.Subscript(cst.Name(name), (cst.SubscriptElement(cst.Index(oldtype)),))


class AddLogicMatchersToUnions(cst.CSTTransformer):
    """
    Append ``OneOf[...]`` and ``AllOf[...]`` entries to every ``Union``.
    """

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.Subscript:
        if updated_node.value.deep_equals(cst.Name("Union")):
            # Remove do not care so we have concrete types only.
            # NOTE(review): an earlier comment here said the *original* node is
            # used in order to discard nested changes, but the code passes
            # updated_node. Behavior is left untouched; confirm which is intended.
            concrete_only_expr = _remove_types(updated_node, ["DoNotCareSentinel"])
            return updated_node.with_changes(
                slice=[
                    *updated_node.slice,
                    cst.SubscriptElement(
                        cst.Index(_add_generic("OneOf", concrete_only_expr))
                    ),
                    cst.SubscriptElement(
                        cst.Index(_add_generic("AllOf", concrete_only_expr))
                    ),
                ]
            )

        return updated_node


class AddWildcardsToSequenceUnions(cst.CSTTransformer):
    """
    Append ``AtLeastN[...]`` and ``AtMostN[...]`` entries to any ``Union`` that
    is the sole parameter of a ``Sequence`` outside of a ``MatchIfTrue``.
    """

    def __init__(self) -> None:
        super().__init__()
        # MatchIfTrue subscripts we are currently inside of.
        self.in_match_if_true: Set[cst.CSTNode] = set()
        # Union subscripts that must be widened when we leave them.
        self.fixup_nodes: Set[cst.Subscript] = set()

    def visit_Subscript(self, node: cst.Subscript) -> None:
        # If the current node is a MatchIfTrue, we don't want to modify it.
        if node.value.deep_equals(cst.Name("MatchIfTrue")):
            self.in_match_if_true.add(node)
        # If the direct descendant is a union, lets add it to be fixed up.
        elif node.value.deep_equals(cst.Name("Sequence")):
            if self.in_match_if_true:
                # We don't want to add AtLeastN/AtMostN inside MatchIfTrue
                # type blocks, even for sequence types.
                return
            if len(node.slice) != 1:
                raise Exception(
                    "Unexpected number of sequence elements inside Sequence type "
                    + "annotation!"
                )
            nodeslice = node.slice[0].slice
            if isinstance(nodeslice, cst.Index):
                possibleunion = nodeslice.value
                if isinstance(possibleunion, cst.Subscript):
                    if possibleunion.value.deep_equals(cst.Name("Union")):
                        self.fixup_nodes.add(possibleunion)

    def leave_Subscript(
        self, original_node: cst.Subscript, updated_node: cst.Subscript
    ) -> cst.Subscript:
        if original_node in self.in_match_if_true:
            self.in_match_if_true.remove(original_node)
        if original_node in self.fixup_nodes:
            self.fixup_nodes.remove(original_node)
            return updated_node.with_changes(
                slice=[
                    *updated_node.slice,
                    cst.SubscriptElement(
                        cst.Index(_add_generic("AtLeastN", original_node))
                    ),
                    cst.SubscriptElement(
                        cst.Index(_add_generic("AtMostN", original_node))
                    ),
                ]
            )
        return updated_node


def _get_do_not_care() -> cst.SubscriptElement:
    """
    Construct a DoNotCareSentinel entry appropriate for going into a Union.
    """
    return cst.SubscriptElement(cst.Index(cst.Name("DoNotCareSentinel")))


def _get_match_metadata() -> cst.SubscriptElement:
    """
    Construct a MetadataMatchType entry appropriate for going into a Union.
    """
    return cst.SubscriptElement(cst.Index(cst.Name("MetadataMatchType")))


def _get_wrapped_union_type(
    node: cst.BaseExpression,
    addition: cst.SubscriptElement,
    *additions: cst.SubscriptElement,
) -> cst.Subscript:
    """
    Take two or more nodes, wrap them in a union type. Function signature is
    explicitly defined as taking at least one addition for type safety.
    """
    return cst.Subscript(
        cst.Name("Union"), [cst.SubscriptElement(cst.Index(node)), addition, *additions]
    )


# List of global aliases we've already generated, so we don't redefine types
_global_aliases: Set[str] = set()


@dataclass(frozen=True)
class Alias:
    # Identifier the alias is emitted under.
    name: str
    # Source code of the type expression the alias stands for.
    type: str


@dataclass(frozen=True)
class Field:
    # Name of the dataclass field on the LibCST node.
    name: str
    # Source code of the widened matcher type for the field.
    type: str
    # Aliases that must be emitted before the class using this field.
    aliases: List[Alias]


def _get_raw_name(node: cst.CSTNode) -> Optional[str]:
    """
    Return the bare name a Name/SimpleString refers to, looking through
    SubscriptElement and Index wrappers; None when no name can be derived.
    """
    if isinstance(node, cst.Name):
        return node.value
    elif isinstance(node, cst.SimpleString):
        evaluated_value = node.evaluated_value
        if isinstance(evaluated_value, str):
            return evaluated_value
        return None
    elif isinstance(node, cst.SubscriptElement):
        return _get_raw_name(node.slice)
    elif isinstance(node, cst.Index):
        return _get_raw_name(node.value)
    else:
        return None


def _get_alias_name(node: cst.CSTNode) -> Optional[str]:
    """
    Derive a "<...>MatchType" alias name for a simple name or for a Union of
    simple names; return None when the node cannot be named.
    """
    if isinstance(node, (cst.Name, cst.SimpleString)):
        return f"{_get_raw_name(node)}MatchType"
    elif isinstance(node, cst.Subscript):
        if node.value.deep_equals(cst.Name("Union")):
            names = [_get_raw_name(s) for s in node.slice]
            if any(n is None for n in names):
                return None
            return "Or".join(n for n in names if n is not None) + "MatchType"

    return None


def _wrap_clean_type(
    aliases: List[Alias], name: Optional[str], value: cst.Subscript
) -> cst.BaseExpression:
    """
    Register ``value`` under ``name`` in ``aliases`` (when a name is given) and
    return a type that additionally admits DoNotCareSentinel.
    """
    if name is not None:
        # We created an alias, lets use that, wrapping the alias in a do not care.
        aliases.append(Alias(name=name, type=cst.Module(body=()).code_for_node(value)))
        return _get_wrapped_union_type(cst.Name(name), _get_do_not_care())
    else:
        # Couldn't name the alias, fall back to regular node creation, add do not
        # care to the resulting type we widened.
        return value.with_changes(slice=[*value.slice, _get_do_not_care()])


def _get_clean_type_from_expression(
    aliases: List[Alias], typecst: cst.BaseExpression
) -> cst.BaseExpression:
    """
    Widen a non-Union, non-Sequence type with metadata and MatchIfTrue entries.
    """
    name = _get_alias_name(typecst)
    value = _get_wrapped_union_type(
        typecst, _get_match_metadata(), _get_match_if_true(typecst)
    )
    return _wrap_clean_type(aliases, name, value)


def _maybe_fix_sequence_in_union(
    aliases: List[Alias], typecst: cst.SubscriptElement
) -> cst.SubscriptElement:
    """
    If a Union member is itself a subscript, replace it with its widened form;
    otherwise return the element unchanged.
    """
    slc = typecst.slice
    if isinstance(slc, cst.Index):
        val = slc.value
        if isinstance(val, cst.Subscript):
            return cst.ensure_type(
                typecst.deep_replace(val, _get_clean_type_from_subscript(aliases, val)),
                cst.SubscriptElement,
            )
    return typecst


def _get_clean_type_from_union(
    aliases: List[Alias], typecst: cst.Subscript
) -> cst.BaseExpression:
    """
    Widen a Union in place: fix up subscripted members, then append metadata
    and MatchIfTrue entries.
    """
    name = _get_alias_name(typecst)
    value = typecst.with_changes(
        slice=[
            *[_maybe_fix_sequence_in_union(aliases, slc) for slc in typecst.slice],
            _get_match_metadata(),
            _get_match_if_true(typecst),
        ]
    )
    return _wrap_clean_type(aliases, name, value)


def _get_clean_type_from_subscript(
    aliases: List[Alias], typecst: cst.Subscript
) -> cst.BaseExpression:
    """
    Widen a subscripted type, dispatching on whether it is a Sequence, a Union
    or some other generic.

    Raises Exception when a Sequence does not have exactly one plain Index
    parameter of a supported kind.
    """
    if typecst.value.deep_equals(cst.Name("Sequence")):
        # Lets attempt to widen the sequence type and alias it.
        if len(typecst.slice) != 1:
            raise Exception("Logic error, Sequence shouldn't have more than one param!")
        inner_type = typecst.slice[0].slice
        if not isinstance(inner_type, cst.Index):
            raise Exception("Logic error, expecting Index for only Sequence element!")
        inner_type = inner_type.value

        if isinstance(inner_type, cst.Subscript):
            clean_inner_type = _get_clean_type_from_subscript(aliases, inner_type)
        elif isinstance(inner_type, (cst.Name, cst.SimpleString)):
            clean_inner_type = _get_clean_type_from_expression(aliases, inner_type)
        else:
            raise Exception("Logic error, unexpected type in Sequence!")

        return _get_wrapped_union_type(
            typecst.deep_replace(inner_type, clean_inner_type),
            _get_do_not_care(),
            _get_match_if_true(typecst),
        )
    # We can modify this as-is to add our extra values
    elif typecst.value.deep_equals(cst.Name("Union")):
        return _get_clean_type_from_union(aliases, typecst)
    else:
        # Don't handle other types like "Literal", just widen them.
        return _get_clean_type_from_expression(aliases, typecst)


def _get_clean_type_and_aliases(
    typeobj: object,
) -> Tuple[str, List[Alias]]:  # noqa: C901
    """
    Given a type object as returned by dataclasses, sanitize it and convert it
    to a type string that is appropriate for our codegen below.

    Returns the generated type source and the aliases created along the way.
    Raises Exception when the top-level type is not a subscript, name or string.
    """

    # First, get the type as a parseable expression.
    typestr = repr(typeobj)
    typestr = re.sub(CLASS_RE, r"\1", typestr)
    typestr = re.sub(OPTIONAL_RE, r"typing.Optional[\1]", typestr)

    # Now, parse the expression with LibCST.
    cleanser = CleanseFullTypeNames()
    typecst = parse_expression(typestr)
    typecst = typecst.visit(cleanser)
    aliases: List[Alias] = []

    # Now, convert the type to allow for MetadataMatchType and MatchIfTrue values.
    if isinstance(typecst, cst.Subscript):
        clean_type = _get_clean_type_from_subscript(aliases, typecst)
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        clean_type = _get_clean_type_from_expression(aliases, typecst)
    else:
        raise Exception("Logic error, unexpected top level type!")

    # Now, insert OneOf/AllOf and MatchIfTrue into unions so we can typecheck their usage.
    # This allows us to put OneOf[SomeType] or MatchIfTrue[cst.SomeType] into any
    # spot that we would have originally allowed a SomeType.
    clean_type = ensure_type(clean_type.visit(AddLogicMatchersToUnions()), cst.CSTNode)

    # Now, insert AtMostN and AtLeastN into sequence unions, so we can typecheck
    # them. This relies on the previous OneOf/AllOf insertion to ensure that all
    # sequences we care about are Sequence[Union[<x>]].
    clean_type = ensure_type(
        clean_type.visit(AddWildcardsToSequenceUnions()), cst.CSTNode
    )

    # Finally, generate the code given a default Module so we can spit it out.
    return cst.Module(body=()).code_for_node(clean_type), aliases


def _get_fields(node: Type[cst.CSTNode]) -> Generator[Field, None, None]:
    """
    Given a CSTNode, generate a field name and type string for each.
    """

    for field in fields(node) or []:
        if field.name == "_metadata":
            continue

        fieldtype, aliases = _get_clean_type_and_aliases(field.type)
        yield Field(
            name=field.name,
            type=fieldtype,
            # Only hand back aliases that no earlier field has already produced.
            aliases=[a for a in aliases if a.name not in _global_aliases],
        )
        _global_aliases.update(a.name for a in aliases)


all_exports: Set[str] = set()
generated_code: List[str] = []
generated_code.append("# Copyright (c) Meta Platforms, Inc. and affiliates.")
generated_code.append("#")
generated_code.append(
    "# This source code is licensed under the MIT license found in the"
)
generated_code.append("# LICENSE file in the root directory of this source tree.")
generated_code.append("")
generated_code.append("")
generated_code.append("# This file was generated by libcst.codegen.gen_matcher_classes")
generated_code.append("from dataclasses import dataclass")
generated_code.append("from typing import Optional, Sequence, Union")
generated_code.append("from typing_extensions import Literal")
generated_code.append("import libcst as cst")
generated_code.append("")
generated_code.append(
    "from libcst.matchers._matcher_base import AbstractBaseMatcherNodeMeta, BaseMatcherNode, DoNotCareSentinel, DoNotCare, TypeOf, OneOf, AllOf, DoesNotMatch, MatchIfTrue, MatchRegex, MatchMetadata, MatchMetadataIfTrue, ZeroOrMore, AtLeastN, ZeroOrOne, AtMostN, SaveMatchedNode, extract, extractall, findall, matches, replace"
)
all_exports.update(
    [
        "BaseMatcherNode",
        "DoNotCareSentinel",
        "DoNotCare",
        "OneOf",
        "AllOf",
        "DoesNotMatch",
        "MatchIfTrue",
        "MatchRegex",
        "MatchMetadata",
        "MatchMetadataIfTrue",
        "TypeOf",
        "ZeroOrMore",
        "AtLeastN",
        "ZeroOrOne",
        "AtMostN",
        "SaveMatchedNode",
        "extract",
        "extractall",
        "findall",
        "matches",
        "replace",
    ]
)
generated_code.append(
    "from libcst.matchers._decorators import call_if_inside, call_if_not_inside, visit, leave"
)
all_exports.update(["call_if_inside", "call_if_not_inside", "visit", "leave"])
generated_code.append(
    "from libcst.matchers._visitors import MatchDecoratorMismatch, MatcherDecoratableTransformer, MatcherDecoratableVisitor"
)
all_exports.update(
    [
        "MatchDecoratorMismatch",
        "MatcherDecoratableTransformer",
        "MatcherDecoratableVisitor",
    ]
)

# NOTE(review): the class-body lines emitted below (" __slots__ = ()", " pass",
# the field lines and the metadata line) carry a single leading space. That is
# syntactically valid Python, but confirm the indent width is intentional.
generated_code.append("")
generated_code.append("")
generated_code.append("class _NodeABC(metaclass=AbstractBaseMatcherNodeMeta):")
generated_code.append(" __slots__ = ()")

for base in typeclasses:
    generated_code.append("")
    generated_code.append("")
    generated_code.append(f"class {base.__name__}(_NodeABC):")
    generated_code.append(" pass")
    all_exports.add(base.__name__)

# Add a generic MetadataMatchType to be referred to by everywhere else.
generated_code.append("")
generated_code.append("")
generated_code.append("MetadataMatchType = Union[MatchMetadata, MatchMetadataIfTrue]")

for node in all_libcst_nodes:
    if node.__name__.startswith("Base"):
        continue

    # Base classes for the generated matcher: every type class the node
    # derives from, followed by BaseMatcherNode.
    classes: List[str] = [tc.__name__ for tc in typeclasses if issubclass(node, tc)]
    classes.append("BaseMatcherNode")

    has_aliases = False
    node_fields = list(_get_fields(node))
    for field in node_fields:
        for alias in field.aliases:
            # Output a separator if we're going to output any aliases
            if not has_aliases:
                generated_code.append("")
                generated_code.append("")
                has_aliases = True

            # Must generate code for aliases before the class they are referenced in
            generated_code.append(f"{alias.name} = {alias.type}")

    generated_code.append("")
    generated_code.append("")
    generated_code.append("@dataclass(frozen=True, eq=False, unsafe_hash=False)")
    generated_code.append(f'class {node.__name__}({", ".join(classes)}):')
    all_exports.add(node.__name__)

    for field in node_fields:
        generated_code.append(f" {field.name}: {field.type} = DoNotCare()")

    # Add special metadata field
    generated_code.append(
        " metadata: Union[MetadataMatchType, DoNotCareSentinel, OneOf[MetadataMatchType], AllOf[MetadataMatchType]] = DoNotCare()"
    )


# Make sure to add an __all__ for flake8 and compatibility with "from libcst.matchers import *"
generated_code.append(f"__all__ = {repr(sorted(all_exports))}")


if __name__ == "__main__":
    # Output the code
    print("\n".join(generated_code))