diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index a5e62f6866..5db84cfd5a 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -474,6 +474,16 @@ def g(x: str): f(prefix=x, suffix=".tar.gz") ``` +If the type variable is present multiple times in the union, we choose the correct union element to +infer against based on the argument type: + +```py +def h[T](x: list[T] | dict[T, T]) -> T | None: ... +def _(x: list[int], y: dict[int, int]): + reveal_type(h(x)) # revealed: int | None + reveal_type(h(y)) # revealed: int | None +``` + ## Nested functions see typevars bound in outer function ```py diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 444c5badd6..992e664401 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1397,11 +1397,13 @@ impl<'db> SpecializationBuilder<'db> { return Ok(()); } - // Remove the union elements that are not related to `formal`. + // Remove the union elements from `actual` that are not related to `formal`, and vice + // versa. // // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` - // to `int`. + // to `int`, and so ignore the `None`. let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable); + let formal = formal.filter_disjoint_elements(self.db, actual, self.inferable); match (formal, actual) { // TODO: We haven't implemented a full unification solver yet. If typevars appear in