mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 22:01:47 +00:00
[red-knot] Improve chained comparisons handling (#13825)
## Summary
A small fix for comparisons of multiple comparators.
Instead of comparing each comparator to the leftmost item, we should
compare it to the closest item on the left.
While implementing this, I noticed that we don’t yet narrow Yoda
comparisons (e.g., `True is x`), so I didn’t change that behavior in
this PR.
## Test Plan
Added some mdtests 🎉
This commit is contained in:
parent
e9dd92107c
commit
a77512df68
3 changed files with 51 additions and 12 deletions
|
@ -25,3 +25,17 @@ if y is x:
|
||||||
|
|
||||||
reveal_type(y) # revealed: A | None
|
reveal_type(y) # revealed: A | None
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## `is` in chained comparisons
|
||||||
|
|
||||||
|
```py
|
||||||
|
x = True if x_flag else False
|
||||||
|
y = True if y_flag else False
|
||||||
|
|
||||||
|
reveal_type(x) # revealed: bool
|
||||||
|
reveal_type(y) # revealed: bool
|
||||||
|
|
||||||
|
if y is x is False: # Interpreted as `(y is x) and (x is False)`
|
||||||
|
reveal_type(x) # revealed: Literal[False]
|
||||||
|
reveal_type(y) # revealed: bool
|
||||||
|
```
|
||||||
|
|
|
@ -36,3 +36,19 @@ y = 345
|
||||||
if x is not y:
|
if x is not y:
|
||||||
reveal_type(x) # revealed: Literal[345]
|
reveal_type(x) # revealed: Literal[345]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## `is not` in chained comparisons
|
||||||
|
|
||||||
|
The type guard removes `False` from the union type of the tested value only.
|
||||||
|
|
||||||
|
```py
|
||||||
|
x = True if x_flag else False
|
||||||
|
y = True if y_flag else False
|
||||||
|
|
||||||
|
reveal_type(x) # revealed: bool
|
||||||
|
reveal_type(y) # revealed: bool
|
||||||
|
|
||||||
|
if y is not x is not False: # Interpreted as `(y is not x) and (x is not False)`
|
||||||
|
reveal_type(x) # revealed: Literal[True]
|
||||||
|
reveal_type(y) # revealed: bool
|
||||||
|
```
|
||||||
|
|
|
@ -6,6 +6,7 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
|
||||||
use crate::semantic_index::symbol_table;
|
use crate::semantic_index::symbol_table;
|
||||||
use crate::types::{infer_expression_types, IntersectionBuilder, Type};
|
use crate::types::{infer_expression_types, IntersectionBuilder, Type};
|
||||||
use crate::Db;
|
use crate::Db;
|
||||||
|
use itertools::Itertools;
|
||||||
use ruff_python_ast as ast;
|
use ruff_python_ast as ast;
|
||||||
use rustc_hash::FxHashMap;
|
use rustc_hash::FxHashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -142,19 +143,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||||
ops,
|
ops,
|
||||||
comparators,
|
comparators,
|
||||||
} = expr_compare;
|
} = expr_compare;
|
||||||
|
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
|
||||||
|
// If none of the comparators are name expressions,
|
||||||
|
// we have no symbol to narrow down the type of.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let scope = self.scope();
|
||||||
|
let inference = infer_expression_types(self.db, expression);
|
||||||
|
|
||||||
if let ast::Expr::Name(ast::ExprName {
|
let comparator_tuples = std::iter::once(&**left)
|
||||||
range: _,
|
.chain(comparators)
|
||||||
id,
|
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
|
||||||
ctx: _,
|
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||||
}) = left.as_ref()
|
if let ast::Expr::Name(ast::ExprName {
|
||||||
{
|
range: _,
|
||||||
// SAFETY: we should always have a symbol for every Name node.
|
id,
|
||||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
ctx: _,
|
||||||
let scope = self.scope();
|
}) = left
|
||||||
let inference = infer_expression_types(self.db, expression);
|
{
|
||||||
for (op, comparator) in std::iter::zip(ops, comparators) {
|
// SAFETY: we should always have a symbol for every Name node.
|
||||||
let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope));
|
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||||
|
let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
|
||||||
match op {
|
match op {
|
||||||
ast::CmpOp::IsNot => {
|
ast::CmpOp::IsNot => {
|
||||||
if comp_ty.is_singleton() {
|
if comp_ty.is_singleton() {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue