[ty] Avoid unnecessarily widening generic specializations (#20875)

## Summary

Ignore the type context when specializing a generic call if it leads to
an unnecessarily wide return type. For example, [the example mentioned
here](https://github.com/astral-sh/ruff/pull/20796#issuecomment-3403319536)
works as expected after this change:
```py
def id[T](x: T) -> T:
    return x

def _(i: int):
    x: int | None = id(i)
    y: int | None = i
    reveal_type(x)  # revealed: int
    reveal_type(y)  # revealed: int
```

I also added extended our usage of `filter_disjoint_elements` to tuple
and typed-dict inference, which resolves
https://github.com/astral-sh/ty/issues/1266.
This commit is contained in:
Ibraheem Ahmed 2025-10-16 15:17:37 -04:00 committed by GitHub
parent 8dad58de37
commit 1ade4f2081
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 156 additions and 58 deletions

View file

@ -1233,22 +1233,35 @@ impl<'db> Type<'db> {
if yes { self.negate(db) } else { *self }
}
/// Remove the union elements that are not related to `target`.
/// If the type is a union, filters union elements based on the provided predicate.
///
/// Otherwise, returns the type unchanged.
pub(crate) fn filter_union(
self,
db: &'db dyn Db,
f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, f)
} else {
self
}
}
/// If the type is a union, removes union elements that are disjoint from `target`.
///
/// Otherwise, returns the type unchanged.
pub(crate) fn filter_disjoint_elements(
self,
db: &'db dyn Db,
target: Type<'db>,
inferable: InferableTypeVars<'_, 'db>,
) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, |elem| {
!elem
.when_disjoint_from(db, target, inferable)
.is_always_satisfied()
})
} else {
self
}
self.filter_union(db, |elem| {
!elem
.when_disjoint_from(db, target, inferable)
.is_always_satisfied()
})
}
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
@ -11185,9 +11198,9 @@ impl<'db> UnionType<'db> {
pub(crate) fn filter(
self,
db: &'db dyn Db,
filter_fn: impl FnMut(&&Type<'db>) -> bool,
mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().filter(filter_fn))
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
}
pub(crate) fn map_with_boundness(

View file

@ -2524,6 +2524,7 @@ struct ArgumentTypeChecker<'a, 'db> {
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
return_ty: Type<'db>,
errors: &'a mut Vec<BindingError<'db>>,
inferable_typevars: InferableTypeVars<'db, 'db>,
@ -2531,6 +2532,7 @@ struct ArgumentTypeChecker<'a, 'db> {
}
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
#[expect(clippy::too_many_arguments)]
fn new(
db: &'db dyn Db,
signature: &'a Signature<'db>,
@ -2538,6 +2540,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>],
call_expression_tcx: &'a TypeContext<'db>,
return_ty: Type<'db>,
errors: &'a mut Vec<BindingError<'db>>,
) -> Self {
Self {
@ -2547,6 +2550,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
argument_matches,
parameter_tys,
call_expression_tcx,
return_ty,
errors,
inferable_typevars: InferableTypeVars::None,
specialization: None,
@ -2588,25 +2592,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// TODO: Use the list of inferable typevars from the generic context of the callable.
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
// Note that we infer the annotated type _before_ the arguments if this call is part of
// an annotated assignment, to closer match the order of any unions written in the type
// annotation.
if let Some(return_ty) = self.signature.return_ty
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
{
match call_expression_tcx {
// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
Type::TypeVar(_) => {}
_ => {
// Ignore any specialization errors here, because the type context is only used as a hint
// to infer a more assignable return type.
let _ = builder.infer(return_ty, call_expression_tcx);
}
}
}
let parameters = self.signature.parameters();
for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types()
@ -2631,7 +2616,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}
self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx));
// Build the specialization first without inferring the type context.
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
let isolated_return_ty = self
.return_ty
.apply_specialization(self.db, isolated_specialization);
let mut try_infer_tcx = || {
let return_ty = self.signature.return_ty?;
let call_expression_tcx = self.call_expression_tcx.annotation?;
// A type variable is not a useful type-context for expression inference, and applying it
// to the return type can lead to confusing unions in nested generic calls.
if call_expression_tcx.is_type_var() {
return None;
}
// If the return type is already assignable to the annotated type, we can ignore the
// type context and prefer the narrower inferred type.
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
return None;
}
// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
// annotated assignment, to closer match the order of any unions written in the type annotation.
builder.infer(return_ty, call_expression_tcx).ok()?;
// Otherwise, build the specialization again after inferring the type context.
let specialization = builder.build(generic_context, *self.call_expression_tcx);
let return_ty = return_ty.apply_specialization(self.db, specialization);
Some((Some(specialization), return_ty))
};
(self.specialization, self.return_ty) =
try_infer_tcx().unwrap_or((Some(isolated_specialization), isolated_return_ty));
}
fn check_argument_type(
@ -2826,8 +2845,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
}
}
fn finish(self) -> (InferableTypeVars<'db, 'db>, Option<Specialization<'db>>) {
(self.inferable_typevars, self.specialization)
fn finish(
self,
) -> (
InferableTypeVars<'db, 'db>,
Option<Specialization<'db>>,
Type<'db>,
) {
(self.inferable_typevars, self.specialization, self.return_ty)
}
}
@ -2985,18 +3010,16 @@ impl<'db> Binding<'db> {
&self.argument_matches,
&mut self.parameter_tys,
call_expression_tcx,
self.return_ty,
&mut self.errors,
);
// If this overload is generic, first see if we can infer a specialization of the function
// from the arguments that were passed in.
checker.infer_specialization();
checker.check_argument_types();
(self.inferable_typevars, self.specialization) = checker.finish();
if let Some(specialization) = self.specialization {
self.return_ty = self.return_ty.apply_specialization(db, specialization);
}
(self.inferable_typevars, self.specialization, self.return_ty) = checker.finish();
}
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {

View file

@ -1229,6 +1229,7 @@ impl<'db> SpecializationBuilder<'db> {
let tcx = tcx_specialization.and_then(|specialization| {
specialization.get(self.db, variable.bound_typevar)
});
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
}
@ -1251,7 +1252,7 @@ impl<'db> SpecializationBuilder<'db> {
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
mut actual: Type<'db>,
actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
if formal == actual {
return Ok(());
@ -1282,9 +1283,11 @@ impl<'db> SpecializationBuilder<'db> {
return Ok(());
}
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
// So, here we remove the union elements that are not related to `formal`.
actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
// Remove the union elements that are not related to `formal`.
//
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
// to `int`.
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
match (formal, actual) {
// TODO: We haven't implemented a full unification solver yet. If typevars appear in

View file

@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> {
.and_then(|ty| ty.known_specialization(db, known_class))
}
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
Self {
annotation: self.annotation.map(f),
}

View file

@ -5890,6 +5890,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
parenthesized: _,
} = tuple;
// TODO: Use the list of inferable typevars from the generic context of tuple.
let inferable = InferableTypeVars::None;
// Remove any union elements of that are unrelated to the tuple type.
let tcx = tcx.map(|annotation| {
annotation.filter_disjoint_elements(
self.db(),
KnownClass::Tuple.to_instance(self.db()),
inferable,
)
});
let annotated_tuple = tcx
.known_specialization(self.db(), KnownClass::Tuple)
.and_then(|specialization| {
@ -5955,7 +5967,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = dict;
// Validate `TypedDict` dictionary literal assignments.
if let Some(typed_dict) = tcx.annotation.and_then(Type::as_typed_dict)
if let Some(tcx) = tcx.annotation
&& let Some(typed_dict) = tcx
.filter_union(self.db(), Type::is_typed_dict)
.as_typed_dict()
&& let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict)
{
return ty;
@ -6038,9 +6053,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// TODO: Use the list of inferable typevars from the generic context of the collection
// class.
let inferable = InferableTypeVars::None;
let tcx = tcx.map_annotation(|annotation| {
// Remove any union elements of `annotation` that are not related to `collection_ty`.
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
// Remove any union elements of that are unrelated to the collection type.
//
// For example, we only want the `list[int]` from `annotation: list[int] | None` if
// `collection_ty` is `list`.
let tcx = tcx.map(|annotation| {
let collection_ty = collection_class.to_instance(self.db());
annotation.filter_disjoint_elements(self.db(), collection_ty, inferable)
});