mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-28 12:55:05 +00:00
Add definitions for match statement (#13147)
## Summary This PR adds definition for match patterns. ## Test Plan Update the existing test case for match statement symbols to verify that the definitions are added as well.
This commit is contained in:
parent
9986397d56
commit
17eb65b26f
6 changed files with 189 additions and 16 deletions
|
@ -31,10 +31,10 @@ impl<T> AstNodeRef<T> {
|
||||||
/// which the `AstNodeRef` belongs.
|
/// which the `AstNodeRef` belongs.
|
||||||
///
|
///
|
||||||
/// ## Safety
|
/// ## Safety
|
||||||
|
///
|
||||||
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
|
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
|
||||||
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
|
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
|
||||||
/// the invariant `node belongs to parsed` is upheld.
|
/// the invariant `node belongs to parsed` is upheld.
|
||||||
|
|
||||||
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
|
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
|
||||||
Self {
|
Self {
|
||||||
_parsed: parsed,
|
_parsed: parsed,
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
use ruff_python_ast::{AnyNodeRef, NodeKind};
|
use ruff_python_ast::{AnyNodeRef, Identifier, NodeKind};
|
||||||
use ruff_text_size::{Ranged, TextRange};
|
use ruff_text_size::{Ranged, TextRange};
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||||
|
pub(super) enum Kind {
|
||||||
|
Node(NodeKind),
|
||||||
|
Identifier,
|
||||||
|
}
|
||||||
|
|
||||||
/// Compact key for a node for use in a hash map.
|
/// Compact key for a node for use in a hash map.
|
||||||
///
|
///
|
||||||
/// Compares two nodes by their kind and text range.
|
/// Compares two nodes by their kind and text range.
|
||||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
|
||||||
pub(super) struct NodeKey {
|
pub(super) struct NodeKey {
|
||||||
kind: NodeKind,
|
kind: Kind,
|
||||||
range: TextRange,
|
range: TextRange,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,8 +23,15 @@ impl NodeKey {
|
||||||
{
|
{
|
||||||
let node = node.into();
|
let node = node.into();
|
||||||
NodeKey {
|
NodeKey {
|
||||||
kind: node.kind(),
|
kind: Kind::Node(node.kind()),
|
||||||
range: node.range(),
|
range: node.range(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn from_identifier(identifier: &Identifier) -> Self {
|
||||||
|
NodeKey {
|
||||||
|
kind: Kind::Identifier,
|
||||||
|
range: identifier.range(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1073,7 +1073,7 @@ def x():
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn match_stmt_symbols() {
|
fn match_stmt() {
|
||||||
let TestCase { db, file } = test_case(
|
let TestCase { db, file } = test_case(
|
||||||
"
|
"
|
||||||
match subject:
|
match subject:
|
||||||
|
@ -1087,13 +1087,27 @@ match subject:
|
||||||
",
|
",
|
||||||
);
|
);
|
||||||
|
|
||||||
let global_table = symbol_table(&db, global_scope(&db, file));
|
let global_scope_id = global_scope(&db, file);
|
||||||
|
let global_table = symbol_table(&db, global_scope_id);
|
||||||
|
|
||||||
assert!(global_table.symbol_by_name("Foo").unwrap().is_used());
|
assert!(global_table.symbol_by_name("Foo").unwrap().is_used());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
names(&global_table),
|
names(&global_table),
|
||||||
vec!["subject", "a", "b", "c", "d", "f", "e", "h", "g", "Foo", "i", "j", "k", "l"]
|
vec!["subject", "a", "b", "c", "d", "e", "f", "g", "h", "Foo", "i", "j", "k", "l"]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let use_def = use_def_map(&db, global_scope_id);
|
||||||
|
for name in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"] {
|
||||||
|
let definition = use_def
|
||||||
|
.first_public_definition(
|
||||||
|
global_table.symbol_id_by_name(name).expect("symbol exists"),
|
||||||
|
)
|
||||||
|
.expect("Expected with item definition for {name}");
|
||||||
|
assert!(matches!(
|
||||||
|
definition.node(&db),
|
||||||
|
DefinitionKind::MatchPattern(_)
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
@ -26,7 +26,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
|
||||||
use crate::semantic_index::SemanticIndex;
|
use crate::semantic_index::SemanticIndex;
|
||||||
use crate::Db;
|
use crate::Db;
|
||||||
|
|
||||||
use super::definition::WithItemDefinitionNodeRef;
|
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
|
||||||
|
|
||||||
pub(super) struct SemanticIndexBuilder<'db> {
|
pub(super) struct SemanticIndexBuilder<'db> {
|
||||||
// Builder state
|
// Builder state
|
||||||
|
@ -600,6 +600,17 @@ where
|
||||||
self.visit_body(body);
|
self.visit_body(body);
|
||||||
self.visit_body(orelse);
|
self.visit_body(orelse);
|
||||||
}
|
}
|
||||||
|
ast::Stmt::Match(ast::StmtMatch {
|
||||||
|
subject,
|
||||||
|
cases,
|
||||||
|
range: _,
|
||||||
|
}) => {
|
||||||
|
self.add_standalone_expression(subject);
|
||||||
|
self.visit_expr(subject);
|
||||||
|
for case in cases {
|
||||||
|
self.visit_match_case(case);
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
walk_stmt(self, stmt);
|
walk_stmt(self, stmt);
|
||||||
}
|
}
|
||||||
|
@ -803,22 +814,77 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||||
if let ast::Pattern::MatchAs(ast::PatternMatchAs {
|
// The definition visitor will recurse into the pattern so avoid walking it here.
|
||||||
name: Some(name), ..
|
let mut definition_visitor = MatchPatternDefinitionVisitor::new(self, pattern);
|
||||||
})
|
definition_visitor.visit_pattern(pattern);
|
||||||
| ast::Pattern::MatchStar(ast::PatternMatchStar {
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A visitor that adds symbols and definitions for the identifiers in a match pattern.
|
||||||
|
struct MatchPatternDefinitionVisitor<'a, 'db> {
|
||||||
|
/// The semantic index builder in which to add the symbols and definitions.
|
||||||
|
builder: &'a mut SemanticIndexBuilder<'db>,
|
||||||
|
/// The index of the current node in the pattern.
|
||||||
|
index: u32,
|
||||||
|
/// The pattern being visited. This pattern is the outermost pattern that is being visited
|
||||||
|
/// and is required to add the definitions.
|
||||||
|
pattern: &'a ast::Pattern,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'db> MatchPatternDefinitionVisitor<'a, 'db> {
|
||||||
|
fn new(builder: &'a mut SemanticIndexBuilder<'db>, pattern: &'a ast::Pattern) -> Self {
|
||||||
|
Self {
|
||||||
|
index: 0,
|
||||||
|
builder,
|
||||||
|
pattern,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_symbol_and_definition(&mut self, identifier: &ast::Identifier) {
|
||||||
|
let symbol = self
|
||||||
|
.builder
|
||||||
|
.add_or_update_symbol(identifier.id().clone(), SymbolFlags::IS_DEFINED);
|
||||||
|
self.builder.add_definition(
|
||||||
|
symbol,
|
||||||
|
MatchPatternDefinitionNodeRef {
|
||||||
|
pattern: self.pattern,
|
||||||
|
identifier,
|
||||||
|
index: self.index,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, 'db> Visitor<'ast> for MatchPatternDefinitionVisitor<'_, 'db>
|
||||||
|
where
|
||||||
|
'ast: 'db,
|
||||||
|
{
|
||||||
|
fn visit_expr(&mut self, expr: &'ast ast::Expr) {
|
||||||
|
self.builder.visit_expr(expr);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_pattern(&mut self, pattern: &'ast ast::Pattern) {
|
||||||
|
if let ast::Pattern::MatchStar(ast::PatternMatchStar {
|
||||||
name: Some(name),
|
name: Some(name),
|
||||||
range: _,
|
range: _,
|
||||||
|
}) = pattern
|
||||||
|
{
|
||||||
|
self.add_symbol_and_definition(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
walk_pattern(self, pattern);
|
||||||
|
|
||||||
|
if let ast::Pattern::MatchAs(ast::PatternMatchAs {
|
||||||
|
name: Some(name), ..
|
||||||
})
|
})
|
||||||
| ast::Pattern::MatchMapping(ast::PatternMatchMapping {
|
| ast::Pattern::MatchMapping(ast::PatternMatchMapping {
|
||||||
rest: Some(name), ..
|
rest: Some(name), ..
|
||||||
}) = pattern
|
}) = pattern
|
||||||
{
|
{
|
||||||
// TODO(dhruvmanila): Add definition
|
self.add_symbol_and_definition(name);
|
||||||
self.add_or_update_symbol(name.id.clone(), SymbolFlags::IS_DEFINED);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
walk_pattern(self, pattern);
|
self.index += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
|
||||||
Comprehension(ComprehensionDefinitionNodeRef<'a>),
|
Comprehension(ComprehensionDefinitionNodeRef<'a>),
|
||||||
Parameter(ast::AnyParameterRef<'a>),
|
Parameter(ast::AnyParameterRef<'a>),
|
||||||
WithItem(WithItemDefinitionNodeRef<'a>),
|
WithItem(WithItemDefinitionNodeRef<'a>),
|
||||||
|
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
|
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
|
||||||
|
@ -123,6 +124,12 @@ impl<'a> From<ast::AnyParameterRef<'a>> for DefinitionNodeRef<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
|
||||||
|
fn from(node: MatchPatternDefinitionNodeRef<'a>) -> Self {
|
||||||
|
Self::MatchPattern(node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug)]
|
||||||
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
|
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
|
||||||
pub(crate) node: &'a ast::StmtImportFrom,
|
pub(crate) node: &'a ast::StmtImportFrom,
|
||||||
|
@ -153,6 +160,17 @@ pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
|
||||||
pub(crate) first: bool,
|
pub(crate) first: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug)]
|
||||||
|
pub(crate) struct MatchPatternDefinitionNodeRef<'a> {
|
||||||
|
/// The outermost pattern node in which the identifier being defined occurs.
|
||||||
|
pub(crate) pattern: &'a ast::Pattern,
|
||||||
|
/// The identifier being defined.
|
||||||
|
pub(crate) identifier: &'a ast::Identifier,
|
||||||
|
/// The index of the identifier in the pattern when visiting the `pattern` node in evaluation
|
||||||
|
/// order.
|
||||||
|
pub(crate) index: u32,
|
||||||
|
}
|
||||||
|
|
||||||
impl DefinitionNodeRef<'_> {
|
impl DefinitionNodeRef<'_> {
|
||||||
#[allow(unsafe_code)]
|
#[allow(unsafe_code)]
|
||||||
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind {
|
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind {
|
||||||
|
@ -213,6 +231,15 @@ impl DefinitionNodeRef<'_> {
|
||||||
target: AstNodeRef::new(parsed, target),
|
target: AstNodeRef::new(parsed, target),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
DefinitionNodeRef::MatchPattern(MatchPatternDefinitionNodeRef {
|
||||||
|
pattern,
|
||||||
|
identifier,
|
||||||
|
index,
|
||||||
|
}) => DefinitionKind::MatchPattern(MatchPatternDefinitionKind {
|
||||||
|
pattern: AstNodeRef::new(parsed.clone(), pattern),
|
||||||
|
identifier: AstNodeRef::new(parsed, identifier),
|
||||||
|
index,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,6 +268,9 @@ impl DefinitionNodeRef<'_> {
|
||||||
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
|
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
|
||||||
},
|
},
|
||||||
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
|
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
|
||||||
|
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
|
||||||
|
identifier.into()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -260,6 +290,25 @@ pub enum DefinitionKind {
|
||||||
Parameter(AstNodeRef<ast::Parameter>),
|
Parameter(AstNodeRef<ast::Parameter>),
|
||||||
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
|
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
|
||||||
WithItem(WithItemDefinitionKind),
|
WithItem(WithItemDefinitionKind),
|
||||||
|
MatchPattern(MatchPatternDefinitionKind),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub struct MatchPatternDefinitionKind {
|
||||||
|
pattern: AstNodeRef<ast::Pattern>,
|
||||||
|
identifier: AstNodeRef<ast::Identifier>,
|
||||||
|
index: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MatchPatternDefinitionKind {
|
||||||
|
pub(crate) fn pattern(&self) -> &ast::Pattern {
|
||||||
|
self.pattern.node()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn index(&self) -> u32 {
|
||||||
|
self.index
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -410,3 +459,9 @@ impl From<&ast::ParameterWithDefault> for DefinitionNodeKey {
|
||||||
Self(NodeKey::from_node(node))
|
Self(NodeKey::from_node(node))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<&ast::Identifier> for DefinitionNodeKey {
|
||||||
|
fn from(identifier: &ast::Identifier) -> Self {
|
||||||
|
Self(NodeKey::from_identifier(identifier))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -416,6 +416,13 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
DefinitionKind::WithItem(with_item) => {
|
DefinitionKind::WithItem(with_item) => {
|
||||||
self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
|
self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
|
||||||
}
|
}
|
||||||
|
DefinitionKind::MatchPattern(match_pattern) => {
|
||||||
|
self.infer_match_pattern_definition(
|
||||||
|
match_pattern.pattern(),
|
||||||
|
match_pattern.index(),
|
||||||
|
definition,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -795,7 +802,10 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
cases,
|
cases,
|
||||||
} = match_statement;
|
} = match_statement;
|
||||||
|
|
||||||
self.infer_expression(subject);
|
let expression = self.index.expression(subject.as_ref());
|
||||||
|
let result = infer_expression_types(self.db, expression);
|
||||||
|
self.extend(result);
|
||||||
|
|
||||||
for case in cases {
|
for case in cases {
|
||||||
let ast::MatchCase {
|
let ast::MatchCase {
|
||||||
range: _,
|
range: _,
|
||||||
|
@ -809,7 +819,22 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_match_pattern_definition(
|
||||||
|
&mut self,
|
||||||
|
_pattern: &ast::Pattern,
|
||||||
|
_index: u32,
|
||||||
|
definition: Definition<'db>,
|
||||||
|
) {
|
||||||
|
// TODO(dhruvmanila): The correct way to infer types here is to perform structural matching
|
||||||
|
// against the subject expression type (which we can query via `infer_expression_types`)
|
||||||
|
// and extract the type at the `index` position if the pattern matches. This will be
|
||||||
|
// similar to the logic in `self.infer_assignment_definition`.
|
||||||
|
self.types.definitions.insert(definition, Type::Unknown);
|
||||||
|
}
|
||||||
|
|
||||||
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
|
fn infer_match_pattern(&mut self, pattern: &ast::Pattern) {
|
||||||
|
// TODO(dhruvmanila): Add a Salsa query for inferring pattern types and matching against
|
||||||
|
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
|
||||||
match pattern {
|
match pattern {
|
||||||
ast::Pattern::MatchValue(match_value) => {
|
ast::Pattern::MatchValue(match_value) => {
|
||||||
self.infer_expression(&match_value.value);
|
self.infer_expression(&match_value.value);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue