[ty] linear variance inference for PEP-695 type parameters (#18713)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Implement linear-time variance inference for type variables
(https://github.com/astral-sh/ty/issues/488).

Inspired by Martin Huschenbett's [PyCon 2025
Talk](https://www.youtube.com/watch?v=7uixlNTOY4s&t=9705s).

## Test Plan

update tests, add new tests, including for mutually recursive classes

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Eric Mark Martin 2025-08-19 20:54:09 -04:00 committed by GitHub
parent 656fc335f2
commit 33030b34cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1088 additions and 95 deletions

View file

@ -237,10 +237,15 @@ If the type of a constructor parameter is a class typevar, we can use that to in
parameter. The types inferred from a type context and from a constructor parameter must be
consistent with each other.
We have to add `x: T` to the classes to ensure they're not bivariant in `T` (__new__ and __init__
signatures don't count towards variance).
### `__new__` only
```py
class C[T]:
x: T
def __new__(cls, x: T) -> "C[T]":
return object.__new__(cls)
@ -254,6 +259,8 @@ wrong_innards: C[int] = C("five")
```py
class C[T]:
x: T
def __init__(self, x: T) -> None: ...
reveal_type(C(1)) # revealed: C[int]
@ -266,6 +273,8 @@ wrong_innards: C[int] = C("five")
```py
class C[T]:
x: T
def __new__(cls, x: T) -> "C[T]":
return object.__new__(cls)
@ -281,6 +290,8 @@ wrong_innards: C[int] = C("five")
```py
class C[T]:
x: T
def __new__(cls, *args, **kwargs) -> "C[T]":
return object.__new__(cls)
@ -292,6 +303,8 @@ reveal_type(C(1)) # revealed: C[int]
wrong_innards: C[int] = C("five")
class D[T]:
x: T
def __new__(cls, x: T) -> "D[T]":
return object.__new__(cls)
@ -378,6 +391,8 @@ def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: t
```py
class C[T]:
x: T
def __init__[S](self, x: T, y: S) -> None: ...
reveal_type(C(1, 1)) # revealed: C[int]
@ -395,6 +410,10 @@ from __future__ import annotations
from typing import overload
class C[T]:
# we need to use the type variable or else the class is bivariant in T, and
# specializations become meaningless
x: T
@overload
def __init__(self: C[str], x: str) -> None: ...
@overload

View file

@ -40,8 +40,6 @@ class C[T]:
class D[U](C[U]):
pass
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(C[B], C[A]))
static_assert(not is_assignable_to(C[A], C[B]))
static_assert(is_assignable_to(C[A], C[Any]))
@ -49,8 +47,6 @@ static_assert(is_assignable_to(C[B], C[Any]))
static_assert(is_assignable_to(C[Any], C[A]))
static_assert(is_assignable_to(C[Any], C[B]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(D[B], C[A]))
static_assert(not is_assignable_to(D[A], C[B]))
static_assert(is_assignable_to(D[A], C[Any]))
@ -58,8 +54,6 @@ static_assert(is_assignable_to(D[B], C[Any]))
static_assert(is_assignable_to(D[Any], C[A]))
static_assert(is_assignable_to(D[Any], C[B]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
static_assert(not is_subtype_of(C[A], C[Any]))
@ -67,8 +61,6 @@ static_assert(not is_subtype_of(C[B], C[Any]))
static_assert(not is_subtype_of(C[Any], C[A]))
static_assert(not is_subtype_of(C[Any], C[B]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(D[B], C[A]))
static_assert(not is_subtype_of(D[A], C[B]))
static_assert(not is_subtype_of(D[A], C[Any]))
@ -124,8 +116,6 @@ class D[U](C[U]):
pass
static_assert(not is_assignable_to(C[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(C[A], C[B]))
static_assert(is_assignable_to(C[A], C[Any]))
static_assert(is_assignable_to(C[B], C[Any]))
@ -133,8 +123,6 @@ static_assert(is_assignable_to(C[Any], C[A]))
static_assert(is_assignable_to(C[Any], C[B]))
static_assert(not is_assignable_to(D[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(D[A], C[B]))
static_assert(is_assignable_to(D[A], C[Any]))
static_assert(is_assignable_to(D[B], C[Any]))
@ -142,8 +130,6 @@ static_assert(is_assignable_to(D[Any], C[A]))
static_assert(is_assignable_to(D[Any], C[B]))
static_assert(not is_subtype_of(C[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(C[A], C[B]))
static_assert(not is_subtype_of(C[A], C[Any]))
static_assert(not is_subtype_of(C[B], C[Any]))
@ -151,8 +137,6 @@ static_assert(not is_subtype_of(C[Any], C[A]))
static_assert(not is_subtype_of(C[Any], C[B]))
static_assert(not is_subtype_of(D[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(D[A], C[B]))
static_assert(not is_subtype_of(D[A], C[Any]))
static_assert(not is_subtype_of(D[B], C[Any]))
@ -297,34 +281,22 @@ class C[T]:
class D[U](C[U]):
pass
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(C[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(C[A], C[B]))
static_assert(is_assignable_to(C[A], C[Any]))
static_assert(is_assignable_to(C[B], C[Any]))
static_assert(is_assignable_to(C[Any], C[A]))
static_assert(is_assignable_to(C[Any], C[B]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(D[B], C[A]))
static_assert(is_subtype_of(C[A], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_assignable_to(D[A], C[B]))
static_assert(is_assignable_to(D[A], C[Any]))
static_assert(is_assignable_to(D[B], C[Any]))
static_assert(is_assignable_to(D[Any], C[A]))
static_assert(is_assignable_to(D[Any], C[B]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(C[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(C[A], C[B]))
static_assert(not is_subtype_of(C[A], C[Any]))
static_assert(not is_subtype_of(C[B], C[Any]))
@ -332,11 +304,7 @@ static_assert(not is_subtype_of(C[Any], C[A]))
static_assert(not is_subtype_of(C[Any], C[B]))
static_assert(not is_subtype_of(C[Any], C[Any]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(D[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_subtype_of(D[A], C[B]))
static_assert(not is_subtype_of(D[A], C[Any]))
static_assert(not is_subtype_of(D[B], C[Any]))
@ -345,23 +313,11 @@ static_assert(not is_subtype_of(D[Any], C[B]))
static_assert(is_equivalent_to(C[A], C[A]))
static_assert(is_equivalent_to(C[B], C[B]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_equivalent_to(C[B], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_equivalent_to(C[A], C[B]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_equivalent_to(C[A], C[Any]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_equivalent_to(C[B], C[Any]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_equivalent_to(C[Any], C[A]))
# TODO: no error
# error: [static-assert-error]
static_assert(is_equivalent_to(C[Any], C[B]))
static_assert(not is_equivalent_to(D[A], C[A]))
@ -380,4 +336,504 @@ static_assert(not is_equivalent_to(D[Any], C[Any]))
static_assert(not is_equivalent_to(D[Any], C[Unknown]))
```
## Mutual Recursion
This example due to Martin Huschenbett's PyCon 2025 talk,
[Linear Time variance Inference for PEP 695][linear-time-variance-talk]
```py
from ty_extensions import is_subtype_of, static_assert
from typing import Any
class A: ...
class B(A): ...
class C[X]:
def f(self) -> "D[X]":
return D()
def g(self, x: X) -> None: ...
class D[Y]:
def h(self) -> C[Y]:
return C()
```
`C` is contravariant in `X`, and `D` in `Y`:
- `C` has two occurrences of `X`
- `X` occurs in the return type of `f` as `D[X]` (`X` is substituted in for `Y`)
- `D` has one occurrence of `Y`
- `Y` occurs in the return type of `h` as `C[Y]`
- `X` occurs contravariantly as a parameter in `g`
Thus the variance of `X` in `C` depends on itself. We want to infer the least restrictive possible
variance, so in such cases we begin by assuming that the point where we detect the cycle is
bivariant.
If we thus assume `X` is bivariant in `C`, then `Y` will be bivariant in `D`, as `D`'s only
occurrence of `Y` is in `C`. Then we consider `X` in `C` once more. We have two occurrences: `D[X]`
covariantly in a return type, and `X` contravariantly in an argument type. With one bivariant and
one contravariant occurrence, we update our inference of `X` in `C` to contravariant---the supremum
of contravariant and bivariant in the lattice.
Now that we've updated the variance of `X` in `C`, we re-evaluate `Y` in `D`. It only has the one
occurrence `C[Y]`, which we now infer is contravariant, and so we infer contravariance for `Y` in
`D` as well.
Because the variance of `X` in `C` depends on that of `Y` in `D`, we have to re-evaluate now that
we've updated the latter to contravariant. The variance of `X` in `C` is now the supremum of
contravariant and contravariant---giving us contravariant---and so remains unchanged.
Once we've completed a turn around the cycle with nothing changed, we've reached a fixed-point---the
variance inference will not change any further---and so we finally conclude that both `X` in `C` and
`Y` in `D` are contravariant.
```py
static_assert(not is_subtype_of(C[B], C[A]))
static_assert(is_subtype_of(C[A], C[B]))
static_assert(not is_subtype_of(C[A], C[Any]))
static_assert(not is_subtype_of(C[B], C[Any]))
static_assert(not is_subtype_of(C[Any], C[A]))
static_assert(not is_subtype_of(C[Any], C[B]))
static_assert(not is_subtype_of(D[B], D[A]))
static_assert(is_subtype_of(D[A], D[B]))
static_assert(not is_subtype_of(D[A], D[Any]))
static_assert(not is_subtype_of(D[B], D[Any]))
static_assert(not is_subtype_of(D[Any], D[A]))
static_assert(not is_subtype_of(D[Any], D[B]))
```
## Class Attributes
### Mutable Attributes
Normal attributes are mutable, and so make the enclosing class invariant in this typevar (see
[inv]).
```py
from ty_extensions import is_subtype_of, static_assert
class A: ...
class B(A): ...
class C[T]:
x: T
static_assert(not is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
```
One might think that occurrences in the types of normal attributes are covariant, but they are
mutable, and thus the occurrences are invariant.
### Immutable Attributes
Immutable attributes can't be written to, and thus constrain the typevar to covariance, not
invariance.
#### Final attributes
```py
from typing import Final
from ty_extensions import is_subtype_of, static_assert
class A: ...
class B(A): ...
class C[T]:
x: Final[T]
static_assert(is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
```
#### Underscore-prefixed attributes
Underscore-prefixed instance attributes are considered private, and thus are assumed not externally
mutated.
```py
from ty_extensions import is_subtype_of, static_assert
class A: ...
class B(A): ...
class C[T]:
_x: T
@property
def x(self) -> T:
return self._x
static_assert(is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
class D[T]:
def __init__(self, x: T):
self._x = x
@property
def x(self) -> T:
return self._x
static_assert(is_subtype_of(D[B], D[A]))
static_assert(not is_subtype_of(D[A], D[B]))
```
#### Frozen dataclasses in Python 3.12 and earlier
```py
from dataclasses import dataclass, field
from ty_extensions import is_subtype_of, static_assert
class A: ...
class B(A): ...
@dataclass(frozen=True)
class D[U]:
y: U
static_assert(is_subtype_of(D[B], D[A]))
static_assert(not is_subtype_of(D[A], D[B]))
@dataclass(frozen=True)
class E[U]:
y: U = field()
static_assert(is_subtype_of(E[B], E[A]))
static_assert(not is_subtype_of(E[A], E[B]))
```
#### Frozen dataclasses in Python 3.13 and later
```toml
[environment]
python-version = "3.13"
```
Python 3.13 introduced a new synthesized `__replace__` method on dataclasses, which uses every field
type in a contravariant position (as a parameter to `__replace__`). This means that frozen
dataclasses on Python 3.13+ can't be covariant in their field types.
```py
from dataclasses import dataclass
from ty_extensions import is_subtype_of, static_assert
class A: ...
class B(A): ...
@dataclass(frozen=True)
class D[U]:
y: U
static_assert(not is_subtype_of(D[B], D[A]))
static_assert(not is_subtype_of(D[A], D[B]))
```
#### NamedTuple
```py
from typing import NamedTuple
from ty_extensions import is_subtype_of, static_assert
class A: ...
class B(A): ...
class E[V](NamedTuple):
z: V
static_assert(is_subtype_of(E[B], E[A]))
static_assert(not is_subtype_of(E[A], E[B]))
```
A subclass of a `NamedTuple` can still be covariant:
```py
class D[T](E[T]):
pass
static_assert(is_subtype_of(D[B], D[A]))
static_assert(not is_subtype_of(D[A], D[B]))
```
But adding a new generic attribute on the subclass makes it invariant (the added attribute is not a
`NamedTuple` field, and thus not immutable):
```py
class C[T](E[T]):
w: T
static_assert(not is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
```
### Properties
Properties constrain to covariance if they are get-only and invariant if they are get-set:
```py
from ty_extensions import static_assert, is_subtype_of
class A: ...
class B(A): ...
class C[T]:
@property
def x(self) -> T | None:
return None
class D[U]:
@property
def y(self) -> U | None:
return None
@y.setter
def y(self, value: U): ...
static_assert(is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
static_assert(not is_subtype_of(D[B], D[A]))
static_assert(not is_subtype_of(D[A], D[B]))
```
### Implicit Attributes
Implicit attributes work like normal ones
```py
from ty_extensions import static_assert, is_subtype_of
class A: ...
class B(A): ...
class C[T]:
def f(self) -> None:
self.x: T | None = None
static_assert(not is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
```
### Constructors: excluding `__init__` and `__new__`
We consider it invalid to call `__init__` explicitly on an existing object. Likewise, `__new__` is
only used at the beginning of an object's life. As such, we don't need to worry about the variance
impact of these methods.
```py
from ty_extensions import static_assert, is_subtype_of
class A: ...
class B(A): ...
class C[T]:
def __init__(self, x: T): ...
def __new__(self, x: T): ...
static_assert(is_subtype_of(C[B], C[A]))
static_assert(is_subtype_of(C[A], C[B]))
```
This example is then bivariant because it doesn't use `T` outside of the two exempted methods.
This holds likewise for dataclasses with synthesized `__init__`:
```py
from dataclasses import dataclass
@dataclass(init=True, frozen=True)
class D[T]:
x: T
# Covariant due to the read-only T-typed attribute; the `__init__` is ignored and doesn't make it
# invariant:
static_assert(is_subtype_of(D[B], D[A]))
static_assert(not is_subtype_of(D[A], D[B]))
```
## Union Types
Union types are covariant in all their members. If `A <: B`, then `A | C <: B | C` and
`C | A <: C | B`.
```py
from ty_extensions import is_assignable_to, is_subtype_of, static_assert
class A: ...
class B(A): ...
class C: ...
# Union types are covariant in their members
static_assert(is_subtype_of(B | C, A | C))
static_assert(is_subtype_of(C | B, C | A))
static_assert(not is_subtype_of(A | C, B | C))
static_assert(not is_subtype_of(C | A, C | B))
# Assignability follows the same pattern
static_assert(is_assignable_to(B | C, A | C))
static_assert(is_assignable_to(C | B, C | A))
static_assert(not is_assignable_to(A | C, B | C))
static_assert(not is_assignable_to(C | A, C | B))
```
## Intersection Types
Intersection types cannot be expressed directly in Python syntax, but they occur when type narrowing
creates constraints through control flow. In ty's representation, intersection types are covariant
in their positive conjuncts and contravariant in their negative conjuncts.
```py
from ty_extensions import is_assignable_to, is_subtype_of, static_assert, Intersection, Not
class A: ...
class B(A): ...
class C: ...
# Test covariance in positive conjuncts
# If B <: A, then Intersection[X, B] <: Intersection[X, A]
static_assert(is_subtype_of(Intersection[C, B], Intersection[C, A]))
static_assert(not is_subtype_of(Intersection[C, A], Intersection[C, B]))
static_assert(is_assignable_to(Intersection[C, B], Intersection[C, A]))
static_assert(not is_assignable_to(Intersection[C, A], Intersection[C, B]))
# Test contravariance in negative conjuncts
# If B <: A, then Intersection[X, Not[A]] <: Intersection[X, Not[B]]
# (excluding supertype A is more restrictive than excluding subtype B)
static_assert(is_subtype_of(Intersection[C, Not[A]], Intersection[C, Not[B]]))
static_assert(not is_subtype_of(Intersection[C, Not[B]], Intersection[C, Not[A]]))
static_assert(is_assignable_to(Intersection[C, Not[A]], Intersection[C, Not[B]]))
static_assert(not is_assignable_to(Intersection[C, Not[B]], Intersection[C, Not[A]]))
```
## Subclass Types (type[T])
The `type[T]` construct represents the type of classes that are subclasses of `T`. It is covariant
in `T` because if `A <: B`, then `type[A] <: type[B]` holds.
```py
from ty_extensions import is_assignable_to, is_subtype_of, static_assert
class A: ...
class B(A): ...
# type[T] is covariant in T
static_assert(is_subtype_of(type[B], type[A]))
static_assert(not is_subtype_of(type[A], type[B]))
static_assert(is_assignable_to(type[B], type[A]))
static_assert(not is_assignable_to(type[A], type[B]))
# With generic classes using type[T]
class ClassContainer[T]:
def __init__(self, cls: type[T]) -> None:
self.cls = cls
def create_instance(self) -> T:
return self.cls()
# ClassContainer is covariant in T due to type[T]
static_assert(is_subtype_of(ClassContainer[B], ClassContainer[A]))
static_assert(not is_subtype_of(ClassContainer[A], ClassContainer[B]))
static_assert(is_assignable_to(ClassContainer[B], ClassContainer[A]))
static_assert(not is_assignable_to(ClassContainer[A], ClassContainer[B]))
# Practical example: you can pass a ClassContainer[B] where ClassContainer[A] is expected
# because type[B] can safely be used where type[A] is expected
def use_a_class_container(container: ClassContainer[A]) -> A:
return container.create_instance()
b_container = ClassContainer[B](B)
a_instance: A = use_a_class_container(b_container) # This should work
```
## Inheriting from generic classes with inferred variance
When inheriting from a generic class with our type variable substituted in, we count its occurrences
as well. In the following example, `T` is covariant in `C`, and contravariant in the subclass `D` if
you only count its own occurrences. Because we count both then, `T` is invariant in `D`.
```py
from ty_extensions import is_subtype_of, static_assert
class A:
pass
class B(A):
pass
class C[T]:
def f() -> T | None:
pass
static_assert(is_subtype_of(C[B], C[A]))
static_assert(not is_subtype_of(C[A], C[B]))
class D[T](C[T]):
def g(x: T) -> None:
pass
static_assert(not is_subtype_of(D[B], D[A]))
static_assert(not is_subtype_of(D[A], D[B]))
```
## Inheriting from generic classes with explicit variance
```py
from typing import TypeVar, Generic
from ty_extensions import is_subtype_of, static_assert
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)
class A:
pass
class B(A):
pass
class Invariant(Generic[T]):
pass
static_assert(not is_subtype_of(Invariant[B], Invariant[A]))
static_assert(not is_subtype_of(Invariant[A], Invariant[B]))
class DerivedInvariant[T](Invariant[T]):
pass
static_assert(not is_subtype_of(DerivedInvariant[B], DerivedInvariant[A]))
static_assert(not is_subtype_of(DerivedInvariant[A], DerivedInvariant[B]))
class Covariant(Generic[T_co]):
pass
static_assert(is_subtype_of(Covariant[B], Covariant[A]))
static_assert(not is_subtype_of(Covariant[A], Covariant[B]))
class DerivedCovariant[T](Covariant[T]):
pass
static_assert(is_subtype_of(DerivedCovariant[B], DerivedCovariant[A]))
static_assert(not is_subtype_of(DerivedCovariant[A], DerivedCovariant[B]))
class Contravariant(Generic[T_contra]):
pass
static_assert(not is_subtype_of(Contravariant[B], Contravariant[A]))
static_assert(is_subtype_of(Contravariant[A], Contravariant[B]))
class DerivedContravariant[T](Contravariant[T]):
pass
static_assert(not is_subtype_of(DerivedContravariant[B], DerivedContravariant[A]))
static_assert(is_subtype_of(DerivedContravariant[A], DerivedContravariant[B]))
```
[linear-time-variance-talk]: https://www.youtube.com/watch?v=7uixlNTOY4s&t=9705s
[spec]: https://typing.python.org/en/latest/spec/generics.html#variance

View file

@ -217,8 +217,8 @@ class B: ...
def _[T](x: A | B):
if type(x) is A[str]:
# `type()` never returns a generic alias, so `type(x)` cannot be `A[str]`
reveal_type(x) # revealed: Never
# TODO: `type()` never returns a generic alias, so `type(x)` cannot be `A[str]`
reveal_type(x) # revealed: A[int] | B
else:
reveal_type(x) # revealed: A[int] | B
```

View file

@ -60,6 +60,7 @@ use crate::types::mro::{Mro, MroError, MroIterator};
pub(crate) use crate::types::narrow::infer_narrowing_constraint;
use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature};
use crate::types::tuple::TupleSpec;
use crate::types::variance::{TypeVarVariance, VarianceInferable};
use crate::unpack::EvaluationMode;
pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic;
use crate::{Db, FxOrderMap, FxOrderSet, Module, Program};
@ -92,6 +93,7 @@ mod subclass_of;
mod tuple;
mod type_ordering;
mod unpacker;
mod variance;
mod visitor;
mod definition;
@ -322,6 +324,29 @@ fn class_lookup_cycle_initial<'db>(
Place::bound(Type::Never).into()
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn variance_cycle_recover<'db, T>(
_db: &'db dyn Db,
_value: &TypeVarVariance,
count: u32,
_self: T,
_typevar: BoundTypeVarInstance<'db>,
) -> salsa::CycleRecoveryAction<TypeVarVariance> {
assert!(
count <= 2,
"Should only be able to cycle at most twice: there are only three levels in the lattice, each cycle should move us one"
);
salsa::CycleRecoveryAction::Iterate
}
fn variance_cycle_initial<'db, T>(
_db: &'db dyn Db,
_self: T,
_typevar: BoundTypeVarInstance<'db>,
) -> TypeVarVariance {
TypeVarVariance::Bivariant
}
/// Meta data for `Type::Todo`, which represents a known limitation in ty.
#[cfg(debug_assertions)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
@ -755,7 +780,7 @@ impl<'db> Type<'db> {
Name::new_static("T_all"),
None,
None,
variance,
Some(variance),
None,
TypeVarKind::Pep695,
),
@ -963,7 +988,6 @@ impl<'db> Type<'db> {
.expect("Expected a Type::FunctionLiteral variant")
}
#[cfg(test)]
pub(crate) const fn is_function_literal(&self) -> bool {
matches!(self, Type::FunctionLiteral(..))
}
@ -5533,7 +5557,10 @@ impl<'db> Type<'db> {
ast::name::Name::new_static("Self"),
Some(class_definition),
Some(TypeVarBoundOrConstraints::UpperBound(instance).into()),
TypeVarVariance::Invariant,
// According to the [spec], we can consider `Self`
// equivalent to an invariant type variable
// [spec]: https://typing.python.org/en/latest/spec/generics.html#self
Some(TypeVarVariance::Invariant),
None,
TypeVarKind::TypingSelf,
);
@ -6315,6 +6342,99 @@ impl<'db> From<&Type<'db>> for Type<'db> {
}
}
impl<'db> VarianceInferable<'db> for Type<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
tracing::debug!(
"Checking variance of '{tvar}' in `{ty:?}`",
tvar = typevar.typevar(db).name(db),
ty = self.display(db),
);
let v = match self {
Type::ClassLiteral(class_literal) => class_literal.variance_of(db, typevar),
Type::FunctionLiteral(function_type) => {
// TODO: do we need to replace self?
function_type.signature(db).variance_of(db, typevar)
}
Type::BoundMethod(method_type) => {
// TODO: do we need to replace self?
method_type
.function(db)
.signature(db)
.variance_of(db, typevar)
}
Type::NominalInstance(nominal_instance_type) => {
nominal_instance_type.variance_of(db, typevar)
}
Type::GenericAlias(generic_alias) => generic_alias.variance_of(db, typevar),
Type::Callable(callable_type) => callable_type.signatures(db).variance_of(db, typevar),
Type::TypeVar(other_typevar) | Type::NonInferableTypeVar(other_typevar)
if other_typevar == typevar =>
{
// type variables are covariant in themselves
TypeVarVariance::Covariant
}
Type::ProtocolInstance(protocol_instance_type) => {
protocol_instance_type.variance_of(db, typevar)
}
Type::Union(union_type) => union_type
.elements(db)
.iter()
.map(|ty| ty.variance_of(db, typevar))
.collect(),
Type::Intersection(intersection_type) => intersection_type
.positive(db)
.iter()
.map(|ty| ty.variance_of(db, typevar))
.chain(intersection_type.negative(db).iter().map(|ty| {
ty.with_polarity(TypeVarVariance::Contravariant)
.variance_of(db, typevar)
}))
.collect(),
Type::PropertyInstance(property_instance_type) => property_instance_type
.getter(db)
.iter()
.chain(&property_instance_type.setter(db))
.map(|ty| ty.variance_of(db, typevar))
.collect(),
Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar),
Type::Dynamic(_)
| Type::Never
| Type::WrapperDescriptor(_)
| Type::MethodWrapper(_)
| Type::DataclassDecorator(_)
| Type::DataclassTransformer(_)
| Type::ModuleLiteral(_)
| Type::IntLiteral(_)
| Type::BooleanLiteral(_)
| Type::StringLiteral(_)
| Type::EnumLiteral(_)
| Type::LiteralString
| Type::BytesLiteral(_)
| Type::SpecialForm(_)
| Type::KnownInstance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy
| Type::BoundSuper(_)
| Type::TypeVar(_)
| Type::NonInferableTypeVar(_)
| Type::TypeIs(_)
| Type::TypedDict(_)
| Type::TypeAlias(_) => TypeVarVariance::Bivariant,
};
tracing::debug!(
"Result of variance of '{tvar}' in `{ty:?}` is `{v:?}`",
tvar = typevar.typevar(db).name(db),
ty = self.display(db),
);
v
}
}
/// A mapping that can be applied to a type, producing another type. This is applied inductively to
/// the components of complex types.
///
@ -6972,8 +7092,8 @@ pub struct TypeVarInstance<'db> {
/// instead (to evaluate any lazy bound or constraints).
_bound_or_constraints: Option<TypeVarBoundOrConstraintsEvaluation<'db>>,
/// The variance of the TypeVar
variance: TypeVarVariance,
/// The explicitly specified variance of the TypeVar
explicit_variance: Option<TypeVarVariance>,
/// The default type for this TypeVar, if any. Don't use this field directly, use the
/// `default_type` method instead (to evaluate any lazy default).
@ -7065,7 +7185,7 @@ impl<'db> TypeVarInstance<'db> {
.lazy_constraints(db)
.map(|constraints| constraints.normalized_impl(db, visitor).into()),
}),
self.variance(db),
self.explicit_variance(db),
self._default(db).and_then(|default| match default {
TypeVarDefaultEvaluation::Eager(ty) => Some(ty.normalized_impl(db, visitor).into()),
TypeVarDefaultEvaluation::Lazy => self
@ -7093,7 +7213,7 @@ impl<'db> TypeVarInstance<'db> {
.lazy_constraints(db)
.map(|constraints| constraints.materialize(db, variance).into()),
}),
self.variance(db),
self.explicit_variance(db),
self._default(db).and_then(|default| match default {
TypeVarDefaultEvaluation::Eager(ty) => Some(ty.materialize(db, variance).into()),
TypeVarDefaultEvaluation::Lazy => self
@ -7118,7 +7238,7 @@ impl<'db> TypeVarInstance<'db> {
Name::new(format!("{}'instance", self.name(db))),
None,
Some(bound_or_constraints.into()),
self.variance(db),
self.explicit_variance(db),
None,
self.kind(db),
))
@ -7187,6 +7307,33 @@ pub struct BoundTypeVarInstance<'db> {
// The Salsa heap is tracked separately.
impl get_size2::GetSize for BoundTypeVarInstance<'_> {}
impl<'db> BoundTypeVarInstance<'db> {
pub(crate) fn variance_with_polarity(
self,
db: &'db dyn Db,
polarity: TypeVarVariance,
) -> TypeVarVariance {
let _span = tracing::trace_span!("variance_with_polarity").entered();
match self.typevar(db).explicit_variance(db) {
Some(explicit_variance) => explicit_variance.compose(polarity),
None => match self.binding_context(db) {
BindingContext::Definition(definition) => {
let type_inference = infer_definition_types(db, definition);
type_inference
.binding_type(definition)
.with_polarity(polarity)
.variance_of(db, self)
}
BindingContext::Synthetic => TypeVarVariance::Invariant,
},
}
}
pub(crate) fn variance(self, db: &'db dyn Db) -> TypeVarVariance {
self.variance_with_polarity(db, TypeVarVariance::Covariant)
}
}
fn walk_bound_type_var_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
@ -7250,28 +7397,6 @@ impl<'db> BoundTypeVarInstance<'db> {
}
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub enum TypeVarVariance {
Invariant,
Covariant,
Contravariant,
Bivariant,
}
impl TypeVarVariance {
/// Flips the polarity of the variance.
///
/// Covariant becomes contravariant, contravariant becomes covariant, others remain unchanged.
pub(crate) const fn flip(self) -> Self {
match self {
TypeVarVariance::Invariant => TypeVarVariance::Invariant,
TypeVarVariance::Covariant => TypeVarVariance::Contravariant,
TypeVarVariance::Contravariant => TypeVarVariance::Covariant,
TypeVarVariance::Bivariant => TypeVarVariance::Bivariant,
}
}
}
/// Whether a typevar default is eagerly specified or lazily evaluated.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub enum TypeVarDefaultEvaluation<'db> {

View file

@ -13,8 +13,10 @@ use crate::module_resolver::KnownModule;
use crate::semantic_index::definition::{Definition, DefinitionState};
use crate::semantic_index::place::ScopedPlaceId;
use crate::semantic_index::scope::NodeWithScopeKind;
use crate::semantic_index::symbol::Symbol;
use crate::semantic_index::{
BindingWithConstraints, DeclarationWithConstraint, SemanticIndex, attribute_declarations,
attribute_scopes,
};
use crate::types::context::InferContext;
use crate::types::diagnostic::{INVALID_LEGACY_TYPE_VARIABLE, INVALID_TYPE_ALIAS_TYPE};
@ -28,8 +30,8 @@ use crate::types::{
ApplyTypeMappingVisitor, BareTypeAliasType, Binding, BoundSuperError, BoundSuperType,
CallableType, DataclassParams, DeprecatedInstance, HasRelationToVisitor, KnownInstanceType,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, declaration_type,
infer_definition_types, todo_type,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, VarianceInferable,
declaration_type, infer_definition_types, todo_type,
};
use crate::{
Db, FxIndexMap, FxOrderSet, Program,
@ -314,6 +316,51 @@ impl<'db> From<GenericAlias<'db>> for Type<'db> {
}
}
#[salsa::tracked]
impl<'db> VarianceInferable<'db> for GenericAlias<'db> {
#[salsa::tracked]
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
let origin = self.origin(db);
let specialization = self.specialization(db);
// if the class is the thing defining the variable, then it can
// reference it without it being applied to the specialization
std::iter::once(origin.variance_of(db, typevar))
.chain(
specialization
.generic_context(db)
.variables(db)
.iter()
.zip(specialization.types(db))
.map(|(generic_typevar, ty)| {
if let Some(explicit_variance) =
generic_typevar.typevar(db).explicit_variance(db)
{
ty.with_polarity(explicit_variance).variance_of(db, typevar)
} else {
// `with_polarity` composes the passed variance with the
// inferred one. The inference is done lazily, as we can
// sometimes determine the result just from the passed
// variance. This operation is commutative, so we could
// infer either first. We choose to make the `ClassLiteral`
// variance lazy, as it is known to be expensive, requiring
// that we traverse all members.
//
// If salsa let us look at the cache, we could check first
// to see if the class literal query was already run.
let typevar_variance_in_substituted_type = ty.variance_of(db, typevar);
origin
.with_polarity(typevar_variance_in_substituted_type)
.variance_of(db, *generic_typevar)
}
}),
)
.collect()
}
}
/// Represents a class type, which might be a non-generic class, or a specialization of a generic
/// class.
#[derive(
@ -1136,6 +1183,15 @@ impl<'db> From<ClassType<'db>> for Type<'db> {
}
}
impl<'db> VarianceInferable<'db> for ClassType<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
match self {
Self::NonGeneric(class) => class.variance_of(db, typevar),
Self::Generic(generic) => generic.variance_of(db, typevar),
}
}
}
/// A filter that describes which methods are considered when looking for implicit attribute assignments
/// in [`ClassLiteral::implicit_attribute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -3060,6 +3116,126 @@ impl<'db> From<ClassLiteral<'db>> for ClassType<'db> {
}
}
#[salsa::tracked]
impl<'db> VarianceInferable<'db> for ClassLiteral<'db> {
#[salsa::tracked(cycle_fn=crate::types::variance_cycle_recover, cycle_initial=crate::types::variance_cycle_initial)]
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
let typevar_in_generic_context = self
.generic_context(db)
.is_some_and(|generic_context| generic_context.variables(db).contains(&typevar));
if !typevar_in_generic_context {
return TypeVarVariance::Bivariant;
}
let class_body_scope = self.body_scope(db);
let file = class_body_scope.file(db);
let index = semantic_index(db, file);
let explicit_bases_variances = self
.explicit_bases(db)
.iter()
.map(|class| class.variance_of(db, typevar));
let default_attribute_variance = {
let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self);
// Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses
// their field types in contravariant position, thus meaning a frozen dataclass must
// still be invariant in its field types. Other synthesized methods on dataclasses are
// not considered here, since they don't use field types in their signatures. TODO:
// ideally we'd have a single source of truth for information about synthesized
// methods, so we just look them up normally and don't hardcode this knowledge here.
let is_frozen_dataclass = Program::get(db).python_version(db) <= PythonVersion::PY312
&& self
.dataclass_params(db)
.is_some_and(|params| params.contains(DataclassParams::FROZEN));
if is_namedtuple || is_frozen_dataclass {
TypeVarVariance::Covariant
} else {
TypeVarVariance::Invariant
}
};
let init_name: &Name = &"__init__".into();
let new_name: &Name = &"__new__".into();
let use_def_map = index.use_def_map(class_body_scope.file_scope_id(db));
let table = place_table(db, class_body_scope);
let attribute_places_and_qualifiers =
use_def_map
.all_end_of_scope_symbol_declarations()
.map(|(symbol_id, declarations)| {
let place_and_qual =
place_from_declarations(db, declarations).ignore_conflicting_declarations();
(symbol_id, place_and_qual)
})
.chain(use_def_map.all_end_of_scope_symbol_bindings().map(
|(symbol_id, bindings)| (symbol_id, place_from_bindings(db, bindings).into()),
))
.filter_map(|(symbol_id, place_and_qual)| {
if let Some(name) = table.place(symbol_id).as_symbol().map(Symbol::name) {
(![init_name, new_name].contains(&name))
.then_some((name.to_string(), place_and_qual))
} else {
None
}
});
// Dataclasses can have some additional synthesized methods (`__eq__`, `__hash__`,
// `__lt__`, etc.) but none of these will have field types type variables in their signatures, so we
// don't need to consider them for variance.
let attribute_names = attribute_scopes(db, self.body_scope(db))
.flat_map(|function_scope_id| {
index
.place_table(function_scope_id)
.members()
.filter_map(|member| member.as_instance_attribute())
.filter(|name| *name != init_name && *name != new_name)
.map(std::string::ToString::to_string)
.collect::<Vec<_>>()
})
.dedup();
let attribute_variances = attribute_names
.map(|name| {
let place_and_quals = self.own_instance_member(db, &name);
(name, place_and_quals)
})
.chain(attribute_places_and_qualifiers)
.dedup()
.filter_map(|(name, place_and_qual)| {
place_and_qual.place.ignore_possibly_unbound().map(|ty| {
let variance = if place_and_qual
.qualifiers
// `CLASS_VAR || FINAL` is really `all()`, but
// we want to be robust against new qualifiers
.intersects(TypeQualifiers::CLASS_VAR | TypeQualifiers::FINAL)
// We don't allow mutation of methods or properties
|| ty.is_function_literal()
|| ty.is_property_instance()
// Underscore-prefixed attributes are assumed not to be externally mutated
|| name.starts_with('_')
{
// CLASS_VAR: class vars generally shouldn't contain the
// type variable, but they could if it's a
// callable type. They can't be mutated on instances.
//
// FINAL: final attributes are immutable, and thus covariant
TypeVarVariance::Covariant
} else {
default_attribute_variance
};
ty.with_polarity(variance).variance_of(db, typevar)
})
});
attribute_variances
.chain(explicit_bases_variances)
.collect()
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(super) enum InheritanceCycle {
/// The class is cyclically defined and is a participant in the cycle.
@ -4673,7 +4849,7 @@ impl KnownClass {
&target.id,
Some(containing_assignment),
bound_or_constraint,
variance,
Some(variance),
default.map(Into::into),
TypeVarKind::Legacy,
),

View file

@ -549,12 +549,7 @@ impl<'db> Specialization<'db> {
.into_iter()
.zip(self.types(db))
.map(|(bound_typevar, vartype)| {
let variance = match bound_typevar.typevar(db).variance(db) {
TypeVarVariance::Invariant => TypeVarVariance::Invariant,
TypeVarVariance::Covariant => variance,
TypeVarVariance::Contravariant => variance.flip(),
TypeVarVariance::Bivariant => unreachable!(),
};
let variance = bound_typevar.variance_with_polarity(db, variance);
vartype.materialize(db, variance)
})
.collect();
@ -599,7 +594,7 @@ impl<'db> Specialization<'db> {
// - contravariant: verify that other_type <: self_type
// - invariant: verify that self_type <: other_type AND other_type <: self_type
// - bivariant: skip, can't make subtyping/assignability false
let compatible = match bound_typevar.typevar(db).variance(db) {
let compatible = match bound_typevar.variance(db) {
TypeVarVariance::Invariant => match relation {
TypeRelation::Subtyping => self_type.is_equivalent_to(db, *other_type),
TypeRelation::Assignability => {
@ -639,7 +634,7 @@ impl<'db> Specialization<'db> {
// - contravariant: verify that other_type == self_type
// - invariant: verify that self_type == other_type
// - bivariant: skip, can't make equivalence false
let compatible = match bound_typevar.typevar(db).variance(db) {
let compatible = match bound_typevar.variance(db) {
TypeVarVariance::Invariant
| TypeVarVariance::Covariant
| TypeVarVariance::Contravariant => self_type.is_equivalent_to(db, *other_type),

View file

@ -124,8 +124,8 @@ use crate::types::{
MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm,
Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType,
TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation,
TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionBuilder,
UnionType, binding_type, todo_type,
TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type,
todo_type,
};
use crate::unpack::{EvaluationMode, Unpack, UnpackPosition};
use crate::util::diagnostics::format_enumeration;
@ -3470,7 +3470,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&name.id,
Some(definition),
bound_or_constraint,
TypeVarVariance::Invariant, // TODO: infer this
None,
default.as_deref().map(|_| TypeVarDefaultEvaluation::Lazy),
TypeVarKind::Pep695,
)));

View file

@ -12,7 +12,7 @@ use crate::types::protocol_class::walk_protocol_interface;
use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::{
ApplyTypeMappingVisitor, ClassBase, DynamicType, HasRelationToVisitor, IsDisjointVisitor,
NormalizedVisitor, TypeMapping, TypeRelation, TypeTransformer,
NormalizedVisitor, TypeMapping, TypeRelation, TypeTransformer, VarianceInferable,
};
use crate::{Db, FxOrderSet};
@ -406,6 +406,12 @@ pub(crate) struct SliceLiteral {
pub(crate) step: Option<i32>,
}
impl<'db> VarianceInferable<'db> for NominalInstanceType<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.class(db).variance_of(db, typevar)
}
}
/// A `ProtocolInstanceType` represents the set of all possible runtime objects
/// that conform to the interface described by a certain protocol.
#[derive(
@ -593,6 +599,12 @@ impl<'db> ProtocolInstanceType<'db> {
}
}
impl<'db> VarianceInferable<'db> for ProtocolInstanceType<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.inner.variance_of(db, typevar)
}
}
/// An enumeration of the two kinds of protocol types: those that originate from a class
/// definition in source code, and those that are synthesized from a set of members.
#[derive(
@ -618,12 +630,23 @@ impl<'db> Protocol<'db> {
}
}
impl<'db> VarianceInferable<'db> for Protocol<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
match self {
Protocol::FromClass(class_type) => class_type.variance_of(db, typevar),
Protocol::Synthesized(synthesized_protocol_type) => {
synthesized_protocol_type.variance_of(db, typevar)
}
}
}
}
mod synthesized_protocol {
use crate::semantic_index::definition::Definition;
use crate::types::protocol_class::ProtocolInterface;
use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, NormalizedVisitor, TypeMapping,
TypeVarVariance,
TypeVarVariance, VarianceInferable,
};
use crate::{Db, FxOrderSet};
@ -676,4 +699,14 @@ mod synthesized_protocol {
self.0
}
}
impl<'db> VarianceInferable<'db> for SynthesizedProtocolType<'db> {
fn variance_of(
self,
db: &'db dyn Db,
typevar: BoundTypeVarInstance<'db>,
) -> TypeVarVariance {
self.0.variance_of(db, typevar)
}
}
}

