Add definitions for match statement (#13147)

## Summary

This PR adds definitions for the names bound by match patterns.

## Test Plan

Update the existing test case for match statement symbols to verify that
the definitions are added as well.
This commit is contained in:
Dhruv Manilawala 2024-09-02 14:40:09 +05:30 committed by GitHub
parent 9986397d56
commit 17eb65b26f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 189 additions and 16 deletions

View file

@ -31,10 +31,10 @@ impl<T> AstNodeRef<T> {
/// which the `AstNodeRef` belongs. /// which the `AstNodeRef` belongs.
/// ///
/// ## Safety /// ## Safety
///
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that /// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
/// the invariant `node belongs to parsed` is upheld. /// the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self { pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self { Self {
_parsed: parsed, _parsed: parsed,

View file

@ -1,12 +1,18 @@
use ruff_python_ast::{AnyNodeRef, NodeKind}; use ruff_python_ast::{AnyNodeRef, Identifier, NodeKind};
use ruff_text_size::{Ranged, TextRange}; use ruff_text_size::{Ranged, TextRange};
/// Discriminates what a `NodeKey` refers to: an AST node, tagged with its
/// [`NodeKind`], or an identifier.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(super) enum Kind {
/// An AST node of the given kind; this is what `NodeKey::from_node` produces.
Node(NodeKind),
/// An identifier; this is what `NodeKey::from_identifier` produces.
Identifier,
}
/// Compact key for a node for use in a hash map. /// Compact key for a node for use in a hash map.
/// ///
/// Compares two nodes by their kind and text range. /// Compares two nodes by their kind and text range.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(super) struct NodeKey { pub(super) struct NodeKey {
kind: NodeKind, kind: Kind,
range: TextRange, range: TextRange,
} }
@ -17,8 +23,15 @@ impl NodeKey {
{ {
let node = node.into(); let node = node.into();
NodeKey { NodeKey {
kind: node.kind(), kind: Kind::Node(node.kind()),
range: node.range(), range: node.range(),
} }
} }
pub(super) fn from_identifier(identifier: &Identifier) -> Self {
NodeKey {
kind: Kind::Identifier,
range: identifier.range(),
}
}
} }

View file

@ -1073,7 +1073,7 @@ def x():
} }
#[test] #[test]
fn match_stmt_symbols() { fn match_stmt() {
let TestCase { db, file } = test_case( let TestCase { db, file } = test_case(
" "
match subject: match subject:
@ -1087,13 +1087,27 @@ match subject:
", ",
); );
let global_table = symbol_table(&db, global_scope(&db, file)); let global_scope_id = global_scope(&db, file);
let global_table = symbol_table(&db, global_scope_id);
assert!(global_table.symbol_by_name("Foo").unwrap().is_used()); assert!(global_table.symbol_by_name("Foo").unwrap().is_used());
assert_eq!( assert_eq!(
names(&global_table), names(&global_table),
vec!["subject", "a", "b", "c", "d", "f", "e", "h", "g", "Foo", "i", "j", "k", "l"] vec!["subject", "a", "b", "c", "d", "e", "f", "g", "h", "Foo", "i", "j", "k", "l"]
); );
let use_def = use_def_map(&db, global_scope_id);
// Every name bound by a pattern in the `match` statement must have a
// `MatchPattern` definition as its first public definition.
for name in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] {
    let symbol_id = global_table.symbol_id_by_name(name).expect("symbol exists");
    // `expect` takes a plain `&str` and never interpolates `{name}`, so the failure
    // message is built with `panic!` to actually report the offending symbol.
    let definition = use_def
        .first_public_definition(symbol_id)
        .unwrap_or_else(|| panic!("Expected match pattern definition for {name}"));
    assert!(matches!(
        definition.node(&db),
        DefinitionKind::MatchPattern(_)
    ));
}
} }
#[test] #[test]

View file

@ -26,7 +26,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
use crate::semantic_index::SemanticIndex; use crate::semantic_index::SemanticIndex;
use crate::Db; use crate::Db;
use super::definition::WithItemDefinitionNodeRef; use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
pub(super) struct SemanticIndexBuilder<'db> { pub(super) struct SemanticIndexBuilder<'db> {
// Builder state // Builder state
@ -600,6 +600,17 @@ where
self.visit_body(body); self.visit_body(body);
self.visit_body(orelse); self.visit_body(orelse);
} }
ast::Stmt::Match(ast::StmtMatch {
subject,
cases,
range: _,
}) => {
self.add_standalone_expression(subject);
self.visit_expr(subject);
for case in cases {
self.visit_match_case(case);
}
}
_ => { _ => {
walk_stmt(self, stmt); walk_stmt(self, stmt);
} }
@ -803,22 +814,77 @@ where
} }
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) { fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
if let ast::Pattern::MatchAs(ast::PatternMatchAs { // The definition visitor will recurse into the pattern so avoid walking it here.
name: Some(name), .. let mut definition_visitor = MatchPatternDefinitionVisitor::new(self, pattern);
}) definition_visitor.visit_pattern(pattern);
| ast::Pattern::MatchStar(ast::PatternMatchStar { }
}
/// A visitor that adds symbols and definitions for the identifiers in a match pattern.
///
/// Expressions found inside the pattern are forwarded to the wrapped builder's
/// `visit_expr`.
struct MatchPatternDefinitionVisitor<'a, 'db> {
/// The semantic index builder in which to add the symbols and definitions.
builder: &'a mut SemanticIndexBuilder<'db>,
/// The index of the current node in the pattern. Starts at zero and is incremented
/// once per visited pattern node, after that node's sub-patterns have been walked.
index: u32,
/// The pattern being visited. This pattern is the outermost pattern that is being visited
/// and is required to add the definitions.
pattern: &'a ast::Pattern,
}
impl<'a, 'db> MatchPatternDefinitionVisitor<'a, 'db> {
    /// Creates a visitor that records into `builder` for the outermost `pattern`,
    /// with the running pattern index starting at zero.
    fn new(builder: &'a mut SemanticIndexBuilder<'db>, pattern: &'a ast::Pattern) -> Self {
        Self {
            builder,
            pattern,
            index: 0,
        }
    }

    /// Marks `identifier` as a defined symbol and registers a match-pattern
    /// definition for it, tagged with the outermost pattern and the current index.
    fn add_symbol_and_definition(&mut self, identifier: &ast::Identifier) {
        let definition_node = MatchPatternDefinitionNodeRef {
            pattern: self.pattern,
            identifier,
            index: self.index,
        };
        let symbol_id = self
            .builder
            .add_or_update_symbol(identifier.id().clone(), SymbolFlags::IS_DEFINED);
        self.builder.add_definition(symbol_id, definition_node);
    }
}
impl<'ast, 'db> Visitor<'ast> for MatchPatternDefinitionVisitor<'_, 'db>
where
'ast: 'db,
{
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
self.builder.visit_expr(expr);
}
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
if let ast::Pattern::MatchStar(ast::PatternMatchStar {
name: Some(name), name: Some(name),
range: _, range: _,
}) = pattern
{
self.add_symbol_and_definition(name);
}
walk_pattern(self, pattern);
if let ast::Pattern::MatchAs(ast::PatternMatchAs {
name: Some(name), ..
}) })
| ast::Pattern::MatchMapping(ast::PatternMatchMapping { | ast::Pattern::MatchMapping(ast::PatternMatchMapping {
rest: Some(name), .. rest: Some(name), ..
}) = pattern }) = pattern
{ {
// TODO(dhruvmanila): Add definition self.add_symbol_and_definition(name);
self.add_or_update_symbol(name.id.clone(), SymbolFlags::IS_DEFINED);
} }
walk_pattern(self, pattern); self.index += 1;
} }
} }

View file

