mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-22 00:01:56 +00:00
[ty] Avoid unnecessarily widening generic specializations (#20875)
## Summary Ignore the type context when specializing a generic call if it leads to an unnecessarily wide return type. For example, [the example mentioned here](https://github.com/astral-sh/ruff/pull/20796#issuecomment-3403319536) works as expected after this change: ```py def id[T](x: T) -> T: return x def _(i: int): x: int | None = id(i) y: int | None = i reveal_type(x) # revealed: int reveal_type(y) # revealed: int ``` I also added extended our usage of `filter_disjoint_elements` to tuple and typed-dict inference, which resolves https://github.com/astral-sh/ty/issues/1266.
This commit is contained in:
parent
8dad58de37
commit
1ade4f2081
8 changed files with 156 additions and 58 deletions
|
@ -190,8 +190,7 @@ k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7]
|
|||
reveal_type(k) # revealed: list[tuple[list[int], ...]]
|
||||
|
||||
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
|
||||
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
|
||||
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
|
||||
reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]]
|
||||
|
||||
type IntList = list[int]
|
||||
|
||||
|
@ -416,13 +415,14 @@ a = f("a")
|
|||
reveal_type(a) # revealed: list[Literal["a"]]
|
||||
|
||||
b: list[int | Literal["a"]] = f("a")
|
||||
reveal_type(b) # revealed: list[int | Literal["a"]]
|
||||
reveal_type(b) # revealed: list[Literal["a"] | int]
|
||||
|
||||
c: list[int | str] = f("a")
|
||||
reveal_type(c) # revealed: list[int | str]
|
||||
reveal_type(c) # revealed: list[str | int]
|
||||
|
||||
d: list[int | tuple[int, int]] = f((1, 2))
|
||||
reveal_type(d) # revealed: list[int | tuple[int, int]]
|
||||
# TODO: We could avoid reordering the union elements here.
|
||||
reveal_type(d) # revealed: list[tuple[int, int] | int]
|
||||
|
||||
e: list[int] = f(True)
|
||||
reveal_type(e) # revealed: list[int]
|
||||
|
@ -437,8 +437,49 @@ def f2[T: int](x: T) -> T:
|
|||
return x
|
||||
|
||||
i: int = f2(True)
|
||||
reveal_type(i) # revealed: int
|
||||
reveal_type(i) # revealed: Literal[True]
|
||||
|
||||
j: int | str = f2(True)
|
||||
reveal_type(j) # revealed: Literal[True]
|
||||
```
|
||||
|
||||
Types are not widened unnecessarily:
|
||||
|
||||
```py
|
||||
def id[T](x: T) -> T:
|
||||
return x
|
||||
|
||||
def lst[T](x: T) -> list[T]:
|
||||
return [x]
|
||||
|
||||
def _(i: int):
|
||||
a: int | None = i
|
||||
b: int | None = id(i)
|
||||
c: int | str | None = id(i)
|
||||
reveal_type(a) # revealed: int
|
||||
reveal_type(b) # revealed: int
|
||||
reveal_type(c) # revealed: int
|
||||
|
||||
a: list[int | None] | None = [i]
|
||||
b: list[int | None] | None = id([i])
|
||||
c: list[int | None] | int | None = id([i])
|
||||
reveal_type(a) # revealed: list[int | None]
|
||||
# TODO: these should reveal `list[int | None]`
|
||||
# we currently do not use the call expression annotation as type context for argument inference
|
||||
reveal_type(b) # revealed: list[Unknown | int]
|
||||
reveal_type(c) # revealed: list[Unknown | int]
|
||||
|
||||
a: list[int | None] | None = [i]
|
||||
b: list[int | None] | None = lst(i)
|
||||
c: list[int | None] | int | None = lst(i)
|
||||
reveal_type(a) # revealed: list[int | None]
|
||||
reveal_type(b) # revealed: list[int | None]
|
||||
reveal_type(c) # revealed: list[int | None]
|
||||
|
||||
a: list | None = []
|
||||
b: list | None = id([])
|
||||
c: list | int | None = id([])
|
||||
reveal_type(a) # revealed: list[Unknown]
|
||||
reveal_type(b) # revealed: list[Unknown]
|
||||
reveal_type(c) # revealed: list[Unknown]
|
||||
```
|
||||
|
|
|
@ -11,7 +11,7 @@ class Member:
|
|||
role: str = field(default="user")
|
||||
tag: str | None = field(default=None, init=False)
|
||||
|
||||
# revealed: (self: Member, name: str, role: str = str) -> None
|
||||
# revealed: (self: Member, name: str, role: str = Literal["user"]) -> None
|
||||
reveal_type(Member.__init__)
|
||||
|
||||
alice = Member(name="Alice", role="admin")
|
||||
|
@ -37,7 +37,7 @@ class Data:
|
|||
content: list[int] = field(default_factory=list)
|
||||
timestamp: datetime = field(default_factory=datetime.now, init=False)
|
||||
|
||||
# revealed: (self: Data, content: list[int] = list[int]) -> None
|
||||
# revealed: (self: Data, content: list[int] = Unknown) -> None
|
||||
reveal_type(Data.__init__)
|
||||
|
||||
data = Data([1, 2, 3])
|
||||
|
@ -64,7 +64,7 @@ class Person:
|
|||
role: str = field(default="user", kw_only=True)
|
||||
|
||||
# TODO: this would ideally show a default value of `None` for `age`
|
||||
# revealed: (self: Person, name: str, *, age: int | None = int | None, role: str = str) -> None
|
||||
# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None
|
||||
reveal_type(Person.__init__)
|
||||
|
||||
alice = Person(role="admin", name="Alice")
|
||||
|
|
|
@ -907,7 +907,7 @@ grandchild: Node = {"name": "grandchild", "parent": child}
|
|||
|
||||
nested: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": "n3", "parent": None}}}
|
||||
|
||||
# TODO: this should be an error (invalid type for `name` in innermost node)
|
||||
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Node`: value of type `Literal[3]`"
|
||||
nested_invalid: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": 3, "parent": None}}}
|
||||
```
|
||||
|
||||
|
|
|
@ -1233,22 +1233,35 @@ impl<'db> Type<'db> {
|
|||
if yes { self.negate(db) } else { *self }
|
||||
}
|
||||
|
||||
/// Remove the union elements that are not related to `target`.
|
||||
/// If the type is a union, filters union elements based on the provided predicate.
|
||||
///
|
||||
/// Otherwise, returns the type unchanged.
|
||||
pub(crate) fn filter_union(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
f: impl FnMut(&Type<'db>) -> bool,
|
||||
) -> Type<'db> {
|
||||
if let Type::Union(union) = self {
|
||||
union.filter(db, f)
|
||||
} else {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// If the type is a union, removes union elements that are disjoint from `target`.
|
||||
///
|
||||
/// Otherwise, returns the type unchanged.
|
||||
pub(crate) fn filter_disjoint_elements(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
target: Type<'db>,
|
||||
inferable: InferableTypeVars<'_, 'db>,
|
||||
) -> Type<'db> {
|
||||
if let Type::Union(union) = self {
|
||||
union.filter(db, |elem| {
|
||||
!elem
|
||||
.when_disjoint_from(db, target, inferable)
|
||||
.is_always_satisfied()
|
||||
})
|
||||
} else {
|
||||
self
|
||||
}
|
||||
self.filter_union(db, |elem| {
|
||||
!elem
|
||||
.when_disjoint_from(db, target, inferable)
|
||||
.is_always_satisfied()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
|
||||
|
@ -11185,9 +11198,9 @@ impl<'db> UnionType<'db> {
|
|||
pub(crate) fn filter(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
filter_fn: impl FnMut(&&Type<'db>) -> bool,
|
||||
mut f: impl FnMut(&Type<'db>) -> bool,
|
||||
) -> Type<'db> {
|
||||
Self::from_elements(db, self.elements(db).iter().filter(filter_fn))
|
||||
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
|
||||
}
|
||||
|
||||
pub(crate) fn map_with_boundness(
|
||||
|
|
|
@ -2524,6 +2524,7 @@ struct ArgumentTypeChecker<'a, 'db> {
|
|||
argument_matches: &'a [MatchedArgument<'db>],
|
||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||
call_expression_tcx: &'a TypeContext<'db>,
|
||||
return_ty: Type<'db>,
|
||||
errors: &'a mut Vec<BindingError<'db>>,
|
||||
|
||||
inferable_typevars: InferableTypeVars<'db, 'db>,
|
||||
|
@ -2531,6 +2532,7 @@ struct ArgumentTypeChecker<'a, 'db> {
|
|||
}
|
||||
|
||||
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||
#[expect(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
db: &'db dyn Db,
|
||||
signature: &'a Signature<'db>,
|
||||
|
@ -2538,6 +2540,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
argument_matches: &'a [MatchedArgument<'db>],
|
||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||
call_expression_tcx: &'a TypeContext<'db>,
|
||||
return_ty: Type<'db>,
|
||||
errors: &'a mut Vec<BindingError<'db>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
@ -2547,6 +2550,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
argument_matches,
|
||||
parameter_tys,
|
||||
call_expression_tcx,
|
||||
return_ty,
|
||||
errors,
|
||||
inferable_typevars: InferableTypeVars::None,
|
||||
specialization: None,
|
||||
|
@ -2588,25 +2592,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
// TODO: Use the list of inferable typevars from the generic context of the callable.
|
||||
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
|
||||
|
||||
// Note that we infer the annotated type _before_ the arguments if this call is part of
|
||||
// an annotated assignment, to closer match the order of any unions written in the type
|
||||
// annotation.
|
||||
if let Some(return_ty) = self.signature.return_ty
|
||||
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
|
||||
{
|
||||
match call_expression_tcx {
|
||||
// A type variable is not a useful type-context for expression inference, and applying it
|
||||
// to the return type can lead to confusing unions in nested generic calls.
|
||||
Type::TypeVar(_) => {}
|
||||
|
||||
_ => {
|
||||
// Ignore any specialization errors here, because the type context is only used as a hint
|
||||
// to infer a more assignable return type.
|
||||
let _ = builder.infer(return_ty, call_expression_tcx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let parameters = self.signature.parameters();
|
||||
for (argument_index, adjusted_argument_index, _, argument_type) in
|
||||
self.enumerate_argument_types()
|
||||
|
@ -2631,7 +2616,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
}
|
||||
}
|
||||
|
||||
self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx));
|
||||
// Build the specialization first without inferring the type context.
|
||||
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
|
||||
let isolated_return_ty = self
|
||||
.return_ty
|
||||
.apply_specialization(self.db, isolated_specialization);
|
||||
|
||||
let mut try_infer_tcx = || {
|
||||
let return_ty = self.signature.return_ty?;
|
||||
let call_expression_tcx = self.call_expression_tcx.annotation?;
|
||||
|
||||
// A type variable is not a useful type-context for expression inference, and applying it
|
||||
// to the return type can lead to confusing unions in nested generic calls.
|
||||
if call_expression_tcx.is_type_var() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// If the return type is already assignable to the annotated type, we can ignore the
|
||||
// type context and prefer the narrower inferred type.
|
||||
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
|
||||
// annotated assignment, to closer match the order of any unions written in the type annotation.
|
||||
builder.infer(return_ty, call_expression_tcx).ok()?;
|
||||
|
||||
// Otherwise, build the specialization again after inferring the type context.
|
||||
let specialization = builder.build(generic_context, *self.call_expression_tcx);
|
||||
let return_ty = return_ty.apply_specialization(self.db, specialization);
|
||||
|
||||
Some((Some(specialization), return_ty))
|
||||
};
|
||||
|
||||
(self.specialization, self.return_ty) =
|
||||
try_infer_tcx().unwrap_or((Some(isolated_specialization), isolated_return_ty));
|
||||
}
|
||||
|
||||
fn check_argument_type(
|
||||
|
@ -2826,8 +2845,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn finish(self) -> (InferableTypeVars<'db, 'db>, Option<Specialization<'db>>) {
|
||||
(self.inferable_typevars, self.specialization)
|
||||
fn finish(
|
||||
self,
|
||||
) -> (
|
||||
InferableTypeVars<'db, 'db>,
|
||||
Option<Specialization<'db>>,
|
||||
Type<'db>,
|
||||
) {
|
||||
(self.inferable_typevars, self.specialization, self.return_ty)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2985,18 +3010,16 @@ impl<'db> Binding<'db> {
|
|||
&self.argument_matches,
|
||||
&mut self.parameter_tys,
|
||||
call_expression_tcx,
|
||||
self.return_ty,
|
||||
&mut self.errors,
|
||||
);
|
||||
|
||||
// If this overload is generic, first see if we can infer a specialization of the function
|
||||
// from the arguments that were passed in.
|
||||
checker.infer_specialization();
|
||||
|
||||
checker.check_argument_types();
|
||||
(self.inferable_typevars, self.specialization) = checker.finish();
|
||||
if let Some(specialization) = self.specialization {
|
||||
self.return_ty = self.return_ty.apply_specialization(db, specialization);
|
||||
}
|
||||
|
||||
(self.inferable_typevars, self.specialization, self.return_ty) = checker.finish();
|
||||
}
|
||||
|
||||
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {
|
||||
|
|
|
@ -1229,6 +1229,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
let tcx = tcx_specialization.and_then(|specialization| {
|
||||
specialization.get(self.db, variable.bound_typevar)
|
||||
});
|
||||
|
||||
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
|
||||
}
|
||||
|
||||
|
@ -1251,7 +1252,7 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
pub(crate) fn infer(
|
||||
&mut self,
|
||||
formal: Type<'db>,
|
||||
mut actual: Type<'db>,
|
||||
actual: Type<'db>,
|
||||
) -> Result<(), SpecializationError<'db>> {
|
||||
if formal == actual {
|
||||
return Ok(());
|
||||
|
@ -1282,9 +1283,11 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
|
||||
// So, here we remove the union elements that are not related to `formal`.
|
||||
actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
|
||||
// Remove the union elements that are not related to `formal`.
|
||||
//
|
||||
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
|
||||
// to `int`.
|
||||
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
|
||||
|
||||
match (formal, actual) {
|
||||
// TODO: We haven't implemented a full unification solver yet. If typevars appear in
|
||||
|
|
|
@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> {
|
|||
.and_then(|ty| ty.known_specialization(db, known_class))
|
||||
}
|
||||
|
||||
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
|
||||
pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
|
||||
Self {
|
||||
annotation: self.annotation.map(f),
|
||||
}
|
||||
|
|
|
@ -5890,6 +5890,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
parenthesized: _,
|
||||
} = tuple;
|
||||
|
||||
// TODO: Use the list of inferable typevars from the generic context of tuple.
|
||||
let inferable = InferableTypeVars::None;
|
||||
|
||||
// Remove any union elements of that are unrelated to the tuple type.
|
||||
let tcx = tcx.map(|annotation| {
|
||||
annotation.filter_disjoint_elements(
|
||||
self.db(),
|
||||
KnownClass::Tuple.to_instance(self.db()),
|
||||
inferable,
|
||||
)
|
||||
});
|
||||
|
||||
let annotated_tuple = tcx
|
||||
.known_specialization(self.db(), KnownClass::Tuple)
|
||||
.and_then(|specialization| {
|
||||
|
@ -5955,7 +5967,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
} = dict;
|
||||
|
||||
// Validate `TypedDict` dictionary literal assignments.
|
||||
if let Some(typed_dict) = tcx.annotation.and_then(Type::as_typed_dict)
|
||||
if let Some(tcx) = tcx.annotation
|
||||
&& let Some(typed_dict) = tcx
|
||||
.filter_union(self.db(), Type::is_typed_dict)
|
||||
.as_typed_dict()
|
||||
&& let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict)
|
||||
{
|
||||
return ty;
|
||||
|
@ -6038,9 +6053,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// TODO: Use the list of inferable typevars from the generic context of the collection
|
||||
// class.
|
||||
let inferable = InferableTypeVars::None;
|
||||
let tcx = tcx.map_annotation(|annotation| {
|
||||
// Remove any union elements of `annotation` that are not related to `collection_ty`.
|
||||
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
|
||||
|
||||
// Remove any union elements of that are unrelated to the collection type.
|
||||
//
|
||||
// For example, we only want the `list[int]` from `annotation: list[int] | None` if
|
||||
// `collection_ty` is `list`.
|
||||
let tcx = tcx.map(|annotation| {
|
||||
let collection_ty = collection_class.to_instance(self.db());
|
||||
annotation.filter_disjoint_elements(self.db(), collection_ty, inferable)
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue