[ty] Expansion of enums into unions of literals (#19382)

## Summary

Implement expansion of enums into unions of enum literals (and the
reverse operation). For the enum below, this allows us to understand
that `Color = Literal[Color.RED, Color.GREEN, Color.BLUE]`, or that
`Color & ~Literal[Color.RED] = Literal[Color.GREEN, Color.BLUE]`. This
helps in exhaustiveness checking, which is why we see some removed
`assert_never` false positives. And since exhaustiveness checking also
helps with understanding terminal control flow, we also see a few
removed `invalid-return-type` and `possibly-unresolved-reference` false
positives. This PR also adds expansion of enums in overload resolution
and type narrowing constructs.

```py
from enum import Enum
from typing_extensions import Literal, assert_never
from ty_extensions import Intersection, Not, static_assert, is_equivalent_to

class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

type Red = Literal[Color.RED]
type Green = Literal[Color.GREEN]
type Blue = Literal[Color.BLUE]

static_assert(is_equivalent_to(Red | Green | Blue, Color))
static_assert(is_equivalent_to(Intersection[Color, Not[Red]], Green | Blue))


def color_name(color: Color) -> str:  # no error here (we detect that this can not implicitly return None)
    if color is Color.RED:
        return "Red"
    elif color is Color.GREEN:
        return "Green"
    elif color is Color.BLUE:
        return "Blue"
    else:
        assert_never(color)  # no error here
```

## Performance

I avoided an initial regression here for large enums, but the
`UnionBuilder` and `IntersectionBuilder` parts can certainly still be
optimized. We might want to use the same technique that we also use for
unions of other literals. I didn't see any problems in our benchmarks so
far, so this is not included yet.

## Test Plan

Many new Markdown tests
This commit is contained in:
David Peter 2025-07-21 19:37:55 +02:00 committed by GitHub
parent 926e83323a
commit dc66019fbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 750 additions and 102 deletions

View file

@ -2355,12 +2355,13 @@ import enum
reveal_type(enum.Enum.__members__) # revealed: MappingProxyType[str, Unknown]
class Foo(enum.Enum):
BAR = 1
class Answer(enum.Enum):
NO = 0
YES = 1
reveal_type(Foo.BAR) # revealed: Literal[Foo.BAR]
reveal_type(Foo.BAR.value) # revealed: Any
reveal_type(Foo.__members__) # revealed: MappingProxyType[str, Unknown]
reveal_type(Answer.NO) # revealed: Literal[Answer.NO]
reveal_type(Answer.NO.value) # revealed: Any
reveal_type(Answer.__members__) # revealed: MappingProxyType[str, Unknown]
```
## References

View file

@ -369,6 +369,8 @@ def _(x: type[A | B]):
### Expanding enums
#### Basic
`overloaded.pyi`:
```pyi
@ -394,15 +396,106 @@ def f(x: Literal[SomeEnum.C]) -> C: ...
```
```py
from typing import Literal
from overloaded import SomeEnum, A, B, C, f
def _(x: SomeEnum):
def _(x: SomeEnum, y: Literal[SomeEnum.A, SomeEnum.C]):
reveal_type(f(SomeEnum.A)) # revealed: A
reveal_type(f(SomeEnum.B)) # revealed: B
reveal_type(f(SomeEnum.C)) # revealed: C
# TODO: This should not be an error. The return type should be `A | B | C` once enums are expanded
# error: [no-matching-overload]
reveal_type(f(x)) # revealed: Unknown
reveal_type(f(x)) # revealed: A | B | C
reveal_type(f(y)) # revealed: A | C
```
#### Enum with single member
This pattern appears in typeshed. Here, it is used to represent two optional, mutually exclusive
keyword parameters:
`overloaded.pyi`:
```pyi
from enum import Enum, auto
from typing import overload, Literal
class Missing(Enum):
Value = auto()
class OnlyASpecified: ...
class OnlyBSpecified: ...
class BothMissing: ...
@overload
def f(*, a: int, b: Literal[Missing.Value] = ...) -> OnlyASpecified: ...
@overload
def f(*, a: Literal[Missing.Value] = ..., b: int) -> OnlyBSpecified: ...
@overload
def f(*, a: Literal[Missing.Value] = ..., b: Literal[Missing.Value] = ...) -> BothMissing: ...
```
```py
from typing import Literal
from overloaded import f, Missing
reveal_type(f()) # revealed: BothMissing
reveal_type(f(a=0)) # revealed: OnlyASpecified
reveal_type(f(b=0)) # revealed: OnlyBSpecified
f(a=0, b=0) # error: [no-matching-overload]
def _(missing: Literal[Missing.Value], missing_or_present: Literal[Missing.Value] | int):
reveal_type(f(a=missing, b=missing)) # revealed: BothMissing
reveal_type(f(a=missing)) # revealed: BothMissing
reveal_type(f(b=missing)) # revealed: BothMissing
reveal_type(f(a=0, b=missing)) # revealed: OnlyASpecified
reveal_type(f(a=missing, b=0)) # revealed: OnlyBSpecified
reveal_type(f(a=missing_or_present)) # revealed: BothMissing | OnlyASpecified
reveal_type(f(b=missing_or_present)) # revealed: BothMissing | OnlyBSpecified
# Here, both could be present, so this should be an error
f(a=missing_or_present, b=missing_or_present) # error: [no-matching-overload]
```
#### Enum subclass without members
An `Enum` subclass without members should *not* be expanded:
`overloaded.pyi`:
```pyi
from enum import Enum
from typing import overload, Literal
class MyEnumSubclass(Enum):
pass
class ActualEnum(MyEnumSubclass):
A = 1
B = 2
class OnlyA: ...
class OnlyB: ...
class Both: ...
@overload
def f(x: Literal[ActualEnum.A]) -> OnlyA: ...
@overload
def f(x: Literal[ActualEnum.B]) -> OnlyB: ...
@overload
def f(x: ActualEnum) -> Both: ...
@overload
def f(x: MyEnumSubclass) -> MyEnumSubclass: ...
```
```py
from overloaded import MyEnumSubclass, ActualEnum, f
def _(actual_enum: ActualEnum, my_enum_instance: MyEnumSubclass):
reveal_type(f(actual_enum)) # revealed: Both
reveal_type(f(ActualEnum.A)) # revealed: OnlyA
reveal_type(f(ActualEnum.B)) # revealed: OnlyB
reveal_type(f(my_enum_instance)) # revealed: MyEnumSubclass
```
### No matching overloads

View file

@ -570,7 +570,111 @@ To do: <https://typing.python.org/en/latest/spec/enums.html#enum-definition>
## Exhaustiveness checking
To do
## `if` statements
```py
from enum import Enum
from typing_extensions import assert_never
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
def color_name(color: Color) -> str:
if color is Color.RED:
return "Red"
elif color is Color.GREEN:
return "Green"
elif color is Color.BLUE:
return "Blue"
else:
assert_never(color)
# No `invalid-return-type` error here because the implicit `else` branch is detected as unreachable:
def color_name_without_assertion(color: Color) -> str:
if color is Color.RED:
return "Red"
elif color is Color.GREEN:
return "Green"
elif color is Color.BLUE:
return "Blue"
def color_name_misses_one_variant(color: Color) -> str:
if color is Color.RED:
return "Red"
elif color is Color.GREEN:
return "Green"
else:
assert_never(color) # error: [type-assertion-failure] "Argument does not have asserted type `Never`"
class Singleton(Enum):
VALUE = 1
def singleton_check(value: Singleton) -> str:
if value is Singleton.VALUE:
return "Singleton value"
else:
assert_never(value)
```
## `match` statements
```toml
[environment]
python-version = "3.10"
```
```py
from enum import Enum
from typing_extensions import assert_never
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
def color_name(color: Color) -> str:
match color:
case Color.RED:
return "Red"
case Color.GREEN:
return "Green"
case Color.BLUE:
return "Blue"
case _:
assert_never(color)
# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488
# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`"
def color_name_without_assertion(color: Color) -> str:
match color:
case Color.RED:
return "Red"
case Color.GREEN:
return "Green"
case Color.BLUE:
return "Blue"
def color_name_misses_one_variant(color: Color) -> str:
match color:
case Color.RED:
return "Red"
case Color.GREEN:
return "Green"
case _:
assert_never(color) # error: [type-assertion-failure] "Argument does not have asserted type `Never`"
class Singleton(Enum):
VALUE = 1
def singleton_check(value: Singleton) -> str:
match value:
case Singleton.VALUE:
return "Singleton value"
case _:
assert_never(value)
```
## References

View file

@ -763,6 +763,65 @@ def f(
reveal_type(j) # revealed: Unknown & Literal[""]
```
## Simplifications involving enums and enum literals
```toml
[environment]
python-version = "3.12"
```
```py
from ty_extensions import Intersection, Not
from typing import Literal
from enum import Enum
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
type Red = Literal[Color.RED]
type Green = Literal[Color.GREEN]
type Blue = Literal[Color.BLUE]
def f(
a: Intersection[Color, Red],
b: Intersection[Color, Not[Red]],
c: Intersection[Color, Not[Red | Green]],
d: Intersection[Color, Not[Red | Green | Blue]],
e: Intersection[Red, Not[Color]],
f: Intersection[Red | Green, Not[Color]],
g: Intersection[Not[Red], Color],
h: Intersection[Red, Green],
i: Intersection[Red | Green, Green | Blue],
):
reveal_type(a) # revealed: Literal[Color.RED]
reveal_type(b) # revealed: Literal[Color.GREEN, Color.BLUE]
reveal_type(c) # revealed: Literal[Color.BLUE]
reveal_type(d) # revealed: Never
reveal_type(e) # revealed: Never
reveal_type(f) # revealed: Never
reveal_type(g) # revealed: Literal[Color.GREEN, Color.BLUE]
reveal_type(h) # revealed: Never
reveal_type(i) # revealed: Literal[Color.GREEN]
class Single(Enum):
VALUE = 0
def g(
a: Intersection[Single, Literal[Single.VALUE]],
b: Intersection[Single, Not[Literal[Single.VALUE]]],
c: Intersection[Not[Literal[Single.VALUE]], Single],
d: Intersection[Single, Not[Single]],
e: Intersection[Single | int, Not[Single]],
):
reveal_type(a) # revealed: Single
reveal_type(b) # revealed: Never
reveal_type(c) # revealed: Never
reveal_type(d) # revealed: Never
reveal_type(e) # revealed: int
```
## Addition of a type to an intersection with many non-disjoint types
This slightly strange-looking test is a regression test for a mistake that was nearly made in a PR:

View file

@ -1,4 +1,4 @@
# Narrowing for `!=` conditionals
# Narrowing for `!=` and `==` conditionals
## `x != None`
@ -22,6 +22,12 @@ def _(x: bool):
reveal_type(x) # revealed: Literal[True]
else:
reveal_type(x) # revealed: Literal[False]
def _(x: bool):
if x == False:
reveal_type(x) # revealed: Literal[False]
else:
reveal_type(x) # revealed: Literal[True]
```
### Enums
@ -35,11 +41,31 @@ class Answer(Enum):
def _(answer: Answer):
if answer != Answer.NO:
# TODO: This should be simplified to `Literal[Answer.YES]`
reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO]
reveal_type(answer) # revealed: Literal[Answer.YES]
else:
# TODO: This should be `Literal[Answer.NO]`
reveal_type(answer) # revealed: Answer
reveal_type(answer) # revealed: Literal[Answer.NO]
def _(answer: Answer):
if answer == Answer.NO:
reveal_type(answer) # revealed: Literal[Answer.NO]
else:
reveal_type(answer) # revealed: Literal[Answer.YES]
class Single(Enum):
VALUE = 1
def _(x: Single | int):
if x != Single.VALUE:
reveal_type(x) # revealed: int
else:
# `int` is not eliminated here because there could be subclasses of `int` with custom `__eq__`/`__ne__` methods
reveal_type(x) # revealed: Single | int
def _(x: Single | int):
if x == Single.VALUE:
reveal_type(x) # revealed: Single | int
else:
reveal_type(x) # revealed: int
```
This narrowing behavior is only safe if the enum has no custom `__eq__`/`__ne__` method:

View file

@ -78,8 +78,16 @@ def _(answer: Answer):
if answer is Answer.NO:
reveal_type(answer) # revealed: Literal[Answer.NO]
else:
# TODO: This should be `Literal[Answer.YES]`
reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO]
reveal_type(answer) # revealed: Literal[Answer.YES]
class Single(Enum):
VALUE = 1
def _(x: Single | int):
if x is Single.VALUE:
reveal_type(x) # revealed: Single
else:
reveal_type(x) # revealed: int
```
## `is` for `EllipsisType` (Python 3.10+)

View file

@ -18,6 +18,8 @@ def _(flag: bool):
## `is not` for other singleton types
Boolean literals:
```py
def _(flag: bool):
x = True if flag else False
@ -29,6 +31,33 @@ def _(flag: bool):
reveal_type(x) # revealed: Literal[False]
```
Enum literals:
```py
from enum import Enum
class Answer(Enum):
NO = 0
YES = 1
def _(answer: Answer):
if answer is not Answer.NO:
reveal_type(answer) # revealed: Literal[Answer.YES]
else:
reveal_type(answer) # revealed: Literal[Answer.NO]
reveal_type(answer) # revealed: Answer
class Single(Enum):
VALUE = 1
def _(x: Single | int):
if x is not Single.VALUE:
reveal_type(x) # revealed: int
else:
reveal_type(x) # revealed: Single
```
## `is not` for non-singleton types
Non-singleton types should *not* narrow the type: two instances of a non-singleton class may occupy

View file

@ -138,11 +138,15 @@ class Answer(Enum):
static_assert(is_assignable_to(Literal[Answer.YES], Literal[Answer.YES]))
static_assert(is_assignable_to(Literal[Answer.YES], Answer))
static_assert(is_assignable_to(Literal[Answer.YES, Answer.NO], Answer))
# TODO: this should not be an error
# error: [static-assert-error]
static_assert(is_assignable_to(Answer, Literal[Answer.YES, Answer.NO]))
static_assert(not is_assignable_to(Literal[Answer.YES], Literal[Answer.NO]))
class Single(Enum):
VALUE = 1
static_assert(is_assignable_to(Literal[Single.VALUE], Single))
static_assert(is_assignable_to(Single, Literal[Single.VALUE]))
```
### Slice literals

View file

@ -21,6 +21,9 @@ class Answer(Enum):
NO = 0
YES = 1
class Single(Enum):
VALUE = 1
static_assert(is_equivalent_to(Literal[1, 2], Literal[1, 2]))
static_assert(is_equivalent_to(type[object], type))
static_assert(is_equivalent_to(type, type[object]))
@ -31,12 +34,15 @@ static_assert(not is_equivalent_to(Literal[1, 2], Literal[1, 2, 3]))
static_assert(not is_equivalent_to(Literal[1, 2, 3], Literal[1, 2]))
static_assert(is_equivalent_to(Literal[Answer.YES], Literal[Answer.YES]))
# TODO: these should be equivalent
# error: [static-assert-error]
static_assert(is_equivalent_to(Literal[Answer.NO, Answer.YES], Answer))
static_assert(is_equivalent_to(Literal[Answer.YES, Answer.NO], Answer))
static_assert(not is_equivalent_to(Literal[Answer.YES], Literal[Answer.NO]))
static_assert(not is_equivalent_to(Literal[Answer.YES], Answer))
static_assert(is_equivalent_to(Literal[Single.VALUE], Single))
static_assert(is_equivalent_to(Single, Literal[Single.VALUE]))
static_assert(is_equivalent_to(Literal[Single.VALUE], Literal[Single.VALUE]))
static_assert(is_equivalent_to(Never, Never))
static_assert(is_equivalent_to(AlwaysTruthy, AlwaysTruthy))
static_assert(is_equivalent_to(AlwaysFalsy, AlwaysFalsy))
@ -69,8 +75,9 @@ static_assert(not is_equivalent_to(type[object], type[Any]))
## Unions and intersections
```py
from typing import Any
from typing import Any, Literal
from ty_extensions import Intersection, Not, Unknown, is_equivalent_to, static_assert
from enum import Enum
static_assert(is_equivalent_to(str | int, str | int))
static_assert(is_equivalent_to(str | int | Any, str | int | Unknown))
@ -111,6 +118,11 @@ static_assert(is_equivalent_to(Intersection[P, Q], Intersection[Q, P]))
static_assert(is_equivalent_to(Intersection[Q, Not[P]], Intersection[Not[P], Q]))
static_assert(is_equivalent_to(Intersection[Q, R, Not[P]], Intersection[Not[P], R, Q]))
static_assert(is_equivalent_to(Intersection[Q | R, Not[P | S]], Intersection[Not[S | P], R | Q]))
class Single(Enum):
VALUE = 1
static_assert(is_equivalent_to(P | Q | Single, Literal[Single.VALUE] | Q | P))
```
## Tuples

View file

@ -47,6 +47,9 @@ class NormalEnum(Enum):
NO = 0
YES = 1
class SingleValuedEnum(Enum):
VALUE = 1
class ComparesEqualEnum(Enum):
NO = 0
YES = 1
@ -70,13 +73,20 @@ class CustomNeEnum(Enum):
static_assert(is_single_valued(Literal[NormalEnum.NO]))
static_assert(is_single_valued(Literal[NormalEnum.YES]))
static_assert(not is_single_valued(NormalEnum))
static_assert(is_single_valued(Literal[SingleValuedEnum.VALUE]))
static_assert(is_single_valued(SingleValuedEnum))
static_assert(is_single_valued(Literal[ComparesEqualEnum.NO]))
static_assert(is_single_valued(Literal[ComparesEqualEnum.YES]))
static_assert(not is_single_valued(ComparesEqualEnum))
static_assert(not is_single_valued(Literal[CustomEqEnum.NO]))
static_assert(not is_single_valued(Literal[CustomEqEnum.YES]))
static_assert(not is_single_valued(CustomEqEnum))
static_assert(not is_single_valued(Literal[CustomNeEnum.NO]))
static_assert(not is_single_valued(Literal[CustomNeEnum.YES]))
static_assert(not is_single_valued(CustomNeEnum))
```

View file

@ -13,11 +13,16 @@ class Answer(Enum):
NO = 0
YES = 1
class Single(Enum):
VALUE = 1
static_assert(is_singleton(None))
static_assert(is_singleton(Literal[True]))
static_assert(is_singleton(Literal[False]))
static_assert(is_singleton(Literal[Answer.YES]))
static_assert(is_singleton(Literal[Answer.NO]))
static_assert(is_singleton(Literal[Single.VALUE]))
static_assert(is_singleton(Single))
static_assert(is_singleton(type[bool]))

View file

@ -96,6 +96,9 @@ class Answer(Enum):
NO = 0
YES = 1
class Single(Enum):
VALUE = 1
# Boolean literals
static_assert(is_subtype_of(Literal[True], bool))
static_assert(is_subtype_of(Literal[True], int))
@ -125,11 +128,12 @@ static_assert(is_subtype_of(Literal[b"foo"], object))
static_assert(is_subtype_of(Literal[Answer.YES], Literal[Answer.YES]))
static_assert(is_subtype_of(Literal[Answer.YES], Answer))
static_assert(is_subtype_of(Literal[Answer.YES, Answer.NO], Answer))
# TODO: this should not be an error
# error: [static-assert-error]
static_assert(is_subtype_of(Answer, Literal[Answer.YES, Answer.NO]))
static_assert(not is_subtype_of(Literal[Answer.YES], Literal[Answer.NO]))
static_assert(is_subtype_of(Literal[Single.VALUE], Single))
static_assert(is_subtype_of(Single, Literal[Single.VALUE]))
```
## Heterogeneous tuple types

