[ty] Apply function specialization to all overloads (#18020)

Function literals have an optional specialization, which is applied to
the parameter/return type annotations lazily when the function's
signature is requested. We were previously only applying this
specialization to the final overload of an overloaded function.

This manifested most visibly for `list.__add__`, which has an overloaded
definition in the typeshed:


b398b83631/crates/ty_vendored/vendor/typeshed/stdlib/builtins.pyi (L1069-L1072)

Closes https://github.com/astral-sh/ty/issues/314
This commit is contained in:
Douglas Creager 2025-05-12 13:48:54 -04:00 committed by GitHub
parent 3ccc0edfe4
commit bdccb37b4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 110 additions and 86 deletions

View file

@ -3418,16 +3418,10 @@ impl<'db> Type<'db> {
Type::BoundMethod(bound_method) => {
let signature = bound_method.function(db).signature(db);
Signatures::single(match signature {
FunctionSignature::Single(signature) => {
CallableSignature::single(self, signature.clone())
.with_bound_type(bound_method.self_instance(db))
}
FunctionSignature::Overloaded(signatures, _) => {
CallableSignature::from_overloads(self, signatures.iter().cloned())
.with_bound_type(bound_method.self_instance(db))
}
})
Signatures::single(
CallableSignature::from_overloads(self, signature.overloads.iter().cloned())
.with_bound_type(bound_method.self_instance(db)),
)
}
Type::MethodWrapper(
@ -3785,14 +3779,7 @@ impl<'db> Type<'db> {
Signatures::single(signature)
}
_ => Signatures::single(match function_type.signature(db) {
FunctionSignature::Single(signature) => {
CallableSignature::single(self, signature.clone())
}
FunctionSignature::Overloaded(signatures, _) => {
CallableSignature::from_overloads(self, signatures.iter().cloned())
}
}),
_ => Signatures::single(function_type.signature(db).overloads.clone()),
},
Type::ClassLiteral(class) => match class.known(db) {
@ -6561,46 +6548,21 @@ bitflags! {
}
}
/// A function signature, which can be either a single signature or an overloaded signature.
/// A function signature, which optionally includes an implementation signature if the function is
/// overloaded.
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]
pub(crate) enum FunctionSignature<'db> {
/// A single function signature.
Single(Signature<'db>),
/// An overloaded function signature containing the `@overload`-ed signatures and an optional
/// implementation signature.
Overloaded(Vec<Signature<'db>>, Option<Signature<'db>>),
pub(crate) struct FunctionSignature<'db> {
pub(crate) overloads: CallableSignature<'db>,
pub(crate) implementation: Option<Signature<'db>>,
}
impl<'db> FunctionSignature<'db> {
/// Returns a slice of all signatures.
///
/// For an overloaded function, this only includes the `@overload`-ed signatures and not the
/// implementation signature.
pub(crate) fn as_slice(&self) -> &[Signature<'db>] {
match self {
Self::Single(signature) => std::slice::from_ref(signature),
Self::Overloaded(signatures, _) => signatures,
}
}
/// Returns an iterator over the signatures.
pub(crate) fn iter(&self) -> Iter<Signature<'db>> {
self.as_slice().iter()
}
/// Returns the "bottom" signature (subtype of all fully-static signatures.)
pub(crate) fn bottom(db: &'db dyn Db) -> Self {
Self::Single(Signature::bottom(db))
}
}
impl<'db> IntoIterator for &'db FunctionSignature<'db> {
type Item = &'db Signature<'db>;
type IntoIter = Iter<'db, Signature<'db>>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
FunctionSignature {
overloads: CallableSignature::single(Type::any(), Signature::bottom(db)),
implementation: None,
}
}
}
@ -6671,7 +6633,7 @@ impl<'db> FunctionType<'db> {
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
Type::Callable(CallableType::from_overloads(
db,
self.signature(db).iter().cloned(),
self.signature(db).overloads.iter().cloned(),
))
}
@ -6739,20 +6701,32 @@ impl<'db> FunctionType<'db> {
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(returns(ref), cycle_fn=signature_cycle_recover, cycle_initial=signature_cycle_initial)]
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let specialization = self.specialization(db);
if let Some(overloaded) = self.to_overloaded(db) {
FunctionSignature::Overloaded(
overloaded
.overloads
.iter()
.copied()
.map(|overload| overload.internal_signature(db))
.collect(),
overloaded
.implementation
.map(|implementation| implementation.internal_signature(db)),
)
FunctionSignature {
overloads: CallableSignature::from_overloads(
Type::FunctionLiteral(self),
overloaded.overloads.iter().copied().map(|overload| {
overload
.internal_signature(db)
.apply_optional_specialization(db, specialization)
}),
),
implementation: overloaded.implementation.map(|implementation| {
implementation
.internal_signature(db)
.apply_optional_specialization(db, specialization)
}),
}
} else {
FunctionSignature::Single(self.internal_signature(db))
FunctionSignature {
overloads: CallableSignature::single(
Type::FunctionLiteral(self),
self.internal_signature(db)
.apply_optional_specialization(db, specialization),
),
implementation: None,
}
}
}
@ -6774,17 +6748,13 @@ impl<'db> FunctionType<'db> {
let index = semantic_index(db, scope.file(db));
GenericContext::from_type_params(db, index, type_params)
});
let mut signature = Signature::from_function(
Signature::from_function(
db,
generic_context,
self.inherited_generic_context(db),
definition,
function_stmt_node,
);
if let Some(specialization) = self.specialization(db) {
signature = signature.apply_specialization(db, specialization);
}
signature
)
}
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
@ -6854,7 +6824,7 @@ impl<'db> FunctionType<'db> {
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
) {
let signatures = self.signature(db);
for signature in signatures {
for signature in &signatures.overloads {
signature.find_legacy_typevars(db, typevars);
}
}
@ -7114,6 +7084,7 @@ impl<'db> BoundMethodType<'db> {
db,
self.function(db)
.signature(db)
.overloads
.iter()
.map(signatures::Signature::bind_self),
))

