diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 6a5c96842f..1a60ef729b 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -1097,16 +1097,62 @@ match subject: ); let use_def = use_def_map(&db, global_scope_id); - for name in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] { + for (name, expected_index) in [ + ("a", 0), + ("b", 0), + ("c", 1), + ("d", 2), + ("e", 0), + ("f", 1), + ("g", 0), + ("h", 1), + ("i", 0), + ("j", 1), + ("k", 0), + ("l", 1), + ] { let definition = use_def .first_public_definition( global_table.symbol_id_by_name(name).expect("symbol exists"), ) .expect("Expected with item definition for {name}"); - assert!(matches!( - definition.node(&db), - DefinitionKind::MatchPattern(_) - )); + if let DefinitionKind::MatchPattern(pattern) = definition.node(&db) { + assert_eq!(pattern.index(), expected_index); + } else { + panic!("Expected match pattern definition for {name}"); + } + } + } + + #[test] + fn nested_match_case() { + let TestCase { db, file } = test_case( + " +match 1: + case first: + match 2: + case second: + pass +", + ); + + let global_scope_id = global_scope(&db, file); + let global_table = symbol_table(&db, global_scope_id); + + assert_eq!(names(&global_table), vec!["first", "second"]); + + let use_def = use_def_map(&db, global_scope_id); + for (name, expected_index) in [("first", 0), ("second", 0)] { + let definition = use_def + .first_public_definition( + global_table.symbol_id_by_name(name).expect("symbol exists"), + ) + .expect("Expected with item definition for {name}"); + if let DefinitionKind::MatchPattern(pattern) = definition.node(&db) { + assert_eq!(pattern.index(), expected_index); + } else { + panic!("Expected match pattern definition for {name}"); + } } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index dfdab1ec71..3f6d0c23e0 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -36,6 +36,8 @@ pub(super) struct SemanticIndexBuilder<'db> { scope_stack: Vec, /// The assignment we're currently visiting. current_assignment: Option>, + /// The match case we're currently visiting. + current_match_case: Option>, /// Flow states at each `break` in the current loop. loop_break_states: Vec, @@ -59,6 +61,7 @@ impl<'db> SemanticIndexBuilder<'db> { module: parsed, scope_stack: Vec::new(), current_assignment: None, + current_match_case: None, loop_break_states: vec![], scopes: IndexVec::new(), @@ -805,7 +808,7 @@ where } } - fn visit_parameters(&mut self, parameters: &'ast ruff_python_ast::Parameters) { + fn visit_parameters(&mut self, parameters: &'ast ast::Parameters) { // Intentionally avoid walking default expressions, as we handle them in the enclosing // scope. for parameter in parameters.iter().map(ast::AnyParameterRef::as_parameter) { @@ -813,54 +816,16 @@ where } } - fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { - // The definition visitor will recurse into the pattern so avoid walking it here. - let mut definition_visitor = MatchPatternDefinitionVisitor::new(self, pattern); - definition_visitor.visit_pattern(pattern); - } -} + fn visit_match_case(&mut self, match_case: &'ast ast::MatchCase) { + debug_assert!(self.current_match_case.is_none()); + self.current_match_case = Some(CurrentMatchCase::new(&match_case.pattern)); + self.visit_pattern(&match_case.pattern); + self.current_match_case = None; -/// A visitor that adds symbols and definitions for the identifiers in a match pattern. -struct MatchPatternDefinitionVisitor<'a, 'db> { - /// The semantic index builder in which to add the symbols and definitions. - builder: &'a mut SemanticIndexBuilder<'db>, - /// The index of the current node in the pattern. - index: u32, - /// The pattern being visited. This pattern is the outermost pattern that is being visited - /// and is required to add the definitions. - pattern: &'a ast::Pattern, -} - -impl<'a, 'db> MatchPatternDefinitionVisitor<'a, 'db> { - fn new(builder: &'a mut SemanticIndexBuilder<'db>, pattern: &'a ast::Pattern) -> Self { - Self { - index: 0, - builder, - pattern, + if let Some(expr) = &match_case.guard { + self.visit_expr(expr); } - } - - fn add_symbol_and_definition(&mut self, identifier: &ast::Identifier) { - let symbol = self - .builder - .add_or_update_symbol(identifier.id().clone(), SymbolFlags::IS_DEFINED); - self.builder.add_definition( - symbol, - MatchPatternDefinitionNodeRef { - pattern: self.pattern, - identifier, - index: self.index, - }, - ); - } -} - -impl<'ast, 'db> Visitor<'ast> for MatchPatternDefinitionVisitor<'_, 'db> -where - 'ast: 'db, -{ - fn visit_expr(&mut self, expr: &'ast ast::Expr) { - self.builder.visit_expr(expr); + self.visit_body(&match_case.body); } fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { @@ -869,7 +834,16 @@ where range: _, }) = pattern { - self.add_symbol_and_definition(name); + let symbol = self.add_or_update_symbol(name.id().clone(), SymbolFlags::IS_DEFINED); + let state = self.current_match_case.as_ref().unwrap(); + self.add_definition( + symbol, + MatchPatternDefinitionNodeRef { + pattern: state.pattern, + identifier: name, + index: state.index, + }, + ); } walk_pattern(self, pattern); @@ -881,10 +855,19 @@ where rest: Some(name), .. }) = pattern { - self.add_symbol_and_definition(name); + let symbol = self.add_or_update_symbol(name.id().clone(), SymbolFlags::IS_DEFINED); + let state = self.current_match_case.as_ref().unwrap(); + self.add_definition( + symbol, + MatchPatternDefinitionNodeRef { + pattern: state.pattern, + identifier: name, + index: state.index, + }, + ); } - self.index += 1; + self.current_match_case.as_mut().unwrap().index += 1; } } @@ -937,3 +920,27 @@ impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> { Self::WithItem(value) } } + +struct CurrentMatchCase<'a> { + /// The pattern that's part of the current match case. + pattern: &'a ast::Pattern, + + /// The index of the sub-pattern that's being currently visited within the pattern. + /// + /// For example: + /// ```py + /// match subject: + /// case a as b: ... + /// case [a, b]: ... + /// case a | b: ... + /// ``` + /// + /// In all of the above cases, the index would be 0 for `a` and 1 for `b`. + index: u32, +} + +impl<'a> CurrentMatchCase<'a> { + fn new(pattern: &'a ast::Pattern) -> Self { + Self { pattern, index: 0 } + } +}