Add FlattenSentinel to support replacing a statement with multiple statements (#455)

* Add flatten_sentinal

* Add FlattenSentinal to __all__

* Fix lint errors

* Fix type errors

* Update test to use leave_Return

* Update and run codegen

* Add empty test

* Update docs

* autofix
This commit is contained in:
Caleb Donovick 2021-03-22 23:23:40 -07:00 committed by GitHub
parent 507b453e74
commit 0ee0831eb6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 281 additions and 65 deletions

View file

@ -7,6 +7,7 @@ Visitors
.. autoclass:: libcst.CSTTransformer
.. autofunction:: libcst.RemoveFromParent
.. autoclass:: libcst.RemovalSentinel
.. autoclass:: libcst.FlattenSentinel
Visit and Leave Helper Functions
--------------------------------

View file

@ -5,6 +5,7 @@
from libcst._batched_visitor import BatchableCSTVisitor, visit_batched
from libcst._exceptions import MetadataException, ParserSyntaxError
from libcst._flatten_sentinel import FlattenSentinel
from libcst._maybe_sentinel import MaybeSentinel
from libcst._metadata_dependent import MetadataDependent
from libcst._nodes.base import CSTNode, CSTValidationError
@ -211,6 +212,7 @@ __all__ = [
"CSTValidationError",
"CSTVisitor",
"CSTVisitorT",
"FlattenSentinel",
"MaybeSentinel",
"MetadataException",
"ParserSyntaxError",

View file

@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
# PEP 585
if sys.version_info < (3, 9):
from typing import Iterable, Sequence
else:
from collections.abc import Iterable, Sequence
from libcst._types import CSTNodeT_co
class FlattenSentinel(Sequence[CSTNodeT_co]):
"""
A :class:`FlattenSentinel` may be returned by a :meth:`CSTTransformer.on_leave`
method when one wants to replace a node with multiple nodes. The replaced
node must be contained in a `Sequence` attribute such as
:attr:`~libcst.Module.body`. This is generally the case for
:class:`~libcst.BaseStatement` and :class:`~libcst.BaseSmallStatement`.
For example to insert a print before every return::
def leave_Return(
self, original_node: cst.Return, updated_node: cst.Return
) -> Union[cst.Return, cst.RemovalSentinel, cst.FlattenSentinel[cst.BaseSmallStatement]]:
log_stmt = cst.Expr(cst.parse_expression("print('returning')"))
return cst.FlattenSentinel([log_stmt, updated_node])
Returning an empty :class:`FlattenSentinel` is equivalent to returning
:attr:`cst.RemovalSentinel.REMOVE` and is subject to its requirements.
"""
nodes: Sequence[CSTNodeT_co]
def __init__(self, nodes: Iterable[CSTNodeT_co]) -> None:
self.nodes = tuple(nodes)
def __getitem__(self, idx: int) -> CSTNodeT_co:
return self.nodes[idx]
def __len__(self) -> int:
return len(self.nodes)

View file

@ -8,6 +8,7 @@ from copy import deepcopy
from dataclasses import dataclass, field, fields, replace
from typing import Any, Dict, List, Mapping, Sequence, TypeVar, Union, cast
from libcst._flatten_sentinel import FlattenSentinel
from libcst._nodes.internal import CodegenState
from libcst._removal_sentinel import RemovalSentinel
from libcst._type_enforce import is_value_of_type
@ -207,7 +208,7 @@ class CSTNode(ABC):
def visit(
self: _CSTNodeSelfT, visitor: CSTVisitorT
) -> Union[_CSTNodeSelfT, RemovalSentinel]:
) -> Union[_CSTNodeSelfT, RemovalSentinel, FlattenSentinel[_CSTNodeSelfT]]:
"""
Visits the current node, its children, and all transitive children using
the given visitor's callbacks.
@ -234,7 +235,7 @@ class CSTNode(ABC):
leave_result = visitor.on_leave(self, with_updated_children)
# validate return type of the user-defined `visitor.on_leave` method
if not isinstance(leave_result, (CSTNode, RemovalSentinel)):
if not isinstance(leave_result, (CSTNode, RemovalSentinel, FlattenSentinel)):
raise Exception(
"Expected a node of type CSTNode or a RemovalSentinel, "
+ f"but got a return value of {type(leave_result).__name__}"
@ -379,9 +380,9 @@ class CSTNode(ABC):
child, all instances will be replaced.
"""
new_tree = self.visit(_ChildReplacementTransformer(old_node, new_node))
if isinstance(new_tree, RemovalSentinel):
# The above transform never returns RemovalSentinel, so this isn't possible
raise Exception("Logic error, cannot get a RemovalSentinel here!")
if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
# The above transform never returns *Sentinel, so this isn't possible
raise Exception("Logic error, cannot get a *Sentinal here!")
return new_tree
def deep_remove(
@ -392,10 +393,16 @@ class CSTNode(ABC):
have previously modified the tree in a way that ``old_node`` appears more than
once as a deep child, all instances will be removed.
"""
return self.visit(
new_tree = self.visit(
_ChildReplacementTransformer(old_node, RemovalSentinel.REMOVE)
)
if isinstance(new_tree, FlattenSentinel):
# The above transform never returns FlattenSentinel, so this isn't possible
raise Exception("Logic error, cannot get a FlattenSentinel here!")
return new_tree
def with_deep_changes(
self: _CSTNodeSelfT, old_node: "CSTNode", **changes: Any
) -> _CSTNodeSelfT:
@ -412,9 +419,9 @@ class CSTNode(ABC):
similar API in the future.
"""
new_tree = self.visit(_ChildWithChangesTransformer(old_node, changes))
if isinstance(new_tree, RemovalSentinel):
if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
# This is impossible with the above transform.
raise Exception("Logic error, cannot get a RemovalSentinel here!")
raise Exception("Logic error, cannot get a *Sentinel here!")
return new_tree
def __eq__(self: _CSTNodeSelfT, other: _CSTNodeSelfT) -> bool:

View file

@ -9,6 +9,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional, Sequence, Union
from libcst._add_slots import add_slots
from libcst._flatten_sentinel import FlattenSentinel
from libcst._maybe_sentinel import MaybeSentinel
from libcst._removal_sentinel import RemovalSentinel
from libcst._types import CSTNodeT
@ -84,6 +85,13 @@ def visit_required(
f"We got a RemovalSentinel while visiting a {type(node).__name__}. This "
+ "node's parent does not allow it to be removed."
)
elif isinstance(result, FlattenSentinel):
raise TypeError(
f"We got a FlattenSentinel while visiting a {type(node).__name__}. This "
+ "node's parent does not allow for it to be it to be replaced with a "
+ "sequence."
)
visitor.on_leave_attribute(parent, fieldname)
return result
@ -101,6 +109,12 @@ def visit_optional(
return None
visitor.on_visit_attribute(parent, fieldname)
result = node.visit(visitor)
if isinstance(result, FlattenSentinel):
raise TypeError(
f"We got a FlattenSentinel while visiting a {type(node).__name__}. This "
+ "node's parent does not allow for it to be it to be replaced with a "
+ "sequence."
)
visitor.on_leave_attribute(parent, fieldname)
return None if isinstance(result, RemovalSentinel) else result
@ -121,6 +135,12 @@ def visit_sentinel(
return MaybeSentinel.DEFAULT
visitor.on_visit_attribute(parent, fieldname)
result = node.visit(visitor)
if isinstance(result, FlattenSentinel):
raise TypeError(
f"We got a FlattenSentinel while visiting a {type(node).__name__}. This "
+ "node's parent does not allow for it to be it to be replaced with a "
+ "sequence."
)
visitor.on_leave_attribute(parent, fieldname)
return MaybeSentinel.DEFAULT if isinstance(result, RemovalSentinel) else result
@ -138,7 +158,9 @@ def visit_iterable(
visitor.on_visit_attribute(parent, fieldname)
for child in children:
new_child = child.visit(visitor)
if not isinstance(new_child, RemovalSentinel):
if isinstance(new_child, FlattenSentinel):
yield from new_child
elif not isinstance(new_child, RemovalSentinel):
yield new_child
visitor.on_leave_attribute(parent, fieldname)
@ -179,11 +201,17 @@ def visit_body_iterable(
# and the new child is. This means a RemovalSentinel
# caused a child of this node to be dropped, and it
# is now useless.
if (not child._is_removable()) and new_child._is_removable():
continue
# Safe to yield child in this case.
yield new_child
if isinstance(new_child, FlattenSentinel):
for child_ in new_child:
if (not child._is_removable()) and child_._is_removable():
continue
yield child_
else:
if (not child._is_removable()) and new_child._is_removable():
continue
# Safe to yield child in this case.
yield new_child
visitor.on_leave_attribute(parent, fieldname)

View file

@ -0,0 +1,79 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Type, Union
import libcst as cst
from libcst import FlattenSentinel, RemovalSentinel, parse_expression, parse_module
from libcst._nodes.tests.base import CSTNodeTest
from libcst._types import CSTNodeT
from libcst._visitors import CSTTransformer
from libcst.testing.utils import data_provider
class InsertPrintBeforeReturn(CSTTransformer):
def leave_Return(
self, original_node: cst.Return, updated_node: cst.Return
) -> Union[cst.Return, RemovalSentinel, FlattenSentinel[cst.BaseSmallStatement]]:
return FlattenSentinel(
[
cst.Expr(parse_expression("print('returning')")),
updated_node,
]
)
class FlattenLines(CSTTransformer):
def on_leave(
self, original_node: CSTNodeT, updated_node: CSTNodeT
) -> Union[CSTNodeT, RemovalSentinel, FlattenSentinel[cst.SimpleStatementLine]]:
if isinstance(updated_node, cst.SimpleStatementLine):
return FlattenSentinel(
[
cst.SimpleStatementLine(
[stmt.with_changes(semicolon=cst.MaybeSentinel.DEFAULT)]
)
for stmt in updated_node.body
]
)
else:
return updated_node
class RemoveReturnWithEmpty(CSTTransformer):
def leave_Return(
self, original_node: cst.Return, updated_node: cst.Return
) -> Union[cst.Return, RemovalSentinel, FlattenSentinel[cst.BaseSmallStatement]]:
return FlattenSentinel([])
class FlattenBehavior(CSTNodeTest):
@data_provider(
(
("return", "print('returning'); return", InsertPrintBeforeReturn),
(
"print('returning'); return",
"print('returning')\nreturn",
FlattenLines,
),
(
"print('returning')\nreturn",
"print('returning')",
RemoveReturnWithEmpty,
),
)
)
def test_flatten_pass_behavior(
self, before: str, after: str, visitor: Type[CSTTransformer]
) -> None:
# Test doesn't have newline termination case
before_module = parse_module(before)
after_module = before_module.visit(visitor())
self.assertEqual(after, after_module.code)
# Test does have newline termination case
before_module = parse_module(before + "\n")
after_module = before_module.visit(visitor())
self.assertEqual(after + "\n", after_module.code)

View file

@ -7,6 +7,7 @@
# This file was generated by libcst.codegen.gen_matcher_classes
from typing import TYPE_CHECKING, Optional, Union
from libcst._flatten_sentinel import FlattenSentinel
from libcst._maybe_sentinel import MaybeSentinel
from libcst._removal_sentinel import RemovalSentinel
from libcst._typed_visitor_base import mark_no_op
@ -5284,7 +5285,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_AnnAssign(
self, original_node: "AnnAssign", updated_node: "AnnAssign"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5296,7 +5299,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Arg(
self, original_node: "Arg", updated_node: "Arg"
) -> Union["Arg", RemovalSentinel]:
) -> Union["Arg", FlattenSentinel["Arg"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5306,13 +5309,17 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Assert(
self, original_node: "Assert", updated_node: "Assert"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
def leave_Assign(
self, original_node: "Assign", updated_node: "Assign"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5324,7 +5331,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_AssignTarget(
self, original_node: "AssignTarget", updated_node: "AssignTarget"
) -> Union["AssignTarget", RemovalSentinel]:
) -> Union["AssignTarget", FlattenSentinel["AssignTarget"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5342,7 +5349,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_AugAssign(
self, original_node: "AugAssign", updated_node: "AugAssign"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5408,7 +5417,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Break(
self, original_node: "Break", updated_node: "Break"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5420,7 +5431,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_ClassDef(
self, original_node: "ClassDef", updated_node: "ClassDef"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5460,7 +5471,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_ComparisonTarget(
self, original_node: "ComparisonTarget", updated_node: "ComparisonTarget"
) -> Union["ComparisonTarget", RemovalSentinel]:
) -> Union[
"ComparisonTarget", FlattenSentinel["ComparisonTarget"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5472,19 +5485,23 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Continue(
self, original_node: "Continue", updated_node: "Continue"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
def leave_Decorator(
self, original_node: "Decorator", updated_node: "Decorator"
) -> Union["Decorator", RemovalSentinel]:
) -> Union["Decorator", FlattenSentinel["Decorator"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_Del(
self, original_node: "Del", updated_node: "Del"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5502,7 +5519,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_DictElement(
self, original_node: "DictElement", updated_node: "DictElement"
) -> Union["BaseDictElement", RemovalSentinel]:
) -> Union["BaseDictElement", FlattenSentinel["BaseDictElement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5520,13 +5537,13 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Dot(
self, original_node: "Dot", updated_node: "Dot"
) -> Union["Dot", RemovalSentinel]:
) -> Union["Dot", FlattenSentinel["Dot"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_Element(
self, original_node: "Element", updated_node: "Element"
) -> Union["BaseElement", RemovalSentinel]:
) -> Union["BaseElement", FlattenSentinel["BaseElement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5542,7 +5559,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_EmptyLine(
self, original_node: "EmptyLine", updated_node: "EmptyLine"
) -> Union["EmptyLine", RemovalSentinel]:
) -> Union["EmptyLine", FlattenSentinel["EmptyLine"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5554,13 +5571,15 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_ExceptHandler(
self, original_node: "ExceptHandler", updated_node: "ExceptHandler"
) -> Union["ExceptHandler", RemovalSentinel]:
) -> Union["ExceptHandler", FlattenSentinel["ExceptHandler"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_Expr(
self, original_node: "Expr", updated_node: "Expr"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5590,7 +5609,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_For(
self, original_node: "For", updated_node: "For"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5604,13 +5623,21 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
self,
original_node: "FormattedStringExpression",
updated_node: "FormattedStringExpression",
) -> Union["BaseFormattedStringContent", RemovalSentinel]:
) -> Union[
"BaseFormattedStringContent",
FlattenSentinel["BaseFormattedStringContent"],
RemovalSentinel,
]:
return updated_node
@mark_no_op
def leave_FormattedStringText(
self, original_node: "FormattedStringText", updated_node: "FormattedStringText"
) -> Union["BaseFormattedStringContent", RemovalSentinel]:
) -> Union[
"BaseFormattedStringContent",
FlattenSentinel["BaseFormattedStringContent"],
RemovalSentinel,
]:
return updated_node
@mark_no_op
@ -5620,7 +5647,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_FunctionDef(
self, original_node: "FunctionDef", updated_node: "FunctionDef"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5632,7 +5659,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Global(
self, original_node: "Global", updated_node: "Global"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5650,7 +5679,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_If(
self, original_node: "If", updated_node: "If"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5668,19 +5697,23 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Import(
self, original_node: "Import", updated_node: "Import"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
def leave_ImportAlias(
self, original_node: "ImportAlias", updated_node: "ImportAlias"
) -> Union["ImportAlias", RemovalSentinel]:
) -> Union["ImportAlias", FlattenSentinel["ImportAlias"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_ImportFrom(
self, original_node: "ImportFrom", updated_node: "ImportFrom"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5734,7 +5767,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_LeftParen(
self, original_node: "LeftParen", updated_node: "LeftParen"
) -> Union["LeftParen", MaybeSentinel, RemovalSentinel]:
) -> Union[
"LeftParen", MaybeSentinel, FlattenSentinel["LeftParen"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5836,7 +5871,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_NameItem(
self, original_node: "NameItem", updated_node: "NameItem"
) -> Union["NameItem", RemovalSentinel]:
) -> Union["NameItem", FlattenSentinel["NameItem"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5854,7 +5889,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Nonlocal(
self, original_node: "Nonlocal", updated_node: "Nonlocal"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5880,7 +5917,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Param(
self, original_node: "Param", updated_node: "Param"
) -> Union["Param", MaybeSentinel, RemovalSentinel]:
) -> Union["Param", MaybeSentinel, FlattenSentinel["Param"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -5912,7 +5949,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Pass(
self, original_node: "Pass", updated_node: "Pass"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5934,13 +5973,17 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Raise(
self, original_node: "Raise", updated_node: "Raise"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
def leave_Return(
self, original_node: "Return", updated_node: "Return"
) -> Union["BaseSmallStatement", RemovalSentinel]:
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5952,7 +5995,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_RightParen(
self, original_node: "RightParen", updated_node: "RightParen"
) -> Union["RightParen", MaybeSentinel, RemovalSentinel]:
) -> Union[
"RightParen", MaybeSentinel, FlattenSentinel["RightParen"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -5992,7 +6037,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_SimpleStatementLine(
self, original_node: "SimpleStatementLine", updated_node: "SimpleStatementLine"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -6022,13 +6067,13 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_StarredDictElement(
self, original_node: "StarredDictElement", updated_node: "StarredDictElement"
) -> Union["BaseDictElement", RemovalSentinel]:
) -> Union["BaseDictElement", FlattenSentinel["BaseDictElement"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_StarredElement(
self, original_node: "StarredElement", updated_node: "StarredElement"
) -> Union["BaseElement", RemovalSentinel]:
) -> Union["BaseElement", FlattenSentinel["BaseElement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -6040,7 +6085,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_SubscriptElement(
self, original_node: "SubscriptElement", updated_node: "SubscriptElement"
) -> Union["SubscriptElement", RemovalSentinel]:
) -> Union[
"SubscriptElement", FlattenSentinel["SubscriptElement"], RemovalSentinel
]:
return updated_node
@mark_no_op
@ -6064,7 +6111,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Try(
self, original_node: "Try", updated_node: "Try"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -6082,19 +6129,19 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_While(
self, original_node: "While", updated_node: "While"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_With(
self, original_node: "With", updated_node: "With"
) -> Union["BaseStatement", RemovalSentinel]:
) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_WithItem(
self, original_node: "WithItem", updated_node: "WithItem"
) -> Union["WithItem", RemovalSentinel]:
) -> Union["WithItem", FlattenSentinel["WithItem"], RemovalSentinel]:
return updated_node
@mark_no_op

View file

@ -12,3 +12,4 @@ if TYPE_CHECKING:
CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode")
CSTNodeT_co = TypeVar("CSTNodeT_co", bound="CSTNode", covariant=True)

View file

@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Union
from libcst._flatten_sentinel import FlattenSentinel
from libcst._metadata_dependent import MetadataDependent
from libcst._removal_sentinel import RemovalSentinel
from libcst._typed_visitor import CSTTypedTransformerFunctions, CSTTypedVisitorFunctions
@ -49,7 +50,7 @@ class CSTTransformer(CSTTypedTransformerFunctions, MetadataDependent):
def on_leave(
self, original_node: CSTNodeT, updated_node: CSTNodeT
) -> Union[CSTNodeT, RemovalSentinel]:
) -> Union[CSTNodeT, RemovalSentinel, FlattenSentinel[CSTNodeT]]:
"""
Called every time we leave a node, after we've visited its children. If
the :func:`~libcst.CSTTransformer.on_visit` function for this node returns

View file

@ -21,6 +21,7 @@ generated_code.append("")
generated_code.append("# This file was generated by libcst.codegen.gen_matcher_classes")
generated_code.append("from typing import Optional, Union, TYPE_CHECKING")
generated_code.append("")
generated_code.append("from libcst._flatten_sentinel import FlattenSentinel")
generated_code.append("from libcst._maybe_sentinel import MaybeSentinel")
generated_code.append("from libcst._removal_sentinel import RemovalSentinel")
generated_code.append("from libcst._typed_visitor_base import mark_no_op")
@ -99,12 +100,11 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
base_uses = nodeuses[nodebases[node]]
if node_uses.maybe or base_uses.maybe:
valid_return_types.append("MaybeSentinel")
if (
node_uses.optional
or node_uses.sequence
or base_uses.optional
or base_uses.sequence
):
if node_uses.sequence or base_uses.sequence:
valid_return_types.append(f'FlattenSentinel["{nodebases[node].__name__}"]')
valid_return_types.append("RemovalSentinel")
elif node_uses.optional or base_uses.optional:
valid_return_types.append("RemovalSentinel")
generated_code.append(

View file

@ -30,7 +30,7 @@ from typing import (
import libcst
import libcst.metadata as meta
from libcst import MaybeSentinel, RemovalSentinel
from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel
class DoNotCareSentinel(Enum):
@ -1944,4 +1944,8 @@ def replace(
fetcher = _construct_metadata_fetcher_dependent(metadata_resolver)
replacer = _ReplaceTransformer(matcher, fetcher, replacement)
return tree.visit(replacer)
new_tree = tree.visit(replacer)
if isinstance(new_tree, FlattenSentinel):
# The above transform never returns FlattenSentinel, so this isn't possible
raise Exception("Logic error, cannot get a FlattenSentinel here!")
return new_tree