diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index 488b82886b..f4b2d8c14b 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -2355,12 +2355,13 @@ import enum reveal_type(enum.Enum.__members__) # revealed: MappingProxyType[str, Unknown] -class Foo(enum.Enum): - BAR = 1 +class Answer(enum.Enum): + NO = 0 + YES = 1 -reveal_type(Foo.BAR) # revealed: Literal[Foo.BAR] -reveal_type(Foo.BAR.value) # revealed: Any -reveal_type(Foo.__members__) # revealed: MappingProxyType[str, Unknown] +reveal_type(Answer.NO) # revealed: Literal[Answer.NO] +reveal_type(Answer.NO.value) # revealed: Any +reveal_type(Answer.__members__) # revealed: MappingProxyType[str, Unknown] ``` ## References diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index 1f41f34087..5ab5fb3360 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -369,6 +369,8 @@ def _(x: type[A | B]): ### Expanding enums +#### Basic + `overloaded.pyi`: ```pyi @@ -394,15 +396,106 @@ def f(x: Literal[SomeEnum.C]) -> C: ... ``` ```py +from typing import Literal from overloaded import SomeEnum, A, B, C, f -def _(x: SomeEnum): +def _(x: SomeEnum, y: Literal[SomeEnum.A, SomeEnum.C]): reveal_type(f(SomeEnum.A)) # revealed: A reveal_type(f(SomeEnum.B)) # revealed: B reveal_type(f(SomeEnum.C)) # revealed: C - # TODO: This should not be an error. The return type should be `A | B | C` once enums are expanded - # error: [no-matching-overload] - reveal_type(f(x)) # revealed: Unknown + reveal_type(f(x)) # revealed: A | B | C + reveal_type(f(y)) # revealed: A | C +``` + +#### Enum with single member + +This pattern appears in typeshed. Here, it is used to represent two optional, mutually exclusive +keyword parameters: + +`overloaded.pyi`: + +```pyi +from enum import Enum, auto +from typing import overload, Literal + +class Missing(Enum): + Value = auto() + +class OnlyASpecified: ... +class OnlyBSpecified: ... +class BothMissing: ... + +@overload +def f(*, a: int, b: Literal[Missing.Value] = ...) -> OnlyASpecified: ... +@overload +def f(*, a: Literal[Missing.Value] = ..., b: int) -> OnlyBSpecified: ... +@overload +def f(*, a: Literal[Missing.Value] = ..., b: Literal[Missing.Value] = ...) -> BothMissing: ... +``` + +```py +from typing import Literal +from overloaded import f, Missing + +reveal_type(f()) # revealed: BothMissing +reveal_type(f(a=0)) # revealed: OnlyASpecified +reveal_type(f(b=0)) # revealed: OnlyBSpecified + +f(a=0, b=0) # error: [no-matching-overload] + +def _(missing: Literal[Missing.Value], missing_or_present: Literal[Missing.Value] | int): + reveal_type(f(a=missing, b=missing)) # revealed: BothMissing + reveal_type(f(a=missing)) # revealed: BothMissing + reveal_type(f(b=missing)) # revealed: BothMissing + reveal_type(f(a=0, b=missing)) # revealed: OnlyASpecified + reveal_type(f(a=missing, b=0)) # revealed: OnlyBSpecified + + reveal_type(f(a=missing_or_present)) # revealed: BothMissing | OnlyASpecified + reveal_type(f(b=missing_or_present)) # revealed: BothMissing | OnlyBSpecified + + # Here, both could be present, so this should be an error + f(a=missing_or_present, b=missing_or_present) # error: [no-matching-overload] +``` + +#### Enum subclass without members + +An `Enum` subclass without members should *not* be expanded: + +`overloaded.pyi`: + +```pyi +from enum import Enum +from typing import overload, Literal + +class MyEnumSubclass(Enum): + pass + +class ActualEnum(MyEnumSubclass): + A = 1 + B = 2 + +class OnlyA: ... +class OnlyB: ... +class Both: ... + +@overload +def f(x: Literal[ActualEnum.A]) -> OnlyA: ... +@overload +def f(x: Literal[ActualEnum.B]) -> OnlyB: ... +@overload +def f(x: ActualEnum) -> Both: ... +@overload +def f(x: MyEnumSubclass) -> MyEnumSubclass: ... +``` + +```py +from overloaded import MyEnumSubclass, ActualEnum, f + +def _(actual_enum: ActualEnum, my_enum_instance: MyEnumSubclass): + reveal_type(f(actual_enum)) # revealed: Both + reveal_type(f(ActualEnum.A)) # revealed: OnlyA + reveal_type(f(ActualEnum.B)) # revealed: OnlyB + reveal_type(f(my_enum_instance)) # revealed: MyEnumSubclass ``` ### No matching overloads diff --git a/crates/ty_python_semantic/resources/mdtest/enums.md b/crates/ty_python_semantic/resources/mdtest/enums.md index e6fbf723d5..482fdd4960 100644 --- a/crates/ty_python_semantic/resources/mdtest/enums.md +++ b/crates/ty_python_semantic/resources/mdtest/enums.md @@ -570,7 +570,111 @@ To do: ## Exhaustiveness checking -To do +## `if` statements + +```py +from enum import Enum +from typing_extensions import assert_never + +class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + +def color_name(color: Color) -> str: + if color is Color.RED: + return "Red" + elif color is Color.GREEN: + return "Green" + elif color is Color.BLUE: + return "Blue" + else: + assert_never(color) + +# No `invalid-return-type` error here because the implicit `else` branch is detected as unreachable: +def color_name_without_assertion(color: Color) -> str: + if color is Color.RED: + return "Red" + elif color is Color.GREEN: + return "Green" + elif color is Color.BLUE: + return "Blue" + +def color_name_misses_one_variant(color: Color) -> str: + if color is Color.RED: + return "Red" + elif color is Color.GREEN: + return "Green" + else: + assert_never(color) # error: [type-assertion-failure] "Argument does not have asserted type `Never`" + +class Singleton(Enum): + VALUE = 1 + +def singleton_check(value: Singleton) -> str: + if value is Singleton.VALUE: + return "Singleton value" + else: + assert_never(value) +``` + +## `match` statements + +```toml +[environment] +python-version = "3.10" +``` + +```py +from enum import Enum +from typing_extensions import assert_never + +class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + +def color_name(color: Color) -> str: + match color: + case Color.RED: + return "Red" + case Color.GREEN: + return "Green" + case Color.BLUE: + return "Blue" + case _: + assert_never(color) + +# TODO: this should not be an error, see https://github.com/astral-sh/ty/issues/99#issuecomment-2983054488 +# error: [invalid-return-type] "Function can implicitly return `None`, which is not assignable to return type `str`" +def color_name_without_assertion(color: Color) -> str: + match color: + case Color.RED: + return "Red" + case Color.GREEN: + return "Green" + case Color.BLUE: + return "Blue" + +def color_name_misses_one_variant(color: Color) -> str: + match color: + case Color.RED: + return "Red" + case Color.GREEN: + return "Green" + case _: + assert_never(color) # error: [type-assertion-failure] "Argument does not have asserted type `Never`" + +class Singleton(Enum): + VALUE = 1 + +def singleton_check(value: Singleton) -> str: + match value: + case Singleton.VALUE: + return "Singleton value" + case _: + assert_never(value) +``` ## References diff --git a/crates/ty_python_semantic/resources/mdtest/intersection_types.md b/crates/ty_python_semantic/resources/mdtest/intersection_types.md index 66f350bfe9..0f5b37eb88 100644 --- a/crates/ty_python_semantic/resources/mdtest/intersection_types.md +++ b/crates/ty_python_semantic/resources/mdtest/intersection_types.md @@ -763,6 +763,65 @@ def f( reveal_type(j) # revealed: Unknown & Literal[""] ``` +## Simplifications involving enums and enum literals + +```toml +[environment] +python-version = "3.12" +``` + +```py +from ty_extensions import Intersection, Not +from typing import Literal +from enum import Enum + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + +type Red = Literal[Color.RED] +type Green = Literal[Color.GREEN] +type Blue = Literal[Color.BLUE] + +def f( + a: Intersection[Color, Red], + b: Intersection[Color, Not[Red]], + c: Intersection[Color, Not[Red | Green]], + d: Intersection[Color, Not[Red | Green | Blue]], + e: Intersection[Red, Not[Color]], + f: Intersection[Red | Green, Not[Color]], + g: Intersection[Not[Red], Color], + h: Intersection[Red, Green], + i: Intersection[Red | Green, Green | Blue], +): + reveal_type(a) # revealed: Literal[Color.RED] + reveal_type(b) # revealed: Literal[Color.GREEN, Color.BLUE] + reveal_type(c) # revealed: Literal[Color.BLUE] + reveal_type(d) # revealed: Never + reveal_type(e) # revealed: Never + reveal_type(f) # revealed: Never + reveal_type(g) # revealed: Literal[Color.GREEN, Color.BLUE] + reveal_type(h) # revealed: Never + reveal_type(i) # revealed: Literal[Color.GREEN] + +class Single(Enum): + VALUE = 0 + +def g( + a: Intersection[Single, Literal[Single.VALUE]], + b: Intersection[Single, Not[Literal[Single.VALUE]]], + c: Intersection[Not[Literal[Single.VALUE]], Single], + d: Intersection[Single, Not[Single]], + e: Intersection[Single | int, Not[Single]], +): + reveal_type(a) # revealed: Single + reveal_type(b) # revealed: Never + reveal_type(c) # revealed: Never + reveal_type(d) # revealed: Never + reveal_type(e) # revealed: int +``` + ## Addition of a type to an intersection with many non-disjoint types This slightly strange-looking test is a regression test for a mistake that was nearly made in a PR: diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md index fb1943c0fe..6456b7764b 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md @@ -1,4 +1,4 @@ -# Narrowing for `!=` conditionals +# Narrowing for `!=` and `==` conditionals ## `x != None` @@ -22,6 +22,12 @@ def _(x: bool): reveal_type(x) # revealed: Literal[True] else: reveal_type(x) # revealed: Literal[False] + +def _(x: bool): + if x == False: + reveal_type(x) # revealed: Literal[False] + else: + reveal_type(x) # revealed: Literal[True] ``` ### Enums @@ -35,11 +41,31 @@ class Answer(Enum): def _(answer: Answer): if answer != Answer.NO: - # TODO: This should be simplified to `Literal[Answer.YES]` - reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO] + reveal_type(answer) # revealed: Literal[Answer.YES] else: - # TODO: This should be `Literal[Answer.NO]` - reveal_type(answer) # revealed: Answer + reveal_type(answer) # revealed: Literal[Answer.NO] + +def _(answer: Answer): + if answer == Answer.NO: + reveal_type(answer) # revealed: Literal[Answer.NO] + else: + reveal_type(answer) # revealed: Literal[Answer.YES] + +class Single(Enum): + VALUE = 1 + +def _(x: Single | int): + if x != Single.VALUE: + reveal_type(x) # revealed: int + else: + # `int` is not eliminated here because there could be subclasses of `int` with custom `__eq__`/`__ne__` methods + reveal_type(x) # revealed: Single | int + +def _(x: Single | int): + if x == Single.VALUE: + reveal_type(x) # revealed: Single | int + else: + reveal_type(x) # revealed: int ``` This narrowing behavior is only safe if the enum has no custom `__eq__`/`__ne__` method: diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md index 8736ee40fd..e3c40104f1 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md @@ -78,8 +78,16 @@ def _(answer: Answer): if answer is Answer.NO: reveal_type(answer) # revealed: Literal[Answer.NO] else: - # TODO: This should be `Literal[Answer.YES]` - reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO] + reveal_type(answer) # revealed: Literal[Answer.YES] + +class Single(Enum): + VALUE = 1 + +def _(x: Single | int): + if x is Single.VALUE: + reveal_type(x) # revealed: Single + else: + reveal_type(x) # revealed: int ``` ## `is` for `EllipsisType` (Python 3.10+) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is_not.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is_not.md index fba62e9213..0c4e87c36c 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is_not.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is_not.md @@ -18,6 +18,8 @@ def _(flag: bool): ## `is not` for other singleton types +Boolean literals: + ```py def _(flag: bool): x = True if flag else False @@ -29,6 +31,33 @@ def _(flag: bool): reveal_type(x) # revealed: Literal[False] ``` +Enum literals: + +```py +from enum import Enum + +class Answer(Enum): + NO = 0 + YES = 1 + +def _(answer: Answer): + if answer is not Answer.NO: + reveal_type(answer) # revealed: Literal[Answer.YES] + else: + reveal_type(answer) # revealed: Literal[Answer.NO] + + reveal_type(answer) # revealed: Answer + +class Single(Enum): + VALUE = 1 + +def _(x: Single | int): + if x is not Single.VALUE: + reveal_type(x) # revealed: int + else: + reveal_type(x) # revealed: Single +``` + ## `is not` for non-singleton types Non-singleton types should *not* narrow the type: two instances of a non-singleton class may occupy diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index c499773536..028a411c51 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -138,11 +138,15 @@ class Answer(Enum): static_assert(is_assignable_to(Literal[Answer.YES], Literal[Answer.YES])) static_assert(is_assignable_to(Literal[Answer.YES], Answer)) static_assert(is_assignable_to(Literal[Answer.YES, Answer.NO], Answer)) -# TODO: this should not be an error -# error: [static-assert-error] static_assert(is_assignable_to(Answer, Literal[Answer.YES, Answer.NO])) static_assert(not is_assignable_to(Literal[Answer.YES], Literal[Answer.NO])) + +class Single(Enum): + VALUE = 1 + +static_assert(is_assignable_to(Literal[Single.VALUE], Single)) +static_assert(is_assignable_to(Single, Literal[Single.VALUE])) ``` ### Slice literals diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index d2815b3cdb..678d1c1a11 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -21,6 +21,9 @@ class Answer(Enum): NO = 0 YES = 1 +class Single(Enum): + VALUE = 1 + static_assert(is_equivalent_to(Literal[1, 2], Literal[1, 2])) static_assert(is_equivalent_to(type[object], type)) static_assert(is_equivalent_to(type, type[object])) @@ -31,12 +34,15 @@ static_assert(not is_equivalent_to(Literal[1, 2], Literal[1, 2, 3])) static_assert(not is_equivalent_to(Literal[1, 2, 3], Literal[1, 2])) static_assert(is_equivalent_to(Literal[Answer.YES], Literal[Answer.YES])) -# TODO: these should be equivalent -# error: [static-assert-error] +static_assert(is_equivalent_to(Literal[Answer.NO, Answer.YES], Answer)) static_assert(is_equivalent_to(Literal[Answer.YES, Answer.NO], Answer)) static_assert(not is_equivalent_to(Literal[Answer.YES], Literal[Answer.NO])) static_assert(not is_equivalent_to(Literal[Answer.YES], Answer)) +static_assert(is_equivalent_to(Literal[Single.VALUE], Single)) +static_assert(is_equivalent_to(Single, Literal[Single.VALUE])) +static_assert(is_equivalent_to(Literal[Single.VALUE], Literal[Single.VALUE])) + static_assert(is_equivalent_to(Never, Never)) static_assert(is_equivalent_to(AlwaysTruthy, AlwaysTruthy)) static_assert(is_equivalent_to(AlwaysFalsy, AlwaysFalsy)) @@ -69,8 +75,9 @@ static_assert(not is_equivalent_to(type[object], type[Any])) ## Unions and intersections ```py -from typing import Any +from typing import Any, Literal from ty_extensions import Intersection, Not, Unknown, is_equivalent_to, static_assert +from enum import Enum static_assert(is_equivalent_to(str | int, str | int)) static_assert(is_equivalent_to(str | int | Any, str | int | Unknown)) @@ -111,6 +118,11 @@ static_assert(is_equivalent_to(Intersection[P, Q], Intersection[Q, P])) static_assert(is_equivalent_to(Intersection[Q, Not[P]], Intersection[Not[P], Q])) static_assert(is_equivalent_to(Intersection[Q, R, Not[P]], Intersection[Not[P], R, Q])) static_assert(is_equivalent_to(Intersection[Q | R, Not[P | S]], Intersection[Not[S | P], R | Q])) + +class Single(Enum): + VALUE = 1 + +static_assert(is_equivalent_to(P | Q | Single, Literal[Single.VALUE] | Q | P)) ``` ## Tuples diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_single_valued.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_single_valued.md index 5c97c2524c..ae070df864 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_single_valued.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_single_valued.md @@ -47,6 +47,9 @@ class NormalEnum(Enum): NO = 0 YES = 1 +class SingleValuedEnum(Enum): + VALUE = 1 + class ComparesEqualEnum(Enum): NO = 0 YES = 1 @@ -70,13 +73,20 @@ class CustomNeEnum(Enum): static_assert(is_single_valued(Literal[NormalEnum.NO])) static_assert(is_single_valued(Literal[NormalEnum.YES])) +static_assert(not is_single_valued(NormalEnum)) + +static_assert(is_single_valued(Literal[SingleValuedEnum.VALUE])) +static_assert(is_single_valued(SingleValuedEnum)) static_assert(is_single_valued(Literal[ComparesEqualEnum.NO])) static_assert(is_single_valued(Literal[ComparesEqualEnum.YES])) +static_assert(not is_single_valued(ComparesEqualEnum)) static_assert(not is_single_valued(Literal[CustomEqEnum.NO])) static_assert(not is_single_valued(Literal[CustomEqEnum.YES])) +static_assert(not is_single_valued(CustomEqEnum)) static_assert(not is_single_valued(Literal[CustomNeEnum.NO])) static_assert(not is_single_valued(Literal[CustomNeEnum.YES])) +static_assert(not is_single_valued(CustomNeEnum)) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_singleton.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_singleton.md index c50a968169..a6b87194c9 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_singleton.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_singleton.md @@ -13,11 +13,16 @@ class Answer(Enum): NO = 0 YES = 1 +class Single(Enum): + VALUE = 1 + static_assert(is_singleton(None)) static_assert(is_singleton(Literal[True])) static_assert(is_singleton(Literal[False])) static_assert(is_singleton(Literal[Answer.YES])) static_assert(is_singleton(Literal[Answer.NO])) +static_assert(is_singleton(Literal[Single.VALUE])) +static_assert(is_singleton(Single)) static_assert(is_singleton(type[bool])) diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index 3d46d11a47..b3201da5e2 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -96,6 +96,9 @@ class Answer(Enum): NO = 0 YES = 1 +class Single(Enum): + VALUE = 1 + # Boolean literals static_assert(is_subtype_of(Literal[True], bool)) static_assert(is_subtype_of(Literal[True], int)) @@ -125,11 +128,12 @@ static_assert(is_subtype_of(Literal[b"foo"], object)) static_assert(is_subtype_of(Literal[Answer.YES], Literal[Answer.YES])) static_assert(is_subtype_of(Literal[Answer.YES], Answer)) static_assert(is_subtype_of(Literal[Answer.YES, Answer.NO], Answer)) -# TODO: this should not be an error -# error: [static-assert-error] static_assert(is_subtype_of(Answer, Literal[Answer.YES, Answer.NO])) static_assert(not is_subtype_of(Literal[Answer.YES], Literal[Answer.NO])) + +static_assert(is_subtype_of(Literal[Single.VALUE], Single)) +static_assert(is_subtype_of(Single, Literal[Single.VALUE])) ``` ## Heterogeneous tuple types diff --git a/crates/ty_python_semantic/resources/mdtest/union_types.md b/crates/ty_python_semantic/resources/mdtest/union_types.md index 919fd4921b..52a77028c7 100644 --- a/crates/ty_python_semantic/resources/mdtest/union_types.md +++ b/crates/ty_python_semantic/resources/mdtest/union_types.md @@ -114,6 +114,33 @@ def _( reveal_type(u5) # revealed: bool | Literal[17] ``` +## Enum literals + +```py +from enum import Enum +from typing import Literal + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + +def _( + u1: Literal[Color.RED, Color.GREEN], + u2: Color | Literal[Color.RED], + u3: Literal[Color.RED] | Color, + u4: Literal[Color.RED] | Literal[Color.RED, Color.GREEN], + u5: Literal[Color.RED, Color.GREEN, Color.BLUE], + u6: Literal[Color.RED] | Literal[Color.GREEN] | Literal[Color.BLUE], +) -> None: + reveal_type(u1) # revealed: Literal[Color.RED, Color.GREEN] + reveal_type(u2) # revealed: Color + reveal_type(u3) # revealed: Color + reveal_type(u4) # revealed: Literal[Color.RED, Color.GREEN] + reveal_type(u5) # revealed: Color + reveal_type(u6) # revealed: Color +``` + ## Do not erase `Unknown` ```py diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index badb8d7b40..19be47b8ce 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -39,7 +39,7 @@ use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder}; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; -use crate::types::enums::enum_metadata; +use crate::types::enums::{enum_metadata, is_single_member_enum}; use crate::types::function::{ DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction, }; @@ -789,6 +789,19 @@ impl<'db> Type<'db> { matches!(self, Type::ClassLiteral(..)) } + pub fn into_enum_literal(self) -> Option> { + match self { + Type::EnumLiteral(enum_literal) => Some(enum_literal), + _ => None, + } + } + + #[track_caller] + pub fn expect_enum_literal(self) -> EnumLiteralType<'db> { + self.into_enum_literal() + .expect("Expected a Type::EnumLiteral variant") + } + pub(crate) const fn into_tuple(self) -> Option> { match self { Type::Tuple(tuple_type) => Some(tuple_type), @@ -1420,6 +1433,16 @@ impl<'db> Type<'db> { // All `StringLiteral` types are a subtype of `LiteralString`. (Type::StringLiteral(_), Type::LiteralString) => true, + // An instance is a subtype of an enum literal, if it is an instance of the enum class + // and the enum has only one member. + (Type::NominalInstance(_), Type::EnumLiteral(target_enum_literal)) => { + if target_enum_literal.enum_class_instance(db) != self { + return false; + } + + is_single_member_enum(db, target_enum_literal.enum_class(db)) + } + // Except for the special `LiteralString` case above, // most `Literal` types delegate to their instance fallbacks // unless `self` is exactly equivalent to `target` (handled above) @@ -1656,6 +1679,17 @@ impl<'db> Type<'db> { | (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => { n.class.is_object(db) && protocol.normalized(db) == nominal } + // An instance of an enum class is equivalent to an enum literal of that class, + // if that enum has only has one member. + (Type::NominalInstance(instance), Type::EnumLiteral(literal)) + | (Type::EnumLiteral(literal), Type::NominalInstance(instance)) => { + if literal.enum_class_instance(db) != Type::NominalInstance(instance) { + return false; + } + + let class_literal = instance.class.class_literal(db).0; + is_single_member_enum(db, class_literal) + } _ => false, } } @@ -8409,6 +8443,7 @@ pub struct EnumLiteralType<'db> { /// A reference to the enum class this literal belongs to enum_class: ClassLiteral<'db>, /// The name of the enum member + #[returns(ref)] name: Name, } diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 7087677322..a3f16c6178 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -37,6 +37,7 @@ //! are subtypes of each other (unless exactly the same literal type), we can avoid many //! unnecessary `is_subtype_of` checks. +use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::{ BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType, @@ -87,6 +88,13 @@ enum UnionElement<'db> { } impl<'db> UnionElement<'db> { + const fn to_type_element(&self) -> Option> { + match self { + UnionElement::Type(ty) => Some(*ty), + _ => None, + } + } + /// Try reducing this `UnionElement` given the presence in the same union of `other_type`. fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> { match self { @@ -374,6 +382,38 @@ impl<'db> UnionBuilder<'db> { self.elements.swap_remove(index); } } + Type::EnumLiteral(enum_member_to_add) => { + let enum_class = enum_member_to_add.enum_class(self.db); + let metadata = + enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum"); + + let enum_members_in_union = self + .elements + .iter() + .filter_map(UnionElement::to_type_element) + .filter_map(Type::into_enum_literal) + .map(|literal| literal.name(self.db).clone()) + .chain(std::iter::once(enum_member_to_add.name(self.db).clone())) + .collect::>(); + + let all_members_are_in_union = metadata + .members + .difference(&enum_members_in_union) + .next() + .is_none(); + + if all_members_are_in_union { + self.add_in_place(enum_member_to_add.enum_class_instance(self.db)); + } else if !self + .elements + .iter() + .filter_map(UnionElement::to_type_element) + .any(|ty| Type::EnumLiteral(enum_member_to_add).is_subtype_of(self.db, ty)) + { + self.elements + .push(UnionElement::Type(Type::EnumLiteral(enum_member_to_add))); + } + } // Adding `object` to a union results in `object`. ty if ty.is_object(self.db) => { self.collapse_to_object(); @@ -501,72 +541,147 @@ impl<'db> IntersectionBuilder<'db> { } pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self { - if let Type::Union(union) = ty { - // Distribute ourself over this union: for each union element, clone ourself and - // intersect with that union element, then create a new union-of-intersections with all - // of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2` - // and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's - // not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) | - // (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)` - // and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 & - // T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea. - union - .elements(self.db) - .iter() - .map(|elem| self.clone().add_positive(*elem)) - .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); - builder - }) - } else { - // If we are already a union-of-intersections, distribute the new intersected element - // across all of those intersections. - for inner in &mut self.intersections { - inner.add_positive(self.db, ty); + match ty { + Type::Union(union) => { + // Distribute ourself over this union: for each union element, clone ourself and + // intersect with that union element, then create a new union-of-intersections with all + // of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2` + // and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's + // not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) | + // (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)` + // and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 & + // T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea. + union + .elements(self.db) + .iter() + .map(|elem| self.clone().add_positive(*elem)) + .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { + builder.intersections.extend(sub.intersections); + builder + }) + } + // `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F` + Type::Intersection(other) => { + let db = self.db; + for pos in other.positive(db) { + self = self.add_positive(*pos); + } + for neg in other.negative(db) { + self = self.add_negative(*neg); + } + self + } + Type::NominalInstance(instance) + if enum_metadata(self.db, instance.class.class_literal(self.db).0).is_some() => + { + let mut contains_enum_literal_as_negative_element = false; + for intersection in &self.intersections { + if intersection.negative.iter().any(|negative| { + negative + .into_enum_literal() + .is_some_and(|lit| lit.enum_class_instance(self.db) == ty) + }) { + contains_enum_literal_as_negative_element = true; + break; + } + } + + if contains_enum_literal_as_negative_element { + // If we have an enum literal of this enum already in the negative side of + // the intersection, expand the instance into the union of enum members, and + // add that union to the intersection. + // Note: we manually construct a `UnionType` here instead of going through + // `UnionBuilder` because we would simplify the union to just the enum instance + // and end up in this branch again. + let db = self.db; + self.add_positive(Type::Union(UnionType::new( + db, + enum_member_literals(db, instance.class.class_literal(db).0, None) + .expect("Calling `enum_member_literals` on an enum class") + .collect::>(), + ))) + } else { + for inner in &mut self.intersections { + inner.add_positive(self.db, ty); + } + self + } + } + _ => { + // If we are already a union-of-intersections, distribute the new intersected element + // across all of those intersections. + for inner in &mut self.intersections { + inner.add_positive(self.db, ty); + } + self } - self } } pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self { + let contains_enum = |enum_instance| { + self.intersections + .iter() + .flat_map(|intersection| &intersection.positive) + .any(|ty| *ty == enum_instance) + }; + // See comments above in `add_positive`; this is just the negated version. - if let Type::Union(union) = ty { - for elem in union.elements(self.db) { - self = self.add_negative(*elem); + match ty { + Type::Union(union) => { + for elem in union.elements(self.db) { + self = self.add_negative(*elem); + } + self } - self - } else if let Type::Intersection(intersection) = ty { - // (A | B) & ~(C & ~D) - // -> (A | B) & (~C | D) - // -> ((A | B) & ~C) | ((A | B) & D) - // i.e. if we have an intersection of positive constraints C - // and negative constraints D, then our new intersection - // is (existing & ~C) | (existing & D) + Type::Intersection(intersection) => { + // (A | B) & ~(C & ~D) + // -> (A | B) & (~C | D) + // -> ((A | B) & ~C) | ((A | B) & D) + // i.e. if we have an intersection of positive constraints C + // and negative constraints D, then our new intersection + // is (existing & ~C) | (existing & D) - let positive_side = intersection - .positive(self.db) - .iter() - // we negate all the positive constraints while distributing - .map(|elem| self.clone().add_negative(*elem)); + let positive_side = intersection + .positive(self.db) + .iter() + // we negate all the positive constraints while distributing + .map(|elem| self.clone().add_negative(*elem)); - let negative_side = intersection - .negative(self.db) - .iter() - // all negative constraints end up becoming positive constraints - .map(|elem| self.clone().add_positive(*elem)); + let negative_side = intersection + .negative(self.db) + .iter() + // all negative constraints end up becoming positive constraints + .map(|elem| self.clone().add_positive(*elem)); - positive_side.chain(negative_side).fold( - IntersectionBuilder::empty(self.db), - |mut builder, sub| { - builder.intersections.extend(sub.intersections); - builder - }, - ) - } else { - for inner in &mut self.intersections { - inner.add_negative(self.db, ty); + positive_side.chain(negative_side).fold( + IntersectionBuilder::empty(self.db), + |mut builder, sub| { + builder.intersections.extend(sub.intersections); + builder + }, + ) + } + Type::EnumLiteral(enum_literal) + if contains_enum(enum_literal.enum_class_instance(self.db)) => + { + let db = self.db; + self.add_positive(UnionType::from_elements( + db, + enum_member_literals( + db, + enum_literal.enum_class(db), + Some(enum_literal.name(db)), + ) + .expect("Calling `enum_member_literals` on an enum class"), + )) + } + _ => { + for inner in &mut self.intersections { + inner.add_negative(self.db, ty); + } + self } - self } } @@ -643,15 +758,7 @@ impl<'db> InnerIntersectionBuilder<'db> { self.add_positive(db, Type::LiteralString); self.add_negative(db, Type::string_literal(db, "")); } - // `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F` - Type::Intersection(other) => { - for pos in other.positive(db) { - self.add_positive(db, *pos); - } - for neg in other.negative(db) { - self.add_negative(db, *neg); - } - } + _ => { let known_instance = new_positive .into_nominal_instance() @@ -961,7 +1068,10 @@ impl<'db> InnerIntersectionBuilder<'db> { mod tests { use super::{IntersectionBuilder, Type, UnionBuilder, UnionType}; + use crate::KnownModule; use crate::db::tests::setup_db; + use crate::place::known_module_symbol; + use crate::types::enums::enum_member_literals; use crate::types::{KnownClass, Truthiness}; use test_case::test_case; @@ -1044,4 +1154,77 @@ mod tests { .build(); assert_eq!(ty, Type::BooleanLiteral(!bool_value)); } + + #[test] + fn build_intersection_enums() { + let db = setup_db(); + + let safe_uuid_class = known_module_symbol(&db, KnownModule::Uuid, "SafeUUID") + .place + .ignore_possibly_unbound() + .unwrap(); + + let literals = enum_member_literals(&db, safe_uuid_class.expect_class_literal(), None) + .unwrap() + .collect::>(); + assert_eq!(literals.len(), 3); + + // SafeUUID.safe + let l_safe = literals[0]; + assert_eq!(l_safe.expect_enum_literal().name(&db), "safe"); + // SafeUUID.unsafe + let l_unsafe = literals[1]; + assert_eq!(l_unsafe.expect_enum_literal().name(&db), "unsafe"); + // SafeUUID.unknown + let l_unknown = literals[2]; + assert_eq!(l_unknown.expect_enum_literal().name(&db), "unknown"); + + // The enum itself: SafeUUID + let safe_uuid = l_safe.expect_enum_literal().enum_class_instance(&db); + + { + let actual = IntersectionBuilder::new(&db) + .add_positive(safe_uuid) + .add_negative(l_safe) + .build(); + + assert_eq!( + actual.display(&db).to_string(), + "Literal[SafeUUID.unsafe, SafeUUID.unknown]" + ); + } + { + // Same as above, but with the order reversed + let actual = IntersectionBuilder::new(&db) + .add_negative(l_safe) + .add_positive(safe_uuid) + .build(); + + assert_eq!( + actual.display(&db).to_string(), + "Literal[SafeUUID.unsafe, SafeUUID.unknown]" + ); + } + { + // Also the same, but now with a nested intersection + let actual = IntersectionBuilder::new(&db) + .add_positive(safe_uuid) + .add_positive(IntersectionBuilder::new(&db).add_negative(l_safe).build()) + .build(); + + assert_eq!( + actual.display(&db).to_string(), + "Literal[SafeUUID.unsafe, SafeUUID.unknown]" + ); + } + { + let actual = IntersectionBuilder::new(&db) + .add_negative(l_safe) + .add_positive(safe_uuid) + .add_negative(l_unsafe) + .build(); + + assert_eq!(actual.display(&db).to_string(), "Literal[SafeUUID.unknown]"); + } + } } diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index 89c3cf0112..c463d0a982 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -5,6 +5,7 @@ use ruff_python_ast as ast; use crate::Db; use crate::types::KnownClass; +use crate::types::enums::enum_member_literals; use crate::types::tuple::{TupleSpec, TupleType}; use super::Type; @@ -199,13 +200,22 @@ impl<'a, 'db> FromIterator<(Argument<'a>, Option>)> for CallArguments< /// /// Returns [`None`] if the type cannot be expanded. fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option>> { - // TODO: Expand enums to their variants match ty { - Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => { - Some(vec![ - Type::BooleanLiteral(true), - Type::BooleanLiteral(false), - ]) + Type::NominalInstance(instance) => { + if instance.class.is_known(db, KnownClass::Bool) { + return Some(vec![ + Type::BooleanLiteral(true), + Type::BooleanLiteral(false), + ]); + } + + let class_literal = instance.class.class_literal(db).0; + + if let Some(enum_members) = enum_member_literals(db, class_literal, None) { + return Some(enum_members.collect()); + } + + None } Type::Tuple(tuple_type) => { // Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]` diff --git a/crates/ty_python_semantic/src/types/enums.rs b/crates/ty_python_semantic/src/types/enums.rs index 85b3c56094..59c814c147 100644 --- a/crates/ty_python_semantic/src/types/enums.rs +++ b/crates/ty_python_semantic/src/types/enums.rs @@ -2,22 +2,27 @@ use ruff_python_ast::name::Name; use rustc_hash::FxHashMap; use crate::{ - Db, + Db, FxOrderSet, place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations}, semantic_index::{place_table, use_def_map}, - types::{ClassLiteral, DynamicType, KnownClass, MemberLookupPolicy, Type, TypeQualifiers}, + types::{ + ClassLiteral, DynamicType, EnumLiteralType, KnownClass, MemberLookupPolicy, Type, + TypeQualifiers, + }, }; -#[derive(Debug, PartialEq, Eq, get_size2::GetSize)] +#[derive(Debug, PartialEq, Eq)] pub(crate) struct EnumMetadata { - pub(crate) members: Box<[Name]>, + pub(crate) members: FxOrderSet, pub(crate) aliases: FxHashMap, } +impl get_size2::GetSize for EnumMetadata {} + impl EnumMetadata { fn empty() -> Self { EnumMetadata { - members: Box::new([]), + members: FxOrderSet::default(), aliases: FxHashMap::default(), } } @@ -48,7 +53,7 @@ fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option /// List all members of an enum. #[allow(clippy::ref_option, clippy::unnecessary_wraps)] -#[salsa::tracked(returns(ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)] +#[salsa::tracked(returns(as_ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)] pub(crate) fn enum_metadata<'db>( db: &'db dyn Db, class: ClassLiteral<'db>, @@ -208,7 +213,7 @@ pub(crate) fn enum_metadata<'db>( Some(name) }) .cloned() - .collect::>(); + .collect::>(); if members.is_empty() { // Enum subclasses without members are not considered enums. @@ -217,3 +222,21 @@ pub(crate) fn enum_metadata<'db>( Some(EnumMetadata { members, aliases }) } + +pub(crate) fn enum_member_literals<'a, 'db: 'a>( + db: &'db dyn Db, + class: ClassLiteral<'db>, + exclude_member: Option<&'a Name>, +) -> Option> + 'a> { + enum_metadata(db, class).map(|metadata| { + metadata + .members + .iter() + .filter(move |name| Some(*name) != exclude_member) + .map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone()))) + }) +} + +pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool { + enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1) +} diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index c92885adeb..3959b041e1 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -6,6 +6,7 @@ use super::protocol_class::ProtocolInterface; use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance}; use crate::place::PlaceAndQualifiers; use crate::types::cyclic::PairVisitor; +use crate::types::enums::is_single_member_enum; use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::TupleType; use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance}; @@ -125,12 +126,14 @@ impl<'db> NominalInstanceType<'db> { pub(super) fn is_singleton(self, db: &'db dyn Db) -> bool { self.class.known(db).is_some_and(KnownClass::is_singleton) + || is_single_member_enum(db, self.class.class_literal(db).0) } pub(super) fn is_single_valued(self, db: &'db dyn Db) -> bool { self.class .known(db) .is_some_and(KnownClass::is_single_valued) + || is_single_member_enum(db, self.class.class_literal(db).0) } pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> { diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 1838255223..d1c0f9cef7 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -5,6 +5,7 @@ use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, }; +use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::function::KnownFunction; use crate::types::infer::infer_same_file_expression_type; use crate::types::{ @@ -559,6 +560,17 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { .map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)), ) } + // Treat enums as a union of their members. + Type::NominalInstance(instance) + if enum_metadata(db, instance.class.class_literal(db).0).is_some() => + { + UnionType::from_elements( + db, + enum_member_literals(db, instance.class.class_literal(db).0, None) + .expect("Calling `enum_member_literals` on an enum class") + .map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)), + ) + } _ => { if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) { ty