mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:25:17 +00:00
[red-knot] Narrowing for type(x) is C
checks (#14432)
## Summary Add type narrowing for `type(x) is C` conditions (and `else` clauses of `type(x) is not C` conditionals): ```py if type(x) is A: reveal_type(x) # revealed: A else: reveal_type(x) # revealed: A | B ``` closes: #14431, part of: #13694 ## Test Plan New Markdown-based tests.
This commit is contained in:
parent
3642381489
commit
d8538d8c98
2 changed files with 238 additions and 34 deletions
|
@ -257,17 +257,26 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
|
||||
matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_))
|
||||
}
|
||||
|
||||
let ast::ExprCompare {
|
||||
range: _,
|
||||
left,
|
||||
ops,
|
||||
comparators,
|
||||
} = expr_compare;
|
||||
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
|
||||
// If none of the comparators are name expressions,
|
||||
// we have no symbol to narrow down the type of.
|
||||
|
||||
// Performance optimization: early return if there are no potential narrowing targets.
|
||||
if !is_narrowing_target_candidate(left)
|
||||
&& comparators
|
||||
.iter()
|
||||
.all(|c| !is_narrowing_target_candidate(c))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
if !is_positive && comparators.len() > 1 {
|
||||
// We can't negate a constraint made by a multi-comparator expression, since we can't
|
||||
// know which comparison part is the one being negated.
|
||||
|
@ -283,42 +292,85 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||
if let ast::Expr::Name(ast::ExprName {
|
||||
range: _,
|
||||
id,
|
||||
ctx: _,
|
||||
}) = left
|
||||
{
|
||||
// SAFETY: we should always have a symbol for every Name node.
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope));
|
||||
let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope));
|
||||
|
||||
match if is_positive { *op } else { op.negate() } {
|
||||
ast::CmpOp::IsNot => {
|
||||
if rhs_ty.is_singleton(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
constraints.insert(symbol, ty);
|
||||
} else {
|
||||
// Non-singletons cannot be safely narrowed using `is not`
|
||||
match left {
|
||||
ast::Expr::Name(ast::ExprName {
|
||||
range: _,
|
||||
id,
|
||||
ctx: _,
|
||||
}) => {
|
||||
let symbol = self
|
||||
.symbols()
|
||||
.symbol_id_by_name(id)
|
||||
.expect("Should always have a symbol for every Name node");
|
||||
|
||||
match if is_positive { *op } else { op.negate() } {
|
||||
ast::CmpOp::IsNot => {
|
||||
if rhs_ty.is_singleton(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
constraints.insert(symbol, ty);
|
||||
} else {
|
||||
// Non-singletons cannot be safely narrowed using `is not`
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Is => {
|
||||
constraints.insert(symbol, rhs_ty);
|
||||
}
|
||||
ast::CmpOp::NotEq => {
|
||||
if rhs_ty.is_single_valued(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
constraints.insert(symbol, ty);
|
||||
ast::CmpOp::Is => {
|
||||
constraints.insert(symbol, rhs_ty);
|
||||
}
|
||||
ast::CmpOp::NotEq => {
|
||||
if rhs_ty.is_single_valued(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
constraints.insert(symbol, ty);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// TODO other comparison types
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// TODO other comparison types
|
||||
}
|
||||
}
|
||||
ast::Expr::Call(ast::ExprCall {
|
||||
range: _,
|
||||
func: callable,
|
||||
arguments:
|
||||
ast::Arguments {
|
||||
args,
|
||||
keywords,
|
||||
range: _,
|
||||
},
|
||||
}) if rhs_ty.is_class_literal() && keywords.is_empty() => {
|
||||
let [ast::Expr::Name(ast::ExprName { id, .. })] = &**args else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let is_valid_constraint = if is_positive {
|
||||
op == &ast::CmpOp::Is
|
||||
} else {
|
||||
op == &ast::CmpOp::IsNot
|
||||
};
|
||||
|
||||
if !is_valid_constraint {
|
||||
continue;
|
||||
}
|
||||
|
||||
let callable_ty =
|
||||
inference.expression_ty(callable.scoped_expression_id(self.db, scope));
|
||||
|
||||
if callable_ty
|
||||
.into_class_literal()
|
||||
.is_some_and(|c| c.class.is_known(self.db, KnownClass::Type))
|
||||
{
|
||||
let symbol = self
|
||||
.symbols()
|
||||
.symbol_id_by_name(id)
|
||||
.expect("Should always have a symbol for every Name node");
|
||||
constraints.insert(symbol, rhs_ty.to_instance(self.db));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Some(constraints)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue