[red-knot] Improve chained comparisons handling (#13825)

## Summary

A small fix for comparisons of multiple comparators.
Instead of comparing each comparator to the leftmost item, we should
compare it to the closest item on the left.

While implementing this, I noticed that we don’t yet narrow Yoda
comparisons (e.g., `True is x`), so I didn’t change that behavior in
this PR.

## Test Plan

Added some mdtests 🎉
This commit is contained in:
TomerBin 2024-10-21 22:38:08 +03:00 committed by GitHub
parent e9dd92107c
commit a77512df68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 51 additions and 12 deletions

View file

@ -25,3 +25,17 @@ if y is x:
reveal_type(y) # revealed: A | None reveal_type(y) # revealed: A | None
``` ```
## `is` in chained comparisons
```py
x = True if x_flag else False
y = True if y_flag else False
reveal_type(x) # revealed: bool
reveal_type(y) # revealed: bool
if y is x is False: # Interpreted as `(y is x) and (x is False)`
reveal_type(x) # revealed: Literal[False]
reveal_type(y) # revealed: bool
```

View file

@ -36,3 +36,19 @@ y = 345
if x is not y: if x is not y:
reveal_type(x) # revealed: Literal[345] reveal_type(x) # revealed: Literal[345]
``` ```
## `is not` in chained comparisons
The type guard removes `False` from the union type of the tested value only.
```py
x = True if x_flag else False
y = True if y_flag else False
reveal_type(x) # revealed: bool
reveal_type(y) # revealed: bool
if y is not x is not False: # Interpreted as `(y is not x) and (x is not False)`
reveal_type(x) # revealed: Literal[True]
reveal_type(y) # revealed: bool
```

View file

@ -6,6 +6,7 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table; use crate::semantic_index::symbol_table;
use crate::types::{infer_expression_types, IntersectionBuilder, Type}; use crate::types::{infer_expression_types, IntersectionBuilder, Type};
use crate::Db; use crate::Db;
use itertools::Itertools;
use ruff_python_ast as ast; use ruff_python_ast as ast;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use std::sync::Arc; use std::sync::Arc;
@ -142,19 +143,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
ops, ops,
comparators, comparators,
} = expr_compare; } = expr_compare;
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
// If none of the comparators are name expressions,
// we have no symbol to narrow down the type of.
return;
}
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
if let ast::Expr::Name(ast::ExprName { let comparator_tuples = std::iter::once(&**left)
range: _, .chain(comparators)
id, .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
ctx: _, for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
}) = left.as_ref() if let ast::Expr::Name(ast::ExprName {
{ range: _,
// SAFETY: we should always have a symbol for every Name node. id,
let symbol = self.symbols().symbol_id_by_name(id).unwrap(); ctx: _,
let scope = self.scope(); }) = left
let inference = infer_expression_types(self.db, expression); {
for (op, comparator) in std::iter::zip(ops, comparators) { // SAFETY: we should always have a symbol for every Name node.
let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope)); let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
match op { match op {
ast::CmpOp::IsNot => { ast::CmpOp::IsNot => {
if comp_ty.is_singleton() { if comp_ty.is_singleton() {