fix: correctly extract wildcard matchers (#355)

* fix: correctly extract wildcard matchers

Fixes #337 and #338

* refactor: use data classes instead of bare tuples
This commit is contained in:
Sebastian Kreft 2020-08-13 21:50:36 -04:00 committed by GitHub
parent 0c09c9dfbb
commit 3ada79ebcb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 46 deletions

View file

@ -7,7 +7,7 @@ import collections.abc
import copy
import inspect
import re
from dataclasses import fields
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (
Callable,
@ -932,6 +932,16 @@ def SaveMatchedNode(matcher: _OtherNodeT, name: str) -> _OtherNodeT:
return cast(_OtherNodeT, _ExtractMatchingNode(matcher, name))
@dataclass(frozen=True)
class _SequenceMatchesResult:
sequence_capture: Optional[
Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]]
]
matched_nodes: Optional[
Union[libcst.CSTNode, MaybeSentinel, Sequence[libcst.CSTNode]]
]
def _sequence_matches( # noqa: C901
nodes: Sequence[Union[MaybeSentinel, libcst.CSTNode]],
matchers: Sequence[
@ -944,30 +954,35 @@ def _sequence_matches( # noqa: C901
]
],
metadata_lookup: Callable[[meta.ProviderT, libcst.CSTNode], object],
) -> Optional[Dict[str, Union[libcst.CSTNode, Sequence[libcst.CSTNode]]]]:
) -> _SequenceMatchesResult:
if not nodes and not matchers:
# Base case, empty lists are alwatys matches
return {}
# Base case, empty lists are always matches
return _SequenceMatchesResult({}, None)
if not nodes and matchers:
# Base case, we have one or more matcher that wasn't matched
return (
{}
_SequenceMatchesResult({}, [])
if all(
(isinstance(m, AtLeastN) and m.n == 0) or isinstance(m, AtMostN)
for m in matchers
)
else None
else _SequenceMatchesResult(None, None)
)
if nodes and not matchers:
# Base case, we have nodes left that don't match any matcher
return None
return _SequenceMatchesResult(None, None)
# Recursive case, nodes and matchers LHS matches
node = nodes[0]
matcher = matchers[0]
if isinstance(matcher, DoNotCareSentinel):
# We don't care about the value for this node.
return _sequence_matches(nodes[1:], matchers[1:], metadata_lookup)
return _SequenceMatchesResult(
_sequence_matches(
nodes[1:], matchers[1:], metadata_lookup
).sequence_capture,
node,
)
elif isinstance(matcher, _BaseWildcardNode):
if isinstance(matcher, AtMostN):
if matcher.n > 0:
@ -977,18 +992,24 @@ def _sequence_matches( # noqa: C901
nodes[0], matcher.matcher, metadata_lookup
)
if attribute_capture is not None:
sequence_capture = _sequence_matches(
result = _sequence_matches(
nodes[1:],
[AtMostN(matcher.matcher, n=matcher.n - 1), *matchers[1:]],
metadata_lookup,
)
if sequence_capture is not None:
return {**attribute_capture, **sequence_capture}
if result.sequence_capture is not None:
return _SequenceMatchesResult(
{**attribute_capture, **result.sequence_capture},
(node, *result.matched_nodes),
)
# Finally, assume that this does not match the current node.
# Consume the matcher but not the node.
sequence_capture = _sequence_matches(nodes, matchers[1:], metadata_lookup)
if sequence_capture is not None:
return sequence_capture
return _SequenceMatchesResult(
_sequence_matches(
nodes, matchers[1:], metadata_lookup
).sequence_capture,
(),
)
elif isinstance(matcher, AtLeastN):
if matcher.n > 0:
# Only match if we can consume one of the matches, since we still
@ -997,13 +1018,17 @@ def _sequence_matches( # noqa: C901
nodes[0], matcher.matcher, metadata_lookup
)
if attribute_capture is not None:
sequence_capture = _sequence_matches(
result = _sequence_matches(
nodes[1:],
[AtLeastN(matcher.matcher, n=matcher.n - 1), *matchers[1:]],
metadata_lookup,
)
if sequence_capture is not None:
return {**attribute_capture, **sequence_capture}
if result.sequence_capture is not None:
return _SequenceMatchesResult(
{**attribute_capture, **result.sequence_capture},
(node, *result.matched_nodes),
)
return _SequenceMatchesResult(None, None)
else:
# First, assume that this does match a node (greedy).
# Consume one node since it matched this matcher.
@ -1011,45 +1036,52 @@ def _sequence_matches( # noqa: C901
nodes[0], matcher.matcher, metadata_lookup
)
if attribute_capture is not None:
sequence_capture = _sequence_matches(
nodes[1:], matchers, metadata_lookup
)
if sequence_capture is not None:
return {**attribute_capture, **sequence_capture}
result = _sequence_matches(nodes[1:], matchers, metadata_lookup)
if result.sequence_capture is not None:
return _SequenceMatchesResult(
{**attribute_capture, **result.sequence_capture},
(node, *result.matched_nodes),
)
# Now, assume that this does not match the current node.
# Consume the matcher but not the node.
sequence_capture = _sequence_matches(
nodes, matchers[1:], metadata_lookup
return _SequenceMatchesResult(
_sequence_matches(
nodes, matchers[1:], metadata_lookup
).sequence_capture,
(),
)
if sequence_capture is not None:
return sequence_capture
else:
# There are no other types of wildcard consumers, but we're making
# pyre happy with that fact.
raise Exception(f"Logic error unrecognized wildcard {type(matcher)}!")
elif isinstance(matcher, _ExtractMatchingNode):
# See if the raw matcher matches. If it does, capture the sequence we matched and store it.
sequence_capture = _sequence_matches(
result = _sequence_matches(
nodes, [matcher.matcher, *matchers[1:]], metadata_lookup
)
if sequence_capture is not None:
return {
# Our own match capture comes first, since we wnat to allow the same
# name later in the sequence to override us.
matcher.name: nodes,
**sequence_capture,
}
return None
if result.sequence_capture is not None:
return _SequenceMatchesResult(
{
# Our own match capture comes first, since we wnat to allow the same
# name later in the sequence to override us.
matcher.name: result.matched_nodes,
**result.sequence_capture,
},
result.matched_nodes,
)
return _SequenceMatchesResult(None, None)
match_capture = _matches(node, matcher, metadata_lookup)
if match_capture is not None:
# These values match directly
sequence_capture = _sequence_matches(nodes[1:], matchers[1:], metadata_lookup)
if sequence_capture is not None:
return {**match_capture, **sequence_capture}
result = _sequence_matches(nodes[1:], matchers[1:], metadata_lookup)
if result.sequence_capture is not None:
return _SequenceMatchesResult(
{**match_capture, **result.sequence_capture}, node
)
# Failed recursive case, no match
return None
return _SequenceMatchesResult(None, None)
_AttributeValueT = Optional[Union[MaybeSentinel, libcst.CSTNode, str, bool]]
@ -1110,9 +1142,9 @@ def _attribute_matches( # noqa: C901
for m in matcher.options:
if isinstance(m, collections.abc.Sequence):
# Should match the sequence of requested nodes
sequence_capture = _sequence_matches(node, m, metadata_lookup)
if sequence_capture is not None:
return sequence_capture
result = _sequence_matches(node, m, metadata_lookup)
if result.sequence_capture is not None:
return result.sequence_capture
elif isinstance(m, MatchIfTrue):
return {} if matcher.func(node) else None
elif isinstance(matcher, AllOf):
@ -1121,10 +1153,10 @@ def _attribute_matches( # noqa: C901
for m in matcher.options:
if isinstance(m, collections.abc.Sequence):
# Should match the sequence of requested nodes
sequence_capture = _sequence_matches(node, m, metadata_lookup)
if sequence_capture is None:
result = _sequence_matches(node, m, metadata_lookup)
if result.sequence_capture is None:
return None
all_captures = {**all_captures, **sequence_capture}
all_captures = {**all_captures, **result.sequence_capture}
elif isinstance(m, MatchIfTrue):
return {} if matcher.func(node) else None
else:
@ -1150,7 +1182,7 @@ def _attribute_matches( # noqa: C901
matcher,
),
metadata_lookup,
)
).sequence_capture
# We exhausted our possibilities, there's no match
return None

View file

@ -404,3 +404,25 @@ class MatchersExtractTest(UnitTest):
),
)
self.assertIsNone(nodes)
def test_extract_sequence_multiple_wildcards(self) -> None:
expression = cst.parse_expression("1, 2, 3, 4")
nodes = m.extract(
expression,
m.Tuple(
elements=(
m.SaveMatchedNode(m.ZeroOrMore(), "head"),
m.SaveMatchedNode(m.Element(value=m.Integer(value="3")), "element"),
m.SaveMatchedNode(m.ZeroOrMore(), "tail"),
)
),
)
tuple_elements = cst.ensure_type(expression, cst.Tuple).elements
self.assertEqual(
nodes,
{
"head": tuple(tuple_elements[:2]),
"element": tuple_elements[2],
"tail": tuple(tuple_elements[3:]),
},
)