mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-28 12:55:05 +00:00
Add definitions for match statement (#13147)
## Summary This PR adds definition for match patterns. ## Test Plan Update the existing test case for match statement symbols to verify that the definitions are added as well.
This commit is contained in:
parent
9986397d56
commit
17eb65b26f
6 changed files with 189 additions and 16 deletions
|
@ -31,10 +31,10 @@ impl<T> AstNodeRef<T> {
|
|||
/// which the `AstNodeRef` belongs.
|
||||
///
|
||||
/// ## Safety
|
||||
///
|
||||
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
|
||||
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
|
||||
/// the invariant `node belongs to parsed` is upheld.
|
||||
|
||||
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
|
||||
Self {
|
||||
_parsed: parsed,
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
use ruff_python_ast::{AnyNodeRef, NodeKind};
|
||||
use ruff_python_ast::{AnyNodeRef, Identifier, NodeKind};
|
||||
use ruff_text_size::{Ranged, TextRange};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub(super) enum Kind {
|
||||
Node(NodeKind),
|
||||
Identifier,
|
||||
}
|
||||
|
||||
/// Compact key for a node for use in a hash map.
|
||||
///
|
||||
/// Compares two nodes by their kind and text range.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub(super) struct NodeKey {
|
||||
kind: NodeKind,
|
||||
kind: Kind,
|
||||
range: TextRange,
|
||||
}
|
||||
|
||||
|
@ -17,8 +23,15 @@ impl NodeKey {
|
|||
{
|
||||
let node = node.into();
|
||||
NodeKey {
|
||||
kind: node.kind(),
|
||||
kind: Kind::Node(node.kind()),
|
||||
range: node.range(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn from_identifier(identifier: &Identifier) -> Self {
|
||||
NodeKey {
|
||||
kind: Kind::Identifier,
|
||||
range: identifier.range(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1073,7 +1073,7 @@ def x():
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn match_stmt_symbols() {
|
||||
fn match_stmt() {
|
||||
let TestCase { db, file } = test_case(
|
||||
"
|
||||
match subject:
|
||||
|
@ -1087,13 +1087,27 @@ match subject:
|
|||
",
|
||||
);
|
||||
|
||||
let global_table = symbol_table(&db, global_scope(&db, file));
|
||||
let global_scope_id = global_scope(&db, file);
|
||||
let global_table = symbol_table(&db, global_scope_id);
|
||||
|
||||
assert!(global_table.symbol_by_name("Foo").unwrap().is_used());
|
||||
assert_eq!(
|
||||
names(&global_table),
|
||||
vec!["subject", "a", "b", "c", "d", "f", "e", "h", "g", "Foo", "i", "j", "k", "l"]
|
||||
vec!["subject", "a", "b", "c", "d", "e", "f", "g", "h", "Foo", "i", "j", "k", "l"]
|
||||
);
|
||||
|
||||
let use_def = use_def_map(&db, global_scope_id);
|
||||
for name in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] {
|
||||
let definition = use_def
|
||||
.first_public_definition(
|
||||
global_table.symbol_id_by_name(name).expect("symbol exists"),
|
||||
)
|
||||
.expect("Expected with item definition for {name}");
|
||||
assert!(matches!(
|
||||
definition.node(&db),
|
||||
DefinitionKind::MatchPattern(_)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -26,7 +26,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
|
|||
use crate::semantic_index::SemanticIndex;
|
||||
use crate::Db;
|
||||
|
||||
use super::definition::WithItemDefinitionNodeRef;
|
||||
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
|
||||
|
||||
pub(super) struct SemanticIndexBuilder<'db> {
|
||||
// Builder state
|
||||
|
@ -600,6 +600,17 @@ where
|
|||
self.visit_body(body);
|
||||
self.visit_body(orelse);
|
||||
}
|
||||
ast::Stmt::Match(ast::StmtMatch {
|
||||
subject,
|
||||
cases,
|
||||
range: _,
|
||||
}) => {
|
||||
self.add_standalone_expression(subject);
|
||||
self.visit_expr(subject);
|
||||
for case in cases {
|
||||
self.visit_match_case(case);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
walk_stmt(self, stmt);
|
||||
}
|
||||
|
@ -803,22 +814,77 @@ where
|
|||
}
|
||||
|
||||
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||
if let ast::Pattern::MatchAs(ast::PatternMatchAs {
|
||||
name: Some(name), ..
|
||||
})
|
||||
| ast::Pattern::MatchStar(ast::PatternMatchStar {
|
||||
// The definition visitor will recurse into the pattern so avoid walking it here.
|
||||
let mut definition_visitor = MatchPatternDefinitionVisitor::new(self, pattern);
|
||||
definition_visitor.visit_pattern(pattern);
|
||||
}
|
||||
}
|
||||
|
||||
/// A visitor that adds symbols and definitions for the identifiers in a match pattern.
|
||||
struct MatchPatternDefinitionVisitor<'a, 'db> {
|
||||
/// The semantic index builder in which to add the symbols and definitions.
|
||||
builder: &'a mut SemanticIndexBuilder<'db>,
|
||||
/// The index of the current node in the pattern.
|
||||
index: u32,
|
||||
/// The pattern being visited. This pattern is the outermost pattern that is being visited
|
||||
/// and is required to add the definitions.
|
||||
pattern: &'a ast::Pattern,
|
||||
}
|
||||
|
||||
impl<'a, 'db> MatchPatternDefinitionVisitor<'a, 'db> {
|
||||
fn new(builder: &'a mut SemanticIndexBuilder<'db>, pattern: &'a ast::Pattern) -> Self {
|
||||
Self {
|
||||
index: 0,
|
||||
builder,
|
||||
pattern,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_symbol_and_definition(&mut self, identifier: &ast::Identifier) {
|
||||
let symbol = self
|
||||
.builder
|
||||
.add_or_update_symbol(identifier.id().clone(), SymbolFlags::IS_DEFINED);
|
||||
self.builder.add_definition(
|
||||
symbol,
|
||||
MatchPatternDefinitionNodeRef {
|
||||
pattern: self.pattern,
|
||||
identifier,
|
||||
index: self.index,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'db> Visitor<'ast> for MatchPatternDefinitionVisitor<'_, 'db>
|
||||
where
|
||||
'ast: 'db,
|
||||
{
|
||||
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
|
||||
self.builder.visit_expr(expr);
|
||||
}
|
||||
|
||||
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||
if let ast::Pattern::MatchStar(ast::PatternMatchStar {
|
||||
name: Some(name),
|
||||
range: _,
|
||||
}) = pattern
|
||||
{
|
||||
self.add_symbol_and_definition(name);
|
||||
}
|
||||
|
||||
walk_pattern(self, pattern);
|
||||
|
||||
if let ast::Pattern::MatchAs(ast::PatternMatchAs {
|
||||
name: Some(name), ..
|
||||
})
|
||||
| ast::Pattern::MatchMapping(ast::PatternMatchMapping {
|
||||
rest: Some(name), ..
|
||||
}) = pattern
|
||||
{
|
||||
// TODO(dhruvmanila): Add definition
|
||||
self.add_or_update_symbol(name.id.clone(), SymbolFlags::IS_DEFINED);
|
||||
self.add_symbol_and_definition(name);
|
||||
}
|
||||
|
||||
walk_pattern(self, pattern);
|
||||
self.index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
|
|||
Comprehension(ComprehensionDefinitionNodeRef<'a>),
|
||||
Parameter(ast::AnyParameterRef<'a>),
|
||||
WithItem(WithItemDefinitionNodeRef<'a>),
|
||||
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
|
||||
}
|
||||
|
||||
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
|
||||
|
@ -123,6 +124,12 @@ impl<'a> From<ast::AnyParameterRef<'a>> for DefinitionNodeRef<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
|
||||
fn from(node: MatchPatternDefinitionNodeRef<'a>) -> Self {
|
||||
Self::MatchPattern(node)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
|
||||
pub(crate) node: &'a ast::StmtImportFrom,
|
||||
|
@ -153,6 +160,17 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
|
|||
pub(crate) first: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct MatchPatternDefinitionNodeRef<'a> {
|
||||
/// The outermost pattern node in which the identifier being defined occurs.
|
||||
pub(crate) pattern: &'a ast::Pattern,
|
||||
/// The identifier being defined.
|
||||
pub(crate) identifier: &'a ast::Identifier,
|
||||
/// The index of the identifier in the pattern when visiting the `pattern` node in evaluation
|
||||
/// order.
|
||||
pub(crate) index: u32,
|
||||
}
|
||||
|
||||
impl DefinitionNodeRef<'_> {
|
||||
#[allow(unsafe_code)]
|
||||
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind {
|
||||
|
@ -213,6 +231,15 @@ impl DefinitionNodeRef<'_> {
|
|||
target: AstNodeRef::new(parsed, target),
|
||||
})
|
||||
}
|
||||
DefinitionNodeRef::MatchPattern(MatchPatternDefinitionNodeRef {
|
||||
pattern,
|
||||
identifier,
|
||||
index,
|
||||
}) => DefinitionKind::MatchPattern(MatchPatternDefinitionKind {
|
||||
pattern: AstNodeRef::new(parsed.clone(), pattern),
|
||||
identifier: AstNodeRef::new(parsed, identifier),
|
||||
index,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -241,6 +268,9 @@ impl DefinitionNodeRef<'_> {
|
|||
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
|
||||
},
|
||||
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
|
||||
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
|
||||
identifier.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -260,6 +290,25 @@ pub enum DefinitionKind {
|
|||
Parameter(AstNodeRef<ast::Parameter>),
|
||||
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
|
||||
WithItem(WithItemDefinitionKind),
|
||||
MatchPattern(MatchPatternDefinitionKind),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct MatchPatternDefinitionKind {
|
||||
pattern: AstNodeRef<ast::Pattern>,
|
||||
identifier: AstNodeRef<ast::Identifier>,
|
||||
index: u32,
|
||||
}
|
||||
|
||||
impl MatchPatternDefinitionKind {
|
||||
pub(crate) fn pattern(&self) -> &ast::Pattern {
|
||||
self.pattern.node()
|
||||
}
|
||||
|
||||
pub(crate) fn index(&self) -> u32 {
|
||||
self.index
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -410,3 +459,9 @@ impl From<&ast::ParameterWithDefault> for DefinitionNodeKey {
|
|||
Self(NodeKey::from_node(node))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ast::Identifier> for DefinitionNodeKey {
|
||||
fn from(identifier: &ast::Identifier) -> Self {
|
||||
Self(NodeKey::from_identifier(identifier))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -416,6 +416,13 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
DefinitionKind::WithItem(with_item) => {
|
||||
self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
|
||||
}
|
||||
DefinitionKind::MatchPattern(match_pattern) => {
|
||||
self.infer_match_pattern_definition(
|
||||
match_pattern.pattern(),
|
||||
match_pattern.index(),
|
||||
definition,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -795,7 +802,10 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
cases,
|
||||
} = match_statement;
|
||||
|
||||
self.infer_expression(subject);
|
||||
let expression = self.index.expression(subject.as_ref());
|
||||
let result = infer_expression_types(self.db, expression);
|
||||
self.extend(result);
|
||||
|
||||
for case in cases {
|
||||
let ast::MatchCase {
|
||||
range: _,
|
||||
|
@ -809,7 +819,22 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn infer_match_pattern_definition(
|
||||
&mut self,
|
||||
_pattern: &ast::Pattern,
|
||||
_index: u32,
|
||||
definition: Definition<'db>,
|
||||
) {
|
||||
// TODO(dhruvmanila): The correct way to infer types here is to perform structural matching
|
||||
// against the subject expression type (which we can query via `infer_expression_types`)
|
||||
// and extract the type at the `index` position if the pattern matches. This will be
|
||||
// similar to the logic in `self.infer_assignment_definition`.
|
||||
self.types.definitions.insert(definition, Type::Unknown);
|
||||
}
|
||||
|
||||
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
|
||||
// TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against
|
||||
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
|
||||
match pattern {
|
||||
ast::Pattern::MatchValue(match_value) => {
|
||||
self.infer_expression(&match_value.value);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue