[red-knot] support narrowing on constants in matches (#16974)

## Summary

Part of #13694

The implementation here was suspiciously straightforward so please lmk
if I missed something

Also some drive-by changes to DRY things up a bit

## Test Plan

Add new tests to narrow/match.md

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Eric Mark Martin 2025-03-27 22:36:51 -04:00 committed by GitHub
parent 992a1af4c2
commit 3acf4e716d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 94 additions and 49 deletions

View file

@ -64,3 +64,53 @@ match x:
reveal_type(x) # revealed: object
```
## Value patterns
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"]
case 42:
reveal_type(x) # revealed: Literal[42]
case 6.0:
reveal_type(x) # revealed: float
case 1j:
reveal_type(x) # revealed: complex
case b"foo":
reveal_type(x) # revealed: Literal[b"foo"]
reveal_type(x) # revealed: object
```
## Value patterns with guard
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
match x:
case "foo" if reveal_type(x): # revealed: Literal["foo"]
pass
case 42 if reveal_type(x): # revealed: Literal[42]
pass
case 6.0 if reveal_type(x): # revealed: float
pass
case 1j if reveal_type(x): # revealed: complex
pass
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
pass
reveal_type(x) # revealed: object
```

View file

@ -238,8 +238,11 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
PatternPredicateKind::Class(cls, _guard) => {
self.evaluate_match_pattern_class(subject, *cls)
}
PatternPredicateKind::Value(expr, _guard) => {
self.evaluate_match_pattern_value(subject, *expr)
}
// TODO: support more pattern kinds
PatternPredicateKind::Value(..) | PatternPredicateKind::Unsupported => None,
PatternPredicateKind::Unsupported => None,
}
}
@ -254,6 +257,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}
#[track_caller]
fn expect_expr_name_symbol(&self, symbol: &str) -> ScopedSymbolId {
self.symbols()
.symbol_id_by_name(symbol)
.expect("We should always have a symbol for every `Name` node")
}
fn evaluate_expr_name(
&mut self,
expr_name: &ast::ExprName,
@ -261,22 +271,15 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
) -> NarrowingConstraints<'db> {
let ast::ExprName { id, .. } = expr_name;
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
let mut constraints = NarrowingConstraints::default();
let symbol = self.expect_expr_name_symbol(id);
constraints.insert(
symbol,
if is_positive {
Type::AlwaysFalsy.negate(self.db)
} else {
Type::AlwaysTruthy.negate(self.db)
},
);
let ty = if is_positive {
Type::AlwaysFalsy.negate(self.db)
} else {
Type::AlwaysTruthy.negate(self.db)
};
constraints
NarrowingConstraints::from_iter([(symbol, ty)])
}
fn evaluate_expr_compare(
@ -335,10 +338,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
id,
ctx: _,
}) => {
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
let symbol = self.expect_expr_name_symbol(id);
match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
@ -405,10 +405,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
.into_class_literal()
.is_some_and(|c| c.class().is_known(self.db, KnownClass::Type))
{
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
let symbol = self.expect_expr_name_symbol(id);
constraints.insert(symbol, Type::instance(rhs_class));
}
}
@ -442,7 +439,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
return None;
};
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let symbol = self.expect_expr_name_symbol(id);
let class_info_ty =
inference.expression_type(class_info.scoped_expression_id(self.db, scope));
@ -450,9 +447,10 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
function
.generate_constraint(self.db, class_info_ty)
.map(|constraint| {
let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, constraint.negate_if(self.db, !is_positive));
constraints
NarrowingConstraints::from_iter([(
symbol,
constraint.negate_if(self.db, !is_positive),
)])
})
}
// for the expression `bool(E)`, we further narrow the type based on `E`
@ -476,21 +474,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
subject: Expression<'db>,
singleton: ast::Singleton,
) -> Option<NarrowingConstraints<'db>> {
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, ty);
Some(constraints)
} else {
None
}
let ty = match singleton {
ast::Singleton::None => Type::none(self.db),
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
fn evaluate_match_pattern_class(
@ -498,16 +489,20 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
subject: Expression<'db>,
cls: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let ast::ExprName { id, .. } = subject.node_ref(self.db).as_name_expr()?;
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("We should always have a symbol for every `Name` node");
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db)?;
let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, ty);
Some(constraints)
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
fn evaluate_match_pattern_value(
&mut self,
subject: Expression<'db>,
value: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, value);
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
fn evaluate_bool_op(