mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 10:48:32 +00:00
[red-knot] Add some narrowing for assignment expressions (#17448)
<!-- Thank you for contributing to Ruff! To help us out with reviewing, please consider the following: - Does this pull request include a summary of the change? (See below.) - Does this pull request include a descriptive title? - Does this pull request include references to any relevant issues? --> ## Summary Fixes #14866 Fixes #17437 ## Test Plan Update mdtests in `narrow/`
This commit is contained in:
parent
9965cee998
commit
edfa03a692
8 changed files with 163 additions and 44 deletions
|
@ -43,7 +43,7 @@ if True and (x := 1):
|
|||
|
||||
```py
|
||||
def _(flag: bool):
|
||||
flag or (x := 1) or reveal_type(x) # revealed: Literal[1]
|
||||
flag or (x := 1) or reveal_type(x) # revealed: Never
|
||||
|
||||
# error: [unresolved-reference]
|
||||
flag or reveal_type(y) or (y := 1) # revealed: Unknown
|
||||
|
|
|
@ -223,3 +223,15 @@ def _(x: str | None, y: str | None):
|
|||
if y is not x:
|
||||
reveal_type(y) # revealed: str | None
|
||||
```
|
||||
|
||||
## Assignment expressions
|
||||
|
||||
```py
|
||||
def f() -> bool:
|
||||
return True
|
||||
|
||||
if x := f():
|
||||
reveal_type(x) # revealed: Literal[True]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[False]
|
||||
```
|
||||
|
|
|
@ -47,3 +47,16 @@ def _(flag1: bool, flag2: bool):
|
|||
# TODO should be Never
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
```
|
||||
|
||||
## Assignment expressions
|
||||
|
||||
```py
|
||||
def f() -> int | str | None: ...
|
||||
|
||||
if isinstance(x := f(), int):
|
||||
reveal_type(x) # revealed: int
|
||||
elif isinstance(x, str):
|
||||
reveal_type(x) # revealed: str & ~int
|
||||
else:
|
||||
reveal_type(x) # revealed: None
|
||||
```
|
||||
|
|
|
@ -78,3 +78,17 @@ def _(x: Literal[1, "a", "b", "c", "d"]):
|
|||
else:
|
||||
reveal_type(x) # revealed: Literal[1, "d"]
|
||||
```
|
||||
|
||||
## Assignment expressions
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def f() -> Literal[1, 2, 3]:
|
||||
return 1
|
||||
|
||||
if (x := f()) in (1,):
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
```
|
||||
|
|
|
@ -100,3 +100,16 @@ def _(flag: bool):
|
|||
else:
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
```
|
||||
|
||||
## Assignment expressions
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def f() -> Literal[1, 2] | None: ...
|
||||
|
||||
if (x := f()) is None:
|
||||
reveal_type(x) # revealed: None
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
```
|
||||
|
|
|
@ -82,3 +82,14 @@ def _(x_flag: bool, y_flag: bool):
|
|||
reveal_type(x) # revealed: bool
|
||||
reveal_type(y) # revealed: bool
|
||||
```
|
||||
|
||||
## Assignment expressions
|
||||
|
||||
```py
|
||||
def f() -> int | str | None: ...
|
||||
|
||||
if (x := f()) is not None:
|
||||
reveal_type(x) # revealed: int | str
|
||||
else:
|
||||
reveal_type(x) # revealed: None
|
||||
```
|
||||
|
|
|
@ -89,3 +89,18 @@ def _(flag1: bool, flag2: bool, a: int):
|
|||
else:
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
```
|
||||
|
||||
## Assignment expressions
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def f() -> Literal[1, 2, 3]:
|
||||
return 1
|
||||
|
||||
if (x := f()) != 1:
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
else:
|
||||
# TODO should be Literal[1]
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
```
|
||||
|
|
|
@ -275,7 +275,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
self.evaluate_expression_node_predicate(&unary_op.operand, expression, !is_positive)
|
||||
}
|
||||
ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive),
|
||||
_ => None, // TODO other test expression kinds
|
||||
ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named, is_positive),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -343,6 +344,18 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
NarrowingConstraints::from_iter([(symbol, ty)])
|
||||
}
|
||||
|
||||
fn evaluate_expr_named(
|
||||
&mut self,
|
||||
expr_named: &ast::ExprNamed,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
if let ast::Expr::Name(expr_name) = expr_named.target.as_ref() {
|
||||
Some(self.evaluate_expr_name(expr_name, is_positive))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
|
||||
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
|
||||
match rhs_ty {
|
||||
|
@ -365,6 +378,44 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn evaluate_expr_compare_op(
|
||||
&mut self,
|
||||
lhs_ty: Type<'db>,
|
||||
rhs_ty: Type<'db>,
|
||||
op: ast::CmpOp,
|
||||
) -> Option<Type<'db>> {
|
||||
match op {
|
||||
ast::CmpOp::IsNot => {
|
||||
if rhs_ty.is_singleton(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
Some(ty)
|
||||
} else {
|
||||
// Non-singletons cannot be safely narrowed using `is not`
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Is => Some(rhs_ty),
|
||||
ast::CmpOp::NotEq => {
|
||||
if rhs_ty.is_single_valued(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty),
|
||||
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
|
||||
ast::CmpOp::NotIn => self
|
||||
.evaluate_expr_in(lhs_ty, rhs_ty)
|
||||
.map(|ty| ty.negate(self.db)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_expr_compare(
|
||||
&mut self,
|
||||
expr_compare: &ast::ExprCompare,
|
||||
|
@ -372,7 +423,10 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
|
||||
matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_))
|
||||
matches!(
|
||||
expr,
|
||||
ast::Expr::Name(_) | ast::Expr::Call(_) | ast::Expr::Named(_)
|
||||
)
|
||||
}
|
||||
|
||||
let ast::ExprCompare {
|
||||
|
@ -423,43 +477,24 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}) => {
|
||||
let symbol = self.expect_expr_name_symbol(id);
|
||||
|
||||
match if is_positive { *op } else { op.negate() } {
|
||||
ast::CmpOp::IsNot => {
|
||||
if rhs_ty.is_singleton(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
constraints.insert(symbol, ty);
|
||||
} else {
|
||||
// Non-singletons cannot be safely narrowed using `is not`
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Is => {
|
||||
constraints.insert(symbol, rhs_ty);
|
||||
}
|
||||
ast::CmpOp::NotEq => {
|
||||
if rhs_ty.is_single_valued(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
constraints.insert(symbol, ty);
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Eq if lhs_ty.is_literal_string() => {
|
||||
constraints.insert(symbol, rhs_ty);
|
||||
}
|
||||
ast::CmpOp::In => {
|
||||
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
|
||||
constraints.insert(symbol, ty);
|
||||
}
|
||||
}
|
||||
ast::CmpOp::NotIn => {
|
||||
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
|
||||
constraints.insert(symbol, ty.negate(self.db));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// TODO other comparison types
|
||||
let op = if is_positive { *op } else { op.negate() };
|
||||
|
||||
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
|
||||
constraints.insert(symbol, ty);
|
||||
}
|
||||
}
|
||||
ast::Expr::Named(ast::ExprNamed {
|
||||
range: _,
|
||||
target,
|
||||
value: _,
|
||||
}) => {
|
||||
if let ast::Expr::Name(ast::ExprName { id, .. }) = target.as_ref() {
|
||||
let symbol = self.expect_expr_name_symbol(id);
|
||||
|
||||
let op = if is_positive { *op } else { op.negate() };
|
||||
|
||||
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
|
||||
constraints.insert(symbol, ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -535,10 +570,16 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => {
|
||||
let function = function_type.known(self.db)?.into_constraint_function()?;
|
||||
|
||||
let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] =
|
||||
&*expr_call.arguments.args
|
||||
else {
|
||||
return None;
|
||||
let (id, class_info) = match &*expr_call.arguments.args {
|
||||
[first, class_info] => match first {
|
||||
ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() {
|
||||
ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info),
|
||||
_ => return None,
|
||||
},
|
||||
ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info),
|
||||
_ => return None,
|
||||
},
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let symbol = self.expect_expr_name_symbol(id);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue