[red-knot] Improve chained comparisons handling (#13825)

## Summary

A small fix for comparisons of multiple comparators.
Instead of comparing each comparator to the leftmost item, we should
compare it to the closest item on the left.

While implementing this, I noticed that we don’t yet narrow Yoda
comparisons (e.g., `True is x`), so I didn’t change that behavior in
this PR.

## Test Plan

Added some mdtests 🎉
This commit is contained in:
TomerBin 2024-10-21 22:38:08 +03:00 committed by GitHub
parent e9dd92107c
commit a77512df68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 51 additions and 12 deletions

View file

@ -6,6 +6,7 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table;
use crate::types::{infer_expression_types, IntersectionBuilder, Type};
use crate::Db;
use itertools::Itertools;
use ruff_python_ast as ast;
use rustc_hash::FxHashMap;
use std::sync::Arc;
@ -142,19 +143,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
ops,
comparators,
} = expr_compare;
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
// If none of the comparators are name expressions,
// we have no symbol to narrow down the type of.
return;
}
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
if let ast::Expr::Name(ast::ExprName {
range: _,
id,
ctx: _,
}) = left.as_ref()
{
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
for (op, comparator) in std::iter::zip(ops, comparators) {
let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope));
let comparator_tuples = std::iter::once(&**left)
.chain(comparators)
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
if let ast::Expr::Name(ast::ExprName {
range: _,
id,
ctx: _,
}) = left
{
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
match op {
ast::CmpOp::IsNot => {
if comp_ty.is_singleton() {