[red-knot] Rework Type::to_instance() to return Option<Type> (#16428)

## Summary

This PR fixes https://github.com/astral-sh/ruff/issues/16302.

The PR reworks `Type::to_instance()` to return `Option<Type>` rather
than `Type`. This reflects more accurately the fact that some variants
cannot be "turned into an instance", since they _already_ represent
instances of some kind. On `main`, we silently fall back to `Unknown` for
these variants, but this implicit behaviour can be somewhat surprising
and lead to unexpected bugs.

Returning `Option<Type>` rather than `Type` means that each callsite has
to account for the possibility that the type might already represent an
instance, and decide what to do about it.
In general, I think this increases the robustness of the code. Working
on this PR revealed two latent bugs in the code:
- One which has already been fixed by
https://github.com/astral-sh/ruff/pull/16427
- One which is fixed as part of https://github.com/astral-sh/ruff/pull/16608

I added special handling to `KnownClass::to_instance()`: If we fail to find one of these classes and the `test` feature is
_not_ enabled, we log a warning to the terminal saying that we failed to
find the class in typeshed and that we will be falling back to
`Type::Unknown`. A cache is maintained so that we record all classes
that we have already logged a warning for; we only log a warning for
failing to look up a `KnownClass` if we know that it's the first time
we're looking it up.

## Test Plan

- All existing tests pass
- I ran the property tests via `QUICKCHECK_TESTS=1000000 cargo test
--release -p red_knot_python_semantic -- --ignored
types::property_tests::stable`

I also manually checked that warnings are appropriately printed to the
terminal when `KnownClass::to_instance()` falls back to `Unknown` and
the `test` feature is not enabled. To do this, I applied this diff to
the PR branch:

<details>
<summary>Patch deleting `int` and `str` from builtins</summary>

```diff
diff --git a/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi b/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi
index 0a6dc57b0..86636a05b 100644
--- a/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi
+++ b/crates/red_knot_vendored/vendor/typeshed/stdlib/builtins.pyi
@@ -228,111 +228,6 @@ _PositiveInteger: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
 _NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20]
 _LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0]  # noqa: Y026  # TODO: Use TypeAlias once mypy bugs are fixed
 
-class int:
-    @overload
-    def __new__(cls, x: ConvertibleToInt = ..., /) -> Self: ...
-    @overload
-    def __new__(cls, x: str | bytes | bytearray, /, base: SupportsIndex) -> Self: ...
-    def as_integer_ratio(self) -> tuple[int, Literal[1]]: ...
-    @property
-    def real(self) -> int: ...
-    @property
-    def imag(self) -> Literal[0]: ...
-    @property
-    def numerator(self) -> int: ...
-    @property
-    def denominator(self) -> Literal[1]: ...
-    def conjugate(self) -> int: ...
-    def bit_length(self) -> int: ...
-    if sys.version_info >= (3, 10):
-        def bit_count(self) -> int: ...
-
-    if sys.version_info >= (3, 11):
-        def to_bytes(
-            self, length: SupportsIndex = 1, byteorder: Literal["little", "big"] = "big", *, signed: bool = False
-        ) -> bytes: ...
-        @classmethod
-        def from_bytes(
-            cls,
-            bytes: Iterable[SupportsIndex] | SupportsBytes | ReadableBuffer,
-            byteorder: Literal["little", "big"] = "big",
-            *,
-            signed: bool = False,
-        ) -> Self: ...
-    else:
-        def to_bytes(self, length: SupportsIndex, byteorder: Literal["little", "big"], *, signed: bool = False) -> bytes: ...
-        @classmethod
-        def from_bytes(
-            cls,
-            bytes: Iterable[SupportsIndex] | SupportsBytes | ReadableBuffer,
-            byteorder: Literal["little", "big"],
-            *,
-            signed: bool = False,
-        ) -> Self: ...
-
-    if sys.version_info >= (3, 12):
-        def is_integer(self) -> Literal[True]: ...
-
-    def __add__(self, value: int, /) -> int: ...
-    def __sub__(self, value: int, /) -> int: ...
-    def __mul__(self, value: int, /) -> int: ...
-    def __floordiv__(self, value: int, /) -> int: ...
-    def __truediv__(self, value: int, /) -> float: ...
-    def __mod__(self, value: int, /) -> int: ...
-    def __divmod__(self, value: int, /) -> tuple[int, int]: ...
-    def __radd__(self, value: int, /) -> int: ...
-    def __rsub__(self, value: int, /) -> int: ...
-    def __rmul__(self, value: int, /) -> int: ...
-    def __rfloordiv__(self, value: int, /) -> int: ...
-    def __rtruediv__(self, value: int, /) -> float: ...
-    def __rmod__(self, value: int, /) -> int: ...
-    def __rdivmod__(self, value: int, /) -> tuple[int, int]: ...
-    @overload
-    def __pow__(self, x: Literal[0], /) -> Literal[1]: ...
-    @overload
-    def __pow__(self, value: Literal[0], mod: None, /) -> Literal[1]: ...
-    @overload
-    def __pow__(self, value: _PositiveInteger, mod: None = None, /) -> int: ...
-    @overload
-    def __pow__(self, value: _NegativeInteger, mod: None = None, /) -> float: ...
-    # positive __value -> int; negative __value -> float
-    # return type must be Any as `int | float` causes too many false-positive errors
-    @overload
-    def __pow__(self, value: int, mod: None = None, /) -> Any: ...
-    @overload
-    def __pow__(self, value: int, mod: int, /) -> int: ...
-    def __rpow__(self, value: int, mod: int | None = None, /) -> Any: ...
-    def __and__(self, value: int, /) -> int: ...
-    def __or__(self, value: int, /) -> int: ...
-    def __xor__(self, value: int, /) -> int: ...
-    def __lshift__(self, value: int, /) -> int: ...
-    def __rshift__(self, value: int, /) -> int: ...
-    def __rand__(self, value: int, /) -> int: ...
-    def __ror__(self, value: int, /) -> int: ...
-    def __rxor__(self, value: int, /) -> int: ...
-    def __rlshift__(self, value: int, /) -> int: ...
-    def __rrshift__(self, value: int, /) -> int: ...
-    def __neg__(self) -> int: ...
-    def __pos__(self) -> int: ...
-    def __invert__(self) -> int: ...
-    def __trunc__(self) -> int: ...
-    def __ceil__(self) -> int: ...
-    def __floor__(self) -> int: ...
-    def __round__(self, ndigits: SupportsIndex = ..., /) -> int: ...
-    def __getnewargs__(self) -> tuple[int]: ...
-    def __eq__(self, value: object, /) -> bool: ...
-    def __ne__(self, value: object, /) -> bool: ...
-    def __lt__(self, value: int, /) -> bool: ...
-    def __le__(self, value: int, /) -> bool: ...
-    def __gt__(self, value: int, /) -> bool: ...
-    def __ge__(self, value: int, /) -> bool: ...
-    def __float__(self) -> float: ...
-    def __int__(self) -> int: ...
-    def __abs__(self) -> int: ...
-    def __hash__(self) -> int: ...
-    def __bool__(self) -> bool: ...
-    def __index__(self) -> int: ...
-
 class float:
     def __new__(cls, x: ConvertibleToFloat = ..., /) -> Self: ...
     def as_integer_ratio(self) -> tuple[int, int]: ...
@@ -437,190 +332,6 @@ class _FormatMapMapping(Protocol):
 class _TranslateTable(Protocol):
     def __getitem__(self, key: int, /) -> str | int | None: ...
 
-class str(Sequence[str]):
-    @overload
-    def __new__(cls, object: object = ...) -> Self: ...
-    @overload
-    def __new__(cls, object: ReadableBuffer, encoding: str = ..., errors: str = ...) -> Self: ...
-    @overload
-    def capitalize(self: LiteralString) -> LiteralString: ...
-    @overload
-    def capitalize(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def casefold(self: LiteralString) -> LiteralString: ...
-    @overload
-    def casefold(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def center(self: LiteralString, width: SupportsIndex, fillchar: LiteralString = " ", /) -> LiteralString: ...
-    @overload
-    def center(self, width: SupportsIndex, fillchar: str = " ", /) -> str: ...  # type: ignore[misc]
-    def count(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: ...
-    def endswith(
-        self, suffix: str | tuple[str, ...], start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /
-    ) -> bool: ...
-    @overload
-    def expandtabs(self: LiteralString, tabsize: SupportsIndex = 8) -> LiteralString: ...
-    @overload
-    def expandtabs(self, tabsize: SupportsIndex = 8) -> str: ...  # type: ignore[misc]
-    def find(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    @overload
-    def format(self: LiteralString, *args: LiteralString, **kwargs: LiteralString) -> LiteralString: ...
-    @overload
-    def format(self, *args: object, **kwargs: object) -> str: ...
-    def format_map(self, mapping: _FormatMapMapping, /) -> str: ...
-    def index(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    def isalnum(self) -> bool: ...
-    def isalpha(self) -> bool: ...
-    def isascii(self) -> bool: ...
-    def isdecimal(self) -> bool: ...
-    def isdigit(self) -> bool: ...
-    def isidentifier(self) -> bool: ...
-    def islower(self) -> bool: ...
-    def isnumeric(self) -> bool: ...
-    def isprintable(self) -> bool: ...
-    def isspace(self) -> bool: ...
-    def istitle(self) -> bool: ...
-    def isupper(self) -> bool: ...
-    @overload
-    def join(self: LiteralString, iterable: Iterable[LiteralString], /) -> LiteralString: ...
-    @overload
-    def join(self, iterable: Iterable[str], /) -> str: ...  # type: ignore[misc]
-    @overload
-    def ljust(self: LiteralString, width: SupportsIndex, fillchar: LiteralString = " ", /) -> LiteralString: ...
-    @overload
-    def ljust(self, width: SupportsIndex, fillchar: str = " ", /) -> str: ...  # type: ignore[misc]
-    @overload
-    def lower(self: LiteralString) -> LiteralString: ...
-    @overload
-    def lower(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def lstrip(self: LiteralString, chars: LiteralString | None = None, /) -> LiteralString: ...
-    @overload
-    def lstrip(self, chars: str | None = None, /) -> str: ...  # type: ignore[misc]
-    @overload
-    def partition(self: LiteralString, sep: LiteralString, /) -> tuple[LiteralString, LiteralString, LiteralString]: ...
-    @overload
-    def partition(self, sep: str, /) -> tuple[str, str, str]: ...  # type: ignore[misc]
-    if sys.version_info >= (3, 13):
-        @overload
-        def replace(
-            self: LiteralString, old: LiteralString, new: LiteralString, /, count: SupportsIndex = -1
-        ) -> LiteralString: ...
-        @overload
-        def replace(self, old: str, new: str, /, count: SupportsIndex = -1) -> str: ...  # type: ignore[misc]
-    else:
-        @overload
-        def replace(
-            self: LiteralString, old: LiteralString, new: LiteralString, count: SupportsIndex = -1, /
-        ) -> LiteralString: ...
-        @overload
-        def replace(self, old: str, new: str, count: SupportsIndex = -1, /) -> str: ...  # type: ignore[misc]
-    if sys.version_info >= (3, 9):
-        @overload
-        def removeprefix(self: LiteralString, prefix: LiteralString, /) -> LiteralString: ...
-        @overload
-        def removeprefix(self, prefix: str, /) -> str: ...  # type: ignore[misc]
-        @overload
-        def removesuffix(self: LiteralString, suffix: LiteralString, /) -> LiteralString: ...
-        @overload
-        def removesuffix(self, suffix: str, /) -> str: ...  # type: ignore[misc]
-
-    def rfind(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    def rindex(self, sub: str, start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /) -> int: ...
-    @overload
-    def rjust(self: LiteralString, width: SupportsIndex, fillchar: LiteralString = " ", /) -> LiteralString: ...
-    @overload
-    def rjust(self, width: SupportsIndex, fillchar: str = " ", /) -> str: ...  # type: ignore[misc]
-    @overload
-    def rpartition(self: LiteralString, sep: LiteralString, /) -> tuple[LiteralString, LiteralString, LiteralString]: ...
-    @overload
-    def rpartition(self, sep: str, /) -> tuple[str, str, str]: ...  # type: ignore[misc]
-    @overload
-    def rsplit(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ...
-    @overload
-    def rsplit(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ...  # type: ignore[misc]
-    @overload
-    def rstrip(self: LiteralString, chars: LiteralString | None = None, /) -> LiteralString: ...
-    @overload
-    def rstrip(self, chars: str | None = None, /) -> str: ...  # type: ignore[misc]
-    @overload
-    def split(self: LiteralString, sep: LiteralString | None = None, maxsplit: SupportsIndex = -1) -> list[LiteralString]: ...
-    @overload
-    def split(self, sep: str | None = None, maxsplit: SupportsIndex = -1) -> list[str]: ...  # type: ignore[misc]
-    @overload
-    def splitlines(self: LiteralString, keepends: bool = False) -> list[LiteralString]: ...
-    @overload
-    def splitlines(self, keepends: bool = False) -> list[str]: ...  # type: ignore[misc]
-    def startswith(
-        self, prefix: str | tuple[str, ...], start: SupportsIndex | None = ..., end: SupportsIndex | None = ..., /
-    ) -> bool: ...
-    @overload
-    def strip(self: LiteralString, chars: LiteralString | None = None, /) -> LiteralString: ...
-    @overload
-    def strip(self, chars: str | None = None, /) -> str: ...  # type: ignore[misc]
-    @overload
-    def swapcase(self: LiteralString) -> LiteralString: ...
-    @overload
-    def swapcase(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def title(self: LiteralString) -> LiteralString: ...
-    @overload
-    def title(self) -> str: ...  # type: ignore[misc]
-    def translate(self, table: _TranslateTable, /) -> str: ...
-    @overload
-    def upper(self: LiteralString) -> LiteralString: ...
-    @overload
-    def upper(self) -> str: ...  # type: ignore[misc]
-    @overload
-    def zfill(self: LiteralString, width: SupportsIndex, /) -> LiteralString: ...
-    @overload
-    def zfill(self, width: SupportsIndex, /) -> str: ...  # type: ignore[misc]
-    @staticmethod
-    @overload
-    def maketrans(x: dict[int, _T] | dict[str, _T] | dict[str | int, _T], /) -> dict[int, _T]: ...
-    @staticmethod
-    @overload
-    def maketrans(x: str, y: str, /) -> dict[int, int]: ...
-    @staticmethod
-    @overload
-    def maketrans(x: str, y: str, z: str, /) -> dict[int, int | None]: ...
-    @overload
-    def __add__(self: LiteralString, value: LiteralString, /) -> LiteralString: ...
-    @overload
-    def __add__(self, value: str, /) -> str: ...  # type: ignore[misc]
-    # Incompatible with Sequence.__contains__
-    def __contains__(self, key: str, /) -> bool: ...  # type: ignore[override]
-    def __eq__(self, value: object, /) -> bool: ...
-    def __ge__(self, value: str, /) -> bool: ...
-    @overload
-    def __getitem__(self: LiteralString, key: SupportsIndex | slice, /) -> LiteralString: ...
-    @overload
-    def __getitem__(self, key: SupportsIndex | slice, /) -> str: ...  # type: ignore[misc]
-    def __gt__(self, value: str, /) -> bool: ...
-    def __hash__(self) -> int: ...
-    @overload
-    def __iter__(self: LiteralString) -> Iterator[LiteralString]: ...
-    @overload
-    def __iter__(self) -> Iterator[str]: ...  # type: ignore[misc]
-    def __le__(self, value: str, /) -> bool: ...
-    def __len__(self) -> int: ...
-    def __lt__(self, value: str, /) -> bool: ...
-    @overload
-    def __mod__(self: LiteralString, value: LiteralString | tuple[LiteralString, ...], /) -> LiteralString: ...
-    @overload
-    def __mod__(self, value: Any, /) -> str: ...
-    @overload
-    def __mul__(self: LiteralString, value: SupportsIndex, /) -> LiteralString: ...
-    @overload
-    def __mul__(self, value: SupportsIndex, /) -> str: ...  # type: ignore[misc]
-    def __ne__(self, value: object, /) -> bool: ...
-    @overload
-    def __rmul__(self: LiteralString, value: SupportsIndex, /) -> LiteralString: ...
-    @overload
-    def __rmul__(self, value: SupportsIndex, /) -> str: ...  # type: ignore[misc]
-    def __getnewargs__(self) -> tuple[str]: ...
-
 class bytes(Sequence[int]):
```

</details>

And then ran red-knot on my
[typeshed-stats](https://github.com/AlexWaygood/typeshed-stats) project
using the command

```
cargo run -p red_knot -- check --project ../typeshed-stats --python-version="3.12" --verbose
```

I observed that the following logs were printed to the terminal, but
that each warning was only printed once (the desired behaviour):

```
INFO Python version: Python 3.12, platform: all
INFO Indexed 15 file(s)
INFO Could not find class `builtins.int` in typeshed on Python 3.12. Falling back to `Unknown` for the symbol instead.
INFO Could not find class `builtins.str` in typeshed on Python 3.12. Falling back to `Unknown` for the symbol instead.
```
This commit is contained in:
Alex Waygood 2025-03-11 16:42:44 +00:00 committed by GitHub
parent 989075dc16
commit c16237ddc0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 245 additions and 80 deletions

View file

@ -134,9 +134,8 @@ impl KnownModule {
}
pub fn name(self) -> ModuleName {
let self_as_str = self.as_str();
ModuleName::new_static(self_as_str)
.unwrap_or_else(|| panic!("{self_as_str} should be a valid module name!"))
ModuleName::new_static(self.as_str())
.unwrap_or_else(|| panic!("{self} should be a valid module name!"))
}
pub(crate) fn try_from_search_path_and_name(
@ -167,6 +166,12 @@ impl KnownModule {
}
}
/// Renders a [`KnownModule`] as the string returned by `as_str()`.
impl std::fmt::Display for KnownModule {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Write the module's string form straight to the formatter.
        let module_name = self.as_str();
        f.write_str(module_name)
    }
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -730,10 +730,9 @@ impl<'db> Type<'db> {
// `Literal[str]` is a subtype of `type` because the `str` class object is an instance of its metaclass `type`.
// `Literal[abc.ABC]` is a subtype of `abc.ABCMeta` because the `abc.ABC` class object
// is an instance of its metaclass `abc.ABCMeta`.
(Type::ClassLiteral(ClassLiteralType { class }), _) => class
.metaclass(db)
.to_instance(db)
.is_subtype_of(db, target),
(Type::ClassLiteral(ClassLiteralType { class }), _) => {
class.metaclass_instance_type(db).is_subtype_of(db, target)
}
// `type[str]` (== `SubclassOf("str")` in red-knot) describes all possible runtime subclasses
// of the class object `str`. It is a subtype of `type` (== `Instance("type")`) because `str`
@ -745,11 +744,9 @@ impl<'db> Type<'db> {
(Type::SubclassOf(subclass_of_ty), _) => subclass_of_ty
.subclass_of()
.into_class()
.is_some_and(|class| {
class
.metaclass(db)
.to_instance(db)
.is_subtype_of(db, target)
.map(|class| class.metaclass_instance_type(db))
.is_some_and(|metaclass_instance_type| {
metaclass_instance_type.is_subtype_of(db, target)
}),
// For example: `Type::KnownInstance(KnownInstanceType::Type)` is a subtype of `Type::Instance(_SpecialForm)`,
@ -1122,16 +1119,17 @@ impl<'db> Type<'db> {
ty.bool(db).is_always_true()
}
(Type::SubclassOf(subclass_of_ty), other)
| (other, Type::SubclassOf(subclass_of_ty)) => {
let metaclass_instance_ty = match subclass_of_ty.subclass_of() {
// for `type[Any]`/`type[Unknown]`/`type[Todo]`, we know the type cannot be any larger than `type`,
// so although the type is dynamic we can still determine disjointness in some situations
ClassBase::Dynamic(_) => KnownClass::Type.to_instance(db),
ClassBase::Class(class) => class.metaclass(db).to_instance(db),
};
other.is_disjoint_from(db, metaclass_instance_ty)
(Type::SubclassOf(subclass_of_ty), other)
| (other, Type::SubclassOf(subclass_of_ty)) => match subclass_of_ty.subclass_of() {
ClassBase::Dynamic(_) => {
KnownClass::Type.to_instance(db).is_disjoint_from(db, other)
}
ClassBase::Class(class) => class
.metaclass_instance_type(db)
.is_disjoint_from(db, other),
},
(Type::KnownInstance(known_instance), Type::Instance(InstanceType { class }))
| (Type::Instance(InstanceType { class }), Type::KnownInstance(known_instance)) => {
@ -1200,8 +1198,7 @@ impl<'db> Type<'db> {
(Type::ClassLiteral(ClassLiteralType { class }), instance @ Type::Instance(_))
| (instance @ Type::Instance(_), Type::ClassLiteral(ClassLiteralType { class })) => {
!class
.metaclass(db)
.to_instance(db)
.metaclass_instance_type(db)
.is_subtype_of(db, instance)
}
@ -2106,19 +2103,13 @@ impl<'db> Type<'db> {
Type::FunctionLiteral(_) => Truthiness::AlwaysTrue,
Type::Callable(_) => Truthiness::AlwaysTrue,
Type::ModuleLiteral(_) => Truthiness::AlwaysTrue,
Type::ClassLiteral(ClassLiteralType { class }) => {
return class
.metaclass(db)
.to_instance(db)
.try_bool_impl(db, allow_short_circuit);
}
Type::ClassLiteral(ClassLiteralType { class }) => class
.metaclass_instance_type(db)
.try_bool_impl(db, allow_short_circuit)?,
Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() {
ClassBase::Dynamic(_) => Truthiness::Ambiguous,
ClassBase::Class(class) => {
return class
.metaclass(db)
.to_instance(db)
.try_bool_impl(db, allow_short_circuit);
Type::class_literal(class).try_bool_impl(db, allow_short_circuit)?
}
},
Type::AlwaysTruthy => Truthiness::AlwaysTrue,
@ -2948,19 +2939,19 @@ impl<'db> Type<'db> {
}
#[must_use]
pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> {
pub fn to_instance(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::Dynamic(_) => *self,
Type::Never => Type::Never,
Type::ClassLiteral(ClassLiteralType { class }) => Type::instance(*class),
Type::SubclassOf(subclass_of_ty) => match subclass_of_ty.subclass_of() {
ClassBase::Class(class) => Type::instance(class),
ClassBase::Dynamic(dynamic) => Type::Dynamic(dynamic),
},
Type::Union(union) => union.map(db, |element| element.to_instance(db)),
Type::Intersection(_) => todo_type!("Type::Intersection.to_instance()"),
// TODO: calling `.to_instance()` on any of these should result in a diagnostic,
// since they already indicate that the object is an instance of some kind:
Type::Dynamic(_) | Type::Never => Some(*self),
Type::ClassLiteral(ClassLiteralType { class }) => Some(Type::instance(*class)),
Type::SubclassOf(subclass_of_ty) => Some(subclass_of_ty.to_instance()),
Type::Union(union) => {
let mut builder = UnionBuilder::new(db);
for element in union.elements(db) {
builder = builder.add(element.to_instance(db)?);
}
Some(builder.build())
}
Type::Intersection(_) => Some(todo_type!("Type::Intersection.to_instance()")),
Type::BooleanLiteral(_)
| Type::BytesLiteral(_)
| Type::FunctionLiteral(_)
@ -2974,7 +2965,7 @@ impl<'db> Type<'db> {
| Type::Tuple(_)
| Type::LiteralString
| Type::AlwaysTruthy
| Type::AlwaysFalsy => Type::unknown(),
| Type::AlwaysFalsy => None,
}
}

View file

@ -1,3 +1,5 @@
use std::sync::{LazyLock, Mutex};
use crate::{
module_resolver::file_to_module,
semantic_index::{
@ -18,6 +20,7 @@ use indexmap::IndexSet;
use itertools::Itertools as _;
use ruff_db::files::File;
use ruff_python_ast::{self as ast, PythonVersion};
use rustc_hash::FxHashSet;
use super::{
class_base::ClassBase, infer_expression_type, infer_unpack_types, IntersectionBuilder,
@ -185,6 +188,14 @@ impl<'db> Class<'db> {
.unwrap_or_else(|_| SubclassOfType::subclass_of_unknown())
}
/// Return the type describing every instance of this class's metaclass.
///
/// # Panics
///
/// Panics if `Type::to_instance()` returns `None` for the metaclass type.
pub(super) fn metaclass_instance_type(self, db: &'db dyn Db) -> Type<'db> {
    let metaclass = self.metaclass(db);
    metaclass.to_instance(db).expect(
        "`Type::to_instance()` should always return `Some()` when called on the type of a metaclass",
    )
}
/// Return the metaclass of this class, or an error if the metaclass cannot be inferred.
#[salsa::tracked]
pub(super) fn try_metaclass(self, db: &'db dyn Db) -> Result<Type<'db>, MetaclassError<'db>> {
@ -879,7 +890,7 @@ impl<'db> KnownClass {
}
}
pub(crate) fn as_str(self, db: &'db dyn Db) -> &'static str {
pub(crate) fn name(self, db: &'db dyn Db) -> &'static str {
match self {
Self::Bool => "bool",
Self::Object => "object",
@ -937,17 +948,101 @@ impl<'db> KnownClass {
}
}
/// Return a displayable wrapper that renders this class as `<module>.<name>`,
/// where `<module>` is `canonical_module(db)` and `<name>` is `name(db)`.
fn display(self, db: &'db dyn Db) -> impl std::fmt::Display + 'db {
    /// Pairs a `KnownClass` with the database needed to resolve its module and name.
    struct QualifiedKnownClass<'db> {
        db: &'db dyn Db,
        class: KnownClass,
    }

    impl std::fmt::Display for QualifiedKnownClass<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Resolve the module first and the class name second.
            let module = self.class.canonical_module(self.db);
            let class = self.class.name(self.db);
            write!(f, "{module}.{class}")
        }
    }

    QualifiedKnownClass { db, class: self }
}
/// Lookup a [`KnownClass`] in typeshed and return a [`Type`]
/// representing all possible instances of the class.
///
/// If the class cannot be found in typeshed, an info-level log message will be emitted stating this.
pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> {
self.to_class_literal(db).to_instance(db)
self.to_class_literal(db)
.into_class_literal()
.map(|ClassLiteralType { class }| Type::instance(class))
.unwrap_or_else(Type::unknown)
}
/// Try to look up a [`KnownClass`] in typeshed, returning its class literal on success.
///
/// # Errors
///
/// - [`KnownClassLookupError::ClassNotFound`] if the symbol is unbound in the expected module.
/// - [`KnownClassLookupError::SymbolNotAClass`] if the symbol resolves to something other
///   than a class literal.
/// - [`KnownClassLookupError::ClassPossiblyUnbound`] if the symbol is a class literal
///   but is only possibly bound.
pub(crate) fn try_to_class_literal(
    self,
    db: &'db dyn Db,
) -> Result<ClassLiteralType<'db>, KnownClassLookupError<'db>> {
    let module = self.canonical_module(db);
    let class_name = self.name(db);

    match known_module_symbol(db, module, class_name).symbol {
        Symbol::Unbound => Err(KnownClassLookupError::ClassNotFound),
        Symbol::Type(Type::ClassLiteral(class_type), Boundness::Bound) => Ok(class_type),
        Symbol::Type(Type::ClassLiteral(class_type), Boundness::PossiblyUnbound) => {
            Err(KnownClassLookupError::ClassPossiblyUnbound { class_type })
        }
        // Any other bound symbol is not a class definition at all.
        Symbol::Type(found_type, _) => {
            Err(KnownClassLookupError::SymbolNotAClass { found_type })
        }
    }
}
/// Lookup a [`KnownClass`] in typeshed and return a [`Type`] representing that class-literal.
///
/// If the class cannot be found in typeshed, an info-level log message will be emitted stating this.
pub(crate) fn to_class_literal(self, db: &'db dyn Db) -> Type<'db> {
known_module_symbol(db, self.canonical_module(db), self.as_str(db))
.symbol
.ignore_possibly_unbound()
.unwrap_or(Type::unknown())
// a cache of the `KnownClass`es that we have already failed to lookup in typeshed
// (and therefore that we've already logged a warning for)
static MESSAGES: LazyLock<Mutex<FxHashSet<KnownClass>>> = LazyLock::new(Mutex::default);
self.try_to_class_literal(db)
.map(Type::ClassLiteral)
.unwrap_or_else(|lookup_error| {
if MESSAGES.lock().unwrap().insert(self) {
if matches!(
lookup_error,
KnownClassLookupError::ClassPossiblyUnbound { .. }
) {
tracing::info!("{}", lookup_error.display(db, self));
} else {
tracing::info!(
"{}. Falling back to `Unknown` for the symbol instead.",
lookup_error.display(db, self)
);
}
}
match lookup_error {
KnownClassLookupError::ClassPossiblyUnbound { class_type, .. } => {
Type::class_literal(class_type.class)
}
KnownClassLookupError::ClassNotFound { .. }
| KnownClassLookupError::SymbolNotAClass { .. } => Type::unknown(),
}
})
}
/// Lookup a [`KnownClass`] in typeshed and return a [`Type`]
/// representing that class and all possible subclasses of the class.
///
/// If the class cannot be found in typeshed, an info-level log message will be emitted stating this.
pub(crate) fn to_subclass_of(self, db: &'db dyn Db) -> Type<'db> {
self.to_class_literal(db)
.into_class_literal()
@ -958,11 +1053,8 @@ impl<'db> KnownClass {
/// Return `true` if this symbol can be resolved to a class definition `class` in typeshed,
/// *and* `class` is a subclass of `other`.
pub(super) fn is_subclass_of(self, db: &'db dyn Db, other: Class<'db>) -> bool {
known_module_symbol(db, self.canonical_module(db), self.as_str(db))
.symbol
.ignore_possibly_unbound()
.and_then(Type::into_class_literal)
.is_some_and(|ClassLiteralType { class }| class.is_subclass_of(db, other))
self.try_to_class_literal(db)
.is_ok_and(|ClassLiteralType { class }| class.is_subclass_of(db, other))
}
/// Return the module in which we should look up the definition for this class
@ -1227,6 +1319,62 @@ impl<'db> KnownClass {
}
}
/// Enumeration of ways in which looking up a [`KnownClass`] in typeshed could fail.
///
/// Returned as the error type of `KnownClass::try_to_class_literal`.
/// Each variant can be rendered as a human-readable message via `display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum KnownClassLookupError<'db> {
/// There is no symbol by that name in the expected typeshed module.
ClassNotFound,
/// There is a symbol by that name in the expected typeshed module,
/// but it's not a class.
///
/// `found_type` is the type of the symbol that was found instead.
SymbolNotAClass { found_type: Type<'db> },
/// There is a symbol by that name in the expected typeshed module,
/// and it's a class definition, but it's possibly unbound.
///
/// `class_type` is the class literal that was found.
ClassPossiblyUnbound { class_type: ClassLiteralType<'db> },
}
impl<'db> KnownClassLookupError<'db> {
    /// Return a displayable wrapper describing this lookup failure for `class`.
    ///
    /// The rendered message names the class (via `KnownClass::display`) and the
    /// Python version configured on the `Program`.
    fn display(&self, db: &'db dyn Db, class: KnownClass) -> impl std::fmt::Display + 'db {
        /// Bundles everything needed to render one lookup-failure message.
        struct LookupFailureMessage<'db> {
            db: &'db dyn Db,
            class: KnownClass,
            error: KnownClassLookupError<'db>,
        }

        impl std::fmt::Display for LookupFailureMessage<'_> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let db = self.db;
                let class = self.class.display(db);
                let python_version = Program::get(db).python_version(db);

                match self.error {
                    KnownClassLookupError::ClassNotFound => write!(
                        f,
                        "Could not find class `{class}` in typeshed on Python {python_version}",
                    ),
                    KnownClassLookupError::SymbolNotAClass { found_type } => {
                        // Render the unexpected type up front so it can be captured inline.
                        let found_type = found_type.display(db);
                        write!(
                            f,
                            "Error looking up `{class}` in typeshed: expected to find a class definition \
                             on Python {python_version}, but found a symbol of type `{found_type}` instead",
                        )
                    }
                    KnownClassLookupError::ClassPossiblyUnbound { .. } => write!(
                        f,
                        "Error looking up `{class}` in typeshed on Python {python_version}: \
                         expected to find a fully bound symbol, but found one that is possibly unbound",
                    ),
                }
            }
        }

        LookupFailureMessage {
            db,
            class,
            error: *self,
        }
    }
}
/// Enumeration of specific runtime objects that are special enough to be considered their own type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
pub enum KnownInstanceType<'db> {
@ -1609,7 +1757,7 @@ mod tests {
fn known_class_roundtrip_from_str() {
let db = setup_db();
for class in KnownClass::iter() {
let class_name = class.as_str(&db);
let class_name = class.name(&db);
let class_module = resolve_module(&db, &class.canonical_module(&db).name()).unwrap();
assert_eq!(

View file

@ -1702,7 +1702,10 @@ impl<'db> TypeInferenceBuilder<'db> {
for element in tuple.elements(self.db()).iter().copied() {
builder = builder.add(
if element.is_assignable_to(self.db(), type_base_exception) {
element.to_instance(self.db())
element.to_instance(self.db()).expect(
"`Type::to_instance()` should always return `Some()` \
if called on a type assignable to `type[BaseException]`",
)
} else {
if let Some(node) = node {
report_invalid_exception_caught(&self.context, node, element);
@ -1717,7 +1720,10 @@ impl<'db> TypeInferenceBuilder<'db> {
} else {
let type_base_exception = KnownClass::BaseException.to_subclass_of(self.db());
if node_ty.is_assignable_to(self.db(), type_base_exception) {
node_ty.to_instance(self.db())
node_ty.to_instance(self.db()).expect(
"`Type::to_instance()` should always return `Some()` \
if called on a type assignable to `type[BaseException]`",
)
} else {
if let Some(node) = node {
report_invalid_exception_caught(&self.context, node, node_ty);
@ -2542,7 +2548,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} = raise;
let base_exception_type = KnownClass::BaseException.to_subclass_of(self.db());
let base_exception_instance = base_exception_type.to_instance(self.db());
let base_exception_instance = KnownClass::BaseException.to_instance(self.db());
let can_be_raised =
UnionType::from_elements(self.db(), [base_exception_type, base_exception_instance]);

View file

@ -8,8 +8,8 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{
infer_expression_types, IntersectionBuilder, KnownClass, SubclassOfType, Truthiness, Type,
UnionBuilder,
infer_expression_types, ClassLiteralType, IntersectionBuilder, KnownClass, SubclassOfType,
Truthiness, Type, UnionBuilder,
};
use crate::Db;
use itertools::Itertools;
@ -379,7 +379,11 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
keywords,
range: _,
},
}) if rhs_ty.is_class_literal() && keywords.is_empty() => {
}) if keywords.is_empty() => {
let Type::ClassLiteral(ClassLiteralType { class: rhs_class }) = rhs_ty else {
continue;
};
let [ast::Expr::Name(ast::ExprName { id, .. })] = &**args else {
continue;
};
@ -394,10 +398,10 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
continue;
}
let callable_ty =
let callable_type =
inference.expression_type(callable.scoped_expression_id(self.db, scope));
if callable_ty
if callable_type
.into_class_literal()
.is_some_and(|c| c.class().is_known(self.db, KnownClass::Type))
{
@ -405,7 +409,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
constraints.insert(symbol, rhs_ty.to_instance(self.db));
constraints.insert(symbol, Type::instance(rhs_class));
}
}
_ => {}
@ -494,17 +498,16 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
subject: Expression<'db>,
cls: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
if let Some(ast::ExprName { id, .. }) = subject.node_ref(self.db).as_name_expr() {
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db);
let ast::ExprName { id, .. } = subject.node_ref(self.db).as_name_expr()?;
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("We should always have a symbol for every `Name` node");
let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db)?;
let mut constraints = NarrowingConstraints::default();
constraints.insert(symbol, ty);
Some(constraints)
} else {
None
}
}
fn evaluate_bool_op(

View file

@ -84,7 +84,7 @@ fn create_bound_method<'db>(
Type::Callable(CallableType::BoundMethod(BoundMethodType::new(
db,
function.expect_function_literal(),
builtins_class.to_instance(db),
builtins_class.to_instance(db).unwrap(),
)))
}
@ -100,11 +100,16 @@ impl Ty {
Ty::BooleanLiteral(b) => Type::BooleanLiteral(b),
Ty::LiteralString => Type::LiteralString,
Ty::BytesLiteral(s) => Type::bytes_literal(db, s.as_bytes()),
Ty::BuiltinInstance(s) => builtins_symbol(db, s).symbol.expect_type().to_instance(db),
Ty::BuiltinInstance(s) => builtins_symbol(db, s)
.symbol
.expect_type()
.to_instance(db)
.unwrap(),
Ty::AbcInstance(s) => known_module_symbol(db, KnownModule::Abc, s)
.symbol
.expect_type()
.to_instance(db),
.to_instance(db)
.unwrap(),
Ty::AbcClassLiteral(s) => known_module_symbol(db, KnownModule::Abc, s)
.symbol
.expect_type(),

View file

@ -92,4 +92,11 @@ impl<'db> SubclassOfType<'db> {
}
}
}
/// Convert this `type[...]` type into the type of the instances it describes.
///
/// A dynamic base is returned as the same dynamic type; a concrete class base
/// becomes `Type::instance(class)`.
pub(crate) fn to_instance(self) -> Type<'db> {
    match self.subclass_of {
        ClassBase::Dynamic(dynamic) => Type::Dynamic(dynamic),
        ClassBase::Class(base_class) => Type::instance(base_class),
    }
}
}