[ty] Expansion of enums into unions of literals (#19382)

## Summary

Implement expansion of enums into unions of enum literals (and the
reverse operation). For the enum below, this allows us to understand
that `Color = Literal[Color.RED, Color.GREEN, Color.BLUE]`, or that
`Color & ~Literal[Color.RED] = Literal[Color.GREEN, Color.BLUE]`. This
helps in exhaustiveness checking, which is why we see some removed
`assert_never` false positives. And since exhaustiveness checking also
helps with understanding terminal control flow, we also see a few
removed `invalid-return-type` and `possibly-unresolved-reference` false
positives. This PR also adds expansion of enums in overload resolution
and type narrowing constructs.

```py
from enum import Enum
from typing_extensions import Literal, assert_never
from ty_extensions import Intersection, Not, static_assert, is_equivalent_to

class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

type Red = Literal[Color.RED]
type Green = Literal[Color.GREEN]
type Blue = Literal[Color.BLUE]

static_assert(is_equivalent_to(Red | Green | Blue, Color))
static_assert(is_equivalent_to(Intersection[Color, Not[Red]], Green | Blue))


def color_name(color: Color) -> str:  # no error here (we detect that this can not implicitly return None)
    if color is Color.RED:
        return "Red"
    elif color is Color.GREEN:
        return "Green"
    elif color is Color.BLUE:
        return "Blue"
    else:
        assert_never(color)  # no error here
```

## Performance

I avoided an initial regression here for large enums, but the
`UnionBuilder` and `IntersectionBuilder` parts can certainly still be
optimized. We might want to use the same technique that we also use for
unions of other literals. I didn't see any problems in our benchmarks so
far, so this is not included yet.

## Test Plan

Many new Markdown tests
This commit is contained in:
David Peter 2025-07-21 19:37:55 +02:00 committed by GitHub
parent 926e83323a
commit dc66019fbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 750 additions and 102 deletions

View file

@ -39,7 +39,7 @@ use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding};
pub(crate) use crate::types::class_base::ClassBase;
use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder};
use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
use crate::types::enums::enum_metadata;
use crate::types::enums::{enum_metadata, is_single_member_enum};
use crate::types::function::{
DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction,
};
@ -789,6 +789,19 @@ impl<'db> Type<'db> {
matches!(self, Type::ClassLiteral(..))
}
pub fn into_enum_literal(self) -> Option<EnumLiteralType<'db>> {
match self {
Type::EnumLiteral(enum_literal) => Some(enum_literal),
_ => None,
}
}
#[track_caller]
pub fn expect_enum_literal(self) -> EnumLiteralType<'db> {
self.into_enum_literal()
.expect("Expected a Type::EnumLiteral variant")
}
pub(crate) const fn into_tuple(self) -> Option<TupleType<'db>> {
match self {
Type::Tuple(tuple_type) => Some(tuple_type),
@ -1420,6 +1433,16 @@ impl<'db> Type<'db> {
// All `StringLiteral` types are a subtype of `LiteralString`.
(Type::StringLiteral(_), Type::LiteralString) => true,
// An instance is a subtype of an enum literal, if it is an instance of the enum class
// and the enum has only one member.
(Type::NominalInstance(_), Type::EnumLiteral(target_enum_literal)) => {
if target_enum_literal.enum_class_instance(db) != self {
return false;
}
is_single_member_enum(db, target_enum_literal.enum_class(db))
}
// Except for the special `LiteralString` case above,
// most `Literal` types delegate to their instance fallbacks
// unless `self` is exactly equivalent to `target` (handled above)
@ -1656,6 +1679,17 @@ impl<'db> Type<'db> {
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => {
n.class.is_object(db) && protocol.normalized(db) == nominal
}
// An instance of an enum class is equivalent to an enum literal of that class,
// if that enum has only has one member.
(Type::NominalInstance(instance), Type::EnumLiteral(literal))
| (Type::EnumLiteral(literal), Type::NominalInstance(instance)) => {
if literal.enum_class_instance(db) != Type::NominalInstance(instance) {
return false;
}
let class_literal = instance.class.class_literal(db).0;
is_single_member_enum(db, class_literal)
}
_ => false,
}
}
@ -8409,6 +8443,7 @@ pub struct EnumLiteralType<'db> {
/// A reference to the enum class this literal belongs to
enum_class: ClassLiteral<'db>,
/// The name of the enum member
#[returns(ref)]
name: Name,
}

View file

@ -37,6 +37,7 @@
//! are subtypes of each other (unless exactly the same literal type), we can avoid many
//! unnecessary `is_subtype_of` checks.
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::{
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
TypeVarBoundOrConstraints, UnionType,
@ -87,6 +88,13 @@ enum UnionElement<'db> {
}
impl<'db> UnionElement<'db> {
const fn to_type_element(&self) -> Option<Type<'db>> {
match self {
UnionElement::Type(ty) => Some(*ty),
_ => None,
}
}
/// Try reducing this `UnionElement` given the presence in the same union of `other_type`.
fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> {
match self {
@ -374,6 +382,38 @@ impl<'db> UnionBuilder<'db> {
self.elements.swap_remove(index);
}
}
Type::EnumLiteral(enum_member_to_add) => {
let enum_class = enum_member_to_add.enum_class(self.db);
let metadata =
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
let enum_members_in_union = self
.elements
.iter()
.filter_map(UnionElement::to_type_element)
.filter_map(Type::into_enum_literal)
.map(|literal| literal.name(self.db).clone())
.chain(std::iter::once(enum_member_to_add.name(self.db).clone()))
.collect::<FxOrderSet<_>>();
let all_members_are_in_union = metadata
.members
.difference(&enum_members_in_union)
.next()
.is_none();
if all_members_are_in_union {
self.add_in_place(enum_member_to_add.enum_class_instance(self.db));
} else if !self
.elements
.iter()
.filter_map(UnionElement::to_type_element)
.any(|ty| Type::EnumLiteral(enum_member_to_add).is_subtype_of(self.db, ty))
{
self.elements
.push(UnionElement::Type(Type::EnumLiteral(enum_member_to_add)));
}
}
// Adding `object` to a union results in `object`.
ty if ty.is_object(self.db) => {
self.collapse_to_object();
@ -501,72 +541,147 @@ impl<'db> IntersectionBuilder<'db> {
}
pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self {
if let Type::Union(union) = ty {
// Distribute ourself over this union: for each union element, clone ourself and
// intersect with that union element, then create a new union-of-intersections with all
// of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2`
// and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's
// not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) |
// (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)`
// and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 &
// T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea.
union
.elements(self.db)
.iter()
.map(|elem| self.clone().add_positive(*elem))
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
})
} else {
// If we are already a union-of-intersections, distribute the new intersected element
// across all of those intersections.
for inner in &mut self.intersections {
inner.add_positive(self.db, ty);
match ty {
Type::Union(union) => {
// Distribute ourself over this union: for each union element, clone ourself and
// intersect with that union element, then create a new union-of-intersections with all
// of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2`
// and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's
// not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) |
// (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)`
// and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 &
// T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea.
union
.elements(self.db)
.iter()
.map(|elem| self.clone().add_positive(*elem))
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
})
}
// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
Type::Intersection(other) => {
let db = self.db;
for pos in other.positive(db) {
self = self.add_positive(*pos);
}
for neg in other.negative(db) {
self = self.add_negative(*neg);
}
self
}
Type::NominalInstance(instance)
if enum_metadata(self.db, instance.class.class_literal(self.db).0).is_some() =>
{
let mut contains_enum_literal_as_negative_element = false;
for intersection in &self.intersections {
if intersection.negative.iter().any(|negative| {
negative
.into_enum_literal()
.is_some_and(|lit| lit.enum_class_instance(self.db) == ty)
}) {
contains_enum_literal_as_negative_element = true;
break;
}
}
if contains_enum_literal_as_negative_element {
// If we have an enum literal of this enum already in the negative side of
// the intersection, expand the instance into the union of enum members, and
// add that union to the intersection.
// Note: we manually construct a `UnionType` here instead of going through
// `UnionBuilder` because we would simplify the union to just the enum instance
// and end up in this branch again.
let db = self.db;
self.add_positive(Type::Union(UnionType::new(
db,
enum_member_literals(db, instance.class.class_literal(db).0, None)
.expect("Calling `enum_member_literals` on an enum class")
.collect::<Box<[_]>>(),
)))
} else {
for inner in &mut self.intersections {
inner.add_positive(self.db, ty);
}
self
}
}
_ => {
// If we are already a union-of-intersections, distribute the new intersected element
// across all of those intersections.
for inner in &mut self.intersections {
inner.add_positive(self.db, ty);
}
self
}
self
}
}
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
let contains_enum = |enum_instance| {
self.intersections
.iter()
.flat_map(|intersection| &intersection.positive)
.any(|ty| *ty == enum_instance)
};
// See comments above in `add_positive`; this is just the negated version.
if let Type::Union(union) = ty {
for elem in union.elements(self.db) {
self = self.add_negative(*elem);
match ty {
Type::Union(union) => {
for elem in union.elements(self.db) {
self = self.add_negative(*elem);
}
self
}
self
} else if let Type::Intersection(intersection) = ty {
// (A | B) & ~(C & ~D)
// -> (A | B) & (~C | D)
// -> ((A | B) & ~C) | ((A | B) & D)
// i.e. if we have an intersection of positive constraints C
// and negative constraints D, then our new intersection
// is (existing & ~C) | (existing & D)
Type::Intersection(intersection) => {
// (A | B) & ~(C & ~D)
// -> (A | B) & (~C | D)
// -> ((A | B) & ~C) | ((A | B) & D)
// i.e. if we have an intersection of positive constraints C
// and negative constraints D, then our new intersection
// is (existing & ~C) | (existing & D)
let positive_side = intersection
.positive(self.db)
.iter()
// we negate all the positive constraints while distributing
.map(|elem| self.clone().add_negative(*elem));
let positive_side = intersection
.positive(self.db)
.iter()
// we negate all the positive constraints while distributing
.map(|elem| self.clone().add_negative(*elem));
let negative_side = intersection
.negative(self.db)
.iter()
// all negative constraints end up becoming positive constraints
.map(|elem| self.clone().add_positive(*elem));
let negative_side = intersection
.negative(self.db)
.iter()
// all negative constraints end up becoming positive constraints
.map(|elem| self.clone().add_positive(*elem));
positive_side.chain(negative_side).fold(
IntersectionBuilder::empty(self.db),
|mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
},
)
} else {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);
positive_side.chain(negative_side).fold(
IntersectionBuilder::empty(self.db),
|mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
},
)
}
Type::EnumLiteral(enum_literal)
if contains_enum(enum_literal.enum_class_instance(self.db)) =>
{
let db = self.db;
self.add_positive(UnionType::from_elements(
db,
enum_member_literals(
db,
enum_literal.enum_class(db),
Some(enum_literal.name(db)),
)
.expect("Calling `enum_member_literals` on an enum class"),
))
}
_ => {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);
}
self
}
self
}
}
@ -643,15 +758,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
self.add_positive(db, Type::LiteralString);
self.add_negative(db, Type::string_literal(db, ""));
}
// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
Type::Intersection(other) => {
for pos in other.positive(db) {
self.add_positive(db, *pos);
}
for neg in other.negative(db) {
self.add_negative(db, *neg);
}
}
_ => {
let known_instance = new_positive
.into_nominal_instance()
@ -961,7 +1068,10 @@ impl<'db> InnerIntersectionBuilder<'db> {
mod tests {
use super::{IntersectionBuilder, Type, UnionBuilder, UnionType};
use crate::KnownModule;
use crate::db::tests::setup_db;
use crate::place::known_module_symbol;
use crate::types::enums::enum_member_literals;
use crate::types::{KnownClass, Truthiness};
use test_case::test_case;
@ -1044,4 +1154,77 @@ mod tests {
.build();
assert_eq!(ty, Type::BooleanLiteral(!bool_value));
}
#[test]
fn build_intersection_enums() {
let db = setup_db();
let safe_uuid_class = known_module_symbol(&db, KnownModule::Uuid, "SafeUUID")
.place
.ignore_possibly_unbound()
.unwrap();
let literals = enum_member_literals(&db, safe_uuid_class.expect_class_literal(), None)
.unwrap()
.collect::<Vec<_>>();
assert_eq!(literals.len(), 3);
// SafeUUID.safe
let l_safe = literals[0];
assert_eq!(l_safe.expect_enum_literal().name(&db), "safe");
// SafeUUID.unsafe
let l_unsafe = literals[1];
assert_eq!(l_unsafe.expect_enum_literal().name(&db), "unsafe");
// SafeUUID.unknown
let l_unknown = literals[2];
assert_eq!(l_unknown.expect_enum_literal().name(&db), "unknown");
// The enum itself: SafeUUID
let safe_uuid = l_safe.expect_enum_literal().enum_class_instance(&db);
{
let actual = IntersectionBuilder::new(&db)
.add_positive(safe_uuid)
.add_negative(l_safe)
.build();
assert_eq!(
actual.display(&db).to_string(),
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
);
}
{
// Same as above, but with the order reversed
let actual = IntersectionBuilder::new(&db)
.add_negative(l_safe)
.add_positive(safe_uuid)
.build();
assert_eq!(
actual.display(&db).to_string(),
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
);
}
{
// Also the same, but now with a nested intersection
let actual = IntersectionBuilder::new(&db)
.add_positive(safe_uuid)
.add_positive(IntersectionBuilder::new(&db).add_negative(l_safe).build())
.build();
assert_eq!(
actual.display(&db).to_string(),
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
);
}
{
let actual = IntersectionBuilder::new(&db)
.add_negative(l_safe)
.add_positive(safe_uuid)
.add_negative(l_unsafe)
.build();
assert_eq!(actual.display(&db).to_string(), "Literal[SafeUUID.unknown]");
}
}
}

View file

@ -5,6 +5,7 @@ use ruff_python_ast as ast;
use crate::Db;
use crate::types::KnownClass;
use crate::types::enums::enum_member_literals;
use crate::types::tuple::{TupleSpec, TupleType};
use super::Type;
@ -199,13 +200,22 @@ impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<
///
/// Returns [`None`] if the type cannot be expanded.
fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
// TODO: Expand enums to their variants
match ty {
Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => {
Some(vec![
Type::BooleanLiteral(true),
Type::BooleanLiteral(false),
])
Type::NominalInstance(instance) => {
if instance.class.is_known(db, KnownClass::Bool) {
return Some(vec![
Type::BooleanLiteral(true),
Type::BooleanLiteral(false),
]);
}
let class_literal = instance.class.class_literal(db).0;
if let Some(enum_members) = enum_member_literals(db, class_literal, None) {
return Some(enum_members.collect());
}
None
}
Type::Tuple(tuple_type) => {
// Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]`

View file

@ -2,22 +2,27 @@ use ruff_python_ast::name::Name;
use rustc_hash::FxHashMap;
use crate::{
Db,
Db, FxOrderSet,
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
semantic_index::{place_table, use_def_map},
types::{ClassLiteral, DynamicType, KnownClass, MemberLookupPolicy, Type, TypeQualifiers},
types::{
ClassLiteral, DynamicType, EnumLiteralType, KnownClass, MemberLookupPolicy, Type,
TypeQualifiers,
},
};
#[derive(Debug, PartialEq, Eq, get_size2::GetSize)]
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct EnumMetadata {
pub(crate) members: Box<[Name]>,
pub(crate) members: FxOrderSet<Name>,
pub(crate) aliases: FxHashMap<Name, Name>,
}
impl get_size2::GetSize for EnumMetadata {}
impl EnumMetadata {
fn empty() -> Self {
EnumMetadata {
members: Box::new([]),
members: FxOrderSet::default(),
aliases: FxHashMap::default(),
}
}
@ -48,7 +53,7 @@ fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option
/// List all members of an enum.
#[allow(clippy::ref_option, clippy::unnecessary_wraps)]
#[salsa::tracked(returns(ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)]
#[salsa::tracked(returns(as_ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)]
pub(crate) fn enum_metadata<'db>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
@ -208,7 +213,7 @@ pub(crate) fn enum_metadata<'db>(
Some(name)
})
.cloned()
.collect::<Box<_>>();
.collect::<FxOrderSet<_>>();
if members.is_empty() {
// Enum subclasses without members are not considered enums.
@ -217,3 +222,21 @@ pub(crate) fn enum_metadata<'db>(
Some(EnumMetadata { members, aliases })
}
pub(crate) fn enum_member_literals<'a, 'db: 'a>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
exclude_member: Option<&'a Name>,
) -> Option<impl Iterator<Item = Type<'a>> + 'a> {
enum_metadata(db, class).map(|metadata| {
metadata
.members
.iter()
.filter(move |name| Some(*name) != exclude_member)
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone())))
})
}
pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool {
enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1)
}

View file

@ -6,6 +6,7 @@ use super::protocol_class::ProtocolInterface;
use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance};
use crate::place::PlaceAndQualifiers;
use crate::types::cyclic::PairVisitor;
use crate::types::enums::is_single_member_enum;
use crate::types::protocol_class::walk_protocol_interface;
use crate::types::tuple::TupleType;
use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance};
@ -125,12 +126,14 @@ impl<'db> NominalInstanceType<'db> {
pub(super) fn is_singleton(self, db: &'db dyn Db) -> bool {
self.class.known(db).is_some_and(KnownClass::is_singleton)
|| is_single_member_enum(db, self.class.class_literal(db).0)
}
pub(super) fn is_single_valued(self, db: &'db dyn Db) -> bool {
self.class
.known(db)
.is_some_and(KnownClass::is_single_valued)
|| is_single_member_enum(db, self.class.class_literal(db).0)
}
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {

View file

@ -5,6 +5,7 @@ use crate::semantic_index::place_table;
use crate::semantic_index::predicate::{
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
};
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{
@ -559,6 +560,17 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
)
}
// Treat enums as a union of their members.
Type::NominalInstance(instance)
if enum_metadata(db, instance.class.class_literal(db).0).is_some() =>
{
UnionType::from_elements(
db,
enum_member_literals(db, instance.class.class_literal(db).0, None)
.expect("Calling `enum_member_literals` on an enum class")
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
)
}
_ => {
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
ty