diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md index 3d9af15a64..ea51c2c724 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is.md @@ -25,3 +25,17 @@ if y is x: reveal_type(y) # revealed: A | None ``` + +## `is` in chained comparisons + +```py +x = True if x_flag else False +y = True if y_flag else False + +reveal_type(x) # revealed: bool +reveal_type(y) # revealed: bool + +if y is x is False: # Interpreted as `(y is x) and (x is False)` + reveal_type(x) # revealed: Literal[False] + reveal_type(y) # revealed: bool +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md index dc094096a9..9495830883 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_is_not.md @@ -36,3 +36,19 @@ y = 345 if x is not y: reveal_type(x) # revealed: Literal[345] ``` + +## `is not` in chained comparisons + +The type guard removes `False` from the union type of the tested value only. + +```py +x = True if x_flag else False +y = True if y_flag else False + +reveal_type(x) # revealed: bool +reveal_type(y) # revealed: bool + +if y is not x is not False: # Interpreted as `(y is not x) and (x is not False)` + reveal_type(x) # revealed: Literal[True] + reveal_type(y) # revealed: bool +``` diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 7d9d589953..a4e69aed1f 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -6,6 +6,7 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol_table; use crate::types::{infer_expression_types, IntersectionBuilder, Type}; use crate::Db; +use itertools::Itertools; use ruff_python_ast as ast; use rustc_hash::FxHashMap; use std::sync::Arc; @@ -142,19 +143,27 @@ impl<'db> NarrowingConstraintsBuilder<'db> { ops, comparators, } = expr_compare; + if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) { + // If none of the comparators are name expressions, + // we have no symbol to narrow down the type of. + return; + } + let scope = self.scope(); + let inference = infer_expression_types(self.db, expression); - if let ast::Expr::Name(ast::ExprName { - range: _, - id, - ctx: _, - }) = left.as_ref() - { - // SAFETY: we should always have a symbol for every Name node. - let symbol = self.symbols().symbol_id_by_name(id).unwrap(); - let scope = self.scope(); - let inference = infer_expression_types(self.db, expression); - for (op, comparator) in std::iter::zip(ops, comparators) { - let comp_ty = inference.expression_ty(comparator.scoped_ast_id(self.db, scope)); + let comparator_tuples = std::iter::once(&**left) + .chain(comparators) + .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); + for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { + if let ast::Expr::Name(ast::ExprName { + range: _, + id, + ctx: _, + }) = left + { + // SAFETY: we should always have a symbol for every Name node. + let symbol = self.symbols().symbol_id_by_name(id).unwrap(); + let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope)); match op { ast::CmpOp::IsNot => { if comp_ty.is_singleton() {