View file

@ -18,7 +18,7 @@ use crate::{
types::{
BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, IsDisjointVisitor,
KnownFunction, NormalizedVisitor, PropertyInstanceType, Signature, Type, TypeMapping,
TypeQualifiers, TypeRelation, TypeTransformer,
TypeQualifiers, TypeRelation, TypeTransformer, VarianceInferable,
signatures::{Parameter, Parameters},
},
};
@ -301,6 +301,15 @@ impl<'db> ProtocolInterface<'db> {
}
}
impl<'db> VarianceInferable<'db> for ProtocolInterface<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.members(db)
// TODO do we need to switch on member kind?
.map(|member| member.ty().variance_of(db, typevar))
.collect()
}
}
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update, get_size2::GetSize)]
pub(super) struct ProtocolMemberData<'db> {
kind: ProtocolMemberKind<'db>,

View file

@ -20,7 +20,7 @@ use crate::semantic_index::definition::Definition;
use crate::types::generics::{GenericContext, walk_generic_context};
use crate::types::{
BindingContext, BoundTypeVarInstance, KnownClass, NormalizedVisitor, TypeMapping, TypeRelation,
todo_type,
VarianceInferable, todo_type,
};
use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name};
@ -223,6 +223,16 @@ impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> {
}
}
impl<'db> VarianceInferable<'db> for &CallableSignature<'db> {
// TODO: possibly need to replace self
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
self.overloads
.iter()
.map(|signature| signature.variance_of(db, typevar))
.collect()
}
}
/// The signature of one of the overloads of a callable.
#[derive(Clone, Debug, salsa::Update, get_size2::GetSize)]
pub struct Signature<'db> {
@ -982,6 +992,28 @@ impl std::hash::Hash for Signature<'_> {
}
}
impl<'db> VarianceInferable<'db> for &Signature<'db> {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
tracing::debug!(
"Checking variance of `{tvar}` in `{self:?}`",
tvar = typevar.typevar(db).name(db)
);
itertools::chain(
self.parameters
.iter()
.filter_map(|parameter| match parameter.form {
ParameterForm::Type => None,
ParameterForm::Value => parameter.annotated_type().map(|ty| {
ty.with_polarity(TypeVarVariance::Contravariant)
.variance_of(db, typevar)
}),
}),
self.return_ty.map(|ty| ty.variance_of(db, typevar)),
)
.collect()
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)]
pub(crate) struct Parameters<'db> {
// TODO: use SmallVec here once invariance bug is fixed