@ -49,6 +49,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
Comprehension(ComprehensionDefinitionNodeRef<'a>), Comprehension(ComprehensionDefinitionNodeRef<'a>),
Parameter(ast::AnyParameterRef<'a>), Parameter(ast::AnyParameterRef<'a>),
WithItem(WithItemDefinitionNodeRef<'a>), WithItem(WithItemDefinitionNodeRef<'a>),
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
} }
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
@ -123,6 +124,12 @@ impl<'a> From<ast::AnyParameterRef<'a>> for DefinitionNodeRef<'a> {
} }
} }
impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: MatchPatternDefinitionNodeRef<'a>) -> Self {
Self::MatchPattern(node)
}
}
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub(crate) struct ImportFromDefinitionNodeRef<'a> { pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom, pub(crate) node: &'a ast::StmtImportFrom,
@ -153,6 +160,17 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) first: bool, pub(crate) first: bool,
} }
/// Borrowed description of a definition introduced by a match pattern.
///
/// `DefinitionNodeRef::into_owned` turns this into a `MatchPatternDefinitionKind`.
#[derive(Copy, Clone, Debug)]
pub(crate) struct MatchPatternDefinitionNodeRef<'a> {
/// The outermost pattern node in which the identifier being defined occurs.
pub(crate) pattern: &'a ast::Pattern,
/// The identifier being defined.
pub(crate) identifier: &'a ast::Identifier,
/// The index of the identifier in the pattern when visiting the `pattern` node in evaluation
/// order.
pub(crate) index: u32,
}
impl DefinitionNodeRef<'_> { impl DefinitionNodeRef<'_> {
#[allow(unsafe_code)] #[allow(unsafe_code)]
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind { pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind {
@ -213,6 +231,15 @@ impl DefinitionNodeRef<'_> {
target: AstNodeRef::new(parsed, target), target: AstNodeRef::new(parsed, target),
}) })
} }
DefinitionNodeRef::MatchPattern(MatchPatternDefinitionNodeRef {
pattern,
identifier,
index,
}) => DefinitionKind::MatchPattern(MatchPatternDefinitionKind {
pattern: AstNodeRef::new(parsed.clone(), pattern),
identifier: AstNodeRef::new(parsed, identifier),
index,
}),
} }
} }
@ -241,6 +268,9 @@ impl DefinitionNodeRef<'_> {
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(), ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
}, },
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(), Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
identifier.into()
}
} }
} }
} }
@ -260,6 +290,25 @@ pub enum DefinitionKind {
Parameter(AstNodeRef<ast::Parameter>), Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>), ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
WithItem(WithItemDefinitionKind), WithItem(WithItemDefinitionKind),
MatchPattern(MatchPatternDefinitionKind),
}
/// Owned definition data for a name bound by a match pattern.
#[derive(Clone, Debug)]
// NOTE(review): presumably needed because `identifier` has no accessor yet — confirm
// and drop the `allow` once the field is read.
#[allow(dead_code)]
pub struct MatchPatternDefinitionKind {
/// The pattern node carried over from `MatchPatternDefinitionNodeRef::pattern`.
pattern: AstNodeRef<ast::Pattern>,
/// The identifier that this definition binds.
identifier: AstNodeRef<ast::Identifier>,
/// The index carried over from `MatchPatternDefinitionNodeRef::index`.
index: u32,
}
impl MatchPatternDefinitionKind {
/// Returns the pattern node stored in this definition.
pub(crate) fn pattern(&self) -> &ast::Pattern {
self.pattern.node()
}
/// Returns the index stored in this definition.
pub(crate) fn index(&self) -> u32 {
self.index
}
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -410,3 +459,9 @@ impl From<&ast::ParameterWithDefault> for DefinitionNodeKey {
Self(NodeKey::from_node(node)) Self(NodeKey::from_node(node))
} }
} }
impl From<&ast::Identifier> for DefinitionNodeKey {
fn from(identifier: &ast::Identifier) -> Self {
Self(NodeKey::from_identifier(identifier))
}
}

View file

@ -416,6 +416,13 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::WithItem(with_item) => { DefinitionKind::WithItem(with_item) => {
self.infer_with_item_definition(with_item.target(), with_item.node(), definition); self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
} }
DefinitionKind::MatchPattern(match_pattern) => {
self.infer_match_pattern_definition(
match_pattern.pattern(),
match_pattern.index(),
definition,
);
}
} }
} }
@ -795,7 +802,10 @@ impl<'db> TypeInferenceBuilder<'db> {
cases, cases,
} = match_statement; } = match_statement;
self.infer_expression(subject); let expression = self.index.expression(subject.as_ref());
let result = infer_expression_types(self.db, expression);
self.extend(result);
for case in cases { for case in cases {
let ast::MatchCase { let ast::MatchCase {
range: _, range: _,
@ -809,7 +819,22 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
} }
/// Records the inferred type for a definition introduced by a match pattern.
///
/// This is currently a placeholder: `_pattern` and `_index` are ignored and the
/// definition is always given `Type::Unknown`.
fn infer_match_pattern_definition(
&mut self,
_pattern: &ast::Pattern,
_index: u32,
definition: Definition<'db>,
) {
// TODO(dhruvmanila): The correct way to infer types here is to perform structural matching
// against the subject expression type (which we can query via `infer_expression_types`)
// and extract the type at the `index` position if the pattern matches. This will be
// similar to the logic in `self.infer_assignment_definition`.
self.types.definitions.insert(definition, Type::Unknown);
}
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
// TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
match pattern { match pattern {
ast::Pattern::MatchValue(match_value) => { ast::Pattern::MatchValue(match_value) => {
self.infer_expression(&match_value.value); self.infer_expression(&match_value.value);