From feaedb1812e043da26a332983e4fa54e33391f2c Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 30 Jul 2025 12:25:44 +0100 Subject: [PATCH] [ty] Synthesize precise `__getitem__` overloads for tuple subclasses (#19493) --- .../resources/mdtest/subscript/tuple.md | 147 ++++++++++++++ crates/ty_python_semantic/src/types/class.rs | 181 +++++++++++++++++- 2 files changed, 323 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/subscript/tuple.md b/crates/ty_python_semantic/resources/mdtest/subscript/tuple.md index 17886ffaef..9b345f7978 100644 --- a/crates/ty_python_semantic/resources/mdtest/subscript/tuple.md +++ b/crates/ty_python_semantic/resources/mdtest/subscript/tuple.md @@ -2,6 +2,11 @@ ## Indexing +```toml +[environment] +python-version = "3.11" +``` + ```py t = (1, "a", "b") @@ -20,6 +25,148 @@ b = t[-4] # error: [index-out-of-bounds] reveal_type(b) # revealed: Unknown ``` +Precise types for index operations are also inferred for tuple subclasses: + +```py +class I0: ... +class I1: ... +class I2: ... +class I3: ... +class I5: ... +class HeterogeneousSubclass0(tuple[()]): ... + +# revealed: Overload[(self, index: SupportsIndex, /) -> Never, (self, index: slice[Any, Any, Any], /) -> tuple[()]] +reveal_type(HeterogeneousSubclass0.__getitem__) + +def f0(h0: HeterogeneousSubclass0, i: int): + reveal_type(h0[0]) # revealed: Never + reveal_type(h0[1]) # revealed: Never + reveal_type(h0[-1]) # revealed: Never + reveal_type(h0[i]) # revealed: Never + +class HeterogeneousSubclass1(tuple[I0]): ... 
+ +# revealed: Overload[(self, index: SupportsIndex, /) -> I0, (self, index: slice[Any, Any, Any], /) -> tuple[I0, ...]] +reveal_type(HeterogeneousSubclass1.__getitem__) + +def f0(h1: HeterogeneousSubclass1, i: int): + reveal_type(h1[0]) # revealed: I0 + reveal_type(h1[1]) # revealed: I0 + reveal_type(h1[-1]) # revealed: I0 + reveal_type(h1[i]) # revealed: I0 + +# Element at index 2 is deliberately the same as the element at index 1, +# to illustrate that the `__getitem__` overloads for these two indices are combined +class HeterogeneousSubclass4(tuple[I0, I1, I0, I3]): ... + +# revealed: Overload[(self, index: Literal[-4, -2, 0, 2], /) -> I0, (self, index: Literal[-3, 1], /) -> I1, (self, index: Literal[-1, 3], /) -> I3, (self, index: SupportsIndex, /) -> I0 | I1 | I3, (self, index: slice[Any, Any, Any], /) -> tuple[I0 | I1 | I3, ...]] +reveal_type(HeterogeneousSubclass4.__getitem__) + +def f(h4: HeterogeneousSubclass4, i: int): + reveal_type(h4[0]) # revealed: I0 + reveal_type(h4[1]) # revealed: I1 + reveal_type(h4[2]) # revealed: I0 + reveal_type(h4[3]) # revealed: I3 + reveal_type(h4[-1]) # revealed: I3 + reveal_type(h4[-2]) # revealed: I0 + reveal_type(h4[-3]) # revealed: I1 + reveal_type(h4[-4]) # revealed: I0 + reveal_type(h4[i]) # revealed: I0 | I1 | I3 + +class MixedSubclass(tuple[I0, *tuple[I1, ...], I2, I3, I2, I5]): ... 
+ +# revealed: Overload[(self, index: Literal[0], /) -> I0, (self, index: Literal[2, 3], /) -> I1 | I2 | I3, (self, index: Literal[-1], /) -> I5, (self, index: Literal[1], /) -> I1 | I2, (self, index: Literal[-3], /) -> I3, (self, index: Literal[-5], /) -> I1 | I0, (self, index: Literal[-4, -2], /) -> I2, (self, index: Literal[4], /) -> I1 | I2 | I3 | I5, (self, index: SupportsIndex, /) -> I0 | I1 | I2 | I3 | I5, (self, index: slice[Any, Any, Any], /) -> tuple[I0 | I1 | I2 | I3 | I5, ...]] +reveal_type(MixedSubclass.__getitem__) + +def g(m: MixedSubclass, i: int): + reveal_type(m[0]) # revealed: I0 + reveal_type(m[1]) # revealed: I1 | I2 + reveal_type(m[2]) # revealed: I1 | I2 | I3 + reveal_type(m[3]) # revealed: I1 | I2 | I3 + reveal_type(m[4]) # revealed: I1 | I2 | I3 | I5 + + reveal_type(m[-1]) # revealed: I5 + reveal_type(m[-2]) # revealed: I2 + reveal_type(m[-3]) # revealed: I3 + reveal_type(m[-4]) # revealed: I2 + reveal_type(m[-5]) # revealed: I1 | I0 + + reveal_type(m[i]) # revealed: I0 | I1 | I2 | I3 | I5 + + # Ideally we would not include `I0` in the unions for these, + # but it's not possible to do this using only synthesized overloads. + reveal_type(m[5]) # revealed: I0 | I1 | I2 | I3 | I5 + reveal_type(m[10]) # revealed: I0 | I1 | I2 | I3 | I5 + + # Similarly, ideally these would just be `I0 | I1`, + # but achieving that with only synthesized overloads wouldn't be possible + reveal_type(m[-6]) # revealed: I0 | I1 | I2 | I3 | I5 + reveal_type(m[-10]) # revealed: I0 | I1 | I2 | I3 | I5 + +class MixedSubclass2(tuple[I0, I1, *tuple[I2, ...], I3]): ... 
+ +# revealed: Overload[(self, index: Literal[-1], /) -> I3, (self, index: Literal[0], /) -> I0, (self, index: Literal[-2], /) -> I2 | I1, (self, index: Literal[2], /) -> I2 | I3, (self, index: Literal[1], /) -> I1, (self, index: Literal[-3], /) -> I2 | I1 | I0, (self, index: SupportsIndex, /) -> I0 | I1 | I2 | I3, (self, index: slice[Any, Any, Any], /) -> tuple[I0 | I1 | I2 | I3, ...]] +reveal_type(MixedSubclass2.__getitem__) + +def g(m: MixedSubclass2, i: int): + reveal_type(m[0]) # revealed: I0 + reveal_type(m[1]) # revealed: I1 + reveal_type(m[2]) # revealed: I2 | I3 + + # Ideally this would just be `I2 | I3`, + # but that's not possible to achieve with synthesized overloads + reveal_type(m[3]) # revealed: I0 | I1 | I2 | I3 + + reveal_type(m[-1]) # revealed: I3 + reveal_type(m[-2]) # revealed: I2 | I1 + reveal_type(m[-3]) # revealed: I2 | I1 | I0 + + # Ideally this would just be `I2 | I1 | I0`, + # but that's not possible to achieve with synthesized overloads + reveal_type(m[-4]) # revealed: I0 | I1 | I2 | I3 +``` + +The stdlib API `os.stat` is a commonly used API that returns an instance of a tuple subclass +(`os.stat_result`), and therefore provides a good integration test for tuple subclasses. 
+ +```py +import os +import stat + +reveal_type(os.stat("my_file.txt")) # revealed: stat_result +reveal_type(os.stat("my_file.txt")[stat.ST_MODE]) # revealed: int +reveal_type(os.stat("my_file.txt")[stat.ST_ATIME]) # revealed: int | float + +# revealed: tuple[<class 'stat_result'>, <class 'structseq[int | float]'>, <class 'tuple[int, int, int, int, int, int, int, int | float, int | float, int | float]'>, <class 'Sequence[int | float]'>, <class 'Reversible[int | float]'>, <class 'Collection[int | float]'>, <class 'Iterable[int | float]'>, <class 'Container[int | float]'>, typing.Protocol, typing.Generic, <class 'object'>] +reveal_type(os.stat_result.__mro__) + +# There are no specific overloads for the `float` elements in `os.stat_result`, +# because the fallback `(self, index: SupportsIndex, /) -> int | float` overload +# gives the right result for those elements in the tuple, and we aim to synthesize +# the minimum number of overloads for any given tuple +# +# revealed: Overload[(self, index: Literal[-10, -9, -8, -7, -6, -5, -4, 0, 1, 2, 3, 4, 5, 6], /) -> int, (self, index: SupportsIndex, /) -> int | float, (self, index: slice[Any, Any, Any], /) -> tuple[int | float, ...]] +reveal_type(os.stat_result.__getitem__) +``` + +Because of the synthesized `__getitem__` overloads we synthesize for tuples and tuple subclasses, +tuples are naturally understood as being subtypes of protocols that have precise return types from +`__getitem__` method members: + +```py +from typing import Protocol, Literal +from ty_extensions import static_assert, is_subtype_of + +class IntFromZeroSubscript(Protocol): + def __getitem__(self, index: Literal[0], /) -> int: ... + +static_assert(is_subtype_of(tuple[int, str], IntFromZeroSubscript)) + +class TupleSubclass(tuple[int, str]): ... 
+ +static_assert(is_subtype_of(TupleSubclass, IntFromZeroSubscript)) +``` + ## Slices ```py diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 201307dae5..cf9e7ff331 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -19,7 +19,7 @@ use crate::types::function::{DataclassTransformerParams, KnownFunction}; use crate::types::generics::{GenericContext, Specialization, walk_specialization}; use crate::types::infer::nearest_enclosing_class; use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature}; -use crate::types::tuple::TupleType; +use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ BareTypeAliasType, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams, DeprecatedInstance, DynamicType, KnownInstanceType, TypeAliasType, TypeMapping, TypeRelation, @@ -53,7 +53,7 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_python_ast::name::Name; use ruff_python_ast::{self as ast, PythonVersion}; use ruff_text_size::{Ranged, TextRange}; -use rustc_hash::{FxHashSet, FxHasher}; +use rustc_hash::{FxHashMap, FxHashSet, FxHasher}; type FxOrderMap = ordermap::map::OrderMap>; @@ -574,8 +574,25 @@ impl<'db> ClassType<'db> { /// directly. Use [`ClassType::class_member`] if you require a method that will /// traverse through the MRO until it finds the member. 
pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> PlaceAndQualifiers<'db> { + fn synthesize_getitem_overload_signature<'db>( + index_annotation: Type<'db>, + return_annotation: Type<'db>, + ) -> Signature<'db> { + let self_parameter = Parameter::positional_only(Some(Name::new_static("self"))); + let index_parameter = Parameter::positional_only(Some(Name::new_static("index"))) + .with_annotated_type(index_annotation); + let parameters = Parameters::new([self_parameter, index_parameter]); + Signature::new(parameters, Some(return_annotation)) + } + let (class_literal, specialization) = self.class_literal(db); + let fallback_member_lookup = || { + class_literal + .own_class_member(db, specialization, name) + .map_type(|ty| ty.apply_optional_specialization(db, specialization)) + }; + let synthesize_simple_tuple_method = |return_type| { let parameters = Parameters::new([Parameter::positional_only(Some(Name::new_static("self"))) @@ -606,6 +623,162 @@ impl<'db> ClassType<'db> { synthesize_simple_tuple_method(return_type) } + "__getitem__" if class_literal.is_tuple(db) => { + specialization + .map(|spec| { + let tuple = spec.tuple(db); + + let mut element_type_to_indices: FxHashMap, Vec> = + FxHashMap::default(); + + match tuple { + // E.g. for `tuple[int, str]`, we will generate the following overloads: + // + // __getitem__(self, index: Literal[0, -2], /) -> int + // __getitem__(self, index: Literal[1, -1], /) -> str + // + TupleSpec::Fixed(fixed_length_tuple) => { + let tuple_length = fixed_length_tuple.len(); + + for (index, ty) in fixed_length_tuple.elements().enumerate() { + let entry = element_type_to_indices.entry(*ty).or_default(); + if let Ok(index) = i64::try_from(index) { + entry.push(index); + } + if let Ok(index) = i64::try_from(tuple_length - index) { + entry.push(0 - index); + } + } + } + + // E.g. 
for `tuple[str, *tuple[float, ...], bytes, range]`, we will generate the following overloads: + // + // __getitem__(self, index: Literal[0], /) -> str + // __getitem__(self, index: Literal[1], /) -> float | bytes + // __getitem__(self, index: Literal[2], /) -> float | bytes | range + // __getitem__(self, index: Literal[-1], /) -> range + // __getitem__(self, index: Literal[-2], /) -> bytes + // __getitem__(self, index: Literal[-3], /) -> float | str + // + TupleSpec::Variable(variable_length_tuple) => { + for (index, ty) in variable_length_tuple.prefix.iter().enumerate() { + if let Ok(index) = i64::try_from(index) { + element_type_to_indices.entry(*ty).or_default().push(index); + } + + let one_based_index = index + 1; + + if let Ok(i) = i64::try_from( + variable_length_tuple.suffix.len() + one_based_index, + ) { + let overload_return = UnionType::from_elements( + db, + std::iter::once(variable_length_tuple.variable).chain( + variable_length_tuple + .prefix + .iter() + .rev() + .take(one_based_index) + .copied(), + ), + ); + element_type_to_indices + .entry(overload_return) + .or_default() + .push(0 - i); + } + } + + for (index, ty) in + variable_length_tuple.suffix.iter().rev().enumerate() + { + if let Some(index) = + index.checked_add(1).and_then(|i| i64::try_from(i).ok()) + { + element_type_to_indices + .entry(*ty) + .or_default() + .push(0 - index); + } + + if let Ok(i) = + i64::try_from(variable_length_tuple.prefix.len() + index) + { + let overload_return = UnionType::from_elements( + db, + std::iter::once(variable_length_tuple.variable).chain( + variable_length_tuple + .suffix + .iter() + .take(index + 1) + .copied(), + ), + ); + element_type_to_indices + .entry(overload_return) + .or_default() + .push(i); + } + } + } + } + + let all_elements_unioned = + UnionType::from_elements(db, tuple.all_elements()); + + let mut overload_signatures = + Vec::with_capacity(element_type_to_indices.len().saturating_add(2)); + + 
overload_signatures.extend(element_type_to_indices.into_iter().filter_map( + |(return_type, mut indices)| { + if return_type.is_equivalent_to(db, all_elements_unioned) { + return None; + } + + // Sorting isn't strictly required, but leads to nicer `reveal_type` output + indices.sort_unstable(); + + let index_annotation = UnionType::from_elements( + db, + indices.into_iter().map(Type::IntLiteral), + ); + + Some(synthesize_getitem_overload_signature( + index_annotation, + return_type, + )) + }, + )); + + // Fallback overloads: for `tuple[int, str]`, we will generate the following overloads: + // + // __getitem__(self, index: int, /) -> int | str + // __getitem__(self, index: slice[Any, Any, Any], /) -> tuple[int | str, ...] + // + // and for `tuple[str, *tuple[float, ...], bytes]`, we will generate the following overloads: + // + // __getitem__(self, index: int, /) -> str | float | bytes + // __getitem__(self, index: slice[Any, Any, Any], /) -> tuple[str | float | bytes, ...] + // + overload_signatures.push(synthesize_getitem_overload_signature( + KnownClass::SupportsIndex.to_instance(db), + all_elements_unioned, + )); + + overload_signatures.push(synthesize_getitem_overload_signature( + KnownClass::Slice.to_instance(db), + TupleType::homogeneous(db, all_elements_unioned), + )); + + let getitem_signature = + CallableSignature::from_overloads(overload_signatures); + let getitem_type = + Type::Callable(CallableType::new(db, getitem_signature, true)); + Place::bound(getitem_type).into() + }) + .unwrap_or_else(fallback_member_lookup) + } + // ```py // class tuple: // @overload @@ -672,9 +845,7 @@ impl<'db> ClassType<'db> { Place::bound(synthesized_dunder).into() } - _ => class_literal - .own_class_member(db, specialization, name) - .map_type(|ty| ty.apply_optional_specialization(db, specialization)), + _ => fallback_member_lookup(), } }