mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 13:51:37 +00:00
[red-knot] Type inference for comparisons involving intersection types (#14138)
## Summary

This adds type inference for comparison expressions involving intersection types. For example:

```py
x = get_random_int()

if x != 42:
    reveal_type(x == 42)  # revealed: Literal[False]
    reveal_type(x == 43)  # bool
```

closes #13854

## Test Plan

New Markdown-based tests.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
4f74db5630
commit
57ba25caaf
2 changed files with 262 additions and 4 deletions
|
@ -0,0 +1,155 @@
|
||||||
|
# Comparison: Intersections
|
||||||
|
|
||||||
|
## Positive contributions
|
||||||
|
|
||||||
|
If we have an intersection type `A & B` and we get a definitive true/false answer for one of the
|
||||||
|
types, we can infer that the result for the intersection type is also true/false:
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Base: ...
|
||||||
|
|
||||||
|
class Child1(Base):
|
||||||
|
def __eq__(self, other) -> Literal[True]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
class Child2(Base): ...
|
||||||
|
|
||||||
|
def get_base() -> Base: ...
|
||||||
|
|
||||||
|
x = get_base()
|
||||||
|
c1 = Child1()
|
||||||
|
|
||||||
|
# Create an intersection type through narrowing:
|
||||||
|
if isinstance(x, Child1):
|
||||||
|
if isinstance(x, Child2):
|
||||||
|
reveal_type(x) # revealed: Child1 & Child2
|
||||||
|
|
||||||
|
reveal_type(x == 1) # revealed: Literal[True]
|
||||||
|
|
||||||
|
# Other comparison operators fall back to the base type:
|
||||||
|
reveal_type(x > 1) # revealed: bool
|
||||||
|
reveal_type(x is c1) # revealed: bool
|
||||||
|
```
|
||||||
|
|
||||||
|
## Negative contributions
|
||||||
|
|
||||||
|
Negative contributions to the intersection type only allow simplifications in a few special cases
|
||||||
|
(equality and identity comparisons).
|
||||||
|
|
||||||
|
### Equality comparisons
|
||||||
|
|
||||||
|
#### Literal strings
|
||||||
|
|
||||||
|
```py
|
||||||
|
x = "x" * 1_000_000_000
|
||||||
|
y = "y" * 1_000_000_000
|
||||||
|
reveal_type(x) # revealed: LiteralString
|
||||||
|
|
||||||
|
if x != "abc":
|
||||||
|
reveal_type(x) # revealed: LiteralString & ~Literal["abc"]
|
||||||
|
|
||||||
|
reveal_type(x == "abc") # revealed: Literal[False]
|
||||||
|
reveal_type("abc" == x) # revealed: Literal[False]
|
||||||
|
reveal_type(x == "something else") # revealed: bool
|
||||||
|
reveal_type("something else" == x) # revealed: bool
|
||||||
|
|
||||||
|
reveal_type(x != "abc") # revealed: Literal[True]
|
||||||
|
reveal_type("abc" != x) # revealed: Literal[True]
|
||||||
|
reveal_type(x != "something else") # revealed: bool
|
||||||
|
reveal_type("something else" != x) # revealed: bool
|
||||||
|
|
||||||
|
reveal_type(x == y) # revealed: bool
|
||||||
|
reveal_type(y == x) # revealed: bool
|
||||||
|
reveal_type(x != y) # revealed: bool
|
||||||
|
reveal_type(y != x) # revealed: bool
|
||||||
|
|
||||||
|
reveal_type(x >= "abc") # revealed: bool
|
||||||
|
reveal_type("abc" >= x) # revealed: bool
|
||||||
|
|
||||||
|
reveal_type(x in "abc") # revealed: bool
|
||||||
|
reveal_type("abc" in x) # revealed: bool
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Integers
|
||||||
|
|
||||||
|
```py
|
||||||
|
def get_int() -> int: ...
|
||||||
|
|
||||||
|
x = get_int()
|
||||||
|
|
||||||
|
if x != 1:
|
||||||
|
reveal_type(x) # revealed: int & ~Literal[1]
|
||||||
|
|
||||||
|
reveal_type(x != 1) # revealed: Literal[True]
|
||||||
|
reveal_type(x != 2) # revealed: bool
|
||||||
|
|
||||||
|
reveal_type(x == 1) # revealed: Literal[False]
|
||||||
|
reveal_type(x == 2) # revealed: bool
|
||||||
|
```
|
||||||
|
|
||||||
|
### Identity comparisons
|
||||||
|
|
||||||
|
```py
|
||||||
|
class A: ...
|
||||||
|
|
||||||
|
def get_object() -> object: ...
|
||||||
|
|
||||||
|
o = object()
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
n = None
|
||||||
|
|
||||||
|
if o is not None:
|
||||||
|
reveal_type(o) # revealed: object & ~None
|
||||||
|
|
||||||
|
reveal_type(o is n) # revealed: Literal[False]
|
||||||
|
reveal_type(o is not n) # revealed: Literal[True]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Diagnostics
|
||||||
|
|
||||||
|
### Unsupported operators for positive contributions
|
||||||
|
|
||||||
|
Raise an error if any of the positive contributions to the intersection type are unsupported for the
|
||||||
|
given operator:
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Container:
|
||||||
|
def __contains__(self, x) -> bool: ...
|
||||||
|
|
||||||
|
class NonContainer: ...
|
||||||
|
|
||||||
|
def get_object() -> object: ...
|
||||||
|
|
||||||
|
x = get_object()
|
||||||
|
|
||||||
|
if isinstance(x, Container):
|
||||||
|
if isinstance(x, NonContainer):
|
||||||
|
reveal_type(x) # revealed: Container & NonContainer
|
||||||
|
|
||||||
|
# error: [unsupported-operator] "Operator `in` is not supported for types `int` and `NonContainer`"
|
||||||
|
reveal_type(2 in x) # revealed: bool
|
||||||
|
```
|
||||||
|
|
||||||
|
### Unsupported operators for negative contributions
|
||||||
|
|
||||||
|
Do *not* raise an error if any of the negative contributions to the intersection type are
|
||||||
|
unsupported for the given operator:
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Container:
|
||||||
|
def __contains__(self, x) -> bool: ...
|
||||||
|
|
||||||
|
class NonContainer: ...
|
||||||
|
|
||||||
|
def get_object() -> object: ...
|
||||||
|
|
||||||
|
x = get_object()
|
||||||
|
|
||||||
|
if isinstance(x, Container):
|
||||||
|
if not isinstance(x, NonContainer):
|
||||||
|
reveal_type(x) # revealed: Container & ~NonContainer
|
||||||
|
|
||||||
|
# No error here!
|
||||||
|
reveal_type(2 in x) # revealed: bool
|
||||||
|
```
|
|
@ -57,9 +57,9 @@ use crate::types::unpacker::{UnpackResult, Unpacker};
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol,
|
bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol,
|
||||||
Boundness, BytesLiteralType, Class, ClassLiteralType, FunctionType, InstanceType,
|
Boundness, BytesLiteralType, Class, ClassLiteralType, FunctionType, InstanceType,
|
||||||
IterationOutcome, KnownClass, KnownFunction, KnownInstance, MetaclassErrorKind,
|
IntersectionBuilder, IntersectionType, IterationOutcome, KnownClass, KnownFunction,
|
||||||
SliceLiteralType, StringLiteralType, Symbol, Truthiness, TupleType, Type, TypeArrayDisplay,
|
KnownInstance, MetaclassErrorKind, SliceLiteralType, StringLiteralType, Symbol, Truthiness,
|
||||||
UnionBuilder, UnionType,
|
TupleType, Type, TypeArrayDisplay, UnionBuilder, UnionType,
|
||||||
};
|
};
|
||||||
use crate::unpack::Unpack;
|
use crate::unpack::Unpack;
|
||||||
use crate::util::subscript::{PyIndex, PySlice};
|
use crate::util::subscript::{PyIndex, PySlice};
|
||||||
|
@ -266,6 +266,13 @@ impl<'db> TypeInference<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether the intersection type is on the left or right side of the comparison.
///
/// Passed to `infer_binary_intersection_type_comparison` so that each element of the
/// intersection is compared against the other operand with the original operand order
/// preserved (`element op other` for `Left`, `other op element` for `Right`).
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
enum IntersectionOn {
|
||||||
|
/// The intersection type is the left-hand operand.
Left,
|
||||||
|
/// The intersection type is the right-hand operand.
Right,
|
||||||
|
}
|
||||||
|
|
||||||
/// Builder to infer all types in a region.
|
/// Builder to infer all types in a region.
|
||||||
///
|
///
|
||||||
/// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling
|
/// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling
|
||||||
|
@ -3086,7 +3093,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
|
|
||||||
// https://docs.python.org/3/reference/expressions.html#comparisons
|
// https://docs.python.org/3/reference/expressions.html#comparisons
|
||||||
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
|
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
|
||||||
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to a `op1 b and b op2 c and
|
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to `a op1 b and b op2 c and
|
||||||
// ... > y opN z`, except that each expression is evaluated at most once.
|
// ... > y opN z`, except that each expression is evaluated at most once.
|
||||||
//
|
//
|
||||||
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
|
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
|
||||||
|
@ -3140,6 +3147,87 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_binary_intersection_type_comparison(
|
||||||
|
&mut self,
|
||||||
|
intersection: IntersectionType<'db>,
|
||||||
|
op: ast::CmpOp,
|
||||||
|
other: Type<'db>,
|
||||||
|
intersection_on: IntersectionOn,
|
||||||
|
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
|
||||||
|
// If a comparison yields a definitive true/false answer on a (positive) part
|
||||||
|
// of an intersection type, it will also yield a definitive answer on the full
|
||||||
|
// intersection type, which is even more specific.
|
||||||
|
for pos in intersection.positive(self.db) {
|
||||||
|
let result = match intersection_on {
|
||||||
|
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?,
|
||||||
|
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?,
|
||||||
|
};
|
||||||
|
if let Type::BooleanLiteral(b) = result {
|
||||||
|
return Ok(Type::BooleanLiteral(b));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For negative contributions to the intersection type, there are only a few
|
||||||
|
// special cases that allow us to narrow down the result type of the comparison.
|
||||||
|
for neg in intersection.negative(self.db) {
|
||||||
|
let result = match intersection_on {
|
||||||
|
IntersectionOn::Left => self.infer_binary_type_comparison(*neg, op, other).ok(),
|
||||||
|
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *neg).ok(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match (op, result) {
|
||||||
|
(ast::CmpOp::Eq, Some(Type::BooleanLiteral(true))) => {
|
||||||
|
return Ok(Type::BooleanLiteral(false));
|
||||||
|
}
|
||||||
|
(ast::CmpOp::NotEq, Some(Type::BooleanLiteral(false))) => {
|
||||||
|
return Ok(Type::BooleanLiteral(true));
|
||||||
|
}
|
||||||
|
(ast::CmpOp::Is, Some(Type::BooleanLiteral(true))) => {
|
||||||
|
return Ok(Type::BooleanLiteral(false));
|
||||||
|
}
|
||||||
|
(ast::CmpOp::IsNot, Some(Type::BooleanLiteral(false))) => {
|
||||||
|
return Ok(Type::BooleanLiteral(true));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If none of the simplifications above apply, we still need to return *some*
|
||||||
|
// result type for the comparison 'T_inter `op` T_other' (or reversed), where
|
||||||
|
//
|
||||||
|
// T_inter = P1 & P2 & ... & Pn & ~N1 & ~N2 & ... & ~Nm
|
||||||
|
//
|
||||||
|
// is the intersection type. If f(T) is the function that computes the result
|
||||||
|
// type of a `op`-comparison with `T_other`, we are interested in f(T_inter).
|
||||||
|
// Since we can't compute it exactly, we return the following approximation:
|
||||||
|
//
|
||||||
|
// f(T_inter) = f(P1) & f(P2) & ... & f(Pn)
|
||||||
|
//
|
||||||
|
// The reason for this is the following: In general, for any function 'f', the
|
||||||
|
// set f(A) & f(B) can be *larger than* the set f(A & B). This means that we
|
||||||
|
// will return a type that is too wide, which is not necessarily problematic.
|
||||||
|
//
|
||||||
|
// However, we do have to leave out the negative contributions. If we were to
|
||||||
|
// add a contribution like ~f(N1), we would potentially infer result types
|
||||||
|
// that are too narrow, since ~f(A) can be larger than f(~A).
|
||||||
|
//
|
||||||
|
// As an example for this, consider the intersection type `int & ~Literal[1]`.
|
||||||
|
// If 'f' would be the `==`-comparison with 2, we obviously can't tell if that
|
||||||
|
// answer would be true or false, so we need to return `bool`. However, if we
|
||||||
|
// compute f(int) & ~f(Literal[1]), we get `bool & ~Literal[False]`, which can
|
||||||
|
// be simplified to `Literal[True]` -- a type that is too narrow.
|
||||||
|
let mut builder = IntersectionBuilder::new(self.db);
|
||||||
|
for pos in intersection.positive(self.db) {
|
||||||
|
let result = match intersection_on {
|
||||||
|
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?,
|
||||||
|
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?,
|
||||||
|
};
|
||||||
|
builder = builder.add_positive(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(builder.build())
|
||||||
|
}
|
||||||
|
|
||||||
/// Infers the type of a binary comparison (e.g. 'left == right'). See
|
/// Infers the type of a binary comparison (e.g. 'left == right'). See
|
||||||
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
|
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
|
||||||
/// expressions.
|
/// expressions.
|
||||||
|
@ -3172,6 +3260,21 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
Ok(builder.build())
|
Ok(builder.build())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(Type::Intersection(intersection), right) => self
|
||||||
|
.infer_binary_intersection_type_comparison(
|
||||||
|
intersection,
|
||||||
|
op,
|
||||||
|
right,
|
||||||
|
IntersectionOn::Left,
|
||||||
|
),
|
||||||
|
(left, Type::Intersection(intersection)) => self
|
||||||
|
.infer_binary_intersection_type_comparison(
|
||||||
|
intersection,
|
||||||
|
op,
|
||||||
|
left,
|
||||||
|
IntersectionOn::Right,
|
||||||
|
),
|
||||||
|
|
||||||
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
|
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
|
||||||
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
|
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
|
||||||
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),
|
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue