[ty] Apply type mappings to functions eagerly (#20596)

`TypeMapping` is no longer cow-shaped.

Before, `TypeMapping` defined a `to_owned` method, which would make an
owned copy of the type mapping. This let us apply type mappings to
function literals lazily. The primary part of a function that you have
to apply the type mapping to is its signature. The hypothesis was that
doing this lazily would prevent us from constructing the signature of a
function just to apply a type mapping; if you never ended up needed the
updated function signature, that would be extraneous work.

But looking at the CI for this PR, it looks like that hypothesis is
wrong! And this definitely cleans up the code quite a bit. It also means
that over time we can consider replacing all of these `TypeMapping` enum
variants with separate `TypeTransformer` impls.

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Douglas Creager 2025-09-29 07:24:40 -04:00 committed by GitHub
parent 3f640dacd4
commit cf2b083668
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 105 additions and 198 deletions

View file

@ -53,7 +53,6 @@ use crate::types::function::{
}; };
use crate::types::generics::{ use crate::types::generics::{
GenericContext, PartialSpecialization, Specialization, bind_typevar, walk_generic_context, GenericContext, PartialSpecialization, Specialization, bind_typevar, walk_generic_context,
walk_partial_specialization, walk_specialization,
}; };
pub use crate::types::ide_support::{ pub use crate::types::ide_support::{
CallSignatureDetails, Member, MemberWithDefinition, all_members, call_signature_details, CallSignatureDetails, Member, MemberWithDefinition, all_members, call_signature_details,
@ -6109,7 +6108,7 @@ impl<'db> Type<'db> {
} }
Type::FunctionLiteral(function) => { Type::FunctionLiteral(function) => {
let function = Type::FunctionLiteral(function.with_type_mapping(db, type_mapping)); let function = Type::FunctionLiteral(function.apply_type_mapping_impl(db, type_mapping, visitor));
match type_mapping { match type_mapping {
TypeMapping::PromoteLiterals => function.literal_promotion_type(db) TypeMapping::PromoteLiterals => function.literal_promotion_type(db)
@ -6120,8 +6119,8 @@ impl<'db> Type<'db> {
Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new( Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new(
db, db,
method.function(db).with_type_mapping(db, type_mapping), method.function(db).apply_type_mapping_impl(db, type_mapping, visitor),
method.self_instance(db).apply_type_mapping(db, type_mapping), method.self_instance(db).apply_type_mapping_impl(db, type_mapping, visitor),
)), )),
Type::NominalInstance(instance) => Type::NominalInstance(instance) =>
@ -6140,13 +6139,13 @@ impl<'db> Type<'db> {
Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderGet(function)) => { Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderGet(function)) => {
Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderGet( Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderGet(
function.with_type_mapping(db, type_mapping), function.apply_type_mapping_impl(db, type_mapping, visitor),
)) ))
} }
Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall(function)) => { Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall(function)) => {
Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall( Type::KnownBoundMethod(KnownBoundMethodType::FunctionTypeDunderCall(
function.with_type_mapping(db, type_mapping), function.apply_type_mapping_impl(db, type_mapping, visitor),
)) ))
} }
@ -6782,84 +6781,7 @@ pub enum TypeMapping<'a, 'db> {
Materialize(MaterializationKind), Materialize(MaterializationKind),
} }
fn walk_type_mapping<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
db: &'db dyn Db,
mapping: &TypeMapping<'_, 'db>,
visitor: &V,
) {
match mapping {
TypeMapping::Specialization(specialization) => {
walk_specialization(db, *specialization, visitor);
}
TypeMapping::PartialSpecialization(specialization) => {
walk_partial_specialization(db, specialization, visitor);
}
TypeMapping::BindSelf(self_type) => {
visitor.visit_type(db, *self_type);
}
TypeMapping::ReplaceSelf { new_upper_bound } => {
visitor.visit_type(db, *new_upper_bound);
}
TypeMapping::PromoteLiterals
| TypeMapping::BindLegacyTypevars(_)
| TypeMapping::MarkTypeVarsInferable(_)
| TypeMapping::Materialize(_) => {}
}
}
impl<'db> TypeMapping<'_, 'db> { impl<'db> TypeMapping<'_, 'db> {
fn to_owned(&self) -> TypeMapping<'db, 'db> {
match self {
TypeMapping::Specialization(specialization) => {
TypeMapping::Specialization(*specialization)
}
TypeMapping::PartialSpecialization(partial) => {
TypeMapping::PartialSpecialization(partial.to_owned())
}
TypeMapping::PromoteLiterals => TypeMapping::PromoteLiterals,
TypeMapping::BindLegacyTypevars(binding_context) => {
TypeMapping::BindLegacyTypevars(*binding_context)
}
TypeMapping::BindSelf(self_type) => TypeMapping::BindSelf(*self_type),
TypeMapping::ReplaceSelf { new_upper_bound } => TypeMapping::ReplaceSelf {
new_upper_bound: *new_upper_bound,
},
TypeMapping::MarkTypeVarsInferable(binding_context) => {
TypeMapping::MarkTypeVarsInferable(*binding_context)
}
TypeMapping::Materialize(materialization_kind) => {
TypeMapping::Materialize(*materialization_kind)
}
}
}
fn normalized_impl(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
match self {
TypeMapping::Specialization(specialization) => {
TypeMapping::Specialization(specialization.normalized_impl(db, visitor))
}
TypeMapping::PartialSpecialization(partial) => {
TypeMapping::PartialSpecialization(partial.normalized_impl(db, visitor))
}
TypeMapping::PromoteLiterals => TypeMapping::PromoteLiterals,
TypeMapping::BindLegacyTypevars(binding_context) => {
TypeMapping::BindLegacyTypevars(*binding_context)
}
TypeMapping::BindSelf(self_type) => {
TypeMapping::BindSelf(self_type.normalized_impl(db, visitor))
}
TypeMapping::ReplaceSelf { new_upper_bound } => TypeMapping::ReplaceSelf {
new_upper_bound: new_upper_bound.normalized_impl(db, visitor),
},
TypeMapping::MarkTypeVarsInferable(binding_context) => {
TypeMapping::MarkTypeVarsInferable(*binding_context)
}
TypeMapping::Materialize(materialization_kind) => {
TypeMapping::Materialize(*materialization_kind)
}
}
}
/// Update the generic context of a [`Signature`] according to the current type mapping /// Update the generic context of a [`Signature`] according to the current type mapping
pub(crate) fn update_signature_generic_context( pub(crate) fn update_signature_generic_context(
&self, &self,

View file

@ -77,11 +77,11 @@ use crate::types::narrow::ClassInfoConstraintFunction;
use crate::types::signatures::{CallableSignature, Signature}; use crate::types::signatures::{CallableSignature, Signature};
use crate::types::visitor::any_over_type; use crate::types::visitor::any_over_type;
use crate::types::{ use crate::types::{
BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, ApplyTypeMappingVisitor, BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase,
DeprecatedInstance, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, ClassLiteral, ClassType, DeprecatedInstance, DynamicType, FindLegacyTypeVarsVisitor,
IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, SpecialFormType, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor,
TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation, UnionBuilder, all_members, SpecialFormType, TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation,
binding_type, todo_type, walk_type_mapping, UnionBuilder, all_members, binding_type, todo_type, walk_signature,
}; };
use crate::{Db, FxOrderSet, ModuleName, resolve_module}; use crate::{Db, FxOrderSet, ModuleName, resolve_module};
@ -623,33 +623,24 @@ impl<'db> FunctionLiteral<'db> {
/// calling query is not in the same file as this function is defined in, then this will create /// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache /// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation. /// over-invalidation.
fn signature<'a>( fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> {
self,
db: &'db dyn Db,
type_mappings: &'a [TypeMapping<'a, 'db>],
) -> CallableSignature<'db>
where
'db: 'a,
{
// We only include an implementation (i.e. a definition not decorated with `@overload`) if // We only include an implementation (i.e. a definition not decorated with `@overload`) if
// it's the only definition. // it's the only definition.
let inherited_generic_context = self.inherited_generic_context(db); let inherited_generic_context = self.inherited_generic_context(db);
let (overloads, implementation) = self.overloads_and_implementation(db); let (overloads, implementation) = self.overloads_and_implementation(db);
if let Some(implementation) = implementation { if let Some(implementation) = implementation {
if overloads.is_empty() { if overloads.is_empty() {
return CallableSignature::single(type_mappings.iter().fold( return CallableSignature::single(
implementation.signature(db, inherited_generic_context), implementation.signature(db, inherited_generic_context),
|sig, mapping| sig.apply_type_mapping(db, mapping), );
));
} }
} }
CallableSignature::from_overloads(overloads.iter().map(|overload| { CallableSignature::from_overloads(
type_mappings.iter().fold( overloads
overload.signature(db, inherited_generic_context), .iter()
|sig, mapping| sig.apply_type_mapping(db, mapping), .map(|overload| overload.signature(db, inherited_generic_context)),
) )
}))
} }
/// Typed externally-visible signature of the last overload or implementation of this function. /// Typed externally-visible signature of the last overload or implementation of this function.
@ -660,20 +651,10 @@ impl<'db> FunctionLiteral<'db> {
/// calling query is not in the same file as this function is defined in, then this will create /// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache /// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation. /// over-invalidation.
fn last_definition_signature<'a>( fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> {
self,
db: &'db dyn Db,
type_mappings: &'a [TypeMapping<'a, 'db>],
) -> Signature<'db>
where
'db: 'a,
{
let inherited_generic_context = self.inherited_generic_context(db); let inherited_generic_context = self.inherited_generic_context(db);
type_mappings.iter().fold(
self.last_definition(db) self.last_definition(db)
.signature(db, inherited_generic_context), .signature(db, inherited_generic_context)
|sig, mapping| sig.apply_type_mapping(db, mapping),
)
} }
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
@ -691,11 +672,19 @@ impl<'db> FunctionLiteral<'db> {
pub struct FunctionType<'db> { pub struct FunctionType<'db> {
pub(crate) literal: FunctionLiteral<'db>, pub(crate) literal: FunctionLiteral<'db>,
/// Type mappings that should be applied to the function's parameter and return types. This /// Contains a potentially modified signature for this function literal, in case certain operations
/// might include specializations of enclosing generic contexts (e.g. for non-generic methods /// (like type mappings) have been applied to it.
/// of a specialized generic class). ///
#[returns(deref)] /// See also: [`FunctionLiteral::updated_signature`].
type_mappings: Box<[TypeMapping<'db, 'db>]>, #[returns(as_ref)]
updated_signature: Option<CallableSignature<'db>>,
/// Contains a potentially modified signature for the last overload or the implementation of this
/// function literal, in case certain operations (like type mappings) have been applied to it.
///
/// See also: [`FunctionLiteral::last_definition_signature`].
#[returns(as_ref)]
updated_last_definition_signature: Option<Signature<'db>>,
} }
// The Salsa heap is tracked separately. // The Salsa heap is tracked separately.
@ -707,8 +696,13 @@ pub(super) fn walk_function_type<'db, V: super::visitor::TypeVisitor<'db> + ?Siz
visitor: &V, visitor: &V,
) { ) {
walk_function_literal(db, function.literal(db), visitor); walk_function_literal(db, function.literal(db), visitor);
for mapping in function.type_mappings(db) { if let Some(callable_signature) = function.updated_signature(db) {
walk_type_mapping(db, mapping, visitor); for signature in &callable_signature.overloads {
walk_signature(db, signature, visitor);
}
}
if let Some(signature) = function.updated_last_definition_signature(db) {
walk_signature(db, signature, visitor);
} }
} }
@ -722,21 +716,41 @@ impl<'db> FunctionType<'db> {
let literal = self let literal = self
.literal(db) .literal(db)
.with_inherited_generic_context(db, inherited_generic_context); .with_inherited_generic_context(db, inherited_generic_context);
Self::new(db, literal, self.type_mappings(db)) let updated_signature = self.updated_signature(db).map(|signature| {
signature.with_inherited_generic_context(Some(inherited_generic_context))
});
let updated_last_definition_signature =
self.updated_last_definition_signature(db).map(|signature| {
signature
.clone()
.with_inherited_generic_context(Some(inherited_generic_context))
});
Self::new(
db,
literal,
updated_signature,
updated_last_definition_signature,
)
} }
pub(crate) fn with_type_mapping<'a>( pub(crate) fn apply_type_mapping_impl<'a>(
self, self,
db: &'db dyn Db, db: &'db dyn Db,
type_mapping: &TypeMapping<'a, 'db>, type_mapping: &TypeMapping<'a, 'db>,
visitor: &ApplyTypeMappingVisitor<'db>,
) -> Self { ) -> Self {
let type_mappings: Box<[_]> = self let updated_signature =
.type_mappings(db) self.signature(db)
.iter() .apply_type_mapping_impl(db, type_mapping, visitor);
.cloned() let updated_last_definition_signature = self
.chain(std::iter::once(type_mapping.to_owned())) .last_definition_signature(db)
.collect(); .apply_type_mapping_impl(db, type_mapping, visitor);
Self::new(db, self.literal(db), type_mappings) Self::new(
db,
self.literal(db),
Some(updated_signature),
Some(updated_last_definition_signature),
)
} }
pub(crate) fn with_dataclass_transformer_params( pub(crate) fn with_dataclass_transformer_params(
@ -752,7 +766,7 @@ impl<'db> FunctionType<'db> {
.with_dataclass_transformer_params(db, params); .with_dataclass_transformer_params(db, params);
let literal = let literal =
FunctionLiteral::new(db, last_definition, literal.inherited_generic_context(db)); FunctionLiteral::new(db, last_definition, literal.inherited_generic_context(db));
Self::new(db, literal, self.type_mappings(db)) Self::new(db, literal, None, None)
} }
/// Returns the [`File`] in which this function is defined. /// Returns the [`File`] in which this function is defined.
@ -907,7 +921,9 @@ impl<'db> FunctionType<'db> {
/// would depend on the function's AST and rerun for every change in that file. /// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(returns(ref), cycle_fn=signature_cycle_recover, cycle_initial=signature_cycle_initial, heap_size=ruff_memory_usage::heap_size)] #[salsa::tracked(returns(ref), cycle_fn=signature_cycle_recover, cycle_initial=signature_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> { pub(crate) fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> {
self.literal(db).signature(db, self.type_mappings(db)) self.updated_signature(db)
.cloned()
.unwrap_or_else(|| self.literal(db).signature(db))
} }
/// Typed externally-visible signature of the last overload or implementation of this function. /// Typed externally-visible signature of the last overload or implementation of this function.
@ -926,8 +942,9 @@ impl<'db> FunctionType<'db> {
heap_size=ruff_memory_usage::heap_size, heap_size=ruff_memory_usage::heap_size,
)] )]
pub(crate) fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> { pub(crate) fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.literal(db) self.updated_last_definition_signature(db)
.last_definition_signature(db, self.type_mappings(db)) .cloned()
.unwrap_or_else(|| self.literal(db).last_definition_signature(db))
} }
/// Convert the `FunctionType` into a [`CallableType`]. /// Convert the `FunctionType` into a [`CallableType`].
@ -1017,12 +1034,19 @@ impl<'db> FunctionType<'db> {
} }
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
let mappings: Box<_> = self let literal = self.literal(db).normalized_impl(db, visitor);
.type_mappings(db) let updated_signature = self
.iter() .updated_signature(db)
.map(|mapping| mapping.normalized_impl(db, visitor)) .map(|signature| signature.normalized_impl(db, visitor));
.collect(); let updated_last_definition_signature = self
Self::new(db, self.literal(db).normalized_impl(db, visitor), mappings) .updated_last_definition_signature(db)
.map(|signature| signature.normalized_impl(db, visitor));
Self::new(
db,
literal,
updated_signature,
updated_last_definition_signature,
)
} }
} }

