[ty] Clean up inherited generic contexts (#20647)

We add an `inherited_generic_context` to the constructors of a generic
class. That lets us infer specializations of the class when invoking the
constructor. The constructor might itself be generic, in which case we
have to merge the list of typevars that we are willing to infer in the
constructor call.

Before we did that by tracking the two (and their specializations)
separately, with distinct `Option` fields/parameters. This PR updates
our call binding logic such that any given function call has _one_
optional generic context that we're willing to infer a specialization
for. If needed, we use the existing `GenericContext::merge` method to
create a new combined generic context for when the class and constructor
are both generic. This simplifies the call binding code considerably,
and is no more complex in the constructor call logic.

We also have a heuristic that we will promote any literals in the
specialized types of a generic class, but we don't promote literals in
the specialized types of the function itself. To handle this, we now
track this `should_promote_literals` property within `GenericContext`.
And moreover, we track this separately for each typevar, instead of a
single property for the generic context as a whole, so that we can
correctly merge the generic context of a constructor method (where the
option should be `false`) with the inherited generic context of its
containing class (where the option should be `true`).

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Douglas Creager 2025-10-03 13:55:43 -04:00 committed by GitHub
parent c91b457044
commit b83ac5e234
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 231 additions and 290 deletions

View file

@ -1,7 +1,7 @@
use std::sync::{LazyLock, Mutex}; use std::sync::{LazyLock, Mutex};
use get_size2::{GetSize, StandardTracker}; use get_size2::{GetSize, StandardTracker};
use ordermap::OrderSet; use ordermap::{OrderMap, OrderSet};
/// Returns the memory usage of the provided object, using a global tracker to avoid /// Returns the memory usage of the provided object, using a global tracker to avoid
/// double-counting shared objects. /// double-counting shared objects.
@ -18,3 +18,11 @@ pub fn heap_size<T: GetSize>(value: &T) -> usize {
pub fn order_set_heap_size<T: GetSize, S>(set: &OrderSet<T, S>) -> usize { pub fn order_set_heap_size<T: GetSize, S>(set: &OrderSet<T, S>) -> usize {
(set.capacity() * T::get_stack_size()) + set.iter().map(heap_size).sum::<usize>() (set.capacity() * T::get_stack_size()) + set.iter().map(heap_size).sum::<usize>()
} }
/// An implementation of [`GetSize::get_heap_size`] for [`OrderMap`].
pub fn order_map_heap_size<K: GetSize, V: GetSize, S>(map: &OrderMap<K, V, S>) -> usize {
(map.capacity() * (K::get_stack_size() + V::get_stack_size()))
+ (map.iter())
.map(|(k, v)| heap_size(k) + heap_size(v))
.sum::<usize>()
}

View file

@ -5456,28 +5456,19 @@ impl<'db> Type<'db> {
} }
} }
let new_specialization = new_call_outcome let specialize_constructor = |outcome: Option<Bindings<'db>>| {
.and_then(Result::ok) let (_, binding) = outcome
.as_ref() .as_ref()?
.and_then(Bindings::single_element) .single_element()?
.into_iter() .matching_overloads()
.flat_map(CallableBinding::matching_overloads) .next()?;
.next() binding.specialization()?.restrict(db, generic_context?)
.and_then(|(_, binding)| binding.inherited_specialization()) };
.filter(|specialization| {
Some(specialization.generic_context(db)) == generic_context let new_specialization =
}); specialize_constructor(new_call_outcome.and_then(Result::ok));
let init_specialization = init_call_outcome let init_specialization =
.and_then(Result::ok) specialize_constructor(init_call_outcome.and_then(Result::ok));
.as_ref()
.and_then(Bindings::single_element)
.into_iter()
.flat_map(CallableBinding::matching_overloads)
.next()
.and_then(|(_, binding)| binding.inherited_specialization())
.filter(|specialization| {
Some(specialization.generic_context(db)) == generic_context
});
let specialization = let specialization =
combine_specializations(db, new_specialization, init_specialization); combine_specializations(db, new_specialization, init_specialization);
let specialized = specialization let specialized = specialization
@ -6768,13 +6759,11 @@ impl<'db> TypeMapping<'_, 'db> {
db, db,
context context
.variables(db) .variables(db)
.iter() .filter(|var| !var.typevar(db).is_self(db)),
.filter(|var| !var.typevar(db).is_self(db))
.copied(),
), ),
TypeMapping::ReplaceSelf { new_upper_bound } => GenericContext::from_typevar_instances( TypeMapping::ReplaceSelf { new_upper_bound } => GenericContext::from_typevar_instances(
db, db,
context.variables(db).iter().map(|typevar| { context.variables(db).map(|typevar| {
if typevar.typevar(db).is_self(db) { if typevar.typevar(db).is_self(db) {
BoundTypeVarInstance::synthetic_self( BoundTypeVarInstance::synthetic_self(
db, db,
@ -6782,7 +6771,7 @@ impl<'db> TypeMapping<'_, 'db> {
typevar.binding_context(db), typevar.binding_context(db),
) )
} else { } else {
*typevar typevar
} }
}), }),
), ),

View file

@ -32,7 +32,7 @@ use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{ use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType, BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionBuilder, UnionType, TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
WrapperDescriptorKind, enums, ide_support, todo_type, WrapperDescriptorKind, enums, ide_support, todo_type,
}; };
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
@ -1701,10 +1701,6 @@ impl<'db> CallableBinding<'db> {
parameter_type = parameter_type =
parameter_type.apply_specialization(db, specialization); parameter_type.apply_specialization(db, specialization);
} }
if let Some(inherited_specialization) = overload.inherited_specialization {
parameter_type =
parameter_type.apply_specialization(db, inherited_specialization);
}
union_parameter_types[parameter_index.saturating_sub(skipped_parameters)] union_parameter_types[parameter_index.saturating_sub(skipped_parameters)]
.add_in_place(parameter_type); .add_in_place(parameter_type);
} }
@ -1983,7 +1979,7 @@ impl<'db> CallableBinding<'db> {
for overload in overloads.iter().take(MAXIMUM_OVERLOADS) { for overload in overloads.iter().take(MAXIMUM_OVERLOADS) {
diag.info(format_args!( diag.info(format_args!(
" {}", " {}",
overload.signature(context.db(), None).display(context.db()) overload.signature(context.db()).display(context.db())
)); ));
} }
if overloads.len() > MAXIMUM_OVERLOADS { if overloads.len() > MAXIMUM_OVERLOADS {
@ -2444,7 +2440,6 @@ struct ArgumentTypeChecker<'a, 'db> {
errors: &'a mut Vec<BindingError<'db>>, errors: &'a mut Vec<BindingError<'db>>,
specialization: Option<Specialization<'db>>, specialization: Option<Specialization<'db>>,
inherited_specialization: Option<Specialization<'db>>,
} }
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
@ -2466,7 +2461,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
call_expression_tcx, call_expression_tcx,
errors, errors,
specialization: None, specialization: None,
inherited_specialization: None,
} }
} }
@ -2498,9 +2492,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
} }
fn infer_specialization(&mut self) { fn infer_specialization(&mut self) {
if self.signature.generic_context.is_none() if self.signature.generic_context.is_none() {
&& self.signature.inherited_generic_context.is_none()
{
return; return;
} }
@ -2542,14 +2534,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
} }
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc)); self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
// The inherited generic context is used when inferring the specialization of a generic
// class from a constructor call. In this case (only), we promote any typevars that are
// inferred as a literal to the corresponding instance type.
builder
.build(gc)
.apply_type_mapping(self.db, &TypeMapping::PromoteLiterals)
});
} }
fn check_argument_type( fn check_argument_type(
@ -2566,11 +2550,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
argument_type = argument_type.apply_specialization(self.db, specialization); argument_type = argument_type.apply_specialization(self.db, specialization);
expected_ty = expected_ty.apply_specialization(self.db, specialization); expected_ty = expected_ty.apply_specialization(self.db, specialization);
} }
if let Some(inherited_specialization) = self.inherited_specialization {
argument_type =
argument_type.apply_specialization(self.db, inherited_specialization);
expected_ty = expected_ty.apply_specialization(self.db, inherited_specialization);
}
// This is one of the few places where we want to check if there's _any_ specialization // This is one of the few places where we want to check if there's _any_ specialization
// where assignability holds; normally we want to check that assignability holds for // where assignability holds; normally we want to check that assignability holds for
// _all_ specializations. // _all_ specializations.
@ -2742,8 +2721,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
} }
} }
fn finish(self) -> (Option<Specialization<'db>>, Option<Specialization<'db>>) { fn finish(self) -> Option<Specialization<'db>> {
(self.specialization, self.inherited_specialization) self.specialization
} }
} }
@ -2807,10 +2786,6 @@ pub(crate) struct Binding<'db> {
/// The specialization that was inferred from the argument types, if the callable is generic. /// The specialization that was inferred from the argument types, if the callable is generic.
specialization: Option<Specialization<'db>>, specialization: Option<Specialization<'db>>,
/// The specialization that was inferred for a class method's containing generic class, if it
/// is being used to infer a specialization for the class.
inherited_specialization: Option<Specialization<'db>>,
/// Information about which parameter(s) each argument was matched with, in argument source /// Information about which parameter(s) each argument was matched with, in argument source
/// order. /// order.
argument_matches: Box<[MatchedArgument<'db>]>, argument_matches: Box<[MatchedArgument<'db>]>,
@ -2835,7 +2810,6 @@ impl<'db> Binding<'db> {
signature_type, signature_type,
return_ty: Type::unknown(), return_ty: Type::unknown(),
specialization: None, specialization: None,
inherited_specialization: None,
argument_matches: Box::from([]), argument_matches: Box::from([]),
variadic_argument_matched_to_variadic_parameter: false, variadic_argument_matched_to_variadic_parameter: false,
parameter_tys: Box::from([]), parameter_tys: Box::from([]),
@ -2906,15 +2880,10 @@ impl<'db> Binding<'db> {
checker.infer_specialization(); checker.infer_specialization();
checker.check_argument_types(); checker.check_argument_types();
(self.specialization, self.inherited_specialization) = checker.finish(); self.specialization = checker.finish();
if let Some(specialization) = self.specialization { if let Some(specialization) = self.specialization {
self.return_ty = self.return_ty.apply_specialization(db, specialization); self.return_ty = self.return_ty.apply_specialization(db, specialization);
} }
if let Some(inherited_specialization) = self.inherited_specialization {
self.return_ty = self
.return_ty
.apply_specialization(db, inherited_specialization);
}
} }
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) { pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {
@ -2925,8 +2894,8 @@ impl<'db> Binding<'db> {
self.return_ty self.return_ty
} }
pub(crate) fn inherited_specialization(&self) -> Option<Specialization<'db>> { pub(crate) fn specialization(&self) -> Option<Specialization<'db>> {
self.inherited_specialization self.specialization
} }
/// Returns the bound types for each parameter, in parameter source order, or `None` if no /// Returns the bound types for each parameter, in parameter source order, or `None` if no
@ -2988,7 +2957,6 @@ impl<'db> Binding<'db> {
BindingSnapshot { BindingSnapshot {
return_ty: self.return_ty, return_ty: self.return_ty,
specialization: self.specialization, specialization: self.specialization,
inherited_specialization: self.inherited_specialization,
argument_matches: self.argument_matches.clone(), argument_matches: self.argument_matches.clone(),
parameter_tys: self.parameter_tys.clone(), parameter_tys: self.parameter_tys.clone(),
errors: self.errors.clone(), errors: self.errors.clone(),
@ -2999,7 +2967,6 @@ impl<'db> Binding<'db> {
let BindingSnapshot { let BindingSnapshot {
return_ty, return_ty,
specialization, specialization,
inherited_specialization,
argument_matches, argument_matches,
parameter_tys, parameter_tys,
errors, errors,
@ -3007,7 +2974,6 @@ impl<'db> Binding<'db> {
self.return_ty = return_ty; self.return_ty = return_ty;
self.specialization = specialization; self.specialization = specialization;
self.inherited_specialization = inherited_specialization;
self.argument_matches = argument_matches; self.argument_matches = argument_matches;
self.parameter_tys = parameter_tys; self.parameter_tys = parameter_tys;
self.errors = errors; self.errors = errors;
@ -3027,7 +2993,6 @@ impl<'db> Binding<'db> {
fn reset(&mut self) { fn reset(&mut self) {
self.return_ty = Type::unknown(); self.return_ty = Type::unknown();
self.specialization = None; self.specialization = None;
self.inherited_specialization = None;
self.argument_matches = Box::from([]); self.argument_matches = Box::from([]);
self.parameter_tys = Box::from([]); self.parameter_tys = Box::from([]);
self.errors.clear(); self.errors.clear();
@ -3038,7 +3003,6 @@ impl<'db> Binding<'db> {
struct BindingSnapshot<'db> { struct BindingSnapshot<'db> {
return_ty: Type<'db>, return_ty: Type<'db>,
specialization: Option<Specialization<'db>>, specialization: Option<Specialization<'db>>,
inherited_specialization: Option<Specialization<'db>>,
argument_matches: Box<[MatchedArgument<'db>]>, argument_matches: Box<[MatchedArgument<'db>]>,
parameter_tys: Box<[Option<Type<'db>>]>, parameter_tys: Box<[Option<Type<'db>>]>,
errors: Vec<BindingError<'db>>, errors: Vec<BindingError<'db>>,
@ -3078,7 +3042,6 @@ impl<'db> CallableBindingSnapshot<'db> {
// ... and update the snapshot with the current state of the binding. // ... and update the snapshot with the current state of the binding.
snapshot.return_ty = binding.return_ty; snapshot.return_ty = binding.return_ty;
snapshot.specialization = binding.specialization; snapshot.specialization = binding.specialization;
snapshot.inherited_specialization = binding.inherited_specialization;
snapshot snapshot
.argument_matches .argument_matches
.clone_from(&binding.argument_matches); .clone_from(&binding.argument_matches);
@ -3373,7 +3336,7 @@ impl<'db> BindingError<'db> {
} }
diag.info(format_args!( diag.info(format_args!(
" {}", " {}",
overload.signature(context.db(), None).display(context.db()) overload.signature(context.db()).display(context.db())
)); ));
} }
if overloads.len() > MAXIMUM_OVERLOADS { if overloads.len() > MAXIMUM_OVERLOADS {

View file

@ -324,7 +324,6 @@ impl<'db> VarianceInferable<'db> for GenericAlias<'db> {
specialization specialization
.generic_context(db) .generic_context(db)
.variables(db) .variables(db)
.iter()
.zip(specialization.types(db)) .zip(specialization.types(db))
.map(|(generic_typevar, ty)| { .map(|(generic_typevar, ty)| {
if let Some(explicit_variance) = if let Some(explicit_variance) =
@ -346,7 +345,7 @@ impl<'db> VarianceInferable<'db> for GenericAlias<'db> {
let typevar_variance_in_substituted_type = ty.variance_of(db, typevar); let typevar_variance_in_substituted_type = ty.variance_of(db, typevar);
origin origin
.with_polarity(typevar_variance_in_substituted_type) .with_polarity(typevar_variance_in_substituted_type)
.variance_of(db, *generic_typevar) .variance_of(db, generic_typevar)
} }
}), }),
) )
@ -1013,8 +1012,7 @@ impl<'db> ClassType<'db> {
let synthesized_dunder = CallableType::function_like( let synthesized_dunder = CallableType::function_like(
db, db,
Signature::new(parameters, None) Signature::new_generic(inherited_generic_context, parameters, None),
.with_inherited_generic_context(inherited_generic_context),
); );
Place::bound(synthesized_dunder).into() Place::bound(synthesized_dunder).into()
@ -1454,6 +1452,16 @@ impl<'db> ClassLiteral<'db> {
) )
} }
/// Returns the generic context that should be inherited by any constructor methods of this
/// class.
///
/// When inferring a specialization of the class's generic context from a constructor call, we
/// promote any typevars that are inferred as a literal to the corresponding instance type.
fn inherited_generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
self.generic_context(db)
.map(|generic_context| generic_context.promote_literals(db))
}
fn file(self, db: &dyn Db) -> File { fn file(self, db: &dyn Db) -> File {
self.body_scope(db).file(db) self.body_scope(db).file(db)
} }
@ -1996,7 +2004,7 @@ impl<'db> ClassLiteral<'db> {
lookup_result = lookup_result.or_else(|lookup_error| { lookup_result = lookup_result.or_else(|lookup_error| {
lookup_error.or_fall_back_to( lookup_error.or_fall_back_to(
db, db,
class.own_class_member(db, self.generic_context(db), name), class.own_class_member(db, self.inherited_generic_context(db), name),
) )
}); });
} }
@ -2246,8 +2254,14 @@ impl<'db> ClassLiteral<'db> {
// so that the keyword-only parameters appear after positional parameters. // so that the keyword-only parameters appear after positional parameters.
parameters.sort_by_key(Parameter::is_keyword_only); parameters.sort_by_key(Parameter::is_keyword_only);
let mut signature = Signature::new(Parameters::new(parameters), return_ty); let signature = match name {
signature.inherited_generic_context = self.generic_context(db); "__new__" | "__init__" => Signature::new_generic(
self.inherited_generic_context(db),
Parameters::new(parameters),
return_ty,
),
_ => Signature::new(Parameters::new(parameters), return_ty),
};
Some(CallableType::function_like(db, signature)) Some(CallableType::function_like(db, signature))
}; };
@ -2295,7 +2309,7 @@ impl<'db> ClassLiteral<'db> {
KnownClass::NamedTupleFallback KnownClass::NamedTupleFallback
.to_class_literal(db) .to_class_literal(db)
.into_class_literal()? .into_class_literal()?
.own_class_member(db, self.generic_context(db), None, name) .own_class_member(db, self.inherited_generic_context(db), None, name)
.place .place
.ignore_possibly_unbound() .ignore_possibly_unbound()
.map(|ty| { .map(|ty| {
@ -5421,7 +5435,7 @@ enum SlotsKind {
impl SlotsKind { impl SlotsKind {
fn from(db: &dyn Db, base: ClassLiteral) -> Self { fn from(db: &dyn Db, base: ClassLiteral) -> Self {
let Place::Type(slots_ty, bound) = base let Place::Type(slots_ty, bound) = base
.own_class_member(db, base.generic_context(db), None, "__slots__") .own_class_member(db, base.inherited_generic_context(db), None, "__slots__")
.place .place
else { else {
return Self::NotSpecified; return Self::NotSpecified;

View file

@ -654,7 +654,7 @@ pub(crate) struct DisplayOverloadLiteral<'db> {
impl Display for DisplayOverloadLiteral<'_> { impl Display for DisplayOverloadLiteral<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let signature = self.literal.signature(self.db, None); let signature = self.literal.signature(self.db);
let type_parameters = DisplayOptionalGenericContext { let type_parameters = DisplayOptionalGenericContext {
generic_context: signature.generic_context.as_ref(), generic_context: signature.generic_context.as_ref(),
db: self.db, db: self.db,
@ -832,7 +832,6 @@ impl Display for DisplayGenericContext<'_> {
let variables = self.generic_context.variables(self.db); let variables = self.generic_context.variables(self.db);
let non_implicit_variables: Vec<_> = variables let non_implicit_variables: Vec<_> = variables
.iter()
.filter(|bound_typevar| !bound_typevar.typevar(self.db).is_self(self.db)) .filter(|bound_typevar| !bound_typevar.typevar(self.db).is_self(self.db))
.collect(); .collect();
@ -852,6 +851,10 @@ impl Display for DisplayGenericContext<'_> {
} }
impl<'db> Specialization<'db> { impl<'db> Specialization<'db> {
pub fn display(&'db self, db: &'db dyn Db) -> DisplaySpecialization<'db> {
self.display_short(db, TupleSpecialization::No, DisplaySettings::default())
}
/// Renders the specialization as it would appear in a subscript expression, e.g. `[int, str]`. /// Renders the specialization as it would appear in a subscript expression, e.g. `[int, str]`.
pub fn display_short( pub fn display_short(
&'db self, &'db self,

View file

@ -72,7 +72,7 @@ use crate::types::diagnostic::{
report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface, report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface,
report_runtime_check_against_non_runtime_checkable_protocol, report_runtime_check_against_non_runtime_checkable_protocol,
}; };
use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::generics::GenericContext;
use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::narrow::ClassInfoConstraintFunction;
use crate::types::signatures::{CallableSignature, Signature}; use crate::types::signatures::{CallableSignature, Signature};
use crate::types::visitor::any_over_type; use crate::types::visitor::any_over_type;
@ -338,11 +338,7 @@ impl<'db> OverloadLiteral<'db> {
/// calling query is not in the same file as this function is defined in, then this will create /// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache /// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation. /// over-invalidation.
pub(crate) fn signature( pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
self,
db: &'db dyn Db,
inherited_generic_context: Option<GenericContext<'db>>,
) -> Signature<'db> {
/// `self` or `cls` can be implicitly positional-only if: /// `self` or `cls` can be implicitly positional-only if:
/// - It is a method AND /// - It is a method AND
/// - No parameters in the method use PEP-570 syntax AND /// - No parameters in the method use PEP-570 syntax AND
@ -420,7 +416,6 @@ impl<'db> OverloadLiteral<'db> {
Signature::from_function( Signature::from_function(
db, db,
generic_context, generic_context,
inherited_generic_context,
definition, definition,
function_stmt_node, function_stmt_node,
is_generator, is_generator,
@ -484,58 +479,13 @@ impl<'db> OverloadLiteral<'db> {
#[derive(PartialOrd, Ord)] #[derive(PartialOrd, Ord)]
pub struct FunctionLiteral<'db> { pub struct FunctionLiteral<'db> {
pub(crate) last_definition: OverloadLiteral<'db>, pub(crate) last_definition: OverloadLiteral<'db>,
/// The inherited generic context, if this function is a constructor method (`__new__` or
/// `__init__`) being used to infer the specialization of its generic class. If any of the
/// method's overloads are themselves generic, this is in addition to those per-overload
/// generic contexts (which are created lazily in [`OverloadLiteral::signature`]).
///
/// If the function is not a constructor method, this field will always be `None`.
///
/// If the function is a constructor method, we will end up creating two `FunctionLiteral`
/// instances for it. The first is created in [`TypeInferenceBuilder`][infer] when we encounter
/// the function definition during type inference. At this point, we don't yet know if the
/// function is a constructor method, so we create a `FunctionLiteral` with `None` for this
/// field.
///
/// If at some point we encounter a call expression, which invokes the containing class's
/// constructor, as will create a _new_ `FunctionLiteral` instance for the function, with this
/// field [updated][] to contain the containing class's generic context.
///
/// [infer]: crate::types::infer::TypeInferenceBuilder::infer_function_definition
/// [updated]: crate::types::class::ClassLiteral::own_class_member
inherited_generic_context: Option<GenericContext<'db>>,
} }
// The Salsa heap is tracked separately. // The Salsa heap is tracked separately.
impl get_size2::GetSize for FunctionLiteral<'_> {} impl get_size2::GetSize for FunctionLiteral<'_> {}
fn walk_function_literal<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
db: &'db dyn Db,
function: FunctionLiteral<'db>,
visitor: &V,
) {
if let Some(context) = function.inherited_generic_context(db) {
walk_generic_context(db, context, visitor);
}
}
#[salsa::tracked] #[salsa::tracked]
impl<'db> FunctionLiteral<'db> { impl<'db> FunctionLiteral<'db> {
fn with_inherited_generic_context(
self,
db: &'db dyn Db,
inherited_generic_context: GenericContext<'db>,
) -> Self {
// A function cannot inherit more than one generic context from its containing class.
debug_assert!(self.inherited_generic_context(db).is_none());
Self::new(
db,
self.last_definition(db),
Some(inherited_generic_context),
)
}
fn name(self, db: &'db dyn Db) -> &'db ast::name::Name { fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
// All of the overloads of a function literal should have the same name. // All of the overloads of a function literal should have the same name.
self.last_definition(db).name(db) self.last_definition(db).name(db)
@ -626,21 +576,14 @@ impl<'db> FunctionLiteral<'db> {
fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> { fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> {
// We only include an implementation (i.e. a definition not decorated with `@overload`) if // We only include an implementation (i.e. a definition not decorated with `@overload`) if
// it's the only definition. // it's the only definition.
let inherited_generic_context = self.inherited_generic_context(db);
let (overloads, implementation) = self.overloads_and_implementation(db); let (overloads, implementation) = self.overloads_and_implementation(db);
if let Some(implementation) = implementation { if let Some(implementation) = implementation {
if overloads.is_empty() { if overloads.is_empty() {
return CallableSignature::single( return CallableSignature::single(implementation.signature(db));
implementation.signature(db, inherited_generic_context),
);
} }
} }
CallableSignature::from_overloads( CallableSignature::from_overloads(overloads.iter().map(|overload| overload.signature(db)))
overloads
.iter()
.map(|overload| overload.signature(db, inherited_generic_context)),
)
} }
/// Typed externally-visible signature of the last overload or implementation of this function. /// Typed externally-visible signature of the last overload or implementation of this function.
@ -652,16 +595,7 @@ impl<'db> FunctionLiteral<'db> {
/// a cross-module dependency directly on the full AST which will lead to cache /// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation. /// over-invalidation.
fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> { fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> {
let inherited_generic_context = self.inherited_generic_context(db); self.last_definition(db).signature(db)
self.last_definition(db)
.signature(db, inherited_generic_context)
}
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
let context = self
.inherited_generic_context(db)
.map(|ctx| ctx.normalized_impl(db, visitor));
Self::new(db, self.last_definition(db), context)
} }
} }
@ -695,7 +629,6 @@ pub(super) fn walk_function_type<'db, V: super::visitor::TypeVisitor<'db> + ?Siz
function: FunctionType<'db>, function: FunctionType<'db>,
visitor: &V, visitor: &V,
) { ) {
walk_function_literal(db, function.literal(db), visitor);
if let Some(callable_signature) = function.updated_signature(db) { if let Some(callable_signature) = function.updated_signature(db) {
for signature in &callable_signature.overloads { for signature in &callable_signature.overloads {
walk_signature(db, signature, visitor); walk_signature(db, signature, visitor);
@ -713,23 +646,18 @@ impl<'db> FunctionType<'db> {
db: &'db dyn Db, db: &'db dyn Db,
inherited_generic_context: GenericContext<'db>, inherited_generic_context: GenericContext<'db>,
) -> Self { ) -> Self {
let literal = self let updated_signature = self
.literal(db) .signature(db)
.with_inherited_generic_context(db, inherited_generic_context); .with_inherited_generic_context(db, inherited_generic_context);
let updated_signature = self.updated_signature(db).map(|signature| { let updated_last_definition_signature = self
signature.with_inherited_generic_context(Some(inherited_generic_context)) .last_definition_signature(db)
});
let updated_last_definition_signature =
self.updated_last_definition_signature(db).map(|signature| {
signature
.clone() .clone()
.with_inherited_generic_context(Some(inherited_generic_context)) .with_inherited_generic_context(db, inherited_generic_context);
});
Self::new( Self::new(
db, db,
literal, self.literal(db),
updated_signature, Some(updated_signature),
updated_last_definition_signature, Some(updated_last_definition_signature),
) )
} }
@ -764,8 +692,7 @@ impl<'db> FunctionType<'db> {
let last_definition = literal let last_definition = literal
.last_definition(db) .last_definition(db)
.with_dataclass_transformer_params(db, params); .with_dataclass_transformer_params(db, params);
let literal = let literal = FunctionLiteral::new(db, last_definition);
FunctionLiteral::new(db, last_definition, literal.inherited_generic_context(db));
Self::new(db, literal, None, None) Self::new(db, literal, None, None)
} }
@ -1036,7 +963,7 @@ impl<'db> FunctionType<'db> {
} }
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
let literal = self.literal(db).normalized_impl(db, visitor); let literal = self.literal(db);
let updated_signature = self let updated_signature = self
.updated_signature(db) .updated_signature(db)
.map(|signature| signature.normalized_impl(db, visitor)); .map(|signature| signature.normalized_impl(db, visitor));

View file

@ -19,7 +19,7 @@ use crate::types::{
NormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, NormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance,
TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type, TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type,
}; };
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderMap, FxOrderSet};
/// Returns an iterator of any generic context introduced by the given scope or any enclosing /// Returns an iterator of any generic context introduced by the given scope or any enclosing
/// scope. /// scope.
@ -137,19 +137,28 @@ pub(crate) fn typing_self<'db>(
.map(typevar_to_type) .map(typevar_to_type)
} }
#[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq, get_size2::GetSize)]
pub struct GenericContextTypeVarOptions {
should_promote_literals: bool,
}
impl GenericContextTypeVarOptions {
fn promote_literals(mut self) -> Self {
self.should_promote_literals = true;
self
}
}
/// A list of formal type variables for a generic function, class, or type alias. /// A list of formal type variables for a generic function, class, or type alias.
/// ///
/// TODO: Handle nested generic contexts better, with actual parent links to the lexically
/// containing context.
///
/// # Ordering /// # Ordering
/// Ordering is based on the context's salsa-assigned id and not on its values. /// Ordering is based on the context's salsa-assigned id and not on its values.
/// The id may change between runs, or when the context was garbage collected and recreated. /// The id may change between runs, or when the context was garbage collected and recreated.
#[salsa::interned(debug, heap_size=GenericContext::heap_size)] #[salsa::interned(debug, constructor=new_internal, heap_size=GenericContext::heap_size)]
#[derive(PartialOrd, Ord)] #[derive(PartialOrd, Ord)]
pub struct GenericContext<'db> { pub struct GenericContext<'db> {
#[returns(ref)] #[returns(ref)]
pub(crate) variables: FxOrderSet<BoundTypeVarInstance<'db>>, variables_inner: FxOrderMap<BoundTypeVarInstance<'db>, GenericContextTypeVarOptions>,
} }
pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
@ -158,7 +167,7 @@ pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?S
visitor: &V, visitor: &V,
) { ) {
for bound_typevar in context.variables(db) { for bound_typevar in context.variables(db) {
visitor.visit_bound_type_var_type(db, *bound_typevar); visitor.visit_bound_type_var_type(db, bound_typevar);
} }
} }
@ -166,6 +175,13 @@ pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?S
impl get_size2::GetSize for GenericContext<'_> {} impl get_size2::GetSize for GenericContext<'_> {}
impl<'db> GenericContext<'db> { impl<'db> GenericContext<'db> {
fn from_variables(
db: &'db dyn Db,
variables: impl IntoIterator<Item = (BoundTypeVarInstance<'db>, GenericContextTypeVarOptions)>,
) -> Self {
Self::new_internal(db, variables.into_iter().collect::<FxOrderMap<_, _>>())
}
/// Creates a generic context from a list of PEP-695 type parameters. /// Creates a generic context from a list of PEP-695 type parameters.
pub(crate) fn from_type_params( pub(crate) fn from_type_params(
db: &'db dyn Db, db: &'db dyn Db,
@ -185,21 +201,44 @@ impl<'db> GenericContext<'db> {
db: &'db dyn Db, db: &'db dyn Db,
type_params: impl IntoIterator<Item = BoundTypeVarInstance<'db>>, type_params: impl IntoIterator<Item = BoundTypeVarInstance<'db>>,
) -> Self { ) -> Self {
Self::new(db, type_params.into_iter().collect::<FxOrderSet<_>>()) Self::from_variables(
db,
type_params
.into_iter()
.map(|bound_typevar| (bound_typevar, GenericContextTypeVarOptions::default())),
)
}
/// Returns a copy of this generic context where we will promote literal types in any inferred
/// specializations.
pub(crate) fn promote_literals(self, db: &'db dyn Db) -> Self {
Self::from_variables(
db,
self.variables_inner(db)
.iter()
.map(|(bound_typevar, options)| (*bound_typevar, options.promote_literals())),
)
} }
/// Merge this generic context with another, returning a new generic context that /// Merge this generic context with another, returning a new generic context that
/// contains type variables from both contexts. /// contains type variables from both contexts.
pub(crate) fn merge(self, db: &'db dyn Db, other: Self) -> Self { pub(crate) fn merge(self, db: &'db dyn Db, other: Self) -> Self {
Self::from_typevar_instances( Self::from_variables(
db, db,
self.variables(db) self.variables_inner(db)
.iter() .iter()
.chain(other.variables(db).iter()) .chain(other.variables_inner(db).iter())
.copied(), .map(|(bound_typevar, options)| (*bound_typevar, *options)),
) )
} }
pub(crate) fn variables(
self,
db: &'db dyn Db,
) -> impl ExactSizeIterator<Item = BoundTypeVarInstance<'db>> + Clone {
self.variables_inner(db).keys().copied()
}
fn variable_from_type_param( fn variable_from_type_param(
db: &'db dyn Db, db: &'db dyn Db,
index: &'db SemanticIndex<'db>, index: &'db SemanticIndex<'db>,
@ -247,7 +286,7 @@ impl<'db> GenericContext<'db> {
if variables.is_empty() { if variables.is_empty() {
return None; return None;
} }
Some(Self::new(db, variables)) Some(Self::from_typevar_instances(db, variables))
} }
/// Creates a generic context from the legacy `TypeVar`s that appear in class's base class /// Creates a generic context from the legacy `TypeVar`s that appear in class's base class
@ -263,18 +302,17 @@ impl<'db> GenericContext<'db> {
if variables.is_empty() { if variables.is_empty() {
return None; return None;
} }
Some(Self::new(db, variables)) Some(Self::from_typevar_instances(db, variables))
} }
pub(crate) fn len(self, db: &'db dyn Db) -> usize { pub(crate) fn len(self, db: &'db dyn Db) -> usize {
self.variables(db).len() self.variables_inner(db).len()
} }
pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> { pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
let parameters = Parameters::new( let parameters = Parameters::new(
self.variables(db) self.variables(db)
.iter() .map(|typevar| Self::parameter_from_typevar(db, typevar)),
.map(|typevar| Self::parameter_from_typevar(db, *typevar)),
); );
Signature::new(parameters, None) Signature::new(parameters, None)
} }
@ -309,8 +347,7 @@ impl<'db> GenericContext<'db> {
db: &'db dyn Db, db: &'db dyn Db,
known_class: Option<KnownClass>, known_class: Option<KnownClass>,
) -> Specialization<'db> { ) -> Specialization<'db> {
let partial = let partial = self.specialize_partial(db, std::iter::repeat_n(None, self.len(db)));
self.specialize_partial(db, std::iter::repeat_n(None, self.variables(db).len()));
if known_class == Some(KnownClass::Tuple) { if known_class == Some(KnownClass::Tuple) {
Specialization::new( Specialization::new(
db, db,
@ -332,31 +369,24 @@ impl<'db> GenericContext<'db> {
db: &'db dyn Db, db: &'db dyn Db,
typevar_to_type: &impl Fn(BoundTypeVarInstance<'db>) -> Type<'db>, typevar_to_type: &impl Fn(BoundTypeVarInstance<'db>) -> Type<'db>,
) -> Specialization<'db> { ) -> Specialization<'db> {
let types = self let types = self.variables(db).map(typevar_to_type).collect();
.variables(db)
.iter()
.map(|typevar| typevar_to_type(*typevar))
.collect();
self.specialize(db, types) self.specialize(db, types)
} }
pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> { pub(crate) fn unknown_specialization(self, db: &'db dyn Db) -> Specialization<'db> {
let types = vec![Type::unknown(); self.variables(db).len()]; let types = vec![Type::unknown(); self.len(db)];
self.specialize(db, types.into()) self.specialize(db, types.into())
} }
/// Returns a tuple type of the typevars introduced by this generic context. /// Returns a tuple type of the typevars introduced by this generic context.
pub(crate) fn as_tuple(self, db: &'db dyn Db) -> Type<'db> { pub(crate) fn as_tuple(self, db: &'db dyn Db) -> Type<'db> {
Type::heterogeneous_tuple( Type::heterogeneous_tuple(db, self.variables(db).map(Type::TypeVar))
db,
self.variables(db)
.iter()
.map(|typevar| Type::TypeVar(*typevar)),
)
} }
pub(crate) fn is_subset_of(self, db: &'db dyn Db, other: GenericContext<'db>) -> bool { pub(crate) fn is_subset_of(self, db: &'db dyn Db, other: GenericContext<'db>) -> bool {
self.variables(db).is_subset(other.variables(db)) let other_variables = other.variables_inner(db);
self.variables(db)
.all(|bound_typevar| other_variables.contains_key(&bound_typevar))
} }
pub(crate) fn binds_typevar( pub(crate) fn binds_typevar(
@ -365,9 +395,7 @@ impl<'db> GenericContext<'db> {
typevar: TypeVarInstance<'db>, typevar: TypeVarInstance<'db>,
) -> Option<BoundTypeVarInstance<'db>> { ) -> Option<BoundTypeVarInstance<'db>> {
self.variables(db) self.variables(db)
.iter()
.find(|self_bound_typevar| self_bound_typevar.typevar(db) == typevar) .find(|self_bound_typevar| self_bound_typevar.typevar(db) == typevar)
.copied()
} }
/// Creates a specialization of this generic context. Panics if the length of `types` does not /// Creates a specialization of this generic context. Panics if the length of `types` does not
@ -379,7 +407,7 @@ impl<'db> GenericContext<'db> {
db: &'db dyn Db, db: &'db dyn Db,
types: Box<[Type<'db>]>, types: Box<[Type<'db>]>,
) -> Specialization<'db> { ) -> Specialization<'db> {
assert!(self.variables(db).len() == types.len()); assert!(self.len(db) == types.len());
Specialization::new(db, self, types, None, None) Specialization::new(db, self, types, None, None)
} }
@ -403,7 +431,7 @@ impl<'db> GenericContext<'db> {
{ {
let types = types.into_iter(); let types = types.into_iter();
let variables = self.variables(db); let variables = self.variables(db);
assert!(variables.len() == types.len()); assert!(self.len(db) == types.len());
// Typevars can have other typevars as their default values, e.g. // Typevars can have other typevars as their default values, e.g.
// //
@ -442,14 +470,15 @@ impl<'db> GenericContext<'db> {
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
let variables = self let variables = self
.variables(db) .variables(db)
.iter()
.map(|bound_typevar| bound_typevar.normalized_impl(db, visitor)); .map(|bound_typevar| bound_typevar.normalized_impl(db, visitor));
Self::from_typevar_instances(db, variables) Self::from_typevar_instances(db, variables)
} }
fn heap_size((variables,): &(FxOrderSet<BoundTypeVarInstance<'db>>,)) -> usize { fn heap_size(
ruff_memory_usage::order_set_heap_size(variables) (variables,): &(FxOrderMap<BoundTypeVarInstance<'db>, GenericContextTypeVarOptions>,),
) -> usize {
ruff_memory_usage::order_map_heap_size(variables)
} }
} }
@ -661,6 +690,31 @@ fn has_relation_in_invariant_position<'db>(
} }
impl<'db> Specialization<'db> { impl<'db> Specialization<'db> {
/// Restricts this specialization to only include the typevars in a generic context. If the
/// specialization does not include all of those typevars, returns `None`.
pub(crate) fn restrict(
self,
db: &'db dyn Db,
generic_context: GenericContext<'db>,
) -> Option<Self> {
let self_variables = self.generic_context(db).variables_inner(db);
let self_types = self.types(db);
let restricted_variables = generic_context.variables(db);
let restricted_types: Option<Box<[_]>> = restricted_variables
.map(|variable| {
let index = self_variables.get_index_of(&variable)?;
self_types.get(index).copied()
})
.collect();
Some(Self::new(
db,
generic_context,
restricted_types?,
self.materialization_kind(db),
None,
))
}
/// Returns the tuple spec for a specialization of the `tuple` class. /// Returns the tuple spec for a specialization of the `tuple` class.
pub(crate) fn tuple(self, db: &'db dyn Db) -> Option<&'db TupleSpec<'db>> { pub(crate) fn tuple(self, db: &'db dyn Db) -> Option<&'db TupleSpec<'db>> {
self.tuple_inner(db).map(|tuple_type| tuple_type.tuple(db)) self.tuple_inner(db).map(|tuple_type| tuple_type.tuple(db))
@ -675,7 +729,7 @@ impl<'db> Specialization<'db> {
) -> Option<Type<'db>> { ) -> Option<Type<'db>> {
let index = self let index = self
.generic_context(db) .generic_context(db)
.variables(db) .variables_inner(db)
.get_index_of(&bound_typevar)?; .get_index_of(&bound_typevar)?;
self.types(db).get(index).copied() self.types(db).get(index).copied()
} }
@ -813,7 +867,6 @@ impl<'db> Specialization<'db> {
let types: Box<[_]> = self let types: Box<[_]> = self
.generic_context(db) .generic_context(db)
.variables(db) .variables(db)
.into_iter()
.zip(self.types(db)) .zip(self.types(db))
.map(|(bound_typevar, vartype)| { .map(|(bound_typevar, vartype)| {
match bound_typevar.variance(db) { match bound_typevar.variance(db) {
@ -882,7 +935,7 @@ impl<'db> Specialization<'db> {
let other_materialization_kind = other.materialization_kind(db); let other_materialization_kind = other.materialization_kind(db);
let mut result = ConstraintSet::from(true); let mut result = ConstraintSet::from(true);
for ((bound_typevar, self_type), other_type) in (generic_context.variables(db).into_iter()) for ((bound_typevar, self_type), other_type) in (generic_context.variables(db))
.zip(self.types(db)) .zip(self.types(db))
.zip(other.types(db)) .zip(other.types(db))
{ {
@ -933,7 +986,7 @@ impl<'db> Specialization<'db> {
} }
let mut result = ConstraintSet::from(true); let mut result = ConstraintSet::from(true);
for ((bound_typevar, self_type), other_type) in (generic_context.variables(db).into_iter()) for ((bound_typevar, self_type), other_type) in (generic_context.variables(db))
.zip(self.types(db)) .zip(self.types(db))
.zip(other.types(db)) .zip(other.types(db))
{ {
@ -1005,7 +1058,7 @@ impl<'db> PartialSpecialization<'_, 'db> {
) -> Option<Type<'db>> { ) -> Option<Type<'db>> {
let index = self let index = self
.generic_context .generic_context
.variables(db) .variables_inner(db)
.get_index_of(&bound_typevar)?; .get_index_of(&bound_typevar)?;
self.types.get(index).copied() self.types.get(index).copied()
} }
@ -1027,10 +1080,18 @@ impl<'db> SpecializationBuilder<'db> {
} }
pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> { pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> {
let types = generic_context let types = (generic_context.variables_inner(self.db).iter()).map(|(variable, options)| {
.variables(self.db) let mut ty = self.types.get(variable).copied();
.iter()
.map(|variable| self.types.get(variable).copied()); // When inferring a specialization for a generic class typevar from a constructor call,
// promote any typevars that are inferred as a literal to the corresponding instance
// type.
if options.should_promote_literals {
ty = ty.map(|ty| ty.promote_literals(self.db));
}
ty
});
// TODO Infer the tuple spec for a tuple type // TODO Infer the tuple spec for a tuple type
generic_context.specialize_partial(self.db, types) generic_context.specialize_partial(self.db, types)
} }

View file

@ -88,13 +88,12 @@ use crate::types::typed_dict::{
}; };
use crate::types::visitor::any_over_type; use crate::types::visitor::any_over_type;
use crate::types::{ use crate::types::{
BoundTypeVarInstance, CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType,
DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy,
MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType,
Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers,
TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation,
TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type,
UnionBuilder, UnionType, binding_type, todo_type,
}; };
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
use crate::unpack::{EvaluationMode, UnpackPosition}; use crate::unpack::{EvaluationMode, UnpackPosition};
@ -2141,10 +2140,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deprecated, deprecated,
dataclass_transformer_params, dataclass_transformer_params,
); );
let function_literal = FunctionLiteral::new(self.db(), overload_literal);
let inherited_generic_context = None;
let function_literal =
FunctionLiteral::new(self.db(), overload_literal, inherited_generic_context);
let mut inferred_ty = let mut inferred_ty =
Type::FunctionLiteral(FunctionType::new(self.db(), function_literal, None, None)); Type::FunctionLiteral(FunctionType::new(self.db(), function_literal, None, None));
@ -5354,16 +5350,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
collection_class: KnownClass, collection_class: KnownClass,
) -> Option<Type<'db>> { ) -> Option<Type<'db>> {
// Extract the type variable `T` from `list[T]` in typeshed. // Extract the type variable `T` from `list[T]` in typeshed.
fn elt_tys( let elt_tys = |collection_class: KnownClass| {
collection_class: KnownClass, let class_literal = collection_class.try_to_class_literal(self.db())?;
db: &dyn Db, let generic_context = class_literal.generic_context(self.db())?;
) -> Option<(ClassLiteral<'_>, &FxOrderSet<BoundTypeVarInstance<'_>>)> { Some((class_literal, generic_context.variables(self.db())))
let class_literal = collection_class.try_to_class_literal(db)?; };
let generic_context = class_literal.generic_context(db)?;
Some((class_literal, generic_context.variables(db)))
}
let (class_literal, elt_tys) = elt_tys(collection_class, self.db()).unwrap_or_else(|| { let (class_literal, elt_tys) = elt_tys(collection_class).unwrap_or_else(|| {
let name = collection_class.name(self.db()); let name = collection_class.name(self.db());
panic!("Typeshed should always have a `{name}` class in `builtins.pyi`") panic!("Typeshed should always have a `{name}` class in `builtins.pyi`")
}); });
@ -5382,9 +5375,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Note that we infer the annotated type _before_ the elements, to more closely match the // Note that we infer the annotated type _before_ the elements, to more closely match the
// order of any unions as written in the type annotation. // order of any unions as written in the type annotation.
Some(annotated_elt_tys) => { Some(annotated_elt_tys) => {
for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys, annotated_elt_tys) { for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys.clone(), annotated_elt_tys) {
builder builder
.infer(Type::TypeVar(*elt_ty), *annotated_elt_ty) .infer(Type::TypeVar(elt_ty), *annotated_elt_ty)
.ok()?; .ok()?;
} }
} }
@ -5392,10 +5385,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// If a valid type annotation was not provided, avoid restricting the type of the collection // If a valid type annotation was not provided, avoid restricting the type of the collection
// by unioning the inferred type with `Unknown`. // by unioning the inferred type with `Unknown`.
None => { None => {
for elt_ty in elt_tys { for elt_ty in elt_tys.clone() {
builder builder.infer(Type::TypeVar(elt_ty), Type::unknown()).ok()?;
.infer(Type::TypeVar(*elt_ty), Type::unknown())
.ok()?;
} }
} }
} }
@ -5415,10 +5406,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
inferred_value_ty.known_specialization(KnownClass::Dict, self.db()) inferred_value_ty.known_specialization(KnownClass::Dict, self.db())
{ {
for (elt_ty, inferred_elt_ty) in for (elt_ty, inferred_elt_ty) in
iter::zip(elt_tys, specialization.types(self.db())) iter::zip(elt_tys.clone(), specialization.types(self.db()))
{ {
builder builder
.infer(Type::TypeVar(*elt_ty), *inferred_elt_ty) .infer(Type::TypeVar(elt_ty), *inferred_elt_ty)
.ok()?; .ok()?;
} }
} }
@ -5427,7 +5418,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
// The inferred type of each element acts as an additional constraint on `T`. // The inferred type of each element acts as an additional constraint on `T`.
for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys, elt_tcxs.clone()) { for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone())
{
let Some(inferred_elt_ty) = self.infer_optional_expression(elt, elt_tcx) else { let Some(inferred_elt_ty) = self.infer_optional_expression(elt, elt_tcx) else {
continue; continue;
}; };
@ -5436,9 +5428,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// unions for large nested list literals, which the constraint solver struggles with. // unions for large nested list literals, which the constraint solver struggles with.
let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db()); let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db());
builder builder.infer(Type::TypeVar(elt_ty), inferred_elt_ty).ok()?;
.infer(Type::TypeVar(*elt_ty), inferred_elt_ty)
.ok()?;
} }
} }
@ -9012,7 +9002,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
}) })
.collect(); .collect();
typevars.map(|typevars| GenericContext::new(self.db(), typevars)) typevars.map(|typevars| GenericContext::from_typevar_instances(self.db(), typevars))
} }
fn infer_slice_expression(&mut self, slice: &ast::ExprSlice) -> Type<'db> { fn infer_slice_expression(&mut self, slice: &ast::ExprSlice) -> Type<'db> {

View file

@ -102,12 +102,13 @@ impl<'db> CallableSignature<'db> {
pub(crate) fn with_inherited_generic_context( pub(crate) fn with_inherited_generic_context(
&self, &self,
inherited_generic_context: Option<GenericContext<'db>>, db: &'db dyn Db,
inherited_generic_context: GenericContext<'db>,
) -> Self { ) -> Self {
Self::from_overloads(self.overloads.iter().map(|signature| { Self::from_overloads(self.overloads.iter().map(|signature| {
signature signature
.clone() .clone()
.with_inherited_generic_context(inherited_generic_context) .with_inherited_generic_context(db, inherited_generic_context)
})) }))
} }
@ -301,11 +302,6 @@ pub struct Signature<'db> {
/// The generic context for this overload, if it is generic. /// The generic context for this overload, if it is generic.
pub(crate) generic_context: Option<GenericContext<'db>>, pub(crate) generic_context: Option<GenericContext<'db>>,
/// The inherited generic context, if this function is a class method being used to infer the
/// specialization of its generic class. If the method is itself generic, this is in addition
/// to its own generic context.
pub(crate) inherited_generic_context: Option<GenericContext<'db>>,
/// The original definition associated with this function, if available. /// The original definition associated with this function, if available.
/// This is useful for locating and extracting docstring information for the signature. /// This is useful for locating and extracting docstring information for the signature.
pub(crate) definition: Option<Definition<'db>>, pub(crate) definition: Option<Definition<'db>>,
@ -332,9 +328,6 @@ pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
if let Some(generic_context) = &signature.generic_context { if let Some(generic_context) = &signature.generic_context {
walk_generic_context(db, *generic_context, visitor); walk_generic_context(db, *generic_context, visitor);
} }
if let Some(inherited_generic_context) = &signature.inherited_generic_context {
walk_generic_context(db, *inherited_generic_context, visitor);
}
// By default we usually don't visit the type of the default value, // By default we usually don't visit the type of the default value,
// as it isn't relevant to most things // as it isn't relevant to most things
for parameter in &signature.parameters { for parameter in &signature.parameters {
@ -351,7 +344,6 @@ impl<'db> Signature<'db> {
pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option<Type<'db>>) -> Self { pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option<Type<'db>>) -> Self {
Self { Self {
generic_context: None, generic_context: None,
inherited_generic_context: None,
definition: None, definition: None,
parameters, parameters,
return_ty, return_ty,
@ -365,7 +357,6 @@ impl<'db> Signature<'db> {
) -> Self { ) -> Self {
Self { Self {
generic_context, generic_context,
inherited_generic_context: None,
definition: None, definition: None,
parameters, parameters,
return_ty, return_ty,
@ -376,7 +367,6 @@ impl<'db> Signature<'db> {
pub(crate) fn dynamic(signature_type: Type<'db>) -> Self { pub(crate) fn dynamic(signature_type: Type<'db>) -> Self {
Signature { Signature {
generic_context: None, generic_context: None,
inherited_generic_context: None,
definition: None, definition: None,
parameters: Parameters::gradual_form(), parameters: Parameters::gradual_form(),
return_ty: Some(signature_type), return_ty: Some(signature_type),
@ -389,7 +379,6 @@ impl<'db> Signature<'db> {
let signature_type = todo_type!(reason); let signature_type = todo_type!(reason);
Signature { Signature {
generic_context: None, generic_context: None,
inherited_generic_context: None,
definition: None, definition: None,
parameters: Parameters::todo(), parameters: Parameters::todo(),
return_ty: Some(signature_type), return_ty: Some(signature_type),
@ -400,7 +389,6 @@ impl<'db> Signature<'db> {
pub(super) fn from_function( pub(super) fn from_function(
db: &'db dyn Db, db: &'db dyn Db,
generic_context: Option<GenericContext<'db>>, generic_context: Option<GenericContext<'db>>,
inherited_generic_context: Option<GenericContext<'db>>,
definition: Definition<'db>, definition: Definition<'db>,
function_node: &ast::StmtFunctionDef, function_node: &ast::StmtFunctionDef,
is_generator: bool, is_generator: bool,
@ -434,7 +422,6 @@ impl<'db> Signature<'db> {
(Some(legacy_ctx), Some(ctx)) => { (Some(legacy_ctx), Some(ctx)) => {
if legacy_ctx if legacy_ctx
.variables(db) .variables(db)
.iter()
.exactly_one() .exactly_one()
.is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db)) .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
{ {
@ -449,7 +436,6 @@ impl<'db> Signature<'db> {
Self { Self {
generic_context: full_generic_context, generic_context: full_generic_context,
inherited_generic_context,
definition: Some(definition), definition: Some(definition),
parameters, parameters,
return_ty, return_ty,
@ -468,9 +454,17 @@ impl<'db> Signature<'db> {
pub(crate) fn with_inherited_generic_context( pub(crate) fn with_inherited_generic_context(
mut self, mut self,
inherited_generic_context: Option<GenericContext<'db>>, db: &'db dyn Db,
inherited_generic_context: GenericContext<'db>,
) -> Self { ) -> Self {
self.inherited_generic_context = inherited_generic_context; match self.generic_context.as_mut() {
Some(generic_context) => {
*generic_context = generic_context.merge(db, inherited_generic_context);
}
None => {
self.generic_context = Some(inherited_generic_context);
}
}
self self
} }
@ -483,9 +477,6 @@ impl<'db> Signature<'db> {
generic_context: self generic_context: self
.generic_context .generic_context
.map(|ctx| ctx.normalized_impl(db, visitor)), .map(|ctx| ctx.normalized_impl(db, visitor)),
inherited_generic_context: self
.inherited_generic_context
.map(|ctx| ctx.normalized_impl(db, visitor)),
// Discard the definition when normalizing, so that two equivalent signatures // Discard the definition when normalizing, so that two equivalent signatures
// with different `Definition`s share the same Salsa ID when normalized // with different `Definition`s share the same Salsa ID when normalized
definition: None, definition: None,
@ -516,7 +507,6 @@ impl<'db> Signature<'db> {
generic_context: self generic_context: self
.generic_context .generic_context
.map(|context| type_mapping.update_signature_generic_context(db, context)), .map(|context| type_mapping.update_signature_generic_context(db, context)),
inherited_generic_context: self.inherited_generic_context,
definition: self.definition, definition: self.definition,
parameters: self parameters: self
.parameters .parameters
@ -571,7 +561,6 @@ impl<'db> Signature<'db> {
} }
Self { Self {
generic_context: self.generic_context, generic_context: self.generic_context,
inherited_generic_context: self.inherited_generic_context,
definition: self.definition, definition: self.definition,
parameters, parameters,
return_ty, return_ty,
@ -1236,10 +1225,7 @@ impl<'db> Parameters<'db> {
let method_has_self_in_generic_context = let method_has_self_in_generic_context =
method.signature(db).overloads.iter().any(|s| { method.signature(db).overloads.iter().any(|s| {
s.generic_context.is_some_and(|context| { s.generic_context.is_some_and(|context| {
context context.variables(db).any(|v| v.typevar(db).is_self(db))
.variables(db)
.iter()
.any(|v| v.typevar(db).is_self(db))
}) })
}); });
@ -1882,7 +1868,7 @@ mod tests {
.literal(&db) .literal(&db)
.last_definition(&db); .last_definition(&db);
let sig = func.signature(&db, None); let sig = func.signature(&db);
assert!(sig.return_ty.is_none()); assert!(sig.return_ty.is_none());
assert_params(&sig, &[]); assert_params(&sig, &[]);
@ -1907,7 +1893,7 @@ mod tests {
.literal(&db) .literal(&db)
.last_definition(&db); .last_definition(&db);
let sig = func.signature(&db, None); let sig = func.signature(&db);
assert_eq!(sig.return_ty.unwrap().display(&db).to_string(), "bytes"); assert_eq!(sig.return_ty.unwrap().display(&db).to_string(), "bytes");
assert_params( assert_params(
@ -1959,7 +1945,7 @@ mod tests {
.literal(&db) .literal(&db)
.last_definition(&db); .last_definition(&db);
let sig = func.signature(&db, None); let sig = func.signature(&db);
let [ let [
Parameter { Parameter {
@ -1997,7 +1983,7 @@ mod tests {
.literal(&db) .literal(&db)
.last_definition(&db); .last_definition(&db);
let sig = func.signature(&db, None); let sig = func.signature(&db);
let [ let [
Parameter { Parameter {
@ -2035,7 +2021,7 @@ mod tests {
.literal(&db) .literal(&db)
.last_definition(&db); .last_definition(&db);
let sig = func.signature(&db, None); let sig = func.signature(&db);
let [ let [
Parameter { Parameter {
@ -2079,7 +2065,7 @@ mod tests {
.literal(&db) .literal(&db)
.last_definition(&db); .last_definition(&db);
let sig = func.signature(&db, None); let sig = func.signature(&db);
let [ let [
Parameter { Parameter {
@ -2116,7 +2102,7 @@ mod tests {
let func = get_function_f(&db, "/src/a.py"); let func = get_function_f(&db, "/src/a.py");
let overload = func.literal(&db).last_definition(&db); let overload = func.literal(&db).last_definition(&db);
let expected_sig = overload.signature(&db, None); let expected_sig = overload.signature(&db);
// With no decorators, internal and external signature are the same // With no decorators, internal and external signature are the same
assert_eq!( assert_eq!(