[red-knot] Add control flow support for match statement (#13241)

## Summary

This PR adds control flow support for the `match` statement.

It also adds the infrastructure required for narrowing
constraints in case blocks and implements the narrowing logic for
`PatternMatchSingleton`, whose value is one of `None` / `True` / `False`. Even
with this change, the inferred type isn't simplified completely; there's
a TODO for that in the test code.

## Test Plan

Add control flow test cases for a `match` statement (a) with a trailing
wildcard pattern and (b) without one. There's also a test case to verify the
narrowing logic.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Dhruv Manilawala 2024-09-10 02:14:19 +05:30 committed by GitHub
parent 6f53aaf931
commit 62c7d8f6ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 321 additions and 37 deletions

View file

@ -21,6 +21,7 @@ use crate::Db;
pub mod ast_ids;
mod builder;
pub(crate) mod constraint;
pub mod definition;
pub mod expression;
pub mod symbol;

View file

@ -26,6 +26,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
use crate::semantic_index::SemanticIndex;
use crate::Db;
use super::constraint::{Constraint, PatternConstraint};
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
pub(super) struct SemanticIndexBuilder<'db> {
@ -204,13 +205,39 @@ impl<'db> SemanticIndexBuilder<'db> {
definition
}
fn add_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
let expression = self.add_standalone_expression(constraint_node);
self.current_use_def_map_mut().record_constraint(expression);
self.current_use_def_map_mut()
.record_constraint(Constraint::Expression(expression));
expression
}
/// Builds a [`PatternConstraint`] pairing the `match` subject with one case pattern, records
/// it on the current use-def map as a [`Constraint::Pattern`], and returns it.
fn add_pattern_constraint(
    &mut self,
    subject: &ast::Expr,
    pattern: &ast::Pattern,
) -> PatternConstraint<'db> {
    // NOTE(review): soundness presumably relies on `subject` and `pattern` being nodes of the
    // AST held by `self.module` — confirm against the contract of `AstNodeRef::new`.
    #[allow(unsafe_code)]
    let subject_ref = unsafe { AstNodeRef::new(self.module.clone(), subject) };
    #[allow(unsafe_code)]
    let pattern_ref = unsafe { AstNodeRef::new(self.module.clone(), pattern) };

    let constraint = PatternConstraint::new(
        self.db,
        self.file,
        self.current_scope(),
        subject_ref,
        pattern_ref,
        countme::Count::default(),
    );
    let use_def_map = self.current_use_def_map_mut();
    use_def_map.record_constraint(Constraint::Pattern(constraint));
    constraint
}
/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
/// standalone (type narrowing tests, RHS of an assignment.)
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> {
@ -523,7 +550,7 @@ where
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let pre_if = self.flow_snapshot();
self.add_constraint(&node.test);
self.add_expression_constraint(&node.test);
self.visit_body(&node.body);
let mut post_clauses: Vec<FlowSnapshot> = vec![];
for clause in &node.elif_else_clauses {
@ -615,9 +642,30 @@ where
}) => {
self.add_standalone_expression(subject);
self.visit_expr(subject);
for case in cases {
let after_subject = self.flow_snapshot();
let Some((first, remaining)) = cases.split_first() else {
return;
};
self.add_pattern_constraint(subject, &first.pattern);
self.visit_match_case(first);
let mut post_case_snapshots = vec![];
for case in remaining {
post_case_snapshots.push(self.flow_snapshot());
self.flow_restore(after_subject.clone());
self.add_pattern_constraint(subject, &case.pattern);
self.visit_match_case(case);
}
for post_clause_state in post_case_snapshots {
self.flow_merge(post_clause_state);
}
if !cases
.last()
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard())
{
self.flow_merge(after_subject);
}
}
_ => {
walk_stmt(self, stmt);

View file

@ -0,0 +1,39 @@
use ruff_db::files::File;
use ruff_python_ast as ast;
use crate::ast_node_ref::AstNodeRef;
use crate::db::Db;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
/// A constraint recorded in the use-def map: either a standalone test [`Expression`] or a
/// `match` case [`PatternConstraint`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Constraint<'db> {
/// Constraint given by an expression.
Expression(Expression<'db>),
/// Constraint given by a `match` subject paired with a case pattern.
Pattern(PatternConstraint<'db>),
}
/// Salsa-tracked pairing of a `match` subject expression with one case pattern; the payload
/// of `Constraint::Pattern`.
#[salsa::tracked]
pub(crate) struct PatternConstraint<'db> {
/// File the constraint belongs to; combined with `file_scope` to resolve the full scope id.
#[id]
pub(crate) file: File,
/// File-local id of the scope the constraint belongs to.
#[id]
pub(crate) file_scope: FileScopeId,
/// The subject expression of the `match` statement.
#[no_eq]
#[return_ref]
pub(crate) subject: AstNodeRef<ast::Expr>,
/// The case pattern tested against the subject.
#[no_eq]
#[return_ref]
pub(crate) pattern: AstNodeRef<ast::Pattern>,
// NOTE(review): presumably a `countme` instance counter used for diagnostics — confirm.
#[no_eq]
count: countme::Count<PatternConstraint<'static>>,
}
impl<'db> PatternConstraint<'db> {
    /// Resolves the file-local scope id of this constraint into a full [`ScopeId`] using the
    /// constraint's file.
    pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
        let file_scope = self.file_scope(db);
        let file = self.file(db);
        file_scope.to_scope_id(db, file)
    }
}

View file

@ -146,10 +146,11 @@ use self::symbol_state::{
};
use crate::semantic_index::ast_ids::ScopedUseId;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::ScopedSymbolId;
use ruff_index::IndexVec;
use super::constraint::Constraint;
mod bitset;
mod symbol_state;
@ -159,8 +160,8 @@ pub(crate) struct UseDefMap<'db> {
/// Array of [`Definition`] in this scope.
all_definitions: IndexVec<ScopedDefinitionId, Definition<'db>>,
/// Array of constraints (as [`Expression`]) in this scope.
all_constraints: IndexVec<ScopedConstraintId, Expression<'db>>,
/// Array of [`Constraint`] in this scope.
all_constraints: IndexVec<ScopedConstraintId, Constraint<'db>>,
/// [`SymbolState`] visible at a [`ScopedUseId`].
definitions_by_use: IndexVec<ScopedUseId, SymbolState>,
@ -204,7 +205,7 @@ impl<'db> UseDefMap<'db> {
#[derive(Debug)]
pub(crate) struct DefinitionWithConstraintsIterator<'map, 'db> {
all_definitions: &'map IndexVec<ScopedDefinitionId, Definition<'db>>,
all_constraints: &'map IndexVec<ScopedConstraintId, Expression<'db>>,
all_constraints: &'map IndexVec<ScopedConstraintId, Constraint<'db>>,
inner: DefinitionIdWithConstraintsIterator<'map>,
}
@ -232,12 +233,12 @@ pub(crate) struct DefinitionWithConstraints<'map, 'db> {
}
pub(crate) struct ConstraintsIterator<'map, 'db> {
all_constraints: &'map IndexVec<ScopedConstraintId, Expression<'db>>,
all_constraints: &'map IndexVec<ScopedConstraintId, Constraint<'db>>,
constraint_ids: ConstraintIdIterator<'map>,
}
impl<'map, 'db> Iterator for ConstraintsIterator<'map, 'db> {
type Item = Expression<'db>;
type Item = Constraint<'db>;
fn next(&mut self) -> Option<Self::Item> {
self.constraint_ids
@ -259,8 +260,8 @@ pub(super) struct UseDefMapBuilder<'db> {
/// Append-only array of [`Definition`]; None is unbound.
all_definitions: IndexVec<ScopedDefinitionId, Definition<'db>>,
/// Append-only array of constraints (as [`Expression`]).
all_constraints: IndexVec<ScopedConstraintId, Expression<'db>>,
/// Append-only array of [`Constraint`].
all_constraints: IndexVec<ScopedConstraintId, Constraint<'db>>,
/// Visible definitions at each so-far-recorded use.
definitions_by_use: IndexVec<ScopedUseId, SymbolState>,
@ -290,7 +291,7 @@ impl<'db> UseDefMapBuilder<'db> {
self.definitions_by_symbol[symbol] = SymbolState::with(def_id);
}
pub(super) fn record_constraint(&mut self, constraint: Expression<'db>) {
pub(super) fn record_constraint(&mut self, constraint: Constraint<'db>) {
let constraint_id = self.all_constraints.push(constraint);
for definitions in &mut self.definitions_by_symbol {
definitions.add_constraint(constraint_id);

View file

@ -17,7 +17,6 @@ pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder};
pub(crate) use self::diagnostic::TypeCheckDiagnostics;
pub(crate) use self::infer::{
infer_deferred_types, infer_definition_types, infer_expression_types, infer_scope_types,
TypeInference,
};
mod builder;
@ -121,8 +120,8 @@ pub(crate) fn definitions_ty<'db>(
definition,
constraints,
}| {
let mut constraint_tys =
constraints.filter_map(|test| narrowing_constraint(db, test, definition));
let mut constraint_tys = constraints
.filter_map(|constraint| narrowing_constraint(db, constraint, definition));
let definition_ty = definition_ty(db, definition);
if let Some(first_constraint_ty) = constraint_tys.next() {
let mut builder = IntersectionBuilder::new(db);

View file

@ -3500,6 +3500,65 @@ mod tests {
Ok(())
}
// A `match` whose last case is a wildcard: the public type of `y` is the union of the
// literals assigned in the two case bodies, with no `Unbound` member.
#[test]
fn match_with_wildcard() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
match 0:
case 1:
y = 2
case _:
y = 3
",
)
.unwrap();
assert_public_ty(&db, "src/a.py", "y", "Literal[2, 3]");
}
// A `match` with no wildcard case: `y` is only assigned inside the case bodies, so the
// expected public type also contains `Unbound` alongside the assigned literals.
#[test]
fn match_without_wildcard() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
match 0:
case 1:
y = 2
case 2:
y = 3
",
)
.unwrap();
assert_public_ty(&db, "src/a.py", "y", "Unbound | Literal[2, 3]");
}
// `y` is assigned before a `match` that has no wildcard case: the expected public type keeps
// the last pre-match value (`2`) together with the values assigned in the case bodies; the
// shadowed first assignment (`1`) does not appear.
#[test]
fn match_stmt() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
y = 1
y = 2
match 0:
case 1:
y = 3
case 2:
y = 4
",
)
.unwrap();
assert_public_ty(&db, "src/a.py", "y", "Literal[2, 3, 4]");
}
#[test]
fn import_cycle() -> anyhow::Result<()> {
let mut db = setup_db();
@ -3814,6 +3873,33 @@ mod tests {
Ok(())
}
// Checks narrowing of the match subject `x` by a singleton pattern (`case None`), observed
// through the public type of `y`, which is assigned from `x` inside the case body.
#[test]
fn narrow_singleton_pattern() {
let mut db = setup_db();
db.write_dedented(
"/src/a.py",
"
x = None if flag else 1
y = 0
match x:
case None:
y = x
",
)
.unwrap();
// TODO: The correct inferred type should be `Literal[0] | None` but currently the
// simplification logic doesn't account for this. The final type with parentheses:
// `Literal[0] | None | (Literal[1] & None)`
assert_public_ty(
&db,
"/src/a.py",
"y",
"Literal[0] | None | Literal[1] & None",
);
}
#[test]
fn while_loop() -> anyhow::Result<()> {
let mut db = setup_db();

View file

@ -1,9 +1,10 @@
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::constraint::{Constraint, PatternConstraint};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table;
use crate::types::{infer_expression_types, IntersectionBuilder, Type, TypeInference};
use crate::types::{infer_expression_types, IntersectionBuilder, Type};
use crate::Db;
use ruff_python_ast as ast;
use rustc_hash::FxHashMap;
@ -27,62 +28,114 @@ use std::sync::Arc;
/// constraint is applied to that definition, so we'd just return `None`.
pub(crate) fn narrowing_constraint<'db>(
db: &'db dyn Db,
test: Expression<'db>,
constraint: Constraint<'db>,
definition: Definition<'db>,
) -> Option<Type<'db>> {
all_narrowing_constraints(db, test)
.get(&definition.symbol(db))
.copied()
match constraint {
Constraint::Expression(expression) => {
all_narrowing_constraints_for_expression(db, expression)
.get(&definition.symbol(db))
.copied()
}
Constraint::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern)
.get(&definition.symbol(db))
.copied(),
}
}
#[salsa::tracked(return_ref)]
fn all_narrowing_constraints<'db>(
fn all_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
test: Expression<'db>,
pattern: PatternConstraint<'db>,
) -> NarrowingConstraints<'db> {
NarrowingConstraintsBuilder::new(db, test).finish()
NarrowingConstraintsBuilder::new(db, Constraint::Pattern(pattern)).finish()
}
/// Salsa query computing the per-symbol narrowing constraints implied by a test expression.
#[salsa::tracked(return_ref)]
fn all_narrowing_constraints_for_expression<'db>(
    db: &'db dyn Db,
    expression: Expression<'db>,
) -> NarrowingConstraints<'db> {
    let builder = NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression));
    builder.finish()
}
type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
struct NarrowingConstraintsBuilder<'db> {
db: &'db dyn Db,
expression: Expression<'db>,
constraint: Constraint<'db>,
constraints: NarrowingConstraints<'db>,
}
impl<'db> NarrowingConstraintsBuilder<'db> {
fn new(db: &'db dyn Db, expression: Expression<'db>) -> Self {
fn new(db: &'db dyn Db, constraint: Constraint<'db>) -> Self {
Self {
db,
expression,
constraint,
constraints: NarrowingConstraints::default(),
}
}
fn finish(mut self) -> NarrowingConstraints<'db> {
if let ast::Expr::Compare(expr_compare) = self.expression.node_ref(self.db).node() {
self.add_expr_compare(expr_compare);
match self.constraint {
Constraint::Expression(expression) => self.evaluate_expression_constraint(expression),
Constraint::Pattern(pattern) => self.evaluate_pattern_constraint(pattern),
}
// TODO other test expression kinds
self.constraints.shrink_to_fit();
self.constraints
}
/// Collects narrowing constraints from a test expression. Only comparison expressions are
/// handled so far; every other expression kind contributes nothing.
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) {
    // TODO other test expression kinds
    let ast::Expr::Compare(expr_compare) = expression.node_ref(self.db).node() else {
        return;
    };
    self.add_expr_compare(expr_compare, expression);
}
/// Collects narrowing constraints from a `match` case pattern. Only singleton patterns
/// narrow anything so far; all other pattern kinds are accepted but contribute nothing.
fn evaluate_pattern_constraint(&mut self, pattern: PatternConstraint<'db>) {
    let subject = pattern.subject(self.db);
    match pattern.pattern(self.db).node() {
        ast::Pattern::MatchSingleton(singleton_pattern) => {
            self.add_match_pattern_singleton(subject, singleton_pattern);
        }
        // TODO: narrowing for the remaining pattern kinds. They are listed explicitly (no
        // `_` arm) so that a newly added `Pattern` variant fails to compile here.
        ast::Pattern::MatchValue(_)
        | ast::Pattern::MatchSequence(_)
        | ast::Pattern::MatchMapping(_)
        | ast::Pattern::MatchClass(_)
        | ast::Pattern::MatchStar(_)
        | ast::Pattern::MatchAs(_)
        | ast::Pattern::MatchOr(_) => {}
    }
}
/// Returns the symbol table of the scope this builder's constraint belongs to.
fn symbols(&self) -> Arc<SymbolTable> {
    let scope = self.scope();
    symbol_table(self.db, scope)
}
fn scope(&self) -> ScopeId<'db> {
self.expression.scope(self.db)
match self.constraint {
Constraint::Expression(expression) => expression.scope(self.db),
Constraint::Pattern(pattern) => pattern.scope(self.db),
}
}
fn inference(&self) -> &'db TypeInference<'db> {
infer_expression_types(self.db, self.expression)
}
fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare) {
fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>) {
let ast::ExprCompare {
range: _,
left,
@ -99,7 +152,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let scope = self.scope();
let inference = self.inference();
let inference = infer_expression_types(self.db, expression);
for (op, comparator) in std::iter::zip(&**ops, &**comparators) {
let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope));
if matches!(op, ast::CmpOp::IsNot) {
@ -112,4 +165,22 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}
}
/// Records the narrowed type for a singleton pattern (`None` / `True` / `False`).
///
/// Narrowing is only recorded when the match subject is a plain name; any other subject
/// expression leaves the constraints untouched.
fn add_match_pattern_singleton(
    &mut self,
    subject: &ast::Expr,
    pattern: &ast::PatternMatchSingleton,
) {
    let Some(ast::ExprName { id, .. }) = subject.as_name_expr() else {
        return;
    };
    // SAFETY: we should always have a symbol for every Name node.
    let symbol = self.symbols().symbol_id_by_name(id).unwrap();
    let narrowed_ty = match pattern.value {
        ast::Singleton::None => Type::None,
        ast::Singleton::True => Type::BooleanLiteral(true),
        ast::Singleton::False => Type::BooleanLiteral(false),
    };
    self.constraints.insert(symbol, narrowed_ty);
}
}

View file

@ -3124,6 +3124,29 @@ impl Pattern {
_ => false,
}
}
/// Checks if the [`Pattern`] is a [wildcard pattern].
///
/// The following are wildcard patterns:
/// ```python
/// match subject:
/// case _ as x: ...
/// case _ | _: ...
/// case _: ...
/// ```
///
/// [wildcard pattern]: https://docs.python.org/3/reference/compound_stmts.html#wildcard-patterns
pub fn is_wildcard(&self) -> bool {
match self {
Pattern::MatchAs(PatternMatchAs { pattern, .. }) => {
pattern.as_deref().map_or(true, Pattern::is_wildcard)
}
Pattern::MatchOr(PatternMatchOr { patterns, .. }) => {
patterns.iter().all(Pattern::is_wildcard)
}
_ => false,
}
}
}
/// See also [MatchValue](https://docs.python.org/3/library/ast.html#ast.MatchValue)

View file

@ -0,0 +1,16 @@
use ruff_python_parser::parse_module;
// Parses a `match` statement whose cases are `_ as x`, `_ | _` and `_`, and asserts that
// `Pattern::is_wildcard` returns true for the pattern of every case.
#[test]
fn pattern_is_wildcard() {
let source_code = r"
match subject:
case _ as x: ...
case _ | _: ...
case _: ...
";
let parsed = parse_module(source_code).unwrap();
let cases = &parsed.syntax().body[0].as_match_stmt().unwrap().cases;
for case in cases {
assert!(case.pattern.is_wildcard());
}
}