mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 02:39:12 +00:00
[red-knot] support narrowing on constants in matches (#16974)
## Summary Part of #13694 The implementation here was suspiciously straightforward so please lmk if I missed something Also some drive-by changes to DRY things up a bit ## Test Plan Add new tests to narrow/match.md --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
992a1af4c2
commit
3acf4e716d
2 changed files with 94 additions and 49 deletions
|
@ -64,3 +64,53 @@ match x:
|
|||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
||||
## Value patterns
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
|
||||
x = get_object()
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case "foo":
|
||||
reveal_type(x) # revealed: Literal["foo"]
|
||||
case 42:
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
case 6.0:
|
||||
reveal_type(x) # revealed: float
|
||||
case 1j:
|
||||
reveal_type(x) # revealed: complex
|
||||
case b"foo":
|
||||
reveal_type(x) # revealed: Literal[b"foo"]
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
||||
## Value patterns with guard
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
|
||||
x = get_object()
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case "foo" if reveal_type(x): # revealed: Literal["foo"]
|
||||
pass
|
||||
case 42 if reveal_type(x): # revealed: Literal[42]
|
||||
pass
|
||||
case 6.0 if reveal_type(x): # revealed: float
|
||||
pass
|
||||
case 1j if reveal_type(x): # revealed: complex
|
||||
pass
|
||||
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
|
||||
pass
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
```
|
||||
|
|
|
@ -238,8 +238,11 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
PatternPredicateKind::Class(cls, _guard) => {
|
||||
self.evaluate_match_pattern_class(subject, *cls)
|
||||
}
|
||||
PatternPredicateKind::Value(expr, _guard) => {
|
||||
self.evaluate_match_pattern_value(subject, *expr)
|
||||
}
|
||||
// TODO: support more pattern kinds
|
||||
PatternPredicateKind::Value(..) | PatternPredicateKind::Unsupported => None,
|
||||
PatternPredicateKind::Unsupported => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,6 +257,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn expect_expr_name_symbol(&self, symbol: &str) -> ScopedSymbolId {
|
||||
self.symbols()
|
||||
.symbol_id_by_name(symbol)
|
||||
.expect("We should always have a symbol for every `Name` node")
|
||||
}
|
||||
|
||||
fn evaluate_expr_name(
|
||||
&mut self,
|
||||
expr_name: &ast::ExprName,
|
||||
|
@ -261,22 +271,15 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
) -> NarrowingConstraints<'db> {
|
||||
let ast::ExprName { id, .. } = expr_name;
|
||||
|
||||
let symbol = self
|
||||
.symbols()
|
||||
.symbol_id_by_name(id)
|
||||
.expect("Should always have a symbol for every Name node");
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
let symbol = self.expect_expr_name_symbol(id);
|
||||
|
||||
constraints.insert(
|
||||
symbol,
|
||||
if is_positive {
|
||||
Type::AlwaysFalsy.negate(self.db)
|
||||
} else {
|
||||
Type::AlwaysTruthy.negate(self.db)
|
||||
},
|
||||
);
|
||||
let ty = if is_positive {
|
||||
Type::AlwaysFalsy.negate(self.db)
|
||||
} else {
|
||||
Type::AlwaysTruthy.negate(self.db)
|
||||
};
|
||||
|
||||
constraints
|
||||
NarrowingConstraints::from_iter([(symbol, ty)])
|
||||
}
|
||||
|
||||
fn evaluate_expr_compare(
|
||||
|
@ -335,10 +338,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
id,
|
||||
ctx: _,
|
||||
}) => {
|
||||
let symbol = self
|
||||
.symbols()
|
||||
.symbol_id_by_name(id)
|
||||
.expect("Should always have a symbol for every Name node");
|
||||
let symbol = self.expect_expr_name_symbol(id);
|
||||
|
||||
match if is_positive { *op } else { op.negate() } {
|
||||
ast::CmpOp::IsNot => {
|
||||
|
@ -405,10 +405,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
.into_class_literal()
|
||||
.is_some_and(|c| c.class().is_known(self.db, KnownClass::Type))
|
||||
{
|
||||
let symbol = self
|
||||
.symbols()
|
||||
.symbol_id_by_name(id)
|
||||
.expect("Should always have a symbol for every Name node");
|
||||
let symbol = self.expect_expr_name_symbol(id);
|
||||
constraints.insert(symbol, Type::instance(rhs_class));
|
||||
}
|
||||
}
|
||||
|
@ -442,7 +439,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
return None;
|
||||
};
|
||||
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
let symbol = self.expect_expr_name_symbol(id);
|
||||
|
||||
let class_info_ty =
|
||||
inference.expression_type(class_info.scoped_expression_id(self.db, scope));
|
||||
|
@ -450,9 +447,10 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
function
|
||||
.generate_constraint(self.db, class_info_ty)
|
||||
.map(|constraint| {
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
constraints.insert(symbol, constraint.negate_if(self.db, !is_positive));
|
||||
constraints
|
||||
NarrowingConstraints::from_iter([(
|
||||
symbol,
|
||||
constraint.negate_if(self.db, !is_positive),
|
||||
)])
|
||||
})
|
||||
}
|
||||
// for the expression `bool(E)`, we further narrow the type based on `E`
|
||||
|
@ -476,21 +474,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
subject: Expression<'db>,
|
||||
singleton: ast::Singleton,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
|
||||
// SAFETY: we should always have a symbol for every Name node.
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
|
||||
|
||||
let ty = match singleton {
|
||||
ast::Singleton::None => Type::none(self.db),
|
||||
ast::Singleton::True => Type::BooleanLiteral(true),
|
||||
ast::Singleton::False => Type::BooleanLiteral(false),
|
||||
};
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
constraints.insert(symbol, ty);
|
||||
Some(constraints)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
let ty = match singleton {
|
||||
ast::Singleton::None => Type::none(self.db),
|
||||
ast::Singleton::True => Type::BooleanLiteral(true),
|
||||
ast::Singleton::False => Type::BooleanLiteral(false),
|
||||
};
|
||||
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
|
||||
}
|
||||
|
||||
fn evaluate_match_pattern_class(
|
||||
|
@ -498,16 +489,20 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
subject: Expression<'db>,
|
||||
cls: Expression<'db>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let ast::ExprName { id, .. } = subject.node_ref(self.db).as_name_expr()?;
|
||||
let symbol = self
|
||||
.symbols()
|
||||
.symbol_id_by_name(id)
|
||||
.expect("We should always have a symbol for every `Name` node");
|
||||
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
|
||||
let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db)?;
|
||||
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
constraints.insert(symbol, ty);
|
||||
Some(constraints)
|
||||
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
|
||||
}
|
||||
|
||||
fn evaluate_match_pattern_value(
|
||||
&mut self,
|
||||
subject: Expression<'db>,
|
||||
value: Expression<'db>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
|
||||
let ty = infer_same_file_expression_type(self.db, value);
|
||||
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
|
||||
}
|
||||
|
||||
fn evaluate_bool_op(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue