From 3ada79ebcb14224e0c0cd4e40a622540656cc0e3 Mon Sep 17 00:00:00 2001 From: Sebastian Kreft Date: Thu, 13 Aug 2020 21:50:36 -0400 Subject: [PATCH] fix: correctly extract wildcard matchers (#355) * fix: correctly extract wildcard matchers Fixes #337 and #338 * refactor: use data classes instead of bare tuples --- libcst/matchers/_matcher_base.py | 124 ++++++++++++++++---------- libcst/matchers/tests/test_extract.py | 22 +++++ 2 files changed, 100 insertions(+), 46 deletions(-) diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index 16c16d5b..27475d5b 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -7,7 +7,7 @@ import collections.abc import copy import inspect import re -from dataclasses import fields +from dataclasses import dataclass, fields from enum import Enum, auto from typing import ( Callable, @@ -932,6 +932,16 @@ def SaveMatchedNode(matcher: _OtherNodeT, name: str) -> _OtherNodeT: return cast(_OtherNodeT, _ExtractMatchingNode(matcher, name)) +@dataclass(frozen=True) +class _SequenceMatchesResult: + sequence_capture: Optional[ + Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]] + ] + matched_nodes: Optional[ + Union[libcst.CSTNode, MaybeSentinel, Sequence[libcst.CSTNode]] + ] + + def _sequence_matches( # noqa: C901 nodes: Sequence[Union[MaybeSentinel, libcst.CSTNode]], matchers: Sequence[ @@ -944,30 +954,35 @@ def _sequence_matches( # noqa: C901 ] ], metadata_lookup: Callable[[meta.ProviderT, libcst.CSTNode], object], -) -> Optional[Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]]]: +) -> _SequenceMatchesResult: if not nodes and not matchers: - # Base case, empty lists are alwatys matches - return {} + # Base case, empty lists are always matches + return _SequenceMatchesResult({}, None) if not nodes and matchers: # Base case, we have one or more matcher that wasn't matched return ( - {} + _SequenceMatchesResult({}, []) if all( (isinstance(m, AtLeastN) and m.n == 0) or isinstance(m, AtMostN) for m in matchers ) - else None + else _SequenceMatchesResult(None, None) ) if nodes and not matchers: # Base case, we have nodes left that don't match any matcher - return None + return _SequenceMatchesResult(None, None) # Recursive case, nodes and matchers LHS matches node = nodes[0] matcher = matchers[0] if isinstance(matcher, DoNotCareSentinel): # We don't care about the value for this node. - return _sequence_matches(nodes[1:], matchers[1:], metadata_lookup) + return _SequenceMatchesResult( + _sequence_matches( + nodes[1:], matchers[1:], metadata_lookup + ).sequence_capture, + node, + ) elif isinstance(matcher, _BaseWildcardNode): if isinstance(matcher, AtMostN): if matcher.n > 0: @@ -977,18 +992,24 @@ def _sequence_matches( # noqa: C901 nodes[0], matcher.matcher, metadata_lookup ) if attribute_capture is not None: - sequence_capture = _sequence_matches( + result = _sequence_matches( nodes[1:], [AtMostN(matcher.matcher, n=matcher.n - 1), *matchers[1:]], metadata_lookup, ) - if sequence_capture is not None: - return {**attribute_capture, **sequence_capture} + if result.sequence_capture is not None: + return _SequenceMatchesResult( + {**attribute_capture, **result.sequence_capture}, + (node, *result.matched_nodes), + ) # Finally, assume that this does not match the current node. # Consume the matcher but not the node. - sequence_capture = _sequence_matches(nodes, matchers[1:], metadata_lookup) - if sequence_capture is not None: - return sequence_capture + return _SequenceMatchesResult( + _sequence_matches( + nodes, matchers[1:], metadata_lookup + ).sequence_capture, + (), + ) elif isinstance(matcher, AtLeastN): if matcher.n > 0: # Only match if we can consume one of the matches, since we still @@ -997,13 +1018,17 @@ def _sequence_matches( # noqa: C901 nodes[0], matcher.matcher, metadata_lookup ) if attribute_capture is not None: - sequence_capture = _sequence_matches( + result = _sequence_matches( nodes[1:], [AtLeastN(matcher.matcher, n=matcher.n - 1), *matchers[1:]], metadata_lookup, ) - if sequence_capture is not None: - return {**attribute_capture, **sequence_capture} + if result.sequence_capture is not None: + return _SequenceMatchesResult( + {**attribute_capture, **result.sequence_capture}, + (node, *result.matched_nodes), + ) + return _SequenceMatchesResult(None, None) else: # First, assume that this does match a node (greedy). # Consume one node since it matched this matcher. @@ -1011,45 +1036,52 @@ def _sequence_matches( # noqa: C901 nodes[0], matcher.matcher, metadata_lookup ) if attribute_capture is not None: - sequence_capture = _sequence_matches( - nodes[1:], matchers, metadata_lookup - ) - if sequence_capture is not None: - return {**attribute_capture, **sequence_capture} + result = _sequence_matches(nodes[1:], matchers, metadata_lookup) + if result.sequence_capture is not None: + return _SequenceMatchesResult( + {**attribute_capture, **result.sequence_capture}, + (node, *result.matched_nodes), + ) # Now, assume that this does not match the current node. # Consume the matcher but not the node. - sequence_capture = _sequence_matches( - nodes, matchers[1:], metadata_lookup + return _SequenceMatchesResult( + _sequence_matches( + nodes, matchers[1:], metadata_lookup + ).sequence_capture, + (), ) - if sequence_capture is not None: - return sequence_capture else: # There are no other types of wildcard consumers, but we're making # pyre happy with that fact. raise Exception(f"Logic error unrecognized wildcard {type(matcher)}!") elif isinstance(matcher, _ExtractMatchingNode): # See if the raw matcher matches. If it does, capture the sequence we matched and store it. - sequence_capture = _sequence_matches( + result = _sequence_matches( nodes, [matcher.matcher, *matchers[1:]], metadata_lookup ) - if sequence_capture is not None: - return { - # Our own match capture comes first, since we wnat to allow the same - # name later in the sequence to override us. - matcher.name: nodes, - **sequence_capture, - } - return None + if result.sequence_capture is not None: + return _SequenceMatchesResult( + { + # Our own match capture comes first, since we wnat to allow the same + # name later in the sequence to override us. + matcher.name: result.matched_nodes, + **result.sequence_capture, + }, + result.matched_nodes, + ) + return _SequenceMatchesResult(None, None) match_capture = _matches(node, matcher, metadata_lookup) if match_capture is not None: # These values match directly - sequence_capture = _sequence_matches(nodes[1:], matchers[1:], metadata_lookup) - if sequence_capture is not None: - return {**match_capture, **sequence_capture} + result = _sequence_matches(nodes[1:], matchers[1:], metadata_lookup) + if result.sequence_capture is not None: + return _SequenceMatchesResult( + {**match_capture, **result.sequence_capture}, node + ) # Failed recursive case, no match - return None + return _SequenceMatchesResult(None, None) _AttributeValueT = Optional[Union[MaybeSentinel, libcst.CSTNode, str, bool]] @@ -1110,9 +1142,9 @@ def _attribute_matches( # noqa: C901 for m in matcher.options: if isinstance(m, collections.abc.Sequence): # Should match the sequence of requested nodes - sequence_capture = _sequence_matches(node, m, metadata_lookup) - if sequence_capture is not None: - return sequence_capture + result = _sequence_matches(node, m, metadata_lookup) + if result.sequence_capture is not None: + return result.sequence_capture elif isinstance(m, MatchIfTrue): return {} if matcher.func(node) else None elif isinstance(matcher, AllOf): @@ -1121,10 +1153,10 @@ def _attribute_matches( # noqa: C901 for m in matcher.options: if isinstance(m, collections.abc.Sequence): # Should match the sequence of requested nodes - sequence_capture = _sequence_matches(node, m, metadata_lookup) - if sequence_capture is None: + result = _sequence_matches(node, m, metadata_lookup) + if result.sequence_capture is None: return None - all_captures = {**all_captures, **sequence_capture} + all_captures = {**all_captures, **result.sequence_capture} elif isinstance(m, MatchIfTrue): return {} if matcher.func(node) else None else: @@ -1150,7 +1182,7 @@ def _attribute_matches( # noqa: C901 matcher, ), metadata_lookup, - ) + ).sequence_capture # We exhausted our possibilities, there's no match return None diff --git a/libcst/matchers/tests/test_extract.py b/libcst/matchers/tests/test_extract.py index 2bf45b91..5c3cf12a 100644 --- a/libcst/matchers/tests/test_extract.py +++ b/libcst/matchers/tests/test_extract.py @@ -404,3 +404,25 @@ class MatchersExtractTest(UnitTest): ), ) self.assertIsNone(nodes) + + def test_extract_sequence_multiple_wildcards(self) -> None: + expression = cst.parse_expression("1, 2, 3, 4") + nodes = m.extract( + expression, + m.Tuple( + elements=( + m.SaveMatchedNode(m.ZeroOrMore(), "head"), + m.SaveMatchedNode(m.Element(value=m.Integer(value="3")), "element"), + m.SaveMatchedNode(m.ZeroOrMore(), "tail"), + ) + ), + ) + tuple_elements = cst.ensure_type(expression, cst.Tuple).elements + self.assertEqual( + nodes, + { + "head": tuple(tuple_elements[:2]), + "element": tuple_elements[2], + "tail": tuple(tuple_elements[3:]), + }, + )