[ty] Synthesize precise __getitem__ overloads for tuple subclasses (#19493)

This commit is contained in:
Alex Waygood 2025-07-30 12:25:44 +01:00 committed by GitHub
parent 6237ecb4db
commit feaedb1812
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 323 additions and 5 deletions

View file

@ -19,7 +19,7 @@ use crate::types::function::{DataclassTransformerParams, KnownFunction};
use crate::types::generics::{GenericContext, Specialization, walk_specialization};
use crate::types::infer::nearest_enclosing_class;
use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature};
use crate::types::tuple::TupleType;
use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::{
BareTypeAliasType, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams,
DeprecatedInstance, DynamicType, KnownInstanceType, TypeAliasType, TypeMapping, TypeRelation,
@ -53,7 +53,7 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast::name::Name;
use ruff_python_ast::{self as ast, PythonVersion};
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::{FxHashSet, FxHasher};
use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
type FxOrderMap<K, V> = ordermap::map::OrderMap<K, V, BuildHasherDefault<FxHasher>>;
@ -574,8 +574,25 @@ impl<'db> ClassType<'db> {
/// directly. Use [`ClassType::class_member`] if you require a method that will
/// traverse through the MRO until it finds the member.
pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> PlaceAndQualifiers<'db> {
fn synthesize_getitem_overload_signature<'db>(
index_annotation: Type<'db>,
return_annotation: Type<'db>,
) -> Signature<'db> {
let self_parameter = Parameter::positional_only(Some(Name::new_static("self")));
let index_parameter = Parameter::positional_only(Some(Name::new_static("index")))
.with_annotated_type(index_annotation);
let parameters = Parameters::new([self_parameter, index_parameter]);
Signature::new(parameters, Some(return_annotation))
}
let (class_literal, specialization) = self.class_literal(db);
let fallback_member_lookup = || {
class_literal
.own_class_member(db, specialization, name)
.map_type(|ty| ty.apply_optional_specialization(db, specialization))
};
let synthesize_simple_tuple_method = |return_type| {
let parameters =
Parameters::new([Parameter::positional_only(Some(Name::new_static("self")))
@ -606,6 +623,162 @@ impl<'db> ClassType<'db> {
synthesize_simple_tuple_method(return_type)
}
"__getitem__" if class_literal.is_tuple(db) => {
specialization
.map(|spec| {
let tuple = spec.tuple(db);
let mut element_type_to_indices: FxHashMap<Type<'db>, Vec<i64>> =
FxHashMap::default();
match tuple {
// E.g. for `tuple[int, str]`, we will generate the following overloads:
//
// __getitem__(self, index: Literal[0, -2], /) -> int
// __getitem__(self, index: Literal[1, -1], /) -> str
//
TupleSpec::Fixed(fixed_length_tuple) => {
let tuple_length = fixed_length_tuple.len();
for (index, ty) in fixed_length_tuple.elements().enumerate() {
let entry = element_type_to_indices.entry(*ty).or_default();
if let Ok(index) = i64::try_from(index) {
entry.push(index);
}
if let Ok(index) = i64::try_from(tuple_length - index) {
entry.push(0 - index);
}
}
}
// E.g. for `tuple[str, *tuple[float, ...], bytes, range]`, we will generate the following overloads:
//
// __getitem__(self, index: Literal[0], /) -> str
// __getitem__(self, index: Literal[1], /) -> float | bytes
// __getitem__(self, index: Literal[2], /) -> float | bytes | range
// __getitem__(self, index: Literal[-1], /) -> range
// __getitem__(self, index: Literal[-2], /) -> bytes
// __getitem__(self, index: Literal[-3], /) -> float | str
//
TupleSpec::Variable(variable_length_tuple) => {
for (index, ty) in variable_length_tuple.prefix.iter().enumerate() {
if let Ok(index) = i64::try_from(index) {
element_type_to_indices.entry(*ty).or_default().push(index);
}
let one_based_index = index + 1;
if let Ok(i) = i64::try_from(
variable_length_tuple.suffix.len() + one_based_index,
) {
let overload_return = UnionType::from_elements(
db,
std::iter::once(variable_length_tuple.variable).chain(
variable_length_tuple
.prefix
.iter()
.rev()
.take(one_based_index)
.copied(),
),
);
element_type_to_indices
.entry(overload_return)
.or_default()
.push(0 - i);
}
}
for (index, ty) in
variable_length_tuple.suffix.iter().rev().enumerate()
{
if let Some(index) =
index.checked_add(1).and_then(|i| i64::try_from(i).ok())
{
element_type_to_indices
.entry(*ty)
.or_default()
.push(0 - index);
}
if let Ok(i) =
i64::try_from(variable_length_tuple.prefix.len() + index)
{
let overload_return = UnionType::from_elements(
db,
std::iter::once(variable_length_tuple.variable).chain(
variable_length_tuple
.suffix
.iter()
.take(index + 1)
.copied(),
),
);
element_type_to_indices
.entry(overload_return)
.or_default()
.push(i);
}
}
}
}
let all_elements_unioned =
UnionType::from_elements(db, tuple.all_elements());
let mut overload_signatures =
Vec::with_capacity(element_type_to_indices.len().saturating_add(2));
overload_signatures.extend(element_type_to_indices.into_iter().filter_map(
|(return_type, mut indices)| {
if return_type.is_equivalent_to(db, all_elements_unioned) {
return None;
}
// Sorting isn't strictly required, but leads to nicer `reveal_type` output
indices.sort_unstable();
let index_annotation = UnionType::from_elements(
db,
indices.into_iter().map(Type::IntLiteral),
);
Some(synthesize_getitem_overload_signature(
index_annotation,
return_type,
))
},
));
// Fallback overloads: for `tuple[int, str]`, we will generate the following overloads:
//
// __getitem__(self, index: int, /) -> int | str
// __getitem__(self, index: slice[Any, Any, Any], /) -> tuple[int | str, ...]
//
// and for `tuple[str, *tuple[float, ...], bytes]`, we will generate the following overloads:
//
// __getitem__(self, index: int, /) -> str | float | bytes
// __getitem__(self, index: slice[Any, Any, Any], /) -> tuple[str | float | bytes, ...]
//
overload_signatures.push(synthesize_getitem_overload_signature(
KnownClass::SupportsIndex.to_instance(db),
all_elements_unioned,
));
overload_signatures.push(synthesize_getitem_overload_signature(
KnownClass::Slice.to_instance(db),
TupleType::homogeneous(db, all_elements_unioned),
));
let getitem_signature =
CallableSignature::from_overloads(overload_signatures);
let getitem_type =
Type::Callable(CallableType::new(db, getitem_signature, true));
Place::bound(getitem_type).into()
})
.unwrap_or_else(fallback_member_lookup)
}
// ```py
// class tuple:
// @overload
@ -672,9 +845,7 @@ impl<'db> ClassType<'db> {
Place::bound(synthesized_dunder).into()
}
_ => class_literal
.own_class_member(db, specialization, name)
.map_type(|ty| ty.apply_optional_specialization(db, specialization)),
_ => fallback_member_lookup(),
}
}