[ty] Handle annotated self parameter in constructor of non-invariant generic classes (#21325)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (${{ github.repository == 'astral-sh/ruff' && 'depot-windows-2022-16' || 'windows-latest' }}) (push) Blocked by required conditions
CI / cargo test (macos-latest) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / ty completion evaluation (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks instrumented (ruff) (push) Blocked by required conditions
CI / benchmarks instrumented (ty) (push) Blocked by required conditions
CI / benchmarks walltime (medium|multithreaded) (push) Blocked by required conditions
CI / benchmarks walltime (small|large) (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

This manifested as an error when inferring the type of a PEP-695 generic
class via its constructor parameters:

```py
class D[T, U]:
    @overload
    def __init__(self: "D[str, U]", u: U) -> None: ...
    @overload
    def __init__(self, t: T, u: U) -> None: ...
    def __init__(self, *args) -> None: ...

# revealed: D[Unknown, str]
# SHOULD BE: D[str, str]
reveal_type(D("string"))
```

This manifested because `D` is inferred to be bivariant in both `T` and
`U`. We weren't seeing this in the equivalent example for legacy
typevars, since those default to invariant. (This issue also showed up
for _covariant_ typevars, so this issue was not limited to bivariance.)

The underlying cause was because of a heuristic that we have in our
current constraint solver, which attempts to handle situations like
this:

```py
def f[T](t: T | None): ...
f(None)
```

Here, the `None` argument matches the non-typevar union element, so this
argument should not add any constraints on what `T` can specialize to.
Our previous heuristic would check for this by seeing if the argument
type is a subtype of the parameter annotation as a whole — even if it
isn't a union! That would cause us to erroneously ignore the `self`
parameter in our constructor call, since bivariant classes are
equivalent to each other, regardless of their specializations.

The quick fix is to move this heuristic "down a level", so that we only
apply it when the parameter annotation is a union. This heuristic should
go away completely 🤞 with the new constraint solver.
This commit is contained in:
Douglas Creager 2025-11-10 19:46:49 -05:00 committed by GitHub
parent 9ce3230add
commit 33b942c7ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 50 additions and 37 deletions

View file

@ -674,8 +674,7 @@ x6: Covariant[Any] = covariant(1)
x7: Contravariant[Any] = contravariant(1) x7: Contravariant[Any] = contravariant(1)
x8: Invariant[Any] = invariant(1) x8: Invariant[Any] = invariant(1)
# TODO: This could reveal `Bivariant[Any]`. reveal_type(x5) # revealed: Bivariant[Any]
reveal_type(x5) # revealed: Bivariant[Literal[1]]
reveal_type(x6) # revealed: Covariant[Any] reveal_type(x6) # revealed: Covariant[Any]
reveal_type(x7) # revealed: Contravariant[Any] reveal_type(x7) # revealed: Contravariant[Any]
reveal_type(x8) # revealed: Invariant[Any] reveal_type(x8) # revealed: Invariant[Any]

View file

@ -436,9 +436,7 @@ def test_seq(x: Sequence[T]) -> Sequence[T]:
def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]):
reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]]
reveal_type(test_seq(t2)) # revealed: Sequence[int | str] reveal_type(test_seq(t2)) # revealed: Sequence[int | str]
reveal_type(test_seq(t3)) # revealed: Sequence[Never]
# TODO: this should be `Sequence[Never]`
reveal_type(test_seq(t3)) # revealed: Sequence[Unknown]
``` ```
### `__init__` is itself generic ### `__init__` is itself generic
@ -466,6 +464,7 @@ wrong_innards: C[int] = C("five", 1)
from typing_extensions import overload, Generic, TypeVar from typing_extensions import overload, Generic, TypeVar
T = TypeVar("T") T = TypeVar("T")
U = TypeVar("U")
class C(Generic[T]): class C(Generic[T]):
@overload @overload
@ -497,6 +496,17 @@ C[int](12)
C[None]("string") # error: [no-matching-overload] C[None]("string") # error: [no-matching-overload]
C[None](b"bytes") # error: [no-matching-overload] C[None](b"bytes") # error: [no-matching-overload]
C[None](12) C[None](12)
class D(Generic[T, U]):
@overload
def __init__(self: "D[str, U]", u: U) -> None: ...
@overload
def __init__(self, t: T, u: U) -> None: ...
def __init__(self, *args) -> None: ...
reveal_type(D("string")) # revealed: D[str, str]
reveal_type(D(1)) # revealed: D[str, int]
reveal_type(D(1, "string")) # revealed: D[int, str]
``` ```
### Synthesized methods with dataclasses ### Synthesized methods with dataclasses

