[red-knot] Add some narrowing for assignment expressions (#17448)

<!--
Thank you for contributing to Ruff! To help us out with reviewing,
please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

Fixes #14866
Fixes #17437

## Test Plan

Update mdtests in `narrow/`
This commit is contained in:
Matthew Mckee 2025-04-18 01:28:06 +01:00 committed by GitHub
parent 9965cee998
commit edfa03a692
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 163 additions and 44 deletions

View file

@ -43,7 +43,7 @@ if True and (x := 1):
```py
def _(flag: bool):
flag or (x := 1) or reveal_type(x) # revealed: Literal[1]
flag or (x := 1) or reveal_type(x) # revealed: Never
# error: [unresolved-reference]
flag or reveal_type(y) or (y := 1) # revealed: Unknown

View file

@ -223,3 +223,15 @@ def _(x: str | None, y: str | None):
if y is not x:
reveal_type(y) # revealed: str | None
```
## Assignment expressions
```py
def f() -> bool:
return True
if x := f():
reveal_type(x) # revealed: Literal[True]
else:
reveal_type(x) # revealed: Literal[False]
```

View file

@ -47,3 +47,16 @@ def _(flag1: bool, flag2: bool):
# TODO should be Never
reveal_type(x) # revealed: Literal[1, 2]
```
## Assignment expressions
```py
def f() -> int | str | None: ...
if isinstance(x := f(), int):
reveal_type(x) # revealed: int
elif isinstance(x, str):
reveal_type(x) # revealed: str & ~int
else:
reveal_type(x) # revealed: None
```

View file

@ -78,3 +78,17 @@ def _(x: Literal[1, "a", "b", "c", "d"]):
else:
reveal_type(x) # revealed: Literal[1, "d"]
```
## Assignment expressions
```py
from typing import Literal
def f() -> Literal[1, 2, 3]:
return 1
if (x := f()) in (1,):
reveal_type(x) # revealed: Literal[1]
else:
reveal_type(x) # revealed: Literal[2, 3]
```

View file

@ -100,3 +100,16 @@ def _(flag: bool):
else:
reveal_type(x) # revealed: Literal[42]
```
## Assignment expressions
```py
from typing import Literal
def f() -> Literal[1, 2] | None: ...
if (x := f()) is None:
reveal_type(x) # revealed: None
else:
reveal_type(x) # revealed: Literal[1, 2]
```

View file

@ -82,3 +82,14 @@ def _(x_flag: bool, y_flag: bool):
reveal_type(x) # revealed: bool
reveal_type(y) # revealed: bool
```
## Assignment expressions
```py
def f() -> int | str | None: ...
if (x := f()) is not None:
reveal_type(x) # revealed: int | str
else:
reveal_type(x) # revealed: None
```

View file

@ -89,3 +89,18 @@ def _(flag1: bool, flag2: bool, a: int):
else:
reveal_type(x) # revealed: Literal[1, 2]
```
## Assignment expressions
```py
from typing import Literal
def f() -> Literal[1, 2, 3]:
return 1
if (x := f()) != 1:
reveal_type(x) # revealed: Literal[2, 3]
else:
# TODO should be Literal[1]
reveal_type(x) # revealed: Literal[1, 2, 3]
```

View file

@ -275,7 +275,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
self.evaluate_expression_node_predicate(&unary_op.operand, expression, !is_positive)
}
ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive),
_ => None, // TODO other test expression kinds
ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named, is_positive),
_ => None,
}
}
@ -343,6 +344,18 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
NarrowingConstraints::from_iter([(symbol, ty)])
}
fn evaluate_expr_named(
&mut self,
expr_named: &ast::ExprNamed,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
if let ast::Expr::Name(expr_name) = expr_named.target.as_ref() {
Some(self.evaluate_expr_name(expr_name, is_positive))
} else {
None
}
}
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
match rhs_ty {
@ -365,6 +378,44 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}
fn evaluate_expr_compare_op(
&mut self,
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
op: ast::CmpOp,
) -> Option<Type<'db>> {
match op {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
Some(ty)
} else {
// Non-singletons cannot be safely narrowed using `is not`
None
}
}
ast::CmpOp::Is => Some(rhs_ty),
ast::CmpOp::NotEq => {
if rhs_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
Some(ty)
} else {
None
}
}
ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty),
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
ast::CmpOp::NotIn => self
.evaluate_expr_in(lhs_ty, rhs_ty)
.map(|ty| ty.negate(self.db)),
_ => None,
}
}
fn evaluate_expr_compare(
&mut self,
expr_compare: &ast::ExprCompare,
@ -372,7 +423,10 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_))
matches!(
expr,
ast::Expr::Name(_) | ast::Expr::Call(_) | ast::Expr::Named(_)
)
}
let ast::ExprCompare {
@ -423,43 +477,24 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}) => {
let symbol = self.expect_expr_name_symbol(id);
match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
} else {
// Non-singletons cannot be safely narrowed using `is not`
}
}
ast::CmpOp::Is => {
constraints.insert(symbol, rhs_ty);
}
ast::CmpOp::NotEq => {
if rhs_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
}
}
ast::CmpOp::Eq if lhs_ty.is_literal_string() => {
constraints.insert(symbol, rhs_ty);
}
ast::CmpOp::In => {
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
constraints.insert(symbol, ty);
}
}
ast::CmpOp::NotIn => {
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
constraints.insert(symbol, ty.negate(self.db));
}
}
_ => {
// TODO other comparison types
let op = if is_positive { *op } else { op.negate() };
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
constraints.insert(symbol, ty);
}
}
ast::Expr::Named(ast::ExprNamed {
range: _,
target,
value: _,
}) => {
if let ast::Expr::Name(ast::ExprName { id, .. }) = target.as_ref() {
let symbol = self.expect_expr_name_symbol(id);
let op = if is_positive { *op } else { op.negate() };
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
constraints.insert(symbol, ty);
}
}
}
@ -535,10 +570,16 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => {
let function = function_type.known(self.db)?.into_constraint_function()?;
let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] =
&*expr_call.arguments.args
else {
return None;
let (id, class_info) = match &*expr_call.arguments.args {
[first, class_info] => match first {
ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() {
ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info),
_ => return None,
},
ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info),
_ => return None,
},
_ => return None,
};
let symbol = self.expect_expr_name_symbol(id);