mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:25:17 +00:00
[red-knot] Remove match pattern definition visitor (#13209)
## Summary This PR is based on this discussion: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739408653. **Todo** - [x] Add documentation for `MatchPatternState` ## Test Plan `cargo insta test` and `cargo clippy`
This commit is contained in:
parent
46e687e8d1
commit
facf6febf0
2 changed files with 108 additions and 55 deletions
|
@ -1097,16 +1097,62 @@ match subject:
|
|||
);
|
||||
|
||||
let use_def = use_def_map(&db, global_scope_id);
|
||||
for name in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] {
|
||||
for (name, expected_index) in [
|
||||
("a", 0),
|
||||
("b", 0),
|
||||
("c", 1),
|
||||
("d", 2),
|
||||
("e", 0),
|
||||
("f", 1),
|
||||
("g", 0),
|
||||
("h", 1),
|
||||
("i", 0),
|
||||
("j", 1),
|
||||
("k", 0),
|
||||
("l", 1),
|
||||
] {
|
||||
let definition = use_def
|
||||
.first_public_definition(
|
||||
global_table.symbol_id_by_name(name).expect("symbol exists"),
|
||||
)
|
||||
.expect("Expected with item definition for {name}");
|
||||
assert!(matches!(
|
||||
definition.node(&db),
|
||||
DefinitionKind::MatchPattern(_)
|
||||
));
|
||||
if let DefinitionKind::MatchPattern(pattern) = definition.node(&db) {
|
||||
assert_eq!(pattern.index(), expected_index);
|
||||
} else {
|
||||
panic!("Expected match pattern definition for {name}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nested_match_case() {
|
||||
let TestCase { db, file } = test_case(
|
||||
"
|
||||
match 1:
|
||||
case first:
|
||||
match 2:
|
||||
case second:
|
||||
pass
|
||||
",
|
||||
);
|
||||
|
||||
let global_scope_id = global_scope(&db, file);
|
||||
let global_table = symbol_table(&db, global_scope_id);
|
||||
|
||||
assert_eq!(names(&global_table), vec!["first", "second"]);
|
||||
|
||||
let use_def = use_def_map(&db, global_scope_id);
|
||||
for (name, expected_index) in [("first", 0), ("second", 0)] {
|
||||
let definition = use_def
|
||||
.first_public_definition(
|
||||
global_table.symbol_id_by_name(name).expect("symbol exists"),
|
||||
)
|
||||
.expect("Expected with item definition for {name}");
|
||||
if let DefinitionKind::MatchPattern(pattern) = definition.node(&db) {
|
||||
assert_eq!(pattern.index(), expected_index);
|
||||
} else {
|
||||
panic!("Expected match pattern definition for {name}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -36,6 +36,8 @@ pub(super) struct SemanticIndexBuilder<'db> {
|
|||
scope_stack: Vec<FileScopeId>,
|
||||
/// The assignment we're currently visiting.
|
||||
current_assignment: Option<CurrentAssignment<'db>>,
|
||||
/// The match case we're currently visiting.
|
||||
current_match_case: Option<CurrentMatchCase<'db>>,
|
||||
/// Flow states at each `break` in the current loop.
|
||||
loop_break_states: Vec<FlowSnapshot>,
|
||||
|
||||
|
@ -59,6 +61,7 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
module: parsed,
|
||||
scope_stack: Vec::new(),
|
||||
current_assignment: None,
|
||||
current_match_case: None,
|
||||
loop_break_states: vec![],
|
||||
|
||||
scopes: IndexVec::new(),
|
||||
|
@ -805,7 +808,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn visit_parameters(&mut self, parameters: &'ast ruff_python_ast::Parameters) {
|
||||
fn visit_parameters(&mut self, parameters: &'ast ast::Parameters) {
|
||||
// Intentionally avoid walking default expressions, as we handle them in the enclosing
|
||||
// scope.
|
||||
for parameter in parameters.iter().map(ast::AnyParameterRef::as_parameter) {
|
||||
|
@ -813,54 +816,16 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||
// The definition visitor will recurse into the pattern so avoid walking it here.
|
||||
let mut definition_visitor = MatchPatternDefinitionVisitor::new(self, pattern);
|
||||
definition_visitor.visit_pattern(pattern);
|
||||
}
|
||||
}
|
||||
fn visit_match_case(&mut self, match_case: &'ast ast::MatchCase) {
|
||||
debug_assert!(self.current_match_case.is_none());
|
||||
self.current_match_case = Some(CurrentMatchCase::new(&match_case.pattern));
|
||||
self.visit_pattern(&match_case.pattern);
|
||||
self.current_match_case = None;
|
||||
|
||||
/// A visitor that adds symbols and definitions for the identifiers in a match pattern.
|
||||
struct MatchPatternDefinitionVisitor<'a, 'db> {
|
||||
/// The semantic index builder in which to add the symbols and definitions.
|
||||
builder: &'a mut SemanticIndexBuilder<'db>,
|
||||
/// The index of the current node in the pattern.
|
||||
index: u32,
|
||||
/// The pattern being visited. This pattern is the outermost pattern that is being visited
|
||||
/// and is required to add the definitions.
|
||||
pattern: &'a ast::Pattern,
|
||||
}
|
||||
|
||||
impl<'a, 'db> MatchPatternDefinitionVisitor<'a, 'db> {
|
||||
fn new(builder: &'a mut SemanticIndexBuilder<'db>, pattern: &'a ast::Pattern) -> Self {
|
||||
Self {
|
||||
index: 0,
|
||||
builder,
|
||||
pattern,
|
||||
if let Some(expr) = &match_case.guard {
|
||||
self.visit_expr(expr);
|
||||
}
|
||||
}
|
||||
|
||||
fn add_symbol_and_definition(&mut self, identifier: &ast::Identifier) {
|
||||
let symbol = self
|
||||
.builder
|
||||
.add_or_update_symbol(identifier.id().clone(), SymbolFlags::IS_DEFINED);
|
||||
self.builder.add_definition(
|
||||
symbol,
|
||||
MatchPatternDefinitionNodeRef {
|
||||
pattern: self.pattern,
|
||||
identifier,
|
||||
index: self.index,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'db> Visitor<'ast> for MatchPatternDefinitionVisitor<'_, 'db>
|
||||
where
|
||||
'ast: 'db,
|
||||
{
|
||||
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
|
||||
self.builder.visit_expr(expr);
|
||||
self.visit_body(&match_case.body);
|
||||
}
|
||||
|
||||
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||
|
@ -869,7 +834,16 @@ where
|
|||
range: _,
|
||||
}) = pattern
|
||||
{
|
||||
self.add_symbol_and_definition(name);
|
||||
let symbol = self.add_or_update_symbol(name.id().clone(), SymbolFlags::IS_DEFINED);
|
||||
let state = self.current_match_case.as_ref().unwrap();
|
||||
self.add_definition(
|
||||
symbol,
|
||||
MatchPatternDefinitionNodeRef {
|
||||
pattern: state.pattern,
|
||||
identifier: name,
|
||||
index: state.index,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
walk_pattern(self, pattern);
|
||||
|
@ -881,10 +855,19 @@ where
|
|||
rest: Some(name), ..
|
||||
}) = pattern
|
||||
{
|
||||
self.add_symbol_and_definition(name);
|
||||
let symbol = self.add_or_update_symbol(name.id().clone(), SymbolFlags::IS_DEFINED);
|
||||
let state = self.current_match_case.as_ref().unwrap();
|
||||
self.add_definition(
|
||||
symbol,
|
||||
MatchPatternDefinitionNodeRef {
|
||||
pattern: state.pattern,
|
||||
identifier: name,
|
||||
index: state.index,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
self.index += 1;
|
||||
self.current_match_case.as_mut().unwrap().index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -937,3 +920,27 @@ impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> {
|
|||
Self::WithItem(value)
|
||||
}
|
||||
}
|
||||
|
||||
struct CurrentMatchCase<'a> {
|
||||
/// The pattern that's part of the current match case.
|
||||
pattern: &'a ast::Pattern,
|
||||
|
||||
/// The index of the sub-pattern that's being currently visited within the pattern.
|
||||
///
|
||||
/// For example:
|
||||
/// ```py
|
||||
/// match subject:
|
||||
/// case a as b: ...
|
||||
/// case [a, b]: ...
|
||||
/// case a | b: ...
|
||||
/// ```
|
||||
///
|
||||
/// In all of the above cases, the index would be 0 for `a` and 1 for `b`.
|
||||
index: u32,
|
||||
}
|
||||
|
||||
impl<'a> CurrentMatchCase<'a> {
|
||||
fn new(pattern: &'a ast::Pattern) -> Self {
|
||||
Self { pattern, index: 0 }
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue