[red-knot] Type inference for comparisons involving intersection types (#14138)

## Summary

This adds type inference for comparison expressions involving
intersection types.

For example:
```py
x = get_random_int()

if x != 42:
    reveal_type(x == 42)  # revealed: Literal[False]
    reveal_type(x == 43)  # bool
```

closes #13854

## Test Plan

New Markdown-based tests.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
David Peter 2024-11-07 20:51:14 +01:00 committed by GitHub
parent 4f74db5630
commit 57ba25caaf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 262 additions and 4 deletions

View file

@ -57,9 +57,9 @@ use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{
bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol,
Boundness, BytesLiteralType, Class, ClassLiteralType, FunctionType, InstanceType,
IterationOutcome, KnownClass, KnownFunction, KnownInstance, MetaclassErrorKind,
SliceLiteralType, StringLiteralType, Symbol, Truthiness, TupleType, Type, TypeArrayDisplay,
UnionBuilder, UnionType,
IntersectionBuilder, IntersectionType, IterationOutcome, KnownClass, KnownFunction,
KnownInstance, MetaclassErrorKind, SliceLiteralType, StringLiteralType, Symbol, Truthiness,
TupleType, Type, TypeArrayDisplay, UnionBuilder, UnionType,
};
use crate::unpack::Unpack;
use crate::util::subscript::{PyIndex, PySlice};
@ -266,6 +266,13 @@ impl<'db> TypeInference<'db> {
}
}
/// Whether the intersection type is on the left or right side of the comparison.
#[derive(Debug, Clone, Copy)]
enum IntersectionOn {
Left,
Right,
}
/// Builder to infer all types in a region.
///
/// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling
@ -3086,7 +3093,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// https://docs.python.org/3/reference/expressions.html#comparisons
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to a `op1 b and b op2 c and
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to `a op1 b and b op2 c and
// ... > y opN z`, except that each expression is evaluated at most once.
//
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
@ -3140,6 +3147,87 @@ impl<'db> TypeInferenceBuilder<'db> {
)
}
fn infer_binary_intersection_type_comparison(
&mut self,
intersection: IntersectionType<'db>,
op: ast::CmpOp,
other: Type<'db>,
intersection_on: IntersectionOn,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// If a comparison yields a definitive true/false answer on a (positive) part
// of an intersection type, it will also yield a definitive answer on the full
// intersection type, which is even more specific.
for pos in intersection.positive(self.db) {
let result = match intersection_on {
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?,
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?,
};
if let Type::BooleanLiteral(b) = result {
return Ok(Type::BooleanLiteral(b));
}
}
// For negative contributions to the intersection type, there are only a few
// special cases that allow us to narrow down the result type of the comparison.
for neg in intersection.negative(self.db) {
let result = match intersection_on {
IntersectionOn::Left => self.infer_binary_type_comparison(*neg, op, other).ok(),
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *neg).ok(),
};
match (op, result) {
(ast::CmpOp::Eq, Some(Type::BooleanLiteral(true))) => {
return Ok(Type::BooleanLiteral(false));
}
(ast::CmpOp::NotEq, Some(Type::BooleanLiteral(false))) => {
return Ok(Type::BooleanLiteral(true));
}
(ast::CmpOp::Is, Some(Type::BooleanLiteral(true))) => {
return Ok(Type::BooleanLiteral(false));
}
(ast::CmpOp::IsNot, Some(Type::BooleanLiteral(false))) => {
return Ok(Type::BooleanLiteral(true));
}
_ => {}
}
}
// If none of the simplifications above apply, we still need to return *some*
// result type for the comparison 'T_inter `op` T_other' (or reversed), where
//
// T_inter = P1 & P2 & ... & Pn & ~N1 & ~N2 & ... & ~Nm
//
// is the intersection type. If f(T) is the function that computes the result
// type of a `op`-comparison with `T_other`, we are interested in f(T_inter).
// Since we can't compute it exactly, we return the following approximation:
//
// f(T_inter) = f(P1) & f(P2) & ... & f(Pn)
//
// The reason for this is the following: In general, for any function 'f', the
// set f(A) & f(B) can be *larger than* the set f(A & B). This means that we
// will return a type that is too wide, which is not necessarily problematic.
//
// However, we do have to leave out the negative contributions. If we were to
// add a contribution like ~f(N1), we would potentially infer result types
// that are too narrow, since ~f(A) can be larger than f(~A).
//
// As an example for this, consider the intersection type `int & ~Literal[1]`.
// If 'f' would be the `==`-comparison with 2, we obviously can't tell if that
// answer would be true or false, so we need to return `bool`. However, if we
// compute f(int) & ~f(Literal[1]), we get `bool & ~Literal[False]`, which can
// be simplified to `Literal[True]` -- a type that is too narrow.
let mut builder = IntersectionBuilder::new(self.db);
for pos in intersection.positive(self.db) {
let result = match intersection_on {
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?,
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?,
};
builder = builder.add_positive(result);
}
Ok(builder.build())
}
/// Infers the type of a binary comparison (e.g. 'left == right'). See
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
/// expressions.
@ -3172,6 +3260,21 @@ impl<'db> TypeInferenceBuilder<'db> {
Ok(builder.build())
}
(Type::Intersection(intersection), right) => self
.infer_binary_intersection_type_comparison(
intersection,
op,
right,
IntersectionOn::Left,
),
(left, Type::Intersection(intersection)) => self
.infer_binary_intersection_type_comparison(
intersection,
op,
left,
IntersectionOn::Right,
),
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),