View file

@ -10,9 +10,9 @@ use crate::types::class::{ClassLiteral, ClassType, GenericAlias};
use crate::types::generics::{GenericContext, Specialization};
use crate::types::signatures::{Parameter, Parameters, Signature};
use crate::types::{
CallableType, FunctionSignature, IntersectionType, KnownClass, MethodWrapperKind, Protocol,
StringLiteralType, SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance,
UnionType, WrapperDescriptorKind,
CallableType, IntersectionType, KnownClass, MethodWrapperKind, Protocol, StringLiteralType,
SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance, UnionType,
WrapperDescriptorKind,
};
use crate::{Db, FxOrderSet};
@ -118,8 +118,8 @@ impl Display for DisplayRepresentation<'_> {
// the generic type parameters to the signature, i.e.
// show `def foo[T](x: T) -> T`.
match signature {
FunctionSignature::Single(signature) => {
match signature.overloads.as_slice() {
[signature] => {
write!(
f,
// "def {name}{specialization}{signature}",
@ -128,7 +128,7 @@ impl Display for DisplayRepresentation<'_> {
signature = signature.display(self.db)
)
}
FunctionSignature::Overloaded(signatures, _) => {
signatures => {
// TODO: How to display overloads?
f.write_str("Overload[")?;
let mut join = f.join(", ");
@ -146,8 +146,8 @@ impl Display for DisplayRepresentation<'_> {
// TODO: use the specialization from the method. Similar to the comment above
// about the function specialization,
match function.signature(self.db) {
FunctionSignature::Single(signature) => {
match function.signature(self.db).overloads.as_slice() {
[signature] => {
write!(
f,
"bound method {instance}.{method}{signature}",
@ -156,7 +156,7 @@ impl Display for DisplayRepresentation<'_> {
signature = signature.bind_self().display(self.db)
)
}
FunctionSignature::Overloaded(signatures, _) => {
signatures => {
// TODO: How to display overloads?
f.write_str("Overload[")?;
let mut join = f.join(", ");

View file

@ -195,6 +195,10 @@ impl<'db> CallableSignature<'db> {
self.overloads.iter()
}
pub(crate) fn as_slice(&self) -> &[Signature<'db>] {
self.overloads.as_slice()
}
fn replace_callable_type(&mut self, before: Type<'db>, after: Type<'db>) {
if self.callable_type == before {
self.callable_type = after;
@ -309,12 +313,16 @@ impl<'db> Signature<'db> {
}
}
pub(crate) fn apply_specialization(
&self,
pub(crate) fn apply_optional_specialization(
self,
db: &'db dyn Db,
specialization: Specialization<'db>,
specialization: Option<Specialization<'db>>,
) -> Self {
self.apply_type_mapping(db, specialization.type_mapping())
if let Some(specialization) = specialization {
self.apply_type_mapping(db, specialization.type_mapping())
} else {
self
}
}
pub(crate) fn apply_type_mapping<'a>(
@ -1743,7 +1751,10 @@ mod tests {
// With no decorators, internal and external signature are the same
assert_eq!(
func.signature(&db),
&FunctionSignature::Single(expected_sig)
&FunctionSignature {
overloads: CallableSignature::single(Type::FunctionLiteral(func), expected_sig),
implementation: None
},
);
}
}