View file

@ -114,6 +114,33 @@ def _(
reveal_type(u5) # revealed: bool | Literal[17]
```
## Enum literals
```py
from enum import Enum
from typing import Literal
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
def _(
u1: Literal[Color.RED, Color.GREEN],
u2: Color | Literal[Color.RED],
u3: Literal[Color.RED] | Color,
u4: Literal[Color.RED] | Literal[Color.RED, Color.GREEN],
u5: Literal[Color.RED, Color.GREEN, Color.BLUE],
u6: Literal[Color.RED] | Literal[Color.GREEN] | Literal[Color.BLUE],
) -> None:
reveal_type(u1) # revealed: Literal[Color.RED, Color.GREEN]
reveal_type(u2) # revealed: Color
reveal_type(u3) # revealed: Color
reveal_type(u4) # revealed: Literal[Color.RED, Color.GREEN]
reveal_type(u5) # revealed: Color
reveal_type(u6) # revealed: Color
```
## Do not erase `Unknown`
```py

View file

@ -39,7 +39,7 @@ use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding};
pub(crate) use crate::types::class_base::ClassBase;
use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder};
use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
use crate::types::enums::enum_metadata;
use crate::types::enums::{enum_metadata, is_single_member_enum};
use crate::types::function::{
DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction,
};
@ -789,6 +789,19 @@ impl<'db> Type<'db> {
matches!(self, Type::ClassLiteral(..))
}
pub fn into_enum_literal(self) -> Option<EnumLiteralType<'db>> {
match self {
Type::EnumLiteral(enum_literal) => Some(enum_literal),
_ => None,
}
}
#[track_caller]
pub fn expect_enum_literal(self) -> EnumLiteralType<'db> {
self.into_enum_literal()
.expect("Expected a Type::EnumLiteral variant")
}
pub(crate) const fn into_tuple(self) -> Option<TupleType<'db>> {
match self {
Type::Tuple(tuple_type) => Some(tuple_type),
@ -1420,6 +1433,16 @@ impl<'db> Type<'db> {
// All `StringLiteral` types are a subtype of `LiteralString`.
(Type::StringLiteral(_), Type::LiteralString) => true,
// An instance is a subtype of an enum literal, if it is an instance of the enum class
// and the enum has only one member.
(Type::NominalInstance(_), Type::EnumLiteral(target_enum_literal)) => {
if target_enum_literal.enum_class_instance(db) != self {
return false;
}
is_single_member_enum(db, target_enum_literal.enum_class(db))
}
// Except for the special `LiteralString` case above,
// most `Literal` types delegate to their instance fallbacks
// unless `self` is exactly equivalent to `target` (handled above)
@ -1656,6 +1679,17 @@ impl<'db> Type<'db> {
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => {
n.class.is_object(db) && protocol.normalized(db) == nominal
}
// An instance of an enum class is equivalent to an enum literal of that class,
// if that enum has only has one member.
(Type::NominalInstance(instance), Type::EnumLiteral(literal))
| (Type::EnumLiteral(literal), Type::NominalInstance(instance)) => {
if literal.enum_class_instance(db) != Type::NominalInstance(instance) {
return false;
}
let class_literal = instance.class.class_literal(db).0;
is_single_member_enum(db, class_literal)
}
_ => false,
}
}
@ -8409,6 +8443,7 @@ pub struct EnumLiteralType<'db> {
/// A reference to the enum class this literal belongs to
enum_class: ClassLiteral<'db>,
/// The name of the enum member
#[returns(ref)]
name: Name,
}

View file

@ -37,6 +37,7 @@
//! are subtypes of each other (unless exactly the same literal type), we can avoid many
//! unnecessary `is_subtype_of` checks.
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::{
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
TypeVarBoundOrConstraints, UnionType,
@ -87,6 +88,13 @@ enum UnionElement<'db> {
}
impl<'db> UnionElement<'db> {
const fn to_type_element(&self) -> Option<Type<'db>> {
match self {
UnionElement::Type(ty) => Some(*ty),
_ => None,
}
}
/// Try reducing this `UnionElement` given the presence in the same union of `other_type`.
fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> {
match self {
@ -374,6 +382,38 @@ impl<'db> UnionBuilder<'db> {
self.elements.swap_remove(index);
}
}
Type::EnumLiteral(enum_member_to_add) => {
let enum_class = enum_member_to_add.enum_class(self.db);
let metadata =
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
let enum_members_in_union = self
.elements
.iter()
.filter_map(UnionElement::to_type_element)
.filter_map(Type::into_enum_literal)
.map(|literal| literal.name(self.db).clone())
.chain(std::iter::once(enum_member_to_add.name(self.db).clone()))
.collect::<FxOrderSet<_>>();
let all_members_are_in_union = metadata
.members
.difference(&enum_members_in_union)
.next()
.is_none();
if all_members_are_in_union {
self.add_in_place(enum_member_to_add.enum_class_instance(self.db));
} else if !self
.elements
.iter()
.filter_map(UnionElement::to_type_element)
.any(|ty| Type::EnumLiteral(enum_member_to_add).is_subtype_of(self.db, ty))
{
self.elements
.push(UnionElement::Type(Type::EnumLiteral(enum_member_to_add)));
}
}
// Adding `object` to a union results in `object`.
ty if ty.is_object(self.db) => {
self.collapse_to_object();
@ -501,7 +541,8 @@ impl<'db> IntersectionBuilder<'db> {
}
pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self {
if let Type::Union(union) = ty {
match ty {
Type::Union(union) => {
// Distribute ourself over this union: for each union element, clone ourself and
// intersect with that union element, then create a new union-of-intersections with all
// of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2`
@ -518,7 +559,55 @@ impl<'db> IntersectionBuilder<'db> {
builder.intersections.extend(sub.intersections);
builder
})
}
// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
Type::Intersection(other) => {
let db = self.db;
for pos in other.positive(db) {
self = self.add_positive(*pos);
}
for neg in other.negative(db) {
self = self.add_negative(*neg);
}
self
}
Type::NominalInstance(instance)
if enum_metadata(self.db, instance.class.class_literal(self.db).0).is_some() =>
{
let mut contains_enum_literal_as_negative_element = false;
for intersection in &self.intersections {
if intersection.negative.iter().any(|negative| {
negative
.into_enum_literal()
.is_some_and(|lit| lit.enum_class_instance(self.db) == ty)
}) {
contains_enum_literal_as_negative_element = true;
break;
}
}
if contains_enum_literal_as_negative_element {
// If we have an enum literal of this enum already in the negative side of
// the intersection, expand the instance into the union of enum members, and
// add that union to the intersection.
// Note: we manually construct a `UnionType` here instead of going through
// `UnionBuilder` because we would simplify the union to just the enum instance
// and end up in this branch again.
let db = self.db;
self.add_positive(Type::Union(UnionType::new(
db,
enum_member_literals(db, instance.class.class_literal(db).0, None)
.expect("Calling `enum_member_literals` on an enum class")
.collect::<Box<[_]>>(),
)))
} else {
for inner in &mut self.intersections {
inner.add_positive(self.db, ty);
}
self
}
}
_ => {
// If we are already a union-of-intersections, distribute the new intersected element
// across all of those intersections.
for inner in &mut self.intersections {
@ -527,15 +616,25 @@ impl<'db> IntersectionBuilder<'db> {
self
}
}
}
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
let contains_enum = |enum_instance| {
self.intersections
.iter()
.flat_map(|intersection| &intersection.positive)
.any(|ty| *ty == enum_instance)
};
// See comments above in `add_positive`; this is just the negated version.
if let Type::Union(union) = ty {
match ty {
Type::Union(union) => {
for elem in union.elements(self.db) {
self = self.add_negative(*elem);
}
self
} else if let Type::Intersection(intersection) = ty {
}
Type::Intersection(intersection) => {
// (A | B) & ~(C & ~D)
// -> (A | B) & (~C | D)
// -> ((A | B) & ~C) | ((A | B) & D)
@ -562,13 +661,29 @@ impl<'db> IntersectionBuilder<'db> {
builder
},
)
} else {
}
Type::EnumLiteral(enum_literal)
if contains_enum(enum_literal.enum_class_instance(self.db)) =>
{
let db = self.db;
self.add_positive(UnionType::from_elements(
db,
enum_member_literals(
db,
enum_literal.enum_class(db),
Some(enum_literal.name(db)),
)
.expect("Calling `enum_member_literals` on an enum class"),
))
}
_ => {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);
}
self
}
}
}
pub(crate) fn positive_elements<I, T>(mut self, elements: I) -> Self
where
@ -643,15 +758,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
Type::Intersection(other) => {
for pos in other.positive(db) {
self.add_positive(db, *pos);
}
for neg in other.negative(db) {
self.add_negative(db, *neg);
}
}
_ => {
let known_instance = new_positive
.into_nominal_instance()
@ -961,7 +1068,10 @@ impl<'db> InnerIntersectionBuilder<'db> {
mod tests {
use super::{IntersectionBuilder, Type, UnionBuilder, UnionType};
use crate::KnownModule;
use crate::db::tests::setup_db;
use crate::place::known_module_symbol;
use crate::types::enums::enum_member_literals;
use crate::types::{KnownClass, Truthiness};
use test_case::test_case;
@ -1044,4 +1154,77 @@ mod tests {
.build();
assert_eq!(ty, Type::BooleanLiteral(!bool_value));
}
#[test]
fn build_intersection_enums() {
let db = setup_db();
let safe_uuid_class = known_module_symbol(&db, KnownModule::Uuid, "SafeUUID")
.place
.ignore_possibly_unbound()
.unwrap();
let literals = enum_member_literals(&db, safe_uuid_class.expect_class_literal(), None)
.unwrap()
.collect::<Vec<_>>();
assert_eq!(literals.len(), 3);
// SafeUUID.safe
let l_safe = literals[0];
assert_eq!(l_safe.expect_enum_literal().name(&db), "safe");
// SafeUUID.unsafe
let l_unsafe = literals[1];
assert_eq!(l_unsafe.expect_enum_literal().name(&db), "unsafe");
// SafeUUID.unknown
let l_unknown = literals[2];
assert_eq!(l_unknown.expect_enum_literal().name(&db), "unknown");
// The enum itself: SafeUUID
let safe_uuid = l_safe.expect_enum_literal().enum_class_instance(&db);
{
let actual = IntersectionBuilder::new(&db)
.add_positive(safe_uuid)
.add_negative(l_safe)
.build();
assert_eq!(
actual.display(&db).to_string(),
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
);
}
{
// Same as above, but with the order reversed
let actual = IntersectionBuilder::new(&db)
.add_negative(l_safe)
.add_positive(safe_uuid)
.build();
assert_eq!(
actual.display(&db).to_string(),
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
);
}
{
// Also the same, but now with a nested intersection
let actual = IntersectionBuilder::new(&db)
.add_positive(safe_uuid)
.add_positive(IntersectionBuilder::new(&db).add_negative(l_safe).build())
.build();
assert_eq!(
actual.display(&db).to_string(),
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
);
}
{
let actual = IntersectionBuilder::new(&db)
.add_negative(l_safe)
.add_positive(safe_uuid)
.add_negative(l_unsafe)
.build();
assert_eq!(actual.display(&db).to_string(), "Literal[SafeUUID.unknown]");
}
}
}

View file

@ -5,6 +5,7 @@ use ruff_python_ast as ast;
use crate::Db;
use crate::types::KnownClass;
use crate::types::enums::enum_member_literals;
use crate::types::tuple::{TupleSpec, TupleType};
use super::Type;
@ -199,13 +200,22 @@ impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<
///
/// Returns [`None`] if the type cannot be expanded.
fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
// TODO: Expand enums to their variants
match ty {
Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => {
Some(vec![
Type::NominalInstance(instance) => {
if instance.class.is_known(db, KnownClass::Bool) {
return Some(vec![
Type::BooleanLiteral(true),
Type::BooleanLiteral(false),
])
]);
}
let class_literal = instance.class.class_literal(db).0;
if let Some(enum_members) = enum_member_literals(db, class_literal, None) {
return Some(enum_members.collect());
}
None
}
Type::Tuple(tuple_type) => {
// Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]`

View file

@ -2,22 +2,27 @@ use ruff_python_ast::name::Name;
use rustc_hash::FxHashMap;
use crate::{
Db,
Db, FxOrderSet,
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
semantic_index::{place_table, use_def_map},
types::{ClassLiteral, DynamicType, KnownClass, MemberLookupPolicy, Type, TypeQualifiers},
types::{
ClassLiteral, DynamicType, EnumLiteralType, KnownClass, MemberLookupPolicy, Type,
TypeQualifiers,
},
};
#[derive(Debug, PartialEq, Eq, get_size2::GetSize)]
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct EnumMetadata {
pub(crate) members: Box<[Name]>,
pub(crate) members: FxOrderSet<Name>,
pub(crate) aliases: FxHashMap<Name, Name>,
}
impl get_size2::GetSize for EnumMetadata {}
impl EnumMetadata {
fn empty() -> Self {
EnumMetadata {
members: Box::new([]),
members: FxOrderSet::default(),
aliases: FxHashMap::default(),
}
}
@ -48,7 +53,7 @@ fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option
/// List all members of an enum.
#[allow(clippy::ref_option, clippy::unnecessary_wraps)]
#[salsa::tracked(returns(ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)]
#[salsa::tracked(returns(as_ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)]
pub(crate) fn enum_metadata<'db>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
@ -208,7 +213,7 @@ pub(crate) fn enum_metadata<'db>(
Some(name)
})
.cloned()
.collect::<Box<_>>();
.collect::<FxOrderSet<_>>();
if members.is_empty() {
// Enum subclasses without members are not considered enums.
@ -217,3 +222,21 @@ pub(crate) fn enum_metadata<'db>(
Some(EnumMetadata { members, aliases })
}
pub(crate) fn enum_member_literals<'a, 'db: 'a>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
exclude_member: Option<&'a Name>,
) -> Option<impl Iterator<Item = Type<'a>> + 'a> {
enum_metadata(db, class).map(|metadata| {
metadata
.members
.iter()
.filter(move |name| Some(*name) != exclude_member)
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone())))
})
}
pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool {
enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1)
}

View file

@ -6,6 +6,7 @@ use super::protocol_class::ProtocolInterface;
use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance};
use crate::place::PlaceAndQualifiers;
use crate::types::cyclic::PairVisitor;
use crate::types::enums::is_single_member_enum;
use crate::types::protocol_class::walk_protocol_interface;
use crate::types::tuple::TupleType;
use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance};
@ -125,12 +126,14 @@ impl<'db> NominalInstanceType<'db> {
pub(super) fn is_singleton(self, db: &'db dyn Db) -> bool {
self.class.known(db).is_some_and(KnownClass::is_singleton)
|| is_single_member_enum(db, self.class.class_literal(db).0)
}
pub(super) fn is_single_valued(self, db: &'db dyn Db) -> bool {
self.class
.known(db)
.is_some_and(KnownClass::is_single_valued)
|| is_single_member_enum(db, self.class.class_literal(db).0)
}
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {

View file

@ -5,6 +5,7 @@ use crate::semantic_index::place_table;
use crate::semantic_index::predicate::{
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
};
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{
@ -559,6 +560,17 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
)
}
// Treat enums as a union of their members.
Type::NominalInstance(instance)
if enum_metadata(db, instance.class.class_literal(db).0).is_some() =>
{
UnionType::from_elements(
db,
enum_member_literals(db, instance.class.class_literal(db).0, None)
.expect("Calling `enum_member_literals` on an enum class")
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
)
}
_ => {
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
ty