diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 7e2c5a1948..64e50dcea6 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -21,6 +21,7 @@ use crate::Db; pub mod ast_ids; mod builder; +pub(crate) mod constraint; pub mod definition; pub mod expression; pub mod symbol; diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 59e514dd85..3f440a89b3 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -26,6 +26,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder}; use crate::semantic_index::SemanticIndex; use crate::Db; +use super::constraint::{Constraint, PatternConstraint}; use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef}; pub(super) struct SemanticIndexBuilder<'db> { @@ -204,13 +205,39 @@ impl<'db> SemanticIndexBuilder<'db> { definition } - fn add_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> { + fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> { let expression = self.add_standalone_expression(constraint_node); - self.current_use_def_map_mut().record_constraint(expression); + self.current_use_def_map_mut() + .record_constraint(Constraint::Expression(expression)); expression } + fn add_pattern_constraint( + &mut self, + subject: &ast::Expr, + pattern: &ast::Pattern, + ) -> PatternConstraint<'db> { + #[allow(unsafe_code)] + let (subject, pattern) = unsafe { + ( + AstNodeRef::new(self.module.clone(), subject), + AstNodeRef::new(self.module.clone(), pattern), + ) + }; + let pattern_constraint = PatternConstraint::new( + self.db, + self.file, + self.current_scope(), + subject, + pattern, + countme::Count::default(), + ); + self.current_use_def_map_mut() + .record_constraint(Constraint::Pattern(pattern_constraint)); + pattern_constraint + } + /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type /// standalone (type narrowing tests, RHS of an assignment.) fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> { @@ -523,7 +550,7 @@ where ast::Stmt::If(node) => { self.visit_expr(&node.test); let pre_if = self.flow_snapshot(); - self.add_constraint(&node.test); + self.add_expression_constraint(&node.test); self.visit_body(&node.body); let mut post_clauses: Vec = vec![]; for clause in &node.elif_else_clauses { @@ -615,9 +642,30 @@ where }) => { self.add_standalone_expression(subject); self.visit_expr(subject); - for case in cases { + + let after_subject = self.flow_snapshot(); + let Some((first, remaining)) = cases.split_first() else { + return; + }; + self.add_pattern_constraint(subject, &first.pattern); + self.visit_match_case(first); + + let mut post_case_snapshots = vec![]; + for case in remaining { + post_case_snapshots.push(self.flow_snapshot()); + self.flow_restore(after_subject.clone()); + self.add_pattern_constraint(subject, &case.pattern); self.visit_match_case(case); } + for post_clause_state in post_case_snapshots { + self.flow_merge(post_clause_state); + } + if !cases + .last() + .is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard()) + { + self.flow_merge(after_subject); + } } _ => { walk_stmt(self, stmt); diff --git a/crates/red_knot_python_semantic/src/semantic_index/constraint.rs b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs new file mode 100644 index 0000000000..9659d5f82f --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/constraint.rs @@ -0,0 +1,39 @@ +use ruff_db::files::File; +use ruff_python_ast as ast; + +use crate::ast_node_ref::AstNodeRef; +use crate::db::Db; +use crate::semantic_index::expression::Expression; +use crate::semantic_index::symbol::{FileScopeId, ScopeId}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum Constraint<'db> { + Expression(Expression<'db>), + Pattern(PatternConstraint<'db>), +} + +#[salsa::tracked] +pub(crate) struct PatternConstraint<'db> { + #[id] + pub(crate) file: File, + + #[id] + pub(crate) file_scope: FileScopeId, + + #[no_eq] + #[return_ref] + pub(crate) subject: AstNodeRef, + + #[no_eq] + #[return_ref] + pub(crate) pattern: AstNodeRef, + + #[no_eq] + count: countme::Count>, +} + +impl<'db> PatternConstraint<'db> { + pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { + self.file_scope(db).to_scope_id(db, self.file(db)) + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index 96fe0fd56d..682ee32a41 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -146,10 +146,11 @@ use self::symbol_state::{ }; use crate::semantic_index::ast_ids::ScopedUseId; use crate::semantic_index::definition::Definition; -use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::ScopedSymbolId; use ruff_index::IndexVec; +use super::constraint::Constraint; + mod bitset; mod symbol_state; @@ -159,8 +160,8 @@ pub(crate) struct UseDefMap<'db> { /// Array of [`Definition`] in this scope. all_definitions: IndexVec>, - /// Array of constraints (as [`Expression`]) in this scope. - all_constraints: IndexVec>, + /// Array of [`Constraint`] in this scope. + all_constraints: IndexVec>, /// [`SymbolState`] visible at a [`ScopedUseId`]. definitions_by_use: IndexVec, @@ -204,7 +205,7 @@ impl<'db> UseDefMap<'db> { #[derive(Debug)] pub(crate) struct DefinitionWithConstraintsIterator<'map, 'db> { all_definitions: &'map IndexVec>, - all_constraints: &'map IndexVec>, + all_constraints: &'map IndexVec>, inner: DefinitionIdWithConstraintsIterator<'map>, } @@ -232,12 +233,12 @@ pub(crate) struct DefinitionWithConstraints<'map, 'db> { } pub(crate) struct ConstraintsIterator<'map, 'db> { - all_constraints: &'map IndexVec>, + all_constraints: &'map IndexVec>, constraint_ids: ConstraintIdIterator<'map>, } impl<'map, 'db> Iterator for ConstraintsIterator<'map, 'db> { - type Item = Expression<'db>; + type Item = Constraint<'db>; fn next(&mut self) -> Option { self.constraint_ids @@ -259,8 +260,8 @@ pub(super) struct UseDefMapBuilder<'db> { /// Append-only array of [`Definition`]; None is unbound. all_definitions: IndexVec>, - /// Append-only array of constraints (as [`Expression`]). - all_constraints: IndexVec>, + /// Append-only array of [`Constraint`]. + all_constraints: IndexVec>, /// Visible definitions at each so-far-recorded use. definitions_by_use: IndexVec, @@ -290,7 +291,7 @@ impl<'db> UseDefMapBuilder<'db> { self.definitions_by_symbol[symbol] = SymbolState::with(def_id); } - pub(super) fn record_constraint(&mut self, constraint: Expression<'db>) { + pub(super) fn record_constraint(&mut self, constraint: Constraint<'db>) { let constraint_id = self.all_constraints.push(constraint); for definitions in &mut self.definitions_by_symbol { definitions.add_constraint(constraint_id); diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index dfdf263b32..252b9125b7 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -17,7 +17,6 @@ pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder}; pub(crate) use self::diagnostic::TypeCheckDiagnostics; pub(crate) use self::infer::{ infer_deferred_types, infer_definition_types, infer_expression_types, infer_scope_types, - TypeInference, }; mod builder; @@ -121,8 +120,8 @@ pub(crate) fn definitions_ty<'db>( definition, constraints, }| { - let mut constraint_tys = - constraints.filter_map(|test| narrowing_constraint(db, test, definition)); + let mut constraint_tys = constraints + .filter_map(|constraint| narrowing_constraint(db, constraint, definition)); let definition_ty = definition_ty(db, definition); if let Some(first_constraint_ty) = constraint_tys.next() { let mut builder = IntersectionBuilder::new(db); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 335be34bfa..02b0efb3b5 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -3500,6 +3500,65 @@ mod tests { Ok(()) } + #[test] + fn match_with_wildcard() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + match 0: + case 1: + y = 2 + case _: + y = 3 +", + ) + .unwrap(); + + assert_public_ty(&db, "src/a.py", "y", "Literal[2, 3]"); + } + + #[test] + fn match_without_wildcard() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + match 0: + case 1: + y = 2 + case 2: + y = 3 +", + ) + .unwrap(); + + assert_public_ty(&db, "src/a.py", "y", "Unbound | Literal[2, 3]"); + } + + #[test] + fn match_stmt() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + y = 1 + y = 2 + match 0: + case 1: + y = 3 + case 2: + y = 4 +", + ) + .unwrap(); + + assert_public_ty(&db, "src/a.py", "y", "Literal[2, 3, 4]"); + } + #[test] fn import_cycle() -> anyhow::Result<()> { let mut db = setup_db(); @@ -3814,6 +3873,33 @@ mod tests { Ok(()) } + #[test] + fn narrow_singleton_pattern() { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x = None if flag else 1 + y = 0 + match x: + case None: + y = x + ", + ) + .unwrap(); + + // TODO: The correct inferred type should be `Literal[0] | None` but currently the + // simplification logic doesn't account for this. The final type with parenthesis: + // `Literal[0] | None | (Literal[1] & None)` + assert_public_ty( + &db, + "/src/a.py", + "y", + "Literal[0] | None | Literal[1] & None", + ); + } + #[test] fn while_loop() -> anyhow::Result<()> { let mut db = setup_db(); diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 381c6effa7..8ca57af116 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -1,9 +1,10 @@ use crate::semantic_index::ast_ids::HasScopedAstId; +use crate::semantic_index::constraint::{Constraint, PatternConstraint}; use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol_table; -use crate::types::{infer_expression_types, IntersectionBuilder, Type, TypeInference}; +use crate::types::{infer_expression_types, IntersectionBuilder, Type}; use crate::Db; use ruff_python_ast as ast; use rustc_hash::FxHashMap; @@ -27,62 +28,114 @@ use std::sync::Arc; /// constraint is applied to that definition, so we'd just return `None`. pub(crate) fn narrowing_constraint<'db>( db: &'db dyn Db, - test: Expression<'db>, + constraint: Constraint<'db>, definition: Definition<'db>, ) -> Option> { - all_narrowing_constraints(db, test) - .get(&definition.symbol(db)) - .copied() + match constraint { + Constraint::Expression(expression) => { + all_narrowing_constraints_for_expression(db, expression) + .get(&definition.symbol(db)) + .copied() + } + Constraint::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern) + .get(&definition.symbol(db)) + .copied(), + } } #[salsa::tracked(return_ref)] -fn all_narrowing_constraints<'db>( +fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, - test: Expression<'db>, + pattern: PatternConstraint<'db>, ) -> NarrowingConstraints<'db> { - NarrowingConstraintsBuilder::new(db, test).finish() + NarrowingConstraintsBuilder::new(db, Constraint::Pattern(pattern)).finish() +} + +#[salsa::tracked(return_ref)] +fn all_narrowing_constraints_for_expression<'db>( + db: &'db dyn Db, + expression: Expression<'db>, +) -> NarrowingConstraints<'db> { + NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression)).finish() } type NarrowingConstraints<'db> = FxHashMap>; struct NarrowingConstraintsBuilder<'db> { db: &'db dyn Db, - expression: Expression<'db>, + constraint: Constraint<'db>, constraints: NarrowingConstraints<'db>, } impl<'db> NarrowingConstraintsBuilder<'db> { - fn new(db: &'db dyn Db, expression: Expression<'db>) -> Self { + fn new(db: &'db dyn Db, constraint: Constraint<'db>) -> Self { Self { db, - expression, + constraint, constraints: NarrowingConstraints::default(), } } fn finish(mut self) -> NarrowingConstraints<'db> { - if let ast::Expr::Compare(expr_compare) = self.expression.node_ref(self.db).node() { - self.add_expr_compare(expr_compare); + match self.constraint { + Constraint::Expression(expression) => self.evaluate_expression_constraint(expression), + Constraint::Pattern(pattern) => self.evaluate_pattern_constraint(pattern), } - // TODO other test expression kinds self.constraints.shrink_to_fit(); self.constraints } + fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) { + if let ast::Expr::Compare(expr_compare) = expression.node_ref(self.db).node() { + self.add_expr_compare(expr_compare, expression); + } + // TODO other test expression kinds + } + + fn evaluate_pattern_constraint(&mut self, pattern: PatternConstraint<'db>) { + let subject = pattern.subject(self.db); + + match pattern.pattern(self.db).node() { + ast::Pattern::MatchValue(_) => { + // TODO + } + ast::Pattern::MatchSingleton(singleton_pattern) => { + self.add_match_pattern_singleton(subject, singleton_pattern); + } + ast::Pattern::MatchSequence(_) => { + // TODO + } + ast::Pattern::MatchMapping(_) => { + // TODO + } + ast::Pattern::MatchClass(_) => { + // TODO + } + ast::Pattern::MatchStar(_) => { + // TODO + } + ast::Pattern::MatchAs(_) => { + // TODO + } + ast::Pattern::MatchOr(_) => { + // TODO + } + } + } + fn symbols(&self) -> Arc { symbol_table(self.db, self.scope()) } fn scope(&self) -> ScopeId<'db> { - self.expression.scope(self.db) + match self.constraint { + Constraint::Expression(expression) => expression.scope(self.db), + Constraint::Pattern(pattern) => pattern.scope(self.db), + } } - fn inference(&self) -> &'db TypeInference<'db> { - infer_expression_types(self.db, self.expression) - } - - fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare) { + fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>) { let ast::ExprCompare { range: _, left, @@ -99,7 +152,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> { // SAFETY: we should always have a symbol for every Name node. let symbol = self.symbols().symbol_id_by_name(id).unwrap(); let scope = self.scope(); - let inference = self.inference(); + let inference = infer_expression_types(self.db, expression); for (op, comparator) in std::iter::zip(&**ops, &**comparators) { let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope)); if matches!(op, ast::CmpOp::IsNot) { @@ -112,4 +165,22 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } } + + fn add_match_pattern_singleton( + &mut self, + subject: &ast::Expr, + pattern: &ast::PatternMatchSingleton, + ) { + if let Some(ast::ExprName { id, .. }) = subject.as_name_expr() { + // SAFETY: we should always have a symbol for every Name node. + let symbol = self.symbols().symbol_id_by_name(id).unwrap(); + + let ty = match pattern.value { + ast::Singleton::None => Type::None, + ast::Singleton::True => Type::BooleanLiteral(true), + ast::Singleton::False => Type::BooleanLiteral(false), + }; + self.constraints.insert(symbol, ty); + } + } } diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs index 079e9003b8..71ea0e85e7 100644 --- a/crates/ruff_python_ast/src/nodes.rs +++ b/crates/ruff_python_ast/src/nodes.rs @@ -3124,6 +3124,29 @@ impl Pattern { _ => false, } } + + /// Checks if the [`Pattern`] is a [wildcard pattern]. + /// + /// The following are wildcard patterns: + /// ```python + /// match subject: + /// case _ as x: ... + /// case _ | _: ... + /// case _: ... + /// ``` + /// + /// [wildcard pattern]: https://docs.python.org/3/reference/compound_stmts.html#wildcard-patterns + pub fn is_wildcard(&self) -> bool { + match self { + Pattern::MatchAs(PatternMatchAs { pattern, .. }) => { + pattern.as_deref().map_or(true, Pattern::is_wildcard) + } + Pattern::MatchOr(PatternMatchOr { patterns, .. }) => { + patterns.iter().all(Pattern::is_wildcard) + } + _ => false, + } + } } /// See also [MatchValue](https://docs.python.org/3/library/ast.html#ast.MatchValue) diff --git a/crates/ruff_python_ast_integration_tests/tests/match_pattern.rs b/crates/ruff_python_ast_integration_tests/tests/match_pattern.rs new file mode 100644 index 0000000000..633e8e4fd4 --- /dev/null +++ b/crates/ruff_python_ast_integration_tests/tests/match_pattern.rs @@ -0,0 +1,16 @@ +use ruff_python_parser::parse_module; + +#[test] +fn pattern_is_wildcard() { + let source_code = r" +match subject: + case _ as x: ... + case _ | _: ... + case _: ... +"; + let parsed = parse_module(source_code).unwrap(); + let cases = &parsed.syntax().body[0].as_match_stmt().unwrap().cases; + for case in cases { + assert!(case.pattern.is_wildcard()); + } +}