View file

@ -375,9 +375,7 @@ def test_seq[T](x: Sequence[T]) -> Sequence[T]:
def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]): def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: tuple[()]):
reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]] reveal_type(test_seq(t1)) # revealed: Sequence[int | float | complex | list[int]]
reveal_type(test_seq(t2)) # revealed: Sequence[int | str] reveal_type(test_seq(t2)) # revealed: Sequence[int | str]
reveal_type(test_seq(t3)) # revealed: Sequence[Never]
# TODO: this should be `Sequence[Never]`
reveal_type(test_seq(t3)) # revealed: Sequence[Unknown]
``` ```
### `__init__` is itself generic ### `__init__` is itself generic
@ -436,6 +434,17 @@ C[int](12)
C[None]("string") # error: [no-matching-overload] C[None]("string") # error: [no-matching-overload]
C[None](b"bytes") # error: [no-matching-overload] C[None](b"bytes") # error: [no-matching-overload]
C[None](12) C[None](12)
class D[T, U]:
@overload
def __init__(self: "D[str, U]", u: U) -> None: ...
@overload
def __init__(self, t: T, u: U) -> None: ...
def __init__(self, *args) -> None: ...
reveal_type(D("string")) # revealed: D[str, str]
reveal_type(D(1)) # revealed: D[str, int]
reveal_type(D(1, "string")) # revealed: D[int, str]
``` ```
### Synthesized methods with dataclasses ### Synthesized methods with dataclasses

View file

@ -1393,31 +1393,6 @@ impl<'db> SpecializationBuilder<'db> {
return Ok(()); return Ok(());
} }
// If the actual type is a subtype of the formal type, then return without adding any new
// type mappings. (Note that if the formal type contains any typevars, this check will
// fail, since no non-typevar types are assignable to a typevar. Also note that we are
// checking _subtyping_, not _assignability_, so that we do specialize typevars to dynamic
// argument types; and we have a special case for `Never`, which is a subtype of all types,
// but which we also do want as a specialization candidate.)
//
// In particular, this handles a case like
//
// ```py
// def f[T](t: T | None): ...
//
// f(None)
// ```
//
// without specializing `T` to `None`.
if !matches!(formal, Type::ProtocolInstance(_))
&& !actual.is_never()
&& actual
.when_subtype_of(self.db, formal, self.inferable)
.is_always_satisfied(self.db)
{
return Ok(());
}
// Remove the union elements from `actual` that are not related to `formal`, and vice // Remove the union elements from `actual` that are not related to `formal`, and vice
// versa. // versa.
// //
@ -1473,10 +1448,30 @@ impl<'db> SpecializationBuilder<'db> {
self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter); self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter);
} }
(Type::Union(formal), _) => { (Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not // Second, if the formal is a union, and precisely one union element is assignable
// _contains_ a typevar), then we add a mapping between that typevar and the actual // from the actual type, then we don't add any type mapping. This handles a case like
// type. (Note that we've already handled above the case where the actual is //
// assignable to any _non-typevar_ union element.) // ```py
// def f[T](t: T | None): ...
//
// f(None)
// ```
//
// without specializing `T` to `None`.
//
// Otherwise, if precisely one union element _is_ a typevar (not _contains_ a
// typevar), then we add a mapping between that typevar and the actual type.
if !actual.is_never() {
let assignable_elements = (formal.elements(self.db).iter()).filter(|ty| {
actual
.when_subtype_of(self.db, **ty, self.inferable)
.is_always_satisfied(self.db)
});
if assignable_elements.exactly_one().is_ok() {
return Ok(());
}
}
let bound_typevars = let bound_typevars =
(formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); (formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar());
if let Ok(bound_typevar) = bound_typevars.exactly_one() { if let Ok(bound_typevar) = bound_typevars.exactly_one() {