mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 13:51:37 +00:00
[red-knot] Improve chained comparisons handling (#13825)
## Summary
A small fix for comparisons of multiple comparators.
Instead of comparing each comparator to the leftmost item, we should
compare it to the closest item on the left.
While implementing this, I noticed that we don’t yet narrow Yoda
comparisons (e.g., `True is x`), so I didn’t change that behavior in
this PR.
## Test Plan
Added some mdtests 🎉
This commit is contained in:
parent
e9dd92107c
commit
a77512df68
3 changed files with 51 additions and 12 deletions
|
@ -6,6 +6,7 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
|
|||
use crate::semantic_index::symbol_table;
|
||||
use crate::types::{infer_expression_types, IntersectionBuilder, Type};
|
||||
use crate::Db;
|
||||
use itertools::Itertools;
|
||||
use ruff_python_ast as ast;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::sync::Arc;
|
||||
|
@ -142,19 +143,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
ops,
|
||||
comparators,
|
||||
} = expr_compare;
|
||||
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
|
||||
// If none of the comparators are name expressions,
|
||||
// we have no symbol to narrow down the type of.
|
||||
return;
|
||||
}
|
||||
let scope = self.scope();
|
||||
let inference = infer_expression_types(self.db, expression);
|
||||
|
||||
if let ast::Expr::Name(ast::ExprName {
|
||||
range: _,
|
||||
id,
|
||||
ctx: _,
|
||||
}) = left.as_ref()
|
||||
{
|
||||
// SAFETY: we should always have a symbol for every Name node.
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
let scope = self.scope();
|
||||
let inference = infer_expression_types(self.db, expression);
|
||||
for (op, comparator) in std::iter::zip(ops, comparators) {
|
||||
let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope));
|
||||
let comparator_tuples = std::iter::once(&**left)
|
||||
.chain(comparators)
|
||||
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
|
||||
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||
if let ast::Expr::Name(ast::ExprName {
|
||||
range: _,
|
||||
id,
|
||||
ctx: _,
|
||||
}) = left
|
||||
{
|
||||
// SAFETY: we should always have a symbol for every Name node.
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
|
||||
match op {
|
||||
ast::CmpOp::IsNot => {
|
||||
if comp_ty.is_singleton() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue