[red-knot] Remove match pattern definition visitor (#13209)

## Summary

This PR is based on this discussion:
https://github.com/astral-sh/ruff/pull/13147#discussion_r1739408653.

**Todo**

- [x] Add documentation for `MatchPatternState`

## Test Plan

`cargo insta test` and `cargo clippy`
This commit is contained in:
Dhruv Manilawala 2024-09-03 14:23:35 +05:30 committed by GitHub
parent 46e687e8d1
commit facf6febf0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 108 additions and 55 deletions

View file

@ -1097,16 +1097,62 @@ match subject:
);
let use_def = use_def_map(&db, global_scope_id);
for name in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] {
for (name, expected_index) in [
("a", 0),
("b", 0),
("c", 1),
("d", 2),
("e", 0),
("f", 1),
("g", 0),
("h", 1),
("i", 0),
("j", 1),
("k", 0),
("l", 1),
] {
let definition = use_def
.first_public_definition(
global_table.symbol_id_by_name(name).expect("symbol exists"),
)
.expect("Expected with item definition for {name}");
assert!(matches!(
definition.node(&db),
DefinitionKind::MatchPattern(_)
));
if let DefinitionKind::MatchPattern(pattern) = definition.node(&db) {
assert_eq!(pattern.index(), expected_index);
} else {
panic!("Expected match pattern definition for {name}");
}
}
}
#[test]
fn nested_match_case() {
let TestCase { db, file } = test_case(
"
match 1:
case first:
match 2:
case second:
pass
",
);
let global_scope_id = global_scope(&db, file);
let global_table = symbol_table(&db, global_scope_id);
assert_eq!(names(&global_table), vec!["first", "second"]);
let use_def = use_def_map(&db, global_scope_id);
for (name, expected_index) in [("first", 0), ("second", 0)] {
let definition = use_def
.first_public_definition(
global_table.symbol_id_by_name(name).expect("symbol exists"),
)
.expect("Expected with item definition for {name}");
if let DefinitionKind::MatchPattern(pattern) = definition.node(&db) {
assert_eq!(pattern.index(), expected_index);
} else {
panic!("Expected match pattern definition for {name}");
}
}
}

View file

@ -36,6 +36,8 @@ pub(super) struct SemanticIndexBuilder<'db> {
scope_stack: Vec<FileScopeId>,
/// The assignment we're currently visiting.
current_assignment: Option<CurrentAssignment<'db>>,
/// The match case we're currently visiting.
current_match_case: Option<CurrentMatchCase<'db>>,
/// Flow states at each `break` in the current loop.
loop_break_states: Vec<FlowSnapshot>,
@ -59,6 +61,7 @@ impl<'db> SemanticIndexBuilder<'db> {
module: parsed,
scope_stack: Vec::new(),
current_assignment: None,
current_match_case: None,
loop_break_states: vec![],
scopes: IndexVec::new(),
@ -805,7 +808,7 @@ where
}
}
fn visit_parameters(&mut self, parameters: &'ast ruff_python_ast::Parameters) {
fn visit_parameters(&mut self, parameters: &'ast ast::Parameters) {
// Intentionally avoid walking default expressions, as we handle them in the enclosing
// scope.
for parameter in parameters.iter().map(ast::AnyParameterRef::as_parameter) {
@ -813,54 +816,16 @@ where
}
}
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
// The definition visitor will recurse into the pattern so avoid walking it here.
let mut definition_visitor = MatchPatternDefinitionVisitor::new(self, pattern);
definition_visitor.visit_pattern(pattern);
}
}
fn visit_match_case(&mut self, match_case: &'ast ast::MatchCase) {
debug_assert!(self.current_match_case.is_none());
self.current_match_case = Some(CurrentMatchCase::new(&match_case.pattern));
self.visit_pattern(&match_case.pattern);
self.current_match_case = None;
/// A visitor that adds symbols and definitions for the identifiers in a match pattern.
struct MatchPatternDefinitionVisitor<'a, 'db> {
/// The semantic index builder in which to add the symbols and definitions.
builder: &'a mut SemanticIndexBuilder<'db>,
/// The index of the current node in the pattern.
index: u32,
/// The pattern being visited. This pattern is the outermost pattern that is being visited
/// and is required to add the definitions.
pattern: &'a ast::Pattern,
if let Some(expr) = &match_case.guard {
self.visit_expr(expr);
}
impl<'a, 'db> MatchPatternDefinitionVisitor<'a, 'db> {
fn new(builder: &'a mut SemanticIndexBuilder<'db>, pattern: &'a ast::Pattern) -> Self {
Self {
index: 0,
builder,
pattern,
}
}
fn add_symbol_and_definition(&mut self, identifier: &ast::Identifier) {
let symbol = self
.builder
.add_or_update_symbol(identifier.id().clone(), SymbolFlags::IS_DEFINED);
self.builder.add_definition(
symbol,
MatchPatternDefinitionNodeRef {
pattern: self.pattern,
identifier,
index: self.index,
},
);
}
}
impl<'ast, 'db> Visitor<'ast> for MatchPatternDefinitionVisitor<'_, 'db>
where
'ast: 'db,
{
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
self.builder.visit_expr(expr);
self.visit_body(&match_case.body);
}
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
@ -869,7 +834,16 @@ where
range: _,
}) = pattern
{
self.add_symbol_and_definition(name);
let symbol = self.add_or_update_symbol(name.id().clone(), SymbolFlags::IS_DEFINED);
let state = self.current_match_case.as_ref().unwrap();
self.add_definition(
symbol,
MatchPatternDefinitionNodeRef {
pattern: state.pattern,
identifier: name,
index: state.index,
},
);
}
walk_pattern(self, pattern);
@ -881,10 +855,19 @@ where
rest: Some(name), ..
}) = pattern
{
self.add_symbol_and_definition(name);
let symbol = self.add_or_update_symbol(name.id().clone(), SymbolFlags::IS_DEFINED);
let state = self.current_match_case.as_ref().unwrap();
self.add_definition(
symbol,
MatchPatternDefinitionNodeRef {
pattern: state.pattern,
identifier: name,
index: state.index,
},
);
}
self.index += 1;
self.current_match_case.as_mut().unwrap().index += 1;
}
}
@ -937,3 +920,27 @@ impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> {
Self::WithItem(value)
}
}
struct CurrentMatchCase<'a> {
/// The pattern that's part of the current match case.
pattern: &'a ast::Pattern,
/// The index of the sub-pattern that's being currently visited within the pattern.
///
/// For example:
/// ```py
/// match subject:
/// case a as b: ...
/// case [a, b]: ...
/// case a | b: ...
/// ```
///
/// In all of the above cases, the index would be 0 for `a` and 1 for `b`.
index: u32,
}
impl<'a> CurrentMatchCase<'a> {
fn new(pattern: &'a ast::Pattern) -> Self {
Self { pattern, index: 0 }
}
}