View file

@ -1,5 +1,3 @@
use std::borrow::Cow;
use crate::types::constraints::ConstraintSet; use crate::types::constraints::ConstraintSet;
use itertools::Itertools; use itertools::Itertools;
@ -392,7 +390,7 @@ impl<'db> GenericContext<'db> {
// requirement for legacy contexts.) // requirement for legacy contexts.)
let partial = PartialSpecialization { let partial = PartialSpecialization {
generic_context: self, generic_context: self,
types: Cow::Borrowed(&expanded[0..idx]), types: &expanded[0..idx],
}; };
let default = let default =
default.apply_type_mapping(db, &TypeMapping::PartialSpecialization(partial)); default.apply_type_mapping(db, &TypeMapping::PartialSpecialization(partial));
@ -947,18 +945,7 @@ impl<'db> Specialization<'db> {
#[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] #[derive(Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)]
pub struct PartialSpecialization<'a, 'db> { pub struct PartialSpecialization<'a, 'db> {
generic_context: GenericContext<'db>, generic_context: GenericContext<'db>,
types: Cow<'a, [Type<'db>]>, types: &'a [Type<'db>],
}
pub(super) fn walk_partial_specialization<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
db: &'db dyn Db,
specialization: &PartialSpecialization<'_, 'db>,
visitor: &V,
) {
walk_generic_context(db, specialization.generic_context, visitor);
for ty in &*specialization.types {
visitor.visit_type(db, *ty);
}
} }
impl<'db> PartialSpecialization<'_, 'db> { impl<'db> PartialSpecialization<'_, 'db> {
@ -975,31 +962,6 @@ impl<'db> PartialSpecialization<'_, 'db> {
.get_index_of(&bound_typevar)?; .get_index_of(&bound_typevar)?;
self.types.get(index).copied() self.types.get(index).copied()
} }
pub(crate) fn to_owned(&self) -> PartialSpecialization<'db, 'db> {
PartialSpecialization {
generic_context: self.generic_context,
types: Cow::from(self.types.clone().into_owned()),
}
}
pub(crate) fn normalized_impl(
&self,
db: &'db dyn Db,
visitor: &NormalizedVisitor<'db>,
) -> PartialSpecialization<'db, 'db> {
let generic_context = self.generic_context.normalized_impl(db, visitor);
let types: Cow<_> = self
.types
.iter()
.map(|ty| ty.normalized_impl(db, visitor))
.collect();
PartialSpecialization {
generic_context,
types,
}
}
} }
/// Performs type inference between parameter annotations and argument types, producing a /// Performs type inference between parameter annotations and argument types, producing a

View file

@ -2144,12 +2144,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let function_literal = let function_literal =
FunctionLiteral::new(self.db(), overload_literal, inherited_generic_context); FunctionLiteral::new(self.db(), overload_literal, inherited_generic_context);
let type_mappings = Box::default(); let mut inferred_ty =
let mut inferred_ty = Type::FunctionLiteral(FunctionType::new( Type::FunctionLiteral(FunctionType::new(self.db(), function_literal, None, None));
self.db(),
function_literal,
type_mappings,
));
self.undecorated_type = Some(inferred_ty); self.undecorated_type = Some(inferred_ty);
for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() { for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() {

View file

@ -62,6 +62,17 @@ impl<'db> CallableSignature<'db> {
self.overloads.iter() self.overloads.iter()
} }
pub(crate) fn with_inherited_generic_context(
&self,
inherited_generic_context: Option<GenericContext<'db>>,
) -> Self {
Self::from_overloads(self.overloads.iter().map(|signature| {
signature
.clone()
.with_inherited_generic_context(inherited_generic_context)
}))
}
pub(crate) fn normalized_impl( pub(crate) fn normalized_impl(
&self, &self,
db: &'db dyn Db, db: &'db dyn Db,
@ -451,14 +462,6 @@ impl<'db> Signature<'db> {
} }
} }
pub(crate) fn apply_type_mapping<'a>(
&self,
db: &'db dyn Db,
type_mapping: &TypeMapping<'a, 'db>,
) -> Self {
self.apply_type_mapping_impl(db, type_mapping, &ApplyTypeMappingVisitor::default())
}
pub(crate) fn apply_type_mapping_impl<'a>( pub(crate) fn apply_type_mapping_impl<'a>(
&self, &self,
db: &'db dyn Db, db: &'db dyn Db,