diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md index 92840bcf1a..467b6ed42c 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md @@ -237,10 +237,15 @@ If the type of a constructor parameter is a class typevar, we can use that to in parameter. The types inferred from a type context and from a constructor parameter must be consistent with each other. +We have to add `x: T` to the classes to ensure they're not bivariant in `T` (__new__ and __init__ +signatures don't count towards variance). + ### `__new__` only ```py class C[T]: + x: T + def __new__(cls, x: T) -> "C[T]": return object.__new__(cls) @@ -254,6 +259,8 @@ wrong_innards: C[int] = C("five") ```py class C[T]: + x: T + def __init__(self, x: T) -> None: ... reveal_type(C(1)) # revealed: C[int] @@ -266,6 +273,8 @@ wrong_innards: C[int] = C("five") ```py class C[T]: + x: T + def __new__(cls, x: T) -> "C[T]": return object.__new__(cls) @@ -281,6 +290,8 @@ wrong_innards: C[int] = C("five") ```py class C[T]: + x: T + def __new__(cls, *args, **kwargs) -> "C[T]": return object.__new__(cls) @@ -292,6 +303,8 @@ reveal_type(C(1)) # revealed: C[int] wrong_innards: C[int] = C("five") class D[T]: + x: T + def __new__(cls, x: T) -> "D[T]": return object.__new__(cls) @@ -378,6 +391,8 @@ def func8(t1: tuple[complex, list[int]], t2: tuple[int, *tuple[str, ...]], t3: t ```py class C[T]: + x: T + def __init__[S](self, x: T, y: S) -> None: ... reveal_type(C(1, 1)) # revealed: C[int] @@ -395,6 +410,10 @@ from __future__ import annotations from typing import overload class C[T]: + # we need to use the type variable or else the class is bivariant in T, and + # specializations become meaningless + x: T + @overload def __init__(self: C[str], x: str) -> None: ... @overload diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md index e98d01a35f..58ec406333 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md @@ -40,8 +40,6 @@ class C[T]: class D[U](C[U]): pass -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(C[B], C[A])) static_assert(not is_assignable_to(C[A], C[B])) static_assert(is_assignable_to(C[A], C[Any])) @@ -49,8 +47,6 @@ static_assert(is_assignable_to(C[B], C[Any])) static_assert(is_assignable_to(C[Any], C[A])) static_assert(is_assignable_to(C[Any], C[B])) -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(D[B], C[A])) static_assert(not is_assignable_to(D[A], C[B])) static_assert(is_assignable_to(D[A], C[Any])) @@ -58,8 +54,6 @@ static_assert(is_assignable_to(D[B], C[Any])) static_assert(is_assignable_to(D[Any], C[A])) static_assert(is_assignable_to(D[Any], C[B])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(C[B], C[A])) static_assert(not is_subtype_of(C[A], C[B])) static_assert(not is_subtype_of(C[A], C[Any])) @@ -67,8 +61,6 @@ static_assert(not is_subtype_of(C[B], C[Any])) static_assert(not is_subtype_of(C[Any], C[A])) static_assert(not is_subtype_of(C[Any], C[B])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(D[B], C[A])) static_assert(not is_subtype_of(D[A], C[B])) static_assert(not is_subtype_of(D[A], C[Any])) @@ -124,8 +116,6 @@ class D[U](C[U]): pass static_assert(not is_assignable_to(C[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(C[A], C[B])) static_assert(is_assignable_to(C[A], C[Any])) static_assert(is_assignable_to(C[B], C[Any])) @@ -133,8 +123,6 @@ static_assert(is_assignable_to(C[Any], C[A])) static_assert(is_assignable_to(C[Any], C[B])) static_assert(not is_assignable_to(D[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(D[A], C[B])) static_assert(is_assignable_to(D[A], C[Any])) static_assert(is_assignable_to(D[B], C[Any])) @@ -142,8 +130,6 @@ static_assert(is_assignable_to(D[Any], C[A])) static_assert(is_assignable_to(D[Any], C[B])) static_assert(not is_subtype_of(C[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(C[A], C[B])) static_assert(not is_subtype_of(C[A], C[Any])) static_assert(not is_subtype_of(C[B], C[Any])) @@ -151,8 +137,6 @@ static_assert(not is_subtype_of(C[Any], C[A])) static_assert(not is_subtype_of(C[Any], C[B])) static_assert(not is_subtype_of(D[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(D[A], C[B])) static_assert(not is_subtype_of(D[A], C[Any])) static_assert(not is_subtype_of(D[B], C[Any])) @@ -297,34 +281,22 @@ class C[T]: class D[U](C[U]): pass -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(C[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(C[A], C[B])) static_assert(is_assignable_to(C[A], C[Any])) static_assert(is_assignable_to(C[B], C[Any])) static_assert(is_assignable_to(C[Any], C[A])) static_assert(is_assignable_to(C[Any], C[B])) -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(D[B], C[A])) static_assert(is_subtype_of(C[A], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_assignable_to(D[A], C[B])) static_assert(is_assignable_to(D[A], C[Any])) static_assert(is_assignable_to(D[B], C[Any])) static_assert(is_assignable_to(D[Any], C[A])) static_assert(is_assignable_to(D[Any], C[B])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(C[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(C[A], C[B])) static_assert(not is_subtype_of(C[A], C[Any])) static_assert(not is_subtype_of(C[B], C[Any])) @@ -332,11 +304,7 @@ static_assert(not is_subtype_of(C[Any], C[A])) static_assert(not is_subtype_of(C[Any], C[B])) static_assert(not is_subtype_of(C[Any], C[Any])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(D[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_subtype_of(D[A], C[B])) static_assert(not is_subtype_of(D[A], C[Any])) static_assert(not is_subtype_of(D[B], C[Any])) @@ -345,23 +313,11 @@ static_assert(not is_subtype_of(D[Any], C[B])) static_assert(is_equivalent_to(C[A], C[A])) static_assert(is_equivalent_to(C[B], C[B])) -# TODO: no error -# error: [static-assert-error] static_assert(is_equivalent_to(C[B], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_equivalent_to(C[A], C[B])) -# TODO: no error -# error: [static-assert-error] static_assert(is_equivalent_to(C[A], C[Any])) -# TODO: no error -# error: [static-assert-error] static_assert(is_equivalent_to(C[B], C[Any])) -# TODO: no error -# error: [static-assert-error] static_assert(is_equivalent_to(C[Any], C[A])) -# TODO: no error -# error: [static-assert-error] static_assert(is_equivalent_to(C[Any], C[B])) static_assert(not is_equivalent_to(D[A], C[A])) @@ -380,4 +336,504 @@ static_assert(not is_equivalent_to(D[Any], C[Any])) static_assert(not is_equivalent_to(D[Any], C[Unknown])) ``` +## Mutual Recursion + +This example due to Martin Huschenbett's PyCon 2025 talk, +[Linear Time variance Inference for PEP 695][linear-time-variance-talk] + +```py +from ty_extensions import is_subtype_of, static_assert +from typing import Any + +class A: ... +class B(A): ... + +class C[X]: + def f(self) -> "D[X]": + return D() + + def g(self, x: X) -> None: ... + +class D[Y]: + def h(self) -> C[Y]: + return C() +``` + +`C` is contravariant in `X`, and `D` in `Y`: + +- `C` has two occurrences of `X` + - `X` occurs in the return type of `f` as `D[X]` (`X` is substituted in for `Y`) + - `D` has one occurrence of `Y` + - `Y` occurs in the return type of `h` as `C[Y]` + - `X` occurs contravariantly as a parameter in `g` + +Thus the variance of `X` in `C` depends on itself. We want to infer the least restrictive possible +variance, so in such cases we begin by assuming that the point where we detect the cycle is +bivariant. + +If we thus assume `X` is bivariant in `C`, then `Y` will be bivariant in `D`, as `D`'s only +occurrence of `Y` is in `C`. Then we consider `X` in `C` once more. We have two occurrences: `D[X]` +covariantly in a return type, and `X` contravariantly in an argument type. With one bivariant and +one contravariant occurrence, we update our inference of `X` in `C` to contravariant---the supremum +of contravariant and bivariant in the lattice. + +Now that we've updated the variance of `X` in `C`, we re-evaluate `Y` in `D`. It only has the one +occurrence `C[Y]`, which we now infer is contravariant, and so we infer contravariance for `Y` in +`D` as well. + +Because the variance of `X` in `C` depends on that of `Y` in `D`, we have to re-evaluate now that +we've updated the latter to contravariant. The variance of `X` in `C` is now the supremum of +contravariant and contravariant---giving us contravariant---and so remains unchanged. + +Once we've completed a turn around the cycle with nothing changed, we've reached a fixed-point---the +variance inference will not change any further---and so we finally conclude that both `X` in `C` and +`Y` in `D` are contravariant. + +```py +static_assert(not is_subtype_of(C[B], C[A])) +static_assert(is_subtype_of(C[A], C[B])) +static_assert(not is_subtype_of(C[A], C[Any])) +static_assert(not is_subtype_of(C[B], C[Any])) +static_assert(not is_subtype_of(C[Any], C[A])) +static_assert(not is_subtype_of(C[Any], C[B])) + +static_assert(not is_subtype_of(D[B], D[A])) +static_assert(is_subtype_of(D[A], D[B])) +static_assert(not is_subtype_of(D[A], D[Any])) +static_assert(not is_subtype_of(D[B], D[Any])) +static_assert(not is_subtype_of(D[Any], D[A])) +static_assert(not is_subtype_of(D[Any], D[B])) +``` + +## Class Attributes + +### Mutable Attributes + +Normal attributes are mutable, and so make the enclosing class invariant in this typevar (see +[inv]). + +```py +from ty_extensions import is_subtype_of, static_assert + +class A: ... +class B(A): ... + +class C[T]: + x: T + +static_assert(not is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) +``` + +One might think that occurrences in the types of normal attributes are covariant, but they are +mutable, and thus the occurrences are invariant. + +### Immutable Attributes + +Immutable attributes can't be written to, and thus constrain the typevar to covariance, not +invariance. + +#### Final attributes + +```py +from typing import Final +from ty_extensions import is_subtype_of, static_assert + +class A: ... +class B(A): ... + +class C[T]: + x: Final[T] + +static_assert(is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) +``` + +#### Underscore-prefixed attributes + +Underscore-prefixed instance attributes are considered private, and thus are assumed not externally +mutated. + +```py +from ty_extensions import is_subtype_of, static_assert + +class A: ... +class B(A): ... + +class C[T]: + _x: T + + @property + def x(self) -> T: + return self._x + +static_assert(is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) + +class D[T]: + def __init__(self, x: T): + self._x = x + + @property + def x(self) -> T: + return self._x + +static_assert(is_subtype_of(D[B], D[A])) +static_assert(not is_subtype_of(D[A], D[B])) +``` + +#### Frozen dataclasses in Python 3.12 and earlier + +```py +from dataclasses import dataclass, field +from ty_extensions import is_subtype_of, static_assert + +class A: ... +class B(A): ... + +@dataclass(frozen=True) +class D[U]: + y: U + +static_assert(is_subtype_of(D[B], D[A])) +static_assert(not is_subtype_of(D[A], D[B])) + +@dataclass(frozen=True) +class E[U]: + y: U = field() + +static_assert(is_subtype_of(E[B], E[A])) +static_assert(not is_subtype_of(E[A], E[B])) +``` + +#### Frozen dataclasses in Python 3.13 and later + +```toml +[environment] +python-version = "3.13" +``` + +Python 3.13 introduced a new synthesized `__replace__` method on dataclasses, which uses every field +type in a contravariant position (as a parameter to `__replace__`). This means that frozen +dataclasses on Python 3.13+ can't be covariant in their field types. + +```py +from dataclasses import dataclass +from ty_extensions import is_subtype_of, static_assert + +class A: ... +class B(A): ... + +@dataclass(frozen=True) +class D[U]: + y: U + +static_assert(not is_subtype_of(D[B], D[A])) +static_assert(not is_subtype_of(D[A], D[B])) +``` + +#### NamedTuple + +```py +from typing import NamedTuple +from ty_extensions import is_subtype_of, static_assert + +class A: ... +class B(A): ... + +class E[V](NamedTuple): + z: V + +static_assert(is_subtype_of(E[B], E[A])) +static_assert(not is_subtype_of(E[A], E[B])) +``` + +A subclass of a `NamedTuple` can still be covariant: + +```py +class D[T](E[T]): + pass + +static_assert(is_subtype_of(D[B], D[A])) +static_assert(not is_subtype_of(D[A], D[B])) +``` + +But adding a new generic attribute on the subclass makes it invariant (the added attribute is not a +`NamedTuple` field, and thus not immutable): + +```py +class C[T](E[T]): + w: T + +static_assert(not is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) +``` + +### Properties + +Properties constrain to covariance if they are get-only and invariant if they are get-set: + +```py +from ty_extensions import static_assert, is_subtype_of + +class A: ... +class B(A): ... + +class C[T]: + @property + def x(self) -> T | None: + return None + +class D[U]: + @property + def y(self) -> U | None: + return None + + @y.setter + def y(self, value: U): ... + +static_assert(is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) +static_assert(not is_subtype_of(D[B], D[A])) +static_assert(not is_subtype_of(D[A], D[B])) +``` + +### Implicit Attributes + +Implicit attributes work like normal ones + +```py +from ty_extensions import static_assert, is_subtype_of + +class A: ... +class B(A): ... + +class C[T]: + def f(self) -> None: + self.x: T | None = None + +static_assert(not is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) +``` + +### Constructors: excluding `__init__` and `__new__` + +We consider it invalid to call `__init__` explicitly on an existing object. Likewise, `__new__` is +only used at the beginning of an object's life. As such, we don't need to worry about the variance +impact of these methods. + +```py +from ty_extensions import static_assert, is_subtype_of + +class A: ... +class B(A): ... + +class C[T]: + def __init__(self, x: T): ... + def __new__(self, x: T): ... + +static_assert(is_subtype_of(C[B], C[A])) +static_assert(is_subtype_of(C[A], C[B])) +``` + +This example is then bivariant because it doesn't use `T` outside of the two exempted methods. + +This holds likewise for dataclasses with synthesized `__init__`: + +```py +from dataclasses import dataclass + +@dataclass(init=True, frozen=True) +class D[T]: + x: T + +# Covariant due to the read-only T-typed attribute; the `__init__` is ignored and doesn't make it +# invariant: + +static_assert(is_subtype_of(D[B], D[A])) +static_assert(not is_subtype_of(D[A], D[B])) +``` + +## Union Types + +Union types are covariant in all their members. If `A <: B`, then `A | C <: B | C` and +`C | A <: C | B`. + +```py +from ty_extensions import is_assignable_to, is_subtype_of, static_assert + +class A: ... +class B(A): ... +class C: ... + +# Union types are covariant in their members +static_assert(is_subtype_of(B | C, A | C)) +static_assert(is_subtype_of(C | B, C | A)) +static_assert(not is_subtype_of(A | C, B | C)) +static_assert(not is_subtype_of(C | A, C | B)) + +# Assignability follows the same pattern +static_assert(is_assignable_to(B | C, A | C)) +static_assert(is_assignable_to(C | B, C | A)) +static_assert(not is_assignable_to(A | C, B | C)) +static_assert(not is_assignable_to(C | A, C | B)) +``` + +## Intersection Types + +Intersection types cannot be expressed directly in Python syntax, but they occur when type narrowing +creates constraints through control flow. In ty's representation, intersection types are covariant +in their positive conjuncts and contravariant in their negative conjuncts. + +```py +from ty_extensions import is_assignable_to, is_subtype_of, static_assert, Intersection, Not + +class A: ... +class B(A): ... +class C: ... + +# Test covariance in positive conjuncts +# If B <: A, then Intersection[X, B] <: Intersection[X, A] +static_assert(is_subtype_of(Intersection[C, B], Intersection[C, A])) +static_assert(not is_subtype_of(Intersection[C, A], Intersection[C, B])) + +static_assert(is_assignable_to(Intersection[C, B], Intersection[C, A])) +static_assert(not is_assignable_to(Intersection[C, A], Intersection[C, B])) + +# Test contravariance in negative conjuncts +# If B <: A, then Intersection[X, Not[A]] <: Intersection[X, Not[B]] +# (excluding supertype A is more restrictive than excluding subtype B) +static_assert(is_subtype_of(Intersection[C, Not[A]], Intersection[C, Not[B]])) +static_assert(not is_subtype_of(Intersection[C, Not[B]], Intersection[C, Not[A]])) + +static_assert(is_assignable_to(Intersection[C, Not[A]], Intersection[C, Not[B]])) +static_assert(not is_assignable_to(Intersection[C, Not[B]], Intersection[C, Not[A]])) +``` + +## Subclass Types (type[T]) + +The `type[T]` construct represents the type of classes that are subclasses of `T`. It is covariant +in `T` because if `A <: B`, then `type[A] <: type[B]` holds. + +```py +from ty_extensions import is_assignable_to, is_subtype_of, static_assert + +class A: ... +class B(A): ... + +# type[T] is covariant in T +static_assert(is_subtype_of(type[B], type[A])) +static_assert(not is_subtype_of(type[A], type[B])) + +static_assert(is_assignable_to(type[B], type[A])) +static_assert(not is_assignable_to(type[A], type[B])) + +# With generic classes using type[T] +class ClassContainer[T]: + def __init__(self, cls: type[T]) -> None: + self.cls = cls + + def create_instance(self) -> T: + return self.cls() + +# ClassContainer is covariant in T due to type[T] +static_assert(is_subtype_of(ClassContainer[B], ClassContainer[A])) +static_assert(not is_subtype_of(ClassContainer[A], ClassContainer[B])) + +static_assert(is_assignable_to(ClassContainer[B], ClassContainer[A])) +static_assert(not is_assignable_to(ClassContainer[A], ClassContainer[B])) + +# Practical example: you can pass a ClassContainer[B] where ClassContainer[A] is expected +# because type[B] can safely be used where type[A] is expected +def use_a_class_container(container: ClassContainer[A]) -> A: + return container.create_instance() + +b_container = ClassContainer[B](B) +a_instance: A = use_a_class_container(b_container) # This should work +``` + +## Inheriting from generic classes with inferred variance + +When inheriting from a generic class with our type variable substituted in, we count its occurrences +as well. In the following example, `T` is covariant in `C`, and contravariant in the subclass `D` if +you only count its own occurrences. Because we count both then, `T` is invariant in `D`. + +```py +from ty_extensions import is_subtype_of, static_assert + +class A: + pass + +class B(A): + pass + +class C[T]: + def f() -> T | None: + pass + +static_assert(is_subtype_of(C[B], C[A])) +static_assert(not is_subtype_of(C[A], C[B])) + +class D[T](C[T]): + def g(x: T) -> None: + pass + +static_assert(not is_subtype_of(D[B], D[A])) +static_assert(not is_subtype_of(D[A], D[B])) +``` + +## Inheriting from generic classes with explicit variance + +```py +from typing import TypeVar, Generic +from ty_extensions import is_subtype_of, static_assert + +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + +class A: + pass + +class B(A): + pass + +class Invariant(Generic[T]): + pass + +static_assert(not is_subtype_of(Invariant[B], Invariant[A])) +static_assert(not is_subtype_of(Invariant[A], Invariant[B])) + +class DerivedInvariant[T](Invariant[T]): + pass + +static_assert(not is_subtype_of(DerivedInvariant[B], DerivedInvariant[A])) +static_assert(not is_subtype_of(DerivedInvariant[A], DerivedInvariant[B])) + +class Covariant(Generic[T_co]): + pass + +static_assert(is_subtype_of(Covariant[B], Covariant[A])) +static_assert(not is_subtype_of(Covariant[A], Covariant[B])) + +class DerivedCovariant[T](Covariant[T]): + pass + +static_assert(is_subtype_of(DerivedCovariant[B], DerivedCovariant[A])) +static_assert(not is_subtype_of(DerivedCovariant[A], DerivedCovariant[B])) + +class Contravariant(Generic[T_contra]): + pass + +static_assert(not is_subtype_of(Contravariant[B], Contravariant[A])) +static_assert(is_subtype_of(Contravariant[A], Contravariant[B])) + +class DerivedContravariant[T](Contravariant[T]): + pass + +static_assert(not is_subtype_of(DerivedContravariant[B], DerivedContravariant[A])) +static_assert(is_subtype_of(DerivedContravariant[A], DerivedContravariant[B])) +``` + +[linear-time-variance-talk]: https://www.youtube.com/watch?v=7uixlNTOY4s&t=9705s [spec]: https://typing.python.org/en/latest/spec/generics.html#variance diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type.md b/crates/ty_python_semantic/resources/mdtest/narrow/type.md index fccd6e54fa..3cf1aa23db 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type.md @@ -217,8 +217,8 @@ class B: ... def _[T](x: A | B): if type(x) is A[str]: - # `type()` never returns a generic alias, so `type(x)` cannot be `A[str]` - reveal_type(x) # revealed: Never + # TODO: `type()` never returns a generic alias, so `type(x)` cannot be `A[str]` + reveal_type(x) # revealed: A[int] | B else: reveal_type(x) # revealed: A[int] | B ``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 378583d9ae..a910e6a206 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -60,6 +60,7 @@ use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; use crate::types::tuple::TupleSpec; +use crate::types::variance::{TypeVarVariance, VarianceInferable}; use crate::unpack::EvaluationMode; pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic; use crate::{Db, FxOrderMap, FxOrderSet, Module, Program}; @@ -92,6 +93,7 @@ mod subclass_of; mod tuple; mod type_ordering; mod unpacker; +mod variance; mod visitor; mod definition; @@ -322,6 +324,29 @@ fn class_lookup_cycle_initial<'db>( Place::bound(Type::Never).into() } +#[allow(clippy::trivially_copy_pass_by_ref)] +fn variance_cycle_recover<'db, T>( + _db: &'db dyn Db, + _value: &TypeVarVariance, + count: u32, + _self: T, + _typevar: BoundTypeVarInstance<'db>, +) -> salsa::CycleRecoveryAction { + assert!( + count <= 2, + "Should only be able to cycle at most twice: there are only three levels in the lattice, each cycle should move us one" + ); + salsa::CycleRecoveryAction::Iterate +} + +fn variance_cycle_initial<'db, T>( + _db: &'db dyn Db, + _self: T, + _typevar: BoundTypeVarInstance<'db>, +) -> TypeVarVariance { + TypeVarVariance::Bivariant +} + /// Meta data for `Type::Todo`, which represents a known limitation in ty. #[cfg(debug_assertions)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)] @@ -755,7 +780,7 @@ impl<'db> Type<'db> { Name::new_static("T_all"), None, None, - variance, + Some(variance), None, TypeVarKind::Pep695, ), @@ -963,7 +988,6 @@ impl<'db> Type<'db> { .expect("Expected a Type::FunctionLiteral variant") } - #[cfg(test)] pub(crate) const fn is_function_literal(&self) -> bool { matches!(self, Type::FunctionLiteral(..)) } @@ -5533,7 +5557,10 @@ impl<'db> Type<'db> { ast::name::Name::new_static("Self"), Some(class_definition), Some(TypeVarBoundOrConstraints::UpperBound(instance).into()), - TypeVarVariance::Invariant, + // According to the [spec], we can consider `Self` + // equivalent to an invariant type variable + // [spec]: https://typing.python.org/en/latest/spec/generics.html#self + Some(TypeVarVariance::Invariant), None, TypeVarKind::TypingSelf, ); @@ -6315,6 +6342,99 @@ impl<'db> From<&Type<'db>> for Type<'db> { } } +impl<'db> VarianceInferable<'db> for Type<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + tracing::debug!( + "Checking variance of '{tvar}' in `{ty:?}`", + tvar = typevar.typevar(db).name(db), + ty = self.display(db), + ); + + let v = match self { + Type::ClassLiteral(class_literal) => class_literal.variance_of(db, typevar), + + Type::FunctionLiteral(function_type) => { + // TODO: do we need to replace self? + function_type.signature(db).variance_of(db, typevar) + } + + Type::BoundMethod(method_type) => { + // TODO: do we need to replace self? + method_type + .function(db) + .signature(db) + .variance_of(db, typevar) + } + + Type::NominalInstance(nominal_instance_type) => { + nominal_instance_type.variance_of(db, typevar) + } + Type::GenericAlias(generic_alias) => generic_alias.variance_of(db, typevar), + Type::Callable(callable_type) => callable_type.signatures(db).variance_of(db, typevar), + Type::TypeVar(other_typevar) | Type::NonInferableTypeVar(other_typevar) + if other_typevar == typevar => + { + // type variables are covariant in themselves + TypeVarVariance::Covariant + } + Type::ProtocolInstance(protocol_instance_type) => { + protocol_instance_type.variance_of(db, typevar) + } + Type::Union(union_type) => union_type + .elements(db) + .iter() + .map(|ty| ty.variance_of(db, typevar)) + .collect(), + Type::Intersection(intersection_type) => intersection_type + .positive(db) + .iter() + .map(|ty| ty.variance_of(db, typevar)) + .chain(intersection_type.negative(db).iter().map(|ty| { + ty.with_polarity(TypeVarVariance::Contravariant) + .variance_of(db, typevar) + })) + .collect(), + Type::PropertyInstance(property_instance_type) => property_instance_type + .getter(db) + .iter() + .chain(&property_instance_type.setter(db)) + .map(|ty| ty.variance_of(db, typevar)) + .collect(), + Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar), + Type::Dynamic(_) + | Type::Never + | Type::WrapperDescriptor(_) + | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) + | Type::ModuleLiteral(_) + | Type::IntLiteral(_) + | Type::BooleanLiteral(_) + | Type::StringLiteral(_) + | Type::EnumLiteral(_) + | Type::LiteralString + | Type::BytesLiteral(_) + | Type::SpecialForm(_) + | Type::KnownInstance(_) + | Type::AlwaysFalsy + | Type::AlwaysTruthy + | Type::BoundSuper(_) + | Type::TypeVar(_) + | Type::NonInferableTypeVar(_) + | Type::TypeIs(_) + | Type::TypedDict(_) + | Type::TypeAlias(_) => TypeVarVariance::Bivariant, + }; + + tracing::debug!( + "Result of variance of '{tvar}' in `{ty:?}` is `{v:?}`", + tvar = typevar.typevar(db).name(db), + ty = self.display(db), + ); + v + } +} + /// A mapping that can be applied to a type, producing another type. This is applied inductively to /// the components of complex types. /// @@ -6972,8 +7092,8 @@ pub struct TypeVarInstance<'db> { /// instead (to evaluate any lazy bound or constraints). _bound_or_constraints: Option>, - /// The variance of the TypeVar - variance: TypeVarVariance, + /// The explicitly specified variance of the TypeVar + explicit_variance: Option, /// The default type for this TypeVar, if any. Don't use this field directly, use the /// `default_type` method instead (to evaluate any lazy default). @@ -7065,7 +7185,7 @@ impl<'db> TypeVarInstance<'db> { .lazy_constraints(db) .map(|constraints| constraints.normalized_impl(db, visitor).into()), }), - self.variance(db), + self.explicit_variance(db), self._default(db).and_then(|default| match default { TypeVarDefaultEvaluation::Eager(ty) => Some(ty.normalized_impl(db, visitor).into()), TypeVarDefaultEvaluation::Lazy => self @@ -7093,7 +7213,7 @@ impl<'db> TypeVarInstance<'db> { .lazy_constraints(db) .map(|constraints| constraints.materialize(db, variance).into()), }), - self.variance(db), + self.explicit_variance(db), self._default(db).and_then(|default| match default { TypeVarDefaultEvaluation::Eager(ty) => Some(ty.materialize(db, variance).into()), TypeVarDefaultEvaluation::Lazy => self @@ -7118,7 +7238,7 @@ impl<'db> TypeVarInstance<'db> { Name::new(format!("{}'instance", self.name(db))), None, Some(bound_or_constraints.into()), - self.variance(db), + self.explicit_variance(db), None, self.kind(db), )) @@ -7187,6 +7307,33 @@ pub struct BoundTypeVarInstance<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for BoundTypeVarInstance<'_> {} +impl<'db> BoundTypeVarInstance<'db> { + pub(crate) fn variance_with_polarity( + self, + db: &'db dyn Db, + polarity: TypeVarVariance, + ) -> TypeVarVariance { + let _span = tracing::trace_span!("variance_with_polarity").entered(); + match self.typevar(db).explicit_variance(db) { + Some(explicit_variance) => explicit_variance.compose(polarity), + None => match self.binding_context(db) { + BindingContext::Definition(definition) => { + let type_inference = infer_definition_types(db, definition); + type_inference + .binding_type(definition) + .with_polarity(polarity) + .variance_of(db, self) + } + BindingContext::Synthetic => TypeVarVariance::Invariant, + }, + } + } + + pub(crate) fn variance(self, db: &'db dyn Db) -> TypeVarVariance { + self.variance_with_polarity(db, TypeVarVariance::Covariant) + } +} + fn walk_bound_type_var_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, bound_typevar: BoundTypeVarInstance<'db>, @@ -7250,28 +7397,6 @@ impl<'db> BoundTypeVarInstance<'db> { } } -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] -pub enum TypeVarVariance { - Invariant, - Covariant, - Contravariant, - Bivariant, -} - -impl TypeVarVariance { - /// Flips the polarity of the variance. - /// - /// Covariant becomes contravariant, contravariant becomes covariant, others remain unchanged. - pub(crate) const fn flip(self) -> Self { - match self { - TypeVarVariance::Invariant => TypeVarVariance::Invariant, - TypeVarVariance::Covariant => TypeVarVariance::Contravariant, - TypeVarVariance::Contravariant => TypeVarVariance::Covariant, - TypeVarVariance::Bivariant => TypeVarVariance::Bivariant, - } - } -} - /// Whether a typevar default is eagerly specified or lazily evaluated. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub enum TypeVarDefaultEvaluation<'db> { diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 59929bf59c..8580dfb61f 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -13,8 +13,10 @@ use crate::module_resolver::KnownModule; use crate::semantic_index::definition::{Definition, DefinitionState}; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::NodeWithScopeKind; +use crate::semantic_index::symbol::Symbol; use crate::semantic_index::{ BindingWithConstraints, DeclarationWithConstraint, SemanticIndex, attribute_declarations, + attribute_scopes, }; use crate::types::context::InferContext; use crate::types::diagnostic::{INVALID_LEGACY_TYPE_VARIABLE, INVALID_TYPE_ALIAS_TYPE}; @@ -28,8 +30,8 @@ use crate::types::{ ApplyTypeMappingVisitor, BareTypeAliasType, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams, DeprecatedInstance, HasRelationToVisitor, KnownInstanceType, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping, - TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, declaration_type, - infer_definition_types, todo_type, + TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, VarianceInferable, + declaration_type, infer_definition_types, todo_type, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -314,6 +316,51 @@ impl<'db> From> for Type<'db> { } } +#[salsa::tracked] +impl<'db> VarianceInferable<'db> for GenericAlias<'db> { + #[salsa::tracked] + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + let origin = self.origin(db); + + let specialization = self.specialization(db); + + // if the class is the thing defining the variable, then it can + // reference it without it being applied to the specialization + std::iter::once(origin.variance_of(db, typevar)) + .chain( + specialization + .generic_context(db) + .variables(db) + .iter() + .zip(specialization.types(db)) + .map(|(generic_typevar, ty)| { + if let Some(explicit_variance) = + generic_typevar.typevar(db).explicit_variance(db) + { + ty.with_polarity(explicit_variance).variance_of(db, typevar) + } else { + // `with_polarity` composes the passed variance with the + // inferred one. The inference is done lazily, as we can + // sometimes determine the result just from the passed + // variance. This operation is commutative, so we could + // infer either first. We choose to make the `ClassLiteral` + // variance lazy, as it is known to be expensive, requiring + // that we traverse all members. + // + // If salsa let us look at the cache, we could check first + // to see if the class literal query was already run. + + let typevar_variance_in_substituted_type = ty.variance_of(db, typevar); + origin + .with_polarity(typevar_variance_in_substituted_type) + .variance_of(db, *generic_typevar) + } + }), + ) + .collect() + } +} + /// Represents a class type, which might be a non-generic class, or a specialization of a generic /// class. #[derive( @@ -1136,6 +1183,15 @@ impl<'db> From> for Type<'db> { } } +impl<'db> VarianceInferable<'db> for ClassType<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + match self { + Self::NonGeneric(class) => class.variance_of(db, typevar), + Self::Generic(generic) => generic.variance_of(db, typevar), + } + } +} + /// A filter that describes which methods are considered when looking for implicit attribute assignments /// in [`ClassLiteral::implicit_attribute`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -3060,6 +3116,126 @@ impl<'db> From> for ClassType<'db> { } } +#[salsa::tracked] +impl<'db> VarianceInferable<'db> for ClassLiteral<'db> { + #[salsa::tracked(cycle_fn=crate::types::variance_cycle_recover, cycle_initial=crate::types::variance_cycle_initial)] + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + let typevar_in_generic_context = self + .generic_context(db) + .is_some_and(|generic_context| generic_context.variables(db).contains(&typevar)); + + if !typevar_in_generic_context { + return TypeVarVariance::Bivariant; + } + let class_body_scope = self.body_scope(db); + + let file = class_body_scope.file(db); + let index = semantic_index(db, file); + + let explicit_bases_variances = self + .explicit_bases(db) + .iter() + .map(|class| class.variance_of(db, typevar)); + + let default_attribute_variance = { + let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self); + // Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses + // their field types in contravariant position, thus meaning a frozen dataclass must + // still be invariant in its field types. Other synthesized methods on dataclasses are + // not considered here, since they don't use field types in their signatures. TODO: + // ideally we'd have a single source of truth for information about synthesized + // methods, so we just look them up normally and don't hardcode this knowledge here. + let is_frozen_dataclass = Program::get(db).python_version(db) <= PythonVersion::PY312 + && self + .dataclass_params(db) + .is_some_and(|params| params.contains(DataclassParams::FROZEN)); + if is_namedtuple || is_frozen_dataclass { + TypeVarVariance::Covariant + } else { + TypeVarVariance::Invariant + } + }; + + let init_name: &Name = &"__init__".into(); + let new_name: &Name = &"__new__".into(); + + let use_def_map = index.use_def_map(class_body_scope.file_scope_id(db)); + let table = place_table(db, class_body_scope); + let attribute_places_and_qualifiers = + use_def_map + .all_end_of_scope_symbol_declarations() + .map(|(symbol_id, declarations)| { + let place_and_qual = + place_from_declarations(db, declarations).ignore_conflicting_declarations(); + (symbol_id, place_and_qual) + }) + .chain(use_def_map.all_end_of_scope_symbol_bindings().map( + |(symbol_id, bindings)| (symbol_id, place_from_bindings(db, bindings).into()), + )) + .filter_map(|(symbol_id, place_and_qual)| { + if let Some(name) = table.place(symbol_id).as_symbol().map(Symbol::name) { + (![init_name, new_name].contains(&name)) + .then_some((name.to_string(), place_and_qual)) + } else { + None + } + }); + + // Dataclasses can have some additional synthesized methods (`__eq__`, `__hash__`, + // `__lt__`, etc.) but none of these will have field types type variables in their signatures, so we + // don't need to consider them for variance. + + let attribute_names = attribute_scopes(db, self.body_scope(db)) + .flat_map(|function_scope_id| { + index + .place_table(function_scope_id) + .members() + .filter_map(|member| member.as_instance_attribute()) + .filter(|name| *name != init_name && *name != new_name) + .map(std::string::ToString::to_string) + .collect::>() + }) + .dedup(); + + let attribute_variances = attribute_names + .map(|name| { + let place_and_quals = self.own_instance_member(db, &name); + (name, place_and_quals) + }) + .chain(attribute_places_and_qualifiers) + .dedup() + .filter_map(|(name, place_and_qual)| { + place_and_qual.place.ignore_possibly_unbound().map(|ty| { + let variance = if place_and_qual + .qualifiers + // `CLASS_VAR || FINAL` is really `all()`, but + // we want to be robust against new qualifiers + .intersects(TypeQualifiers::CLASS_VAR | TypeQualifiers::FINAL) + // We don't allow mutation of methods or properties + || ty.is_function_literal() + || ty.is_property_instance() + // Underscore-prefixed attributes are assumed not to be externally mutated + || name.starts_with('_') + { + // CLASS_VAR: class vars generally shouldn't contain the + // type variable, but they could if it's a + // callable type. They can't be mutated on instances. + // + // FINAL: final attributes are immutable, and thus covariant + TypeVarVariance::Covariant + } else { + default_attribute_variance + }; + ty.with_polarity(variance).variance_of(db, typevar) + }) + }); + + attribute_variances + .chain(explicit_bases_variances) + .collect() + } +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, get_size2::GetSize)] pub(super) enum InheritanceCycle { /// The class is cyclically defined and is a participant in the cycle. @@ -4673,7 +4849,7 @@ impl KnownClass { &target.id, Some(containing_assignment), bound_or_constraint, - variance, + Some(variance), default.map(Into::into), TypeVarKind::Legacy, ), diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index e0619e8343..049b953949 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -549,12 +549,7 @@ impl<'db> Specialization<'db> { .into_iter() .zip(self.types(db)) .map(|(bound_typevar, vartype)| { - let variance = match bound_typevar.typevar(db).variance(db) { - TypeVarVariance::Invariant => TypeVarVariance::Invariant, - TypeVarVariance::Covariant => variance, - TypeVarVariance::Contravariant => variance.flip(), - TypeVarVariance::Bivariant => unreachable!(), - }; + let variance = bound_typevar.variance_with_polarity(db, variance); vartype.materialize(db, variance) }) .collect(); @@ -599,7 +594,7 @@ impl<'db> Specialization<'db> { // - contravariant: verify that other_type <: self_type // - invariant: verify that self_type <: other_type AND other_type <: self_type // - bivariant: skip, can't make subtyping/assignability false - let compatible = match bound_typevar.typevar(db).variance(db) { + let compatible = match bound_typevar.variance(db) { TypeVarVariance::Invariant => match relation { TypeRelation::Subtyping => self_type.is_equivalent_to(db, *other_type), TypeRelation::Assignability => { @@ -639,7 +634,7 @@ impl<'db> Specialization<'db> { // - contravariant: verify that other_type == self_type // - invariant: verify that self_type == other_type // - bivariant: skip, can't make equivalence false - let compatible = match bound_typevar.typevar(db).variance(db) { + let compatible = match bound_typevar.variance(db) { TypeVarVariance::Invariant | TypeVarVariance::Covariant | TypeVarVariance::Contravariant => self_type.is_equivalent_to(db, *other_type), diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 5502715e38..07f82a7329 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -124,8 +124,8 @@ use crate::types::{ MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, - TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, TypeVarVariance, UnionBuilder, - UnionType, binding_type, todo_type, + TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, + todo_type, }; use crate::unpack::{EvaluationMode, Unpack, UnpackPosition}; use crate::util::diagnostics::format_enumeration; @@ -3470,7 +3470,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &name.id, Some(definition), bound_or_constraint, - TypeVarVariance::Invariant, // TODO: infer this + None, default.as_deref().map(|_| TypeVarDefaultEvaluation::Lazy), TypeVarKind::Pep695, ))); diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index ca837cd294..eb6978af22 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -12,7 +12,7 @@ use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, ClassBase, DynamicType, HasRelationToVisitor, IsDisjointVisitor, - NormalizedVisitor, TypeMapping, TypeRelation, TypeTransformer, + NormalizedVisitor, TypeMapping, TypeRelation, TypeTransformer, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -406,6 +406,12 @@ pub(crate) struct SliceLiteral { pub(crate) step: Option, } +impl<'db> VarianceInferable<'db> for NominalInstanceType<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + self.class(db).variance_of(db, typevar) + } +} + /// A `ProtocolInstanceType` represents the set of all possible runtime objects /// that conform to the interface described by a certain protocol. #[derive( @@ -593,6 +599,12 @@ impl<'db> ProtocolInstanceType<'db> { } } +impl<'db> VarianceInferable<'db> for ProtocolInstanceType<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + self.inner.variance_of(db, typevar) + } +} + /// An enumeration of the two kinds of protocol types: those that originate from a class /// definition in source code, and those that are synthesized from a set of members. #[derive( @@ -618,12 +630,23 @@ impl<'db> Protocol<'db> { } } +impl<'db> VarianceInferable<'db> for Protocol<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + match self { + Protocol::FromClass(class_type) => class_type.variance_of(db, typevar), + Protocol::Synthesized(synthesized_protocol_type) => { + synthesized_protocol_type.variance_of(db, typevar) + } + } + } +} + mod synthesized_protocol { use crate::semantic_index::definition::Definition; use crate::types::protocol_class::ProtocolInterface; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, NormalizedVisitor, TypeMapping, - TypeVarVariance, + TypeVarVariance, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -676,4 +699,14 @@ mod synthesized_protocol { self.0 } } + + impl<'db> VarianceInferable<'db> for SynthesizedProtocolType<'db> { + fn variance_of( + self, + db: &'db dyn Db, + typevar: BoundTypeVarInstance<'db>, + ) -> TypeVarVariance { + self.0.variance_of(db, typevar) + } + } } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 12fb5fb988..d910bdc1ce 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -18,7 +18,7 @@ use crate::{ types::{ BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, IsDisjointVisitor, KnownFunction, NormalizedVisitor, PropertyInstanceType, Signature, Type, TypeMapping, - TypeQualifiers, TypeRelation, TypeTransformer, + TypeQualifiers, TypeRelation, TypeTransformer, VarianceInferable, signatures::{Parameter, Parameters}, }, }; @@ -301,6 +301,15 @@ impl<'db> ProtocolInterface<'db> { } } +impl<'db> VarianceInferable<'db> for ProtocolInterface<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + self.members(db) + // TODO do we need to switch on member kind? + .map(|member| member.ty().variance_of(db, typevar)) + .collect() + } +} + #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update, get_size2::GetSize)] pub(super) struct ProtocolMemberData<'db> { kind: ProtocolMemberKind<'db>, diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 8d1492adc6..4c80c37571 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -20,7 +20,7 @@ use crate::semantic_index::definition::Definition; use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::{ BindingContext, BoundTypeVarInstance, KnownClass, NormalizedVisitor, TypeMapping, TypeRelation, - todo_type, + VarianceInferable, todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -223,6 +223,16 @@ impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> { } } +impl<'db> VarianceInferable<'db> for &CallableSignature<'db> { + // TODO: possibly need to replace self + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + self.overloads + .iter() + .map(|signature| signature.variance_of(db, typevar)) + .collect() + } +} + /// The signature of one of the overloads of a callable. #[derive(Clone, Debug, salsa::Update, get_size2::GetSize)] pub struct Signature<'db> { @@ -982,6 +992,28 @@ impl std::hash::Hash for Signature<'_> { } } +impl<'db> VarianceInferable<'db> for &Signature<'db> { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + tracing::debug!( + "Checking variance of `{tvar}` in `{self:?}`", + tvar = typevar.typevar(db).name(db) + ); + itertools::chain( + self.parameters + .iter() + .filter_map(|parameter| match parameter.form { + ParameterForm::Type => None, + ParameterForm::Value => parameter.annotated_type().map(|ty| { + ty.with_polarity(TypeVarVariance::Contravariant) + .variance_of(db, typevar) + }), + }), + self.return_ty.map(|ty| ty.variance_of(db, typevar)), + ) + .collect() + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] pub(crate) struct Parameters<'db> { // TODO: use SmallVec here once invariance bug is fixed diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 02cb7bbebc..eeb3294363 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -2,6 +2,7 @@ use ruff_python_ast::name::Name; use crate::place::PlaceAndQualifiers; use crate::semantic_index::definition::Definition; +use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassType, DynamicType, HasRelationToVisitor, KnownClass, MemberLookupPolicy, NormalizedVisitor, Type, TypeMapping, @@ -103,7 +104,7 @@ impl<'db> SubclassOfType<'db> { ) .into(), ), - variance, + Some(variance), None, TypeVarKind::Pep695, ), @@ -215,6 +216,15 @@ impl<'db> SubclassOfType<'db> { } } +impl<'db> VarianceInferable<'db> for SubclassOfType<'db> { + fn variance_of(self, db: &dyn Db, typevar: BoundTypeVarInstance<'_>) -> TypeVarVariance { + match self.subclass_of { + SubclassOfInner::Dynamic(_) => TypeVarVariance::Bivariant, + SubclassOfInner::Class(class) => class.variance_of(db, typevar), + } + } +} + /// An enumeration of the different kinds of `type[]` types that a [`SubclassOfType`] can represent: /// /// 1. A "subclass of a class": `type[C]` for any class object `C` diff --git a/crates/ty_python_semantic/src/types/variance.rs b/crates/ty_python_semantic/src/types/variance.rs new file mode 100644 index 0000000000..63d250db56 --- /dev/null +++ b/crates/ty_python_semantic/src/types/variance.rs @@ -0,0 +1,138 @@ +use crate::{Db, types::BoundTypeVarInstance}; + +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +pub enum TypeVarVariance { + Invariant, + Covariant, + Contravariant, + Bivariant, +} + +impl TypeVarVariance { + pub const fn bottom() -> Self { + TypeVarVariance::Bivariant + } + + pub const fn top() -> Self { + TypeVarVariance::Invariant + } + + // supremum + #[must_use] + pub(crate) const fn join(self, other: Self) -> Self { + use TypeVarVariance::{Bivariant, Contravariant, Covariant, Invariant}; + match (self, other) { + (Invariant, _) | (_, Invariant) => Invariant, + (Covariant, Covariant) => Covariant, + (Contravariant, Contravariant) => Contravariant, + (Covariant, Contravariant) | (Contravariant, Covariant) => Invariant, + (Bivariant, other) | (other, Bivariant) => other, + } + } + + /// Compose two variances: useful for combining use-site and definition-site variances, e.g. + /// `C[D[T]]` or function argument/return position variances. + /// + /// `other` is a thunk to avoid unnecessary computation when `self` is `Bivariant`. + /// + /// Based on the variance composition/transformation operator in + /// , page 5 + /// + /// While their operation would have `compose(Invariant, Bivariant) == + /// Invariant`, we instead have it evaluate to `Bivariant`. This is a valid + /// choice, as discussed on that same page, where type equality is semantic + /// rather than syntactic. To see that this holds for our setting consider + /// the type + /// ```python + /// type ConstantInt[T] = int + /// ``` + /// We would say `ConstantInt[str]` = `ConstantInt[float]`, so we qualify as + /// using semantic equivalence. + #[must_use] + pub(crate) fn compose(self, other: Self) -> Self { + self.compose_thunk(|| other) + } + + /// Like `compose`, but takes `other` as a thunk to avoid unnecessary + /// computation when `self` is `Bivariant`. + #[must_use] + pub(crate) fn compose_thunk(self, other: F) -> Self + where + F: FnOnce() -> Self, + { + match self { + TypeVarVariance::Covariant => other(), + TypeVarVariance::Contravariant => other().flip(), + TypeVarVariance::Bivariant => TypeVarVariance::Bivariant, + TypeVarVariance::Invariant => { + if TypeVarVariance::Bivariant == other() { + TypeVarVariance::Bivariant + } else { + TypeVarVariance::Invariant + } + } + } + } + + /// Flips the polarity of the variance. + /// + /// Covariant becomes contravariant, contravariant becomes covariant, others remain unchanged. + pub(crate) const fn flip(self) -> Self { + match self { + TypeVarVariance::Invariant => TypeVarVariance::Invariant, + TypeVarVariance::Covariant => TypeVarVariance::Contravariant, + TypeVarVariance::Contravariant => TypeVarVariance::Covariant, + TypeVarVariance::Bivariant => TypeVarVariance::Bivariant, + } + } +} + +impl std::iter::FromIterator for TypeVarVariance { + fn from_iter>(iter: T) -> Self { + use std::ops::ControlFlow; + // TODO: use `into_value` when control_flow_into_value is stable + let (ControlFlow::Break(variance) | ControlFlow::Continue(variance)) = iter + .into_iter() + .try_fold(TypeVarVariance::Bivariant, |acc, variance| { + let supremum = acc.join(variance); + match supremum { + // short circuit at top + TypeVarVariance::Invariant => ControlFlow::Break(supremum), + TypeVarVariance::Bivariant + | TypeVarVariance::Covariant + | TypeVarVariance::Contravariant => ControlFlow::Continue(supremum), + } + }); + variance + } +} + +pub(crate) trait VarianceInferable<'db>: Sized { + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance; + + fn with_polarity(self, polarity: TypeVarVariance) -> WithPolarity { + WithPolarity { + variance_inferable: self, + polarity, + } + } +} + +pub(crate) struct WithPolarity { + variance_inferable: T, + polarity: TypeVarVariance, +} + +impl<'db, T> VarianceInferable<'db> for WithPolarity +where + T: VarianceInferable<'db>, +{ + fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance { + let WithPolarity { + variance_inferable, + polarity, + } = self; + + polarity.compose_thunk(|| variance_inferable.variance_of(db, typevar)) + } +}