View file

@ -2,6 +2,7 @@ use ruff_python_ast::name::Name;
use crate::place::PlaceAndQualifiers;
use crate::semantic_index::definition::Definition;
use crate::types::variance::VarianceInferable;
use crate::types::{
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassType, DynamicType,
HasRelationToVisitor, KnownClass, MemberLookupPolicy, NormalizedVisitor, Type, TypeMapping,
@ -103,7 +104,7 @@ impl<'db> SubclassOfType<'db> {
)
.into(),
),
variance,
Some(variance),
None,
TypeVarKind::Pep695,
),
@ -215,6 +216,15 @@ impl<'db> SubclassOfType<'db> {
}
}
impl<'db> VarianceInferable<'db> for SubclassOfType<'db> {
fn variance_of(self, db: &dyn Db, typevar: BoundTypeVarInstance<'_>) -> TypeVarVariance {
match self.subclass_of {
SubclassOfInner::Dynamic(_) => TypeVarVariance::Bivariant,
SubclassOfInner::Class(class) => class.variance_of(db, typevar),
}
}
}
/// An enumeration of the different kinds of `type[]` types that a [`SubclassOfType`] can represent:
///
/// 1. A "subclass of a class": `type[C]` for any class object `C`

View file

@ -0,0 +1,138 @@
use crate::{Db, types::BoundTypeVarInstance};
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
pub enum TypeVarVariance {
Invariant,
Covariant,
Contravariant,
Bivariant,
}
impl TypeVarVariance {
pub const fn bottom() -> Self {
TypeVarVariance::Bivariant
}
pub const fn top() -> Self {
TypeVarVariance::Invariant
}
// supremum
#[must_use]
pub(crate) const fn join(self, other: Self) -> Self {
use TypeVarVariance::{Bivariant, Contravariant, Covariant, Invariant};
match (self, other) {
(Invariant, _) | (_, Invariant) => Invariant,
(Covariant, Covariant) => Covariant,
(Contravariant, Contravariant) => Contravariant,
(Covariant, Contravariant) | (Contravariant, Covariant) => Invariant,
(Bivariant, other) | (other, Bivariant) => other,
}
}
/// Compose two variances: useful for combining use-site and definition-site variances, e.g.
/// `C[D[T]]` or function argument/return position variances.
///
/// `other` is a thunk to avoid unnecessary computation when `self` is `Bivariant`.
///
/// Based on the variance composition/transformation operator in
/// <https://people.cs.umass.edu/~yannis/variance-extended2011.pdf>, page 5
///
/// While their operation would have `compose(Invariant, Bivariant) ==
/// Invariant`, we instead have it evaluate to `Bivariant`. This is a valid
/// choice, as discussed on that same page, where type equality is semantic
/// rather than syntactic. To see that this holds for our setting consider
/// the type
/// ```python
/// type ConstantInt[T] = int
/// ```
/// We would say `ConstantInt[str]` = `ConstantInt[float]`, so we qualify as
/// using semantic equivalence.
#[must_use]
pub(crate) fn compose(self, other: Self) -> Self {
self.compose_thunk(|| other)
}
/// Like `compose`, but takes `other` as a thunk to avoid unnecessary
/// computation when `self` is `Bivariant`.
#[must_use]
pub(crate) fn compose_thunk<F>(self, other: F) -> Self
where
F: FnOnce() -> Self,
{
match self {
TypeVarVariance::Covariant => other(),
TypeVarVariance::Contravariant => other().flip(),
TypeVarVariance::Bivariant => TypeVarVariance::Bivariant,
TypeVarVariance::Invariant => {
if TypeVarVariance::Bivariant == other() {
TypeVarVariance::Bivariant
} else {
TypeVarVariance::Invariant
}
}
}
}
/// Flips the polarity of the variance.
///
/// Covariant becomes contravariant, contravariant becomes covariant, others remain unchanged.
pub(crate) const fn flip(self) -> Self {
match self {
TypeVarVariance::Invariant => TypeVarVariance::Invariant,
TypeVarVariance::Covariant => TypeVarVariance::Contravariant,
TypeVarVariance::Contravariant => TypeVarVariance::Covariant,
TypeVarVariance::Bivariant => TypeVarVariance::Bivariant,
}
}
}
impl std::iter::FromIterator<Self> for TypeVarVariance {
fn from_iter<T: IntoIterator<Item = Self>>(iter: T) -> Self {
use std::ops::ControlFlow;
// TODO: use `into_value` when control_flow_into_value is stable
let (ControlFlow::Break(variance) | ControlFlow::Continue(variance)) = iter
.into_iter()
.try_fold(TypeVarVariance::Bivariant, |acc, variance| {
let supremum = acc.join(variance);
match supremum {
// short circuit at top
TypeVarVariance::Invariant => ControlFlow::Break(supremum),
TypeVarVariance::Bivariant
| TypeVarVariance::Covariant
| TypeVarVariance::Contravariant => ControlFlow::Continue(supremum),
}
});
variance
}
}
pub(crate) trait VarianceInferable<'db>: Sized {
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance;
fn with_polarity(self, polarity: TypeVarVariance) -> WithPolarity<Self> {
WithPolarity {
variance_inferable: self,
polarity,
}
}
}
pub(crate) struct WithPolarity<T> {
variance_inferable: T,
polarity: TypeVarVariance,
}
impl<'db, T> VarianceInferable<'db> for WithPolarity<T>
where
T: VarianceInferable<'db>,
{
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
let WithPolarity {
variance_inferable,
polarity,
} = self;
polarity.compose_thunk(|| variance_inferable.variance_of(db, typevar))
}
}