diff --git a/docs/source/visitors.rst b/docs/source/visitors.rst index a2b9ee90..722959e1 100644 --- a/docs/source/visitors.rst +++ b/docs/source/visitors.rst @@ -7,6 +7,7 @@ Visitors .. autoclass:: libcst.CSTTransformer .. autofunction:: libcst.RemoveFromParent .. autoclass:: libcst.RemovalSentinel +.. autoclass:: libcst.FlattenSentinel Visit and Leave Helper Functions -------------------------------- diff --git a/libcst/__init__.py b/libcst/__init__.py index 39b0f6dc..cc71ce2a 100644 --- a/libcst/__init__.py +++ b/libcst/__init__.py @@ -5,6 +5,7 @@ from libcst._batched_visitor import BatchableCSTVisitor, visit_batched from libcst._exceptions import MetadataException, ParserSyntaxError +from libcst._flatten_sentinel import FlattenSentinel from libcst._maybe_sentinel import MaybeSentinel from libcst._metadata_dependent import MetadataDependent from libcst._nodes.base import CSTNode, CSTValidationError @@ -211,6 +212,7 @@ __all__ = [ "CSTValidationError", "CSTVisitor", "CSTVisitorT", + "FlattenSentinel", "MaybeSentinel", "MetadataException", "ParserSyntaxError", diff --git a/libcst/_flatten_sentinel.py b/libcst/_flatten_sentinel.py new file mode 100644 index 00000000..18148077 --- /dev/null +++ b/libcst/_flatten_sentinel.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + + +# PEP 585 +if sys.version_info < (3, 9): + from typing import Iterable, Sequence +else: + from collections.abc import Iterable, Sequence + +from libcst._types import CSTNodeT_co + + +class FlattenSentinel(Sequence[CSTNodeT_co]): + """ + A :class:`FlattenSentinel` may be returned by a :meth:`CSTTransformer.on_leave` + method when one wants to replace a node with multiple nodes. The replaced + node must be contained in a `Sequence` attribute such as + :attr:`~libcst.Module.body`. This is generally the case for + :class:`~libcst.BaseStatement` and :class:`~libcst.BaseSmallStatement`. + For example to insert a print before every return:: + + def leave_Return( + self, original_node: cst.Return, updated_node: cst.Return + ) -> Union[cst.Return, cst.RemovalSentinel, cst.FlattenSentinel[cst.BaseSmallStatement]]: + log_stmt = cst.Expr(cst.parse_expression("print('returning')")) + return cst.FlattenSentinel([log_stmt, updated_node]) + + Returning an empty :class:`FlattenSentinel` is equivalent to returning + :attr:`cst.RemovalSentinel.REMOVE` and is subject to its requirements. + """ + + nodes: Sequence[CSTNodeT_co] + + def __init__(self, nodes: Iterable[CSTNodeT_co]) -> None: + self.nodes = tuple(nodes) + + def __getitem__(self, idx: int) -> CSTNodeT_co: + return self.nodes[idx] + + def __len__(self) -> int: + return len(self.nodes) diff --git a/libcst/_nodes/base.py b/libcst/_nodes/base.py index fe2988c9..47bf26ea 100644 --- a/libcst/_nodes/base.py +++ b/libcst/_nodes/base.py @@ -8,6 +8,7 @@ from copy import deepcopy from dataclasses import dataclass, field, fields, replace from typing import Any, Dict, List, Mapping, Sequence, TypeVar, Union, cast +from libcst._flatten_sentinel import FlattenSentinel from libcst._nodes.internal import CodegenState from libcst._removal_sentinel import RemovalSentinel from libcst._type_enforce import is_value_of_type @@ -207,7 +208,7 @@ class CSTNode(ABC): def visit( self: _CSTNodeSelfT, visitor: CSTVisitorT - ) -> Union[_CSTNodeSelfT, RemovalSentinel]: + ) -> Union[_CSTNodeSelfT, RemovalSentinel, FlattenSentinel[_CSTNodeSelfT]]: """ Visits the current node, its children, and all transitive children using the given visitor's callbacks. @@ -234,7 +235,7 @@ class CSTNode(ABC): leave_result = visitor.on_leave(self, with_updated_children) # validate return type of the user-defined `visitor.on_leave` method - if not isinstance(leave_result, (CSTNode, RemovalSentinel)): + if not isinstance(leave_result, (CSTNode, RemovalSentinel, FlattenSentinel)): raise Exception( "Expected a node of type CSTNode or a RemovalSentinel, " + f"but got a return value of {type(leave_result).__name__}" @@ -379,9 +380,9 @@ class CSTNode(ABC): child, all instances will be replaced. """ new_tree = self.visit(_ChildReplacementTransformer(old_node, new_node)) - if isinstance(new_tree, RemovalSentinel): - # The above transform never returns RemovalSentinel, so this isn't possible - raise Exception("Logic error, cannot get a RemovalSentinel here!") + if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)): + # The above transform never returns *Sentinel, so this isn't possible + raise Exception("Logic error, cannot get a *Sentinal here!") return new_tree def deep_remove( @@ -392,10 +393,16 @@ class CSTNode(ABC): have previously modified the tree in a way that ``old_node`` appears more than once as a deep child, all instances will be removed. """ - return self.visit( + new_tree = self.visit( _ChildReplacementTransformer(old_node, RemovalSentinel.REMOVE) ) + if isinstance(new_tree, FlattenSentinel): + # The above transform never returns FlattenSentinel, so this isn't possible + raise Exception("Logic error, cannot get a FlattenSentinel here!") + + return new_tree + def with_deep_changes( self: _CSTNodeSelfT, old_node: "CSTNode", **changes: Any ) -> _CSTNodeSelfT: @@ -412,9 +419,9 @@ class CSTNode(ABC): similar API in the future. """ new_tree = self.visit(_ChildWithChangesTransformer(old_node, changes)) - if isinstance(new_tree, RemovalSentinel): + if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)): # This is impossible with the above transform. - raise Exception("Logic error, cannot get a RemovalSentinel here!") + raise Exception("Logic error, cannot get a *Sentinel here!") return new_tree def __eq__(self: _CSTNodeSelfT, other: _CSTNodeSelfT) -> bool: diff --git a/libcst/_nodes/internal.py b/libcst/_nodes/internal.py index 4b5c7b00..5bbefc01 100644 --- a/libcst/_nodes/internal.py +++ b/libcst/_nodes/internal.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional, Sequence, Union from libcst._add_slots import add_slots +from libcst._flatten_sentinel import FlattenSentinel from libcst._maybe_sentinel import MaybeSentinel from libcst._removal_sentinel import RemovalSentinel from libcst._types import CSTNodeT @@ -84,6 +85,13 @@ def visit_required( f"We got a RemovalSentinel while visiting a {type(node).__name__}. This " + "node's parent does not allow it to be removed." ) + elif isinstance(result, FlattenSentinel): + raise TypeError( + f"We got a FlattenSentinel while visiting a {type(node).__name__}. This " + + "node's parent does not allow for it to be it to be replaced with a " + + "sequence." + ) + visitor.on_leave_attribute(parent, fieldname) return result @@ -101,6 +109,12 @@ def visit_optional( return None visitor.on_visit_attribute(parent, fieldname) result = node.visit(visitor) + if isinstance(result, FlattenSentinel): + raise TypeError( + f"We got a FlattenSentinel while visiting a {type(node).__name__}. This " + + "node's parent does not allow for it to be it to be replaced with a " + + "sequence." + ) visitor.on_leave_attribute(parent, fieldname) return None if isinstance(result, RemovalSentinel) else result @@ -121,6 +135,12 @@ def visit_sentinel( return MaybeSentinel.DEFAULT visitor.on_visit_attribute(parent, fieldname) result = node.visit(visitor) + if isinstance(result, FlattenSentinel): + raise TypeError( + f"We got a FlattenSentinel while visiting a {type(node).__name__}. This " + + "node's parent does not allow for it to be it to be replaced with a " + + "sequence." + ) visitor.on_leave_attribute(parent, fieldname) return MaybeSentinel.DEFAULT if isinstance(result, RemovalSentinel) else result @@ -138,7 +158,9 @@ def visit_iterable( visitor.on_visit_attribute(parent, fieldname) for child in children: new_child = child.visit(visitor) - if not isinstance(new_child, RemovalSentinel): + if isinstance(new_child, FlattenSentinel): + yield from new_child + elif not isinstance(new_child, RemovalSentinel): yield new_child visitor.on_leave_attribute(parent, fieldname) @@ -179,11 +201,17 @@ def visit_body_iterable( # and the new child is. This means a RemovalSentinel # caused a child of this node to be dropped, and it # is now useless. - if (not child._is_removable()) and new_child._is_removable(): - continue - # Safe to yield child in this case. - yield new_child + if isinstance(new_child, FlattenSentinel): + for child_ in new_child: + if (not child._is_removable()) and child_._is_removable(): + continue + yield child_ + else: + if (not child._is_removable()) and new_child._is_removable(): + continue + # Safe to yield child in this case. + yield new_child visitor.on_leave_attribute(parent, fieldname) diff --git a/libcst/_nodes/tests/test_flatten_behavior.py b/libcst/_nodes/tests/test_flatten_behavior.py new file mode 100644 index 00000000..5f37067c --- /dev/null +++ b/libcst/_nodes/tests/test_flatten_behavior.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Type, Union + +import libcst as cst +from libcst import FlattenSentinel, RemovalSentinel, parse_expression, parse_module +from libcst._nodes.tests.base import CSTNodeTest +from libcst._types import CSTNodeT +from libcst._visitors import CSTTransformer +from libcst.testing.utils import data_provider + + +class InsertPrintBeforeReturn(CSTTransformer): + def leave_Return( + self, original_node: cst.Return, updated_node: cst.Return + ) -> Union[cst.Return, RemovalSentinel, FlattenSentinel[cst.BaseSmallStatement]]: + return FlattenSentinel( + [ + cst.Expr(parse_expression("print('returning')")), + updated_node, + ] + ) + + +class FlattenLines(CSTTransformer): + def on_leave( + self, original_node: CSTNodeT, updated_node: CSTNodeT + ) -> Union[CSTNodeT, RemovalSentinel, FlattenSentinel[cst.SimpleStatementLine]]: + if isinstance(updated_node, cst.SimpleStatementLine): + return FlattenSentinel( + [ + cst.SimpleStatementLine( + [stmt.with_changes(semicolon=cst.MaybeSentinel.DEFAULT)] + ) + for stmt in updated_node.body + ] + ) + else: + return updated_node + + +class RemoveReturnWithEmpty(CSTTransformer): + def leave_Return( + self, original_node: cst.Return, updated_node: cst.Return + ) -> Union[cst.Return, RemovalSentinel, FlattenSentinel[cst.BaseSmallStatement]]: + return FlattenSentinel([]) + + +class FlattenBehavior(CSTNodeTest): + @data_provider( + ( + ("return", "print('returning'); return", InsertPrintBeforeReturn), + ( + "print('returning'); return", + "print('returning')\nreturn", + FlattenLines, + ), + ( + "print('returning')\nreturn", + "print('returning')", + RemoveReturnWithEmpty, + ), + ) + ) + def test_flatten_pass_behavior( + self, before: str, after: str, visitor: Type[CSTTransformer] + ) -> None: + # Test doesn't have newline termination case + before_module = parse_module(before) + after_module = before_module.visit(visitor()) + self.assertEqual(after, after_module.code) + + # Test does have newline termination case + before_module = parse_module(before + "\n") + after_module = before_module.visit(visitor()) + self.assertEqual(after + "\n", after_module.code) diff --git a/libcst/_typed_visitor.py b/libcst/_typed_visitor.py index bbc10d55..0246c718 100644 --- a/libcst/_typed_visitor.py +++ b/libcst/_typed_visitor.py @@ -7,6 +7,7 @@ # This file was generated by libcst.codegen.gen_matcher_classes from typing import TYPE_CHECKING, Optional, Union +from libcst._flatten_sentinel import FlattenSentinel from libcst._maybe_sentinel import MaybeSentinel from libcst._removal_sentinel import RemovalSentinel from libcst._typed_visitor_base import mark_no_op @@ -5284,7 +5285,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_AnnAssign( self, original_node: "AnnAssign", updated_node: "AnnAssign" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5296,7 +5299,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Arg( self, original_node: "Arg", updated_node: "Arg" - ) -> Union["Arg", RemovalSentinel]: + ) -> Union["Arg", FlattenSentinel["Arg"], RemovalSentinel]: return updated_node @mark_no_op @@ -5306,13 +5309,17 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Assert( self, original_node: "Assert", updated_node: "Assert" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op def leave_Assign( self, original_node: "Assign", updated_node: "Assign" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5324,7 +5331,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_AssignTarget( self, original_node: "AssignTarget", updated_node: "AssignTarget" - ) -> Union["AssignTarget", RemovalSentinel]: + ) -> Union["AssignTarget", FlattenSentinel["AssignTarget"], RemovalSentinel]: return updated_node @mark_no_op @@ -5342,7 +5349,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_AugAssign( self, original_node: "AugAssign", updated_node: "AugAssign" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5408,7 +5417,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Break( self, original_node: "Break", updated_node: "Break" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5420,7 +5431,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_ClassDef( self, original_node: "ClassDef", updated_node: "ClassDef" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op @@ -5460,7 +5471,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_ComparisonTarget( self, original_node: "ComparisonTarget", updated_node: "ComparisonTarget" - ) -> Union["ComparisonTarget", RemovalSentinel]: + ) -> Union[ + "ComparisonTarget", FlattenSentinel["ComparisonTarget"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5472,19 +5485,23 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Continue( self, original_node: "Continue", updated_node: "Continue" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op def leave_Decorator( self, original_node: "Decorator", updated_node: "Decorator" - ) -> Union["Decorator", RemovalSentinel]: + ) -> Union["Decorator", FlattenSentinel["Decorator"], RemovalSentinel]: return updated_node @mark_no_op def leave_Del( self, original_node: "Del", updated_node: "Del" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5502,7 +5519,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_DictElement( self, original_node: "DictElement", updated_node: "DictElement" - ) -> Union["BaseDictElement", RemovalSentinel]: + ) -> Union["BaseDictElement", FlattenSentinel["BaseDictElement"], RemovalSentinel]: return updated_node @mark_no_op @@ -5520,13 +5537,13 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Dot( self, original_node: "Dot", updated_node: "Dot" - ) -> Union["Dot", RemovalSentinel]: + ) -> Union["Dot", FlattenSentinel["Dot"], RemovalSentinel]: return updated_node @mark_no_op def leave_Element( self, original_node: "Element", updated_node: "Element" - ) -> Union["BaseElement", RemovalSentinel]: + ) -> Union["BaseElement", FlattenSentinel["BaseElement"], RemovalSentinel]: return updated_node @mark_no_op @@ -5542,7 +5559,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_EmptyLine( self, original_node: "EmptyLine", updated_node: "EmptyLine" - ) -> Union["EmptyLine", RemovalSentinel]: + ) -> Union["EmptyLine", FlattenSentinel["EmptyLine"], RemovalSentinel]: return updated_node @mark_no_op @@ -5554,13 +5571,15 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_ExceptHandler( self, original_node: "ExceptHandler", updated_node: "ExceptHandler" - ) -> Union["ExceptHandler", RemovalSentinel]: + ) -> Union["ExceptHandler", FlattenSentinel["ExceptHandler"], RemovalSentinel]: return updated_node @mark_no_op def leave_Expr( self, original_node: "Expr", updated_node: "Expr" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5590,7 +5609,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_For( self, original_node: "For", updated_node: "For" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op @@ -5604,13 +5623,21 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): self, original_node: "FormattedStringExpression", updated_node: "FormattedStringExpression", - ) -> Union["BaseFormattedStringContent", RemovalSentinel]: + ) -> Union[ + "BaseFormattedStringContent", + FlattenSentinel["BaseFormattedStringContent"], + RemovalSentinel, + ]: return updated_node @mark_no_op def leave_FormattedStringText( self, original_node: "FormattedStringText", updated_node: "FormattedStringText" - ) -> Union["BaseFormattedStringContent", RemovalSentinel]: + ) -> Union[ + "BaseFormattedStringContent", + FlattenSentinel["BaseFormattedStringContent"], + RemovalSentinel, + ]: return updated_node @mark_no_op @@ -5620,7 +5647,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_FunctionDef( self, original_node: "FunctionDef", updated_node: "FunctionDef" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op @@ -5632,7 +5659,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Global( self, original_node: "Global", updated_node: "Global" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5650,7 +5679,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_If( self, original_node: "If", updated_node: "If" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op @@ -5668,19 +5697,23 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Import( self, original_node: "Import", updated_node: "Import" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op def leave_ImportAlias( self, original_node: "ImportAlias", updated_node: "ImportAlias" - ) -> Union["ImportAlias", RemovalSentinel]: + ) -> Union["ImportAlias", FlattenSentinel["ImportAlias"], RemovalSentinel]: return updated_node @mark_no_op def leave_ImportFrom( self, original_node: "ImportFrom", updated_node: "ImportFrom" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5734,7 +5767,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_LeftParen( self, original_node: "LeftParen", updated_node: "LeftParen" - ) -> Union["LeftParen", MaybeSentinel, RemovalSentinel]: + ) -> Union[ + "LeftParen", MaybeSentinel, FlattenSentinel["LeftParen"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5836,7 +5871,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_NameItem( self, original_node: "NameItem", updated_node: "NameItem" - ) -> Union["NameItem", RemovalSentinel]: + ) -> Union["NameItem", FlattenSentinel["NameItem"], RemovalSentinel]: return updated_node @mark_no_op @@ -5854,7 +5889,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Nonlocal( self, original_node: "Nonlocal", updated_node: "Nonlocal" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5880,7 +5917,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Param( self, original_node: "Param", updated_node: "Param" - ) -> Union["Param", MaybeSentinel, RemovalSentinel]: + ) -> Union["Param", MaybeSentinel, FlattenSentinel["Param"], RemovalSentinel]: return updated_node @mark_no_op @@ -5912,7 +5949,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Pass( self, original_node: "Pass", updated_node: "Pass" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5934,13 +5973,17 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Raise( self, original_node: "Raise", updated_node: "Raise" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op def leave_Return( self, original_node: "Return", updated_node: "Return" - ) -> Union["BaseSmallStatement", RemovalSentinel]: + ) -> Union[ + "BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5952,7 +5995,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_RightParen( self, original_node: "RightParen", updated_node: "RightParen" - ) -> Union["RightParen", MaybeSentinel, RemovalSentinel]: + ) -> Union[ + "RightParen", MaybeSentinel, FlattenSentinel["RightParen"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -5992,7 +6037,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_SimpleStatementLine( self, original_node: "SimpleStatementLine", updated_node: "SimpleStatementLine" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op @@ -6022,13 +6067,13 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_StarredDictElement( self, original_node: "StarredDictElement", updated_node: "StarredDictElement" - ) -> Union["BaseDictElement", RemovalSentinel]: + ) -> Union["BaseDictElement", FlattenSentinel["BaseDictElement"], RemovalSentinel]: return updated_node @mark_no_op def leave_StarredElement( self, original_node: "StarredElement", updated_node: "StarredElement" - ) -> Union["BaseElement", RemovalSentinel]: + ) -> Union["BaseElement", FlattenSentinel["BaseElement"], RemovalSentinel]: return updated_node @mark_no_op @@ -6040,7 +6085,9 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_SubscriptElement( self, original_node: "SubscriptElement", updated_node: "SubscriptElement" - ) -> Union["SubscriptElement", RemovalSentinel]: + ) -> Union[ + "SubscriptElement", FlattenSentinel["SubscriptElement"], RemovalSentinel + ]: return updated_node @mark_no_op @@ -6064,7 +6111,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_Try( self, original_node: "Try", updated_node: "Try" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op @@ -6082,19 +6129,19 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions): @mark_no_op def leave_While( self, original_node: "While", updated_node: "While" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op def leave_With( self, original_node: "With", updated_node: "With" - ) -> Union["BaseStatement", RemovalSentinel]: + ) -> Union["BaseStatement", FlattenSentinel["BaseStatement"], RemovalSentinel]: return updated_node @mark_no_op def leave_WithItem( self, original_node: "WithItem", updated_node: "WithItem" - ) -> Union["WithItem", RemovalSentinel]: + ) -> Union["WithItem", FlattenSentinel["WithItem"], RemovalSentinel]: return updated_node @mark_no_op diff --git a/libcst/_types.py b/libcst/_types.py index 98342da8..b6b2ea9c 100644 --- a/libcst/_types.py +++ b/libcst/_types.py @@ -12,3 +12,4 @@ if TYPE_CHECKING: CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode") +CSTNodeT_co = TypeVar("CSTNodeT_co", bound="CSTNode", covariant=True) diff --git a/libcst/_visitors.py b/libcst/_visitors.py index 1d710ff2..8da37dbf 100644 --- a/libcst/_visitors.py +++ b/libcst/_visitors.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Union +from libcst._flatten_sentinel import FlattenSentinel from libcst._metadata_dependent import MetadataDependent from libcst._removal_sentinel import RemovalSentinel from libcst._typed_visitor import CSTTypedTransformerFunctions, CSTTypedVisitorFunctions @@ -49,7 +50,7 @@ class CSTTransformer(CSTTypedTransformerFunctions, MetadataDependent): def on_leave( self, original_node: CSTNodeT, updated_node: CSTNodeT - ) -> Union[CSTNodeT, RemovalSentinel]: + ) -> Union[CSTNodeT, RemovalSentinel, FlattenSentinel[CSTNodeT]]: """ Called every time we leave a node, after we've visited its children. If the :func:`~libcst.CSTTransformer.on_visit` function for this node returns diff --git a/libcst/codegen/gen_visitor_functions.py b/libcst/codegen/gen_visitor_functions.py index d9a9401b..0666691b 100644 --- a/libcst/codegen/gen_visitor_functions.py +++ b/libcst/codegen/gen_visitor_functions.py @@ -21,6 +21,7 @@ generated_code.append("") generated_code.append("# This file was generated by libcst.codegen.gen_matcher_classes") generated_code.append("from typing import Optional, Union, TYPE_CHECKING") generated_code.append("") +generated_code.append("from libcst._flatten_sentinel import FlattenSentinel") generated_code.append("from libcst._maybe_sentinel import MaybeSentinel") generated_code.append("from libcst._removal_sentinel import RemovalSentinel") generated_code.append("from libcst._typed_visitor_base import mark_no_op") @@ -99,12 +100,11 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__): base_uses = nodeuses[nodebases[node]] if node_uses.maybe or base_uses.maybe: valid_return_types.append("MaybeSentinel") - if ( - node_uses.optional - or node_uses.sequence - or base_uses.optional - or base_uses.sequence - ): + + if node_uses.sequence or base_uses.sequence: + valid_return_types.append(f'FlattenSentinel["{nodebases[node].__name__}"]') + valid_return_types.append("RemovalSentinel") + elif node_uses.optional or base_uses.optional: valid_return_types.append("RemovalSentinel") generated_code.append( diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index 70a9340a..149904e3 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -30,7 +30,7 @@ from typing import ( import libcst import libcst.metadata as meta -from libcst import MaybeSentinel, RemovalSentinel +from libcst import FlattenSentinel, MaybeSentinel, RemovalSentinel class DoNotCareSentinel(Enum): @@ -1944,4 +1944,8 @@ def replace( fetcher = _construct_metadata_fetcher_dependent(metadata_resolver) replacer = _ReplaceTransformer(matcher, fetcher, replacement) - return tree.visit(replacer) + new_tree = tree.visit(replacer) + if isinstance(new_tree, FlattenSentinel): + # The above transform never returns FlattenSentinel, so this isn't possible + raise Exception("Logic error, cannot get a FlattenSentinel here!") + return new_tree