mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 05:45:24 +00:00
[red-knot] Narrowing for type(x) is C
checks (#14432)
## Summary Add type narrowing for `type(x) is C` conditions (and `else` clauses of `type(x) is not C` conditionals): ```py if type(x) is A: reveal_type(x) # revealed: A else: reveal_type(x) # revealed: A | B ``` closes: #14431, part of: #13694 ## Test Plan New Markdown-based tests.
This commit is contained in:
parent
3642381489
commit
d8538d8c98
2 changed files with 238 additions and 34 deletions
152
crates/red_knot_python_semantic/resources/mdtest/narrow/type.md
Normal file
152
crates/red_knot_python_semantic/resources/mdtest/narrow/type.md
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
# Narrowing for checks involving `type(x)`
|
||||||
|
|
||||||
|
## `type(x) is C`
|
||||||
|
|
||||||
|
```py
|
||||||
|
class A: ...
|
||||||
|
class B: ...
|
||||||
|
|
||||||
|
def get_a_or_b() -> A | B:
|
||||||
|
return A()
|
||||||
|
|
||||||
|
x = get_a_or_b()
|
||||||
|
|
||||||
|
if type(x) is A:
|
||||||
|
reveal_type(x) # revealed: A
|
||||||
|
else:
|
||||||
|
# It would be wrong to infer `B` here. The type
|
||||||
|
# of `x` could be a subclass of `A`, so we need
|
||||||
|
# to infer the full union type:
|
||||||
|
reveal_type(x) # revealed: A | B
|
||||||
|
```
|
||||||
|
|
||||||
|
## `type(x) is not C`
|
||||||
|
|
||||||
|
```py
|
||||||
|
class A: ...
|
||||||
|
class B: ...
|
||||||
|
|
||||||
|
def get_a_or_b() -> A | B:
|
||||||
|
return A()
|
||||||
|
|
||||||
|
x = get_a_or_b()
|
||||||
|
|
||||||
|
if type(x) is not A:
|
||||||
|
# Same reasoning as above: no narrowing should occur here.
|
||||||
|
reveal_type(x) # revealed: A | B
|
||||||
|
else:
|
||||||
|
reveal_type(x) # revealed: A
|
||||||
|
```
|
||||||
|
|
||||||
|
## `type(x) == C`, `type(x) != C`
|
||||||
|
|
||||||
|
No narrowing can occur for equality comparisons, since there might be a custom `__eq__`
|
||||||
|
implementation on the metaclass.
|
||||||
|
|
||||||
|
TODO: Narrowing might be possible in some cases where the classes themselves are `@final` or their
|
||||||
|
metaclass is `@final`.
|
||||||
|
|
||||||
|
```py
|
||||||
|
class IsEqualToEverything(type):
|
||||||
|
def __eq__(cls, other):
|
||||||
|
return True
|
||||||
|
|
||||||
|
class A(metaclass=IsEqualToEverything): ...
|
||||||
|
class B(metaclass=IsEqualToEverything): ...
|
||||||
|
|
||||||
|
def get_a_or_b() -> A | B:
|
||||||
|
return B()
|
||||||
|
|
||||||
|
x = get_a_or_b()
|
||||||
|
|
||||||
|
if type(x) == A:
|
||||||
|
reveal_type(x) # revealed: A | B
|
||||||
|
|
||||||
|
if type(x) != A:
|
||||||
|
reveal_type(x) # revealed: A | B
|
||||||
|
```
|
||||||
|
|
||||||
|
## No narrowing for custom `type` callable
|
||||||
|
|
||||||
|
```py
|
||||||
|
class A: ...
|
||||||
|
class B: ...
|
||||||
|
|
||||||
|
def type(x):
|
||||||
|
return int
|
||||||
|
|
||||||
|
def get_a_or_b() -> A | B:
|
||||||
|
return A()
|
||||||
|
|
||||||
|
x = get_a_or_b()
|
||||||
|
|
||||||
|
if type(x) is A:
|
||||||
|
reveal_type(x) # revealed: A | B
|
||||||
|
else:
|
||||||
|
reveal_type(x) # revealed: A | B
|
||||||
|
```
|
||||||
|
|
||||||
|
## No narrowing for multiple arguments
|
||||||
|
|
||||||
|
No narrowing should occur if `type` is used to dynamically create a class:
|
||||||
|
|
||||||
|
```py
|
||||||
|
def get_str_or_int() -> str | int:
|
||||||
|
return "test"
|
||||||
|
|
||||||
|
x = get_str_or_int()
|
||||||
|
|
||||||
|
if type(x, (), {}) is str:
|
||||||
|
reveal_type(x) # revealed: str | int
|
||||||
|
else:
|
||||||
|
reveal_type(x) # revealed: str | int
|
||||||
|
```
|
||||||
|
|
||||||
|
## No narrowing for keyword arguments
|
||||||
|
|
||||||
|
`type` can't be used with a keyword argument:
|
||||||
|
|
||||||
|
```py
|
||||||
|
def get_str_or_int() -> str | int:
|
||||||
|
return "test"
|
||||||
|
|
||||||
|
x = get_str_or_int()
|
||||||
|
|
||||||
|
# TODO: we could issue a diagnostic here
|
||||||
|
if type(object=x) is str:
|
||||||
|
reveal_type(x) # revealed: str | int
|
||||||
|
```
|
||||||
|
|
||||||
|
## Narrowing if `type` is aliased
|
||||||
|
|
||||||
|
```py
|
||||||
|
class A: ...
|
||||||
|
class B: ...
|
||||||
|
|
||||||
|
alias_for_type = type
|
||||||
|
|
||||||
|
def get_a_or_b() -> A | B:
|
||||||
|
return A()
|
||||||
|
|
||||||
|
x = get_a_or_b()
|
||||||
|
|
||||||
|
if alias_for_type(x) is A:
|
||||||
|
reveal_type(x) # revealed: A
|
||||||
|
```
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Base: ...
|
||||||
|
class Derived(Base): ...
|
||||||
|
|
||||||
|
def get_base() -> Base:
|
||||||
|
return Base()
|
||||||
|
|
||||||
|
x = get_base()
|
||||||
|
|
||||||
|
if type(x) is Base:
|
||||||
|
# Ideally, this could be narrower, but there is now way to
|
||||||
|
# express a constraint like `Base & ~ProperSubtypeOf[Base]`.
|
||||||
|
reveal_type(x) # revealed: Base
|
||||||
|
```
|
|
@ -257,17 +257,26 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||||
expression: Expression<'db>,
|
expression: Expression<'db>,
|
||||||
is_positive: bool,
|
is_positive: bool,
|
||||||
) -> Option<NarrowingConstraints<'db>> {
|
) -> Option<NarrowingConstraints<'db>> {
|
||||||
|
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
|
||||||
|
matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_))
|
||||||
|
}
|
||||||
|
|
||||||
let ast::ExprCompare {
|
let ast::ExprCompare {
|
||||||
range: _,
|
range: _,
|
||||||
left,
|
left,
|
||||||
ops,
|
ops,
|
||||||
comparators,
|
comparators,
|
||||||
} = expr_compare;
|
} = expr_compare;
|
||||||
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
|
|
||||||
// If none of the comparators are name expressions,
|
// Performance optimization: early return if there are no potential narrowing targets.
|
||||||
// we have no symbol to narrow down the type of.
|
if !is_narrowing_target_candidate(left)
|
||||||
|
&& comparators
|
||||||
|
.iter()
|
||||||
|
.all(|c| !is_narrowing_target_candidate(c))
|
||||||
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
if !is_positive && comparators.len() > 1 {
|
if !is_positive && comparators.len() > 1 {
|
||||||
// We can't negate a constraint made by a multi-comparator expression, since we can't
|
// We can't negate a constraint made by a multi-comparator expression, since we can't
|
||||||
// know which comparison part is the one being negated.
|
// know which comparison part is the one being negated.
|
||||||
|
@ -283,15 +292,18 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||||
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
|
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
|
||||||
let mut constraints = NarrowingConstraints::default();
|
let mut constraints = NarrowingConstraints::default();
|
||||||
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||||
if let ast::Expr::Name(ast::ExprName {
|
let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope));
|
||||||
|
|
||||||
|
match left {
|
||||||
|
ast::Expr::Name(ast::ExprName {
|
||||||
range: _,
|
range: _,
|
||||||
id,
|
id,
|
||||||
ctx: _,
|
ctx: _,
|
||||||
}) = left
|
}) => {
|
||||||
{
|
let symbol = self
|
||||||
// SAFETY: we should always have a symbol for every Name node.
|
.symbols()
|
||||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
.symbol_id_by_name(id)
|
||||||
let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope));
|
.expect("Should always have a symbol for every Name node");
|
||||||
|
|
||||||
match if is_positive { *op } else { op.negate() } {
|
match if is_positive { *op } else { op.negate() } {
|
||||||
ast::CmpOp::IsNot => {
|
ast::CmpOp::IsNot => {
|
||||||
|
@ -320,6 +332,46 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ast::Expr::Call(ast::ExprCall {
|
||||||
|
range: _,
|
||||||
|
func: callable,
|
||||||
|
arguments:
|
||||||
|
ast::Arguments {
|
||||||
|
args,
|
||||||
|
keywords,
|
||||||
|
range: _,
|
||||||
|
},
|
||||||
|
}) if rhs_ty.is_class_literal() && keywords.is_empty() => {
|
||||||
|
let [ast::Expr::Name(ast::ExprName { id, .. })] = &**args else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let is_valid_constraint = if is_positive {
|
||||||
|
op == &ast::CmpOp::Is
|
||||||
|
} else {
|
||||||
|
op == &ast::CmpOp::IsNot
|
||||||
|
};
|
||||||
|
|
||||||
|
if !is_valid_constraint {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let callable_ty =
|
||||||
|
inference.expression_ty(callable.scoped_expression_id(self.db, scope));
|
||||||
|
|
||||||
|
if callable_ty
|
||||||
|
.into_class_literal()
|
||||||
|
.is_some_and(|c| c.class.is_known(self.db, KnownClass::Type))
|
||||||
|
{
|
||||||
|
let symbol = self
|
||||||
|
.symbols()
|
||||||
|
.symbol_id_by_name(id)
|
||||||
|
.expect("Should always have a symbol for every Name node");
|
||||||
|
constraints.insert(symbol, rhs_ty.to_instance(self.db));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Some(constraints)
|
Some(constraints)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue