[ty] Fix false positives when subscripting an object inferred as having an Intersection type (#18920)

This commit is contained in:
Alex Waygood 2025-06-24 19:39:02 +01:00 committed by GitHub
parent 3220242dec
commit e44c489273
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 52 additions and 18 deletions

View file

@ -84,3 +84,21 @@ def _(flag: bool):
reveal_type(a) # revealed: str | Unknown reveal_type(a) # revealed: str | Unknown
``` ```
## Intersection of nominal-instance types
If a subscript operation could succeed for *any* positive element of an intersection, no diagnostic
should be reported even if it would not succeed for some other element of the intersection.
```py
class Foo: ...
class Bar:
def __getitem__(self, key: str) -> int:
return 42
def f(x: Foo):
if isinstance(x, Bar):
# TODO: should be `int`
reveal_type(x["whatever"]) # revealed: @Todo(Subscript expressions on intersections)
```

View file

@ -210,8 +210,12 @@ def test3(val: tuple[str] | tuple[int] | int):
### Intersection subscript access ### Intersection subscript access
```py ```py
from ty_extensions import Intersection, Not from ty_extensions import Intersection
def test4(val: Intersection[tuple[str], tuple[int]]): class Foo: ...
reveal_type(val[0]) # revealed: str & int class Bar: ...
def test4(val: Intersection[tuple[Foo], tuple[Bar]]):
# TODO: should be `Foo & Bar`
reveal_type(val[0]) # revealed: @Todo(Subscript expressions on intersections)
``` ```

View file

@ -8140,22 +8140,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
slice_ty, slice_ty,
) )
} }
// If the value type is a union make sure to union the load types.
// For example: (Type::Union(union), _, _) => union.map(self.db(), |element| {
// val: tuple[int] | tuple[str] self.infer_subscript_expression_types(value_node, *element, slice_ty)
// val[0] can be an int or str type
(Type::Union(union_ty), _, _) => union_ty.map(self.db(), |ty| {
self.infer_subscript_expression_types(value_node, *ty, slice_ty)
}), }),
(Type::Intersection(intersection_ty), _, _) => intersection_ty
.positive(self.db()) // TODO: we can map over the intersection and fold the results back into an intersection,
.iter() // but we need to make sure we avoid emitting a diagnostic if one positive element has a `__getitem__`
.map(|ty| self.infer_subscript_expression_types(value_node, *ty, slice_ty)) // method but another does not. This means `infer_subscript_expression_types`
.fold( // needs to return a `Result` rather than eagerly emitting diagnostics.
IntersectionBuilder::new(self.db()), (Type::Intersection(_), _, _) => {
IntersectionBuilder::add_positive, todo_type!("Subscript expressions on intersections")
) }
.build(),
// Ex) Given `("a", "b", "c", "d")[1]`, return `"b"` // Ex) Given `("a", "b", "c", "d")[1]`, return `"b"`
(Type::Tuple(tuple_ty), Type::IntLiteral(int), _) if i32::try_from(int).is_ok() => { (Type::Tuple(tuple_ty), Type::IntLiteral(int), _) if i32::try_from(int).is_ok() => {
let tuple = tuple_ty.tuple(self.db()); let tuple = tuple_ty.tuple(self.db());
@ -8176,6 +8173,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::unknown() Type::unknown()
}) })
} }
// Ex) Given `("a", 1, Null)[0:2]`, return `("a", 1)` // Ex) Given `("a", 1, Null)[0:2]`, return `("a", 1)`
(Type::Tuple(tuple_ty), _, Some(SliceLiteral { start, stop, step })) => { (Type::Tuple(tuple_ty), _, Some(SliceLiteral { start, stop, step })) => {
let TupleSpec::Fixed(tuple) = tuple_ty.tuple(self.db()) else { let TupleSpec::Fixed(tuple) = tuple_ty.tuple(self.db()) else {
@ -8189,6 +8187,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::unknown() Type::unknown()
} }
} }
// Ex) Given `"value"[1]`, return `"a"` // Ex) Given `"value"[1]`, return `"a"`
(Type::StringLiteral(literal_ty), Type::IntLiteral(int), _) (Type::StringLiteral(literal_ty), Type::IntLiteral(int), _)
if i32::try_from(int).is_ok() => if i32::try_from(int).is_ok() =>
@ -8212,6 +8211,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::unknown() Type::unknown()
}) })
} }
// Ex) Given `"value"[1:3]`, return `"al"` // Ex) Given `"value"[1:3]`, return `"al"`
(Type::StringLiteral(literal_ty), _, Some(SliceLiteral { start, stop, step })) => { (Type::StringLiteral(literal_ty), _, Some(SliceLiteral { start, stop, step })) => {
let literal_value = literal_ty.value(self.db()); let literal_value = literal_ty.value(self.db());
@ -8226,6 +8226,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::unknown() Type::unknown()
} }
} }
// Ex) Given `b"value"[1]`, return `97` (i.e., `ord(b"a")`) // Ex) Given `b"value"[1]`, return `97` (i.e., `ord(b"a")`)
(Type::BytesLiteral(literal_ty), Type::IntLiteral(int), _) (Type::BytesLiteral(literal_ty), Type::IntLiteral(int), _)
if i32::try_from(int).is_ok() => if i32::try_from(int).is_ok() =>
@ -8249,6 +8250,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::unknown() Type::unknown()
}) })
} }
// Ex) Given `b"value"[1:3]`, return `b"al"` // Ex) Given `b"value"[1:3]`, return `b"al"`
(Type::BytesLiteral(literal_ty), _, Some(SliceLiteral { start, stop, step })) => { (Type::BytesLiteral(literal_ty), _, Some(SliceLiteral { start, stop, step })) => {
let literal_value = literal_ty.value(self.db()); let literal_value = literal_ty.value(self.db());
@ -8261,6 +8263,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::unknown() Type::unknown()
} }
} }
// Ex) Given `"value"[True]`, return `"a"` // Ex) Given `"value"[True]`, return `"a"`
( (
Type::Tuple(_) | Type::StringLiteral(_) | Type::BytesLiteral(_), Type::Tuple(_) | Type::StringLiteral(_) | Type::BytesLiteral(_),
@ -8271,6 +8274,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
value_ty, value_ty,
Type::IntLiteral(i64::from(bool)), Type::IntLiteral(i64::from(bool)),
), ),
(Type::SpecialForm(SpecialFormType::Protocol), Type::Tuple(typevars), _) => { (Type::SpecialForm(SpecialFormType::Protocol), Type::Tuple(typevars), _) => {
let TupleSpec::Fixed(typevars) = typevars.tuple(self.db()) else { let TupleSpec::Fixed(typevars) = typevars.tuple(self.db()) else {
// TODO: emit a diagnostic // TODO: emit a diagnostic
@ -8284,6 +8288,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(context))) .map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(context)))
.unwrap_or_else(Type::unknown) .unwrap_or_else(Type::unknown)
} }
(Type::SpecialForm(SpecialFormType::Protocol), typevar, _) => self (Type::SpecialForm(SpecialFormType::Protocol), typevar, _) => self
.legacy_generic_class_context( .legacy_generic_class_context(
value_node, value_node,
@ -8292,10 +8297,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) )
.map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(context))) .map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(context)))
.unwrap_or_else(Type::unknown), .unwrap_or_else(Type::unknown),
(Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(_)), _, _) => { (Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(_)), _, _) => {
// TODO: emit a diagnostic // TODO: emit a diagnostic
todo_type!("doubly-specialized typing.Protocol") todo_type!("doubly-specialized typing.Protocol")
} }
(Type::SpecialForm(SpecialFormType::Generic), Type::Tuple(typevars), _) => { (Type::SpecialForm(SpecialFormType::Generic), Type::Tuple(typevars), _) => {
let TupleSpec::Fixed(typevars) = typevars.tuple(self.db()) else { let TupleSpec::Fixed(typevars) = typevars.tuple(self.db()) else {
// TODO: emit a diagnostic // TODO: emit a diagnostic
@ -8309,6 +8316,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(context))) .map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(context)))
.unwrap_or_else(Type::unknown) .unwrap_or_else(Type::unknown)
} }
(Type::SpecialForm(SpecialFormType::Generic), typevar, _) => self (Type::SpecialForm(SpecialFormType::Generic), typevar, _) => self
.legacy_generic_class_context( .legacy_generic_class_context(
value_node, value_node,
@ -8317,18 +8325,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) )
.map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(context))) .map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(context)))
.unwrap_or_else(Type::unknown), .unwrap_or_else(Type::unknown),
(Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(_)), _, _) => { (Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(_)), _, _) => {
// TODO: emit a diagnostic // TODO: emit a diagnostic
todo_type!("doubly-specialized typing.Generic") todo_type!("doubly-specialized typing.Generic")
} }
(Type::SpecialForm(special_form), _, _) if special_form.class().is_special_form() => { (Type::SpecialForm(special_form), _, _) if special_form.class().is_special_form() => {
todo_type!("Inference of subscript on special form") todo_type!("Inference of subscript on special form")
} }
(Type::KnownInstance(known_instance), _, _) (Type::KnownInstance(known_instance), _, _)
if known_instance.class().is_special_form() => if known_instance.class().is_special_form() =>
{ {
todo_type!("Inference of subscript on special form") todo_type!("Inference of subscript on special form")
} }
(value_ty, slice_ty, _) => { (value_ty, slice_ty, _) => {
// If the class defines `__getitem__`, return its return type. // If the class defines `__getitem__`, return its return type.
// //