mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-22 11:24:35 +00:00
[ty] Expansion of enums into unions of literals (#19382)
## Summary Implement expansion of enums into unions of enum literals (and the reverse operation). For the enum below, this allows us to understand that `Color = Literal[Color.RED, Color.GREEN, Color.BLUE]`, or that `Color & ~Literal[Color.RED] = Literal[Color.GREEN, Color.BLUE]`. This helps in exhaustiveness checking, which is why we see some removed `assert_never` false positives. And since exhaustiveness checking also helps with understanding terminal control flow, we also see a few removed `invalid-return-type` and `possibly-unresolved-reference` false positives. This PR also adds expansion of enums in overload resolution and type narrowing constructs. ```py from enum import Enum from typing_extensions import Literal, assert_never from ty_extensions import Intersection, Not, static_assert, is_equivalent_to class Color(Enum): RED = 1 GREEN = 2 BLUE = 3 type Red = Literal[Color.RED] type Green = Literal[Color.GREEN] type Blue = Literal[Color.BLUE] static_assert(is_equivalent_to(Red | Green | Blue, Color)) static_assert(is_equivalent_to(Intersection[Color, Not[Red]], Green | Blue)) def color_name(color: Color) -> str: # no error here (we detect that this can not implicitly return None) if color is Color.RED: return "Red" elif color is Color.GREEN: return "Green" elif color is Color.BLUE: return "Blue" else: assert_never(color) # no error here ``` ## Performance I avoided an initial regression here for large enums, but the `UnionBuilder` and `IntersectionBuilder` parts can certainly still be optimized. We might want to use the same technique that we also use for unions of other literals. I didn't see any problems in our benchmarks so far, so this is not included yet. ## Test Plan Many new Markdown tests
This commit is contained in:
parent
926e83323a
commit
dc66019fbc
19 changed files with 750 additions and 102 deletions
|
@ -39,7 +39,7 @@ use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding};
|
|||
pub(crate) use crate::types::class_base::ClassBase;
|
||||
use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder};
|
||||
use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
|
||||
use crate::types::enums::enum_metadata;
|
||||
use crate::types::enums::{enum_metadata, is_single_member_enum};
|
||||
use crate::types::function::{
|
||||
DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction,
|
||||
};
|
||||
|
@ -789,6 +789,19 @@ impl<'db> Type<'db> {
|
|||
matches!(self, Type::ClassLiteral(..))
|
||||
}
|
||||
|
||||
pub fn into_enum_literal(self) -> Option<EnumLiteralType<'db>> {
|
||||
match self {
|
||||
Type::EnumLiteral(enum_literal) => Some(enum_literal),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn expect_enum_literal(self) -> EnumLiteralType<'db> {
|
||||
self.into_enum_literal()
|
||||
.expect("Expected a Type::EnumLiteral variant")
|
||||
}
|
||||
|
||||
pub(crate) const fn into_tuple(self) -> Option<TupleType<'db>> {
|
||||
match self {
|
||||
Type::Tuple(tuple_type) => Some(tuple_type),
|
||||
|
@ -1420,6 +1433,16 @@ impl<'db> Type<'db> {
|
|||
// All `StringLiteral` types are a subtype of `LiteralString`.
|
||||
(Type::StringLiteral(_), Type::LiteralString) => true,
|
||||
|
||||
// An instance is a subtype of an enum literal, if it is an instance of the enum class
|
||||
// and the enum has only one member.
|
||||
(Type::NominalInstance(_), Type::EnumLiteral(target_enum_literal)) => {
|
||||
if target_enum_literal.enum_class_instance(db) != self {
|
||||
return false;
|
||||
}
|
||||
|
||||
is_single_member_enum(db, target_enum_literal.enum_class(db))
|
||||
}
|
||||
|
||||
// Except for the special `LiteralString` case above,
|
||||
// most `Literal` types delegate to their instance fallbacks
|
||||
// unless `self` is exactly equivalent to `target` (handled above)
|
||||
|
@ -1656,6 +1679,17 @@ impl<'db> Type<'db> {
|
|||
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => {
|
||||
n.class.is_object(db) && protocol.normalized(db) == nominal
|
||||
}
|
||||
// An instance of an enum class is equivalent to an enum literal of that class,
|
||||
// if that enum has only has one member.
|
||||
(Type::NominalInstance(instance), Type::EnumLiteral(literal))
|
||||
| (Type::EnumLiteral(literal), Type::NominalInstance(instance)) => {
|
||||
if literal.enum_class_instance(db) != Type::NominalInstance(instance) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let class_literal = instance.class.class_literal(db).0;
|
||||
is_single_member_enum(db, class_literal)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
@ -8409,6 +8443,7 @@ pub struct EnumLiteralType<'db> {
|
|||
/// A reference to the enum class this literal belongs to
|
||||
enum_class: ClassLiteral<'db>,
|
||||
/// The name of the enum member
|
||||
#[returns(ref)]
|
||||
name: Name,
|
||||
}
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
//! are subtypes of each other (unless exactly the same literal type), we can avoid many
|
||||
//! unnecessary `is_subtype_of` checks.
|
||||
|
||||
use crate::types::enums::{enum_member_literals, enum_metadata};
|
||||
use crate::types::{
|
||||
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
|
||||
TypeVarBoundOrConstraints, UnionType,
|
||||
|
@ -87,6 +88,13 @@ enum UnionElement<'db> {
|
|||
}
|
||||
|
||||
impl<'db> UnionElement<'db> {
|
||||
const fn to_type_element(&self) -> Option<Type<'db>> {
|
||||
match self {
|
||||
UnionElement::Type(ty) => Some(*ty),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try reducing this `UnionElement` given the presence in the same union of `other_type`.
|
||||
fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> {
|
||||
match self {
|
||||
|
@ -374,6 +382,38 @@ impl<'db> UnionBuilder<'db> {
|
|||
self.elements.swap_remove(index);
|
||||
}
|
||||
}
|
||||
Type::EnumLiteral(enum_member_to_add) => {
|
||||
let enum_class = enum_member_to_add.enum_class(self.db);
|
||||
let metadata =
|
||||
enum_metadata(self.db, enum_class).expect("Class of enum literal is an enum");
|
||||
|
||||
let enum_members_in_union = self
|
||||
.elements
|
||||
.iter()
|
||||
.filter_map(UnionElement::to_type_element)
|
||||
.filter_map(Type::into_enum_literal)
|
||||
.map(|literal| literal.name(self.db).clone())
|
||||
.chain(std::iter::once(enum_member_to_add.name(self.db).clone()))
|
||||
.collect::<FxOrderSet<_>>();
|
||||
|
||||
let all_members_are_in_union = metadata
|
||||
.members
|
||||
.difference(&enum_members_in_union)
|
||||
.next()
|
||||
.is_none();
|
||||
|
||||
if all_members_are_in_union {
|
||||
self.add_in_place(enum_member_to_add.enum_class_instance(self.db));
|
||||
} else if !self
|
||||
.elements
|
||||
.iter()
|
||||
.filter_map(UnionElement::to_type_element)
|
||||
.any(|ty| Type::EnumLiteral(enum_member_to_add).is_subtype_of(self.db, ty))
|
||||
{
|
||||
self.elements
|
||||
.push(UnionElement::Type(Type::EnumLiteral(enum_member_to_add)));
|
||||
}
|
||||
}
|
||||
// Adding `object` to a union results in `object`.
|
||||
ty if ty.is_object(self.db) => {
|
||||
self.collapse_to_object();
|
||||
|
@ -501,72 +541,147 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
}
|
||||
|
||||
pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self {
|
||||
if let Type::Union(union) = ty {
|
||||
// Distribute ourself over this union: for each union element, clone ourself and
|
||||
// intersect with that union element, then create a new union-of-intersections with all
|
||||
// of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2`
|
||||
// and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's
|
||||
// not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) |
|
||||
// (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)`
|
||||
// and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 &
|
||||
// T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea.
|
||||
union
|
||||
.elements(self.db)
|
||||
.iter()
|
||||
.map(|elem| self.clone().add_positive(*elem))
|
||||
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
|
||||
builder.intersections.extend(sub.intersections);
|
||||
builder
|
||||
})
|
||||
} else {
|
||||
// If we are already a union-of-intersections, distribute the new intersected element
|
||||
// across all of those intersections.
|
||||
for inner in &mut self.intersections {
|
||||
inner.add_positive(self.db, ty);
|
||||
match ty {
|
||||
Type::Union(union) => {
|
||||
// Distribute ourself over this union: for each union element, clone ourself and
|
||||
// intersect with that union element, then create a new union-of-intersections with all
|
||||
// of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2`
|
||||
// and we add `T3 | T4` to the intersection, we don't get `T1 & T2 & (T3 | T4)` (that's
|
||||
// not in DNF), we distribute the union and get `(T1 & T3) | (T2 & T3) | (T1 & T4) |
|
||||
// (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)`
|
||||
// and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 &
|
||||
// T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea.
|
||||
union
|
||||
.elements(self.db)
|
||||
.iter()
|
||||
.map(|elem| self.clone().add_positive(*elem))
|
||||
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
|
||||
builder.intersections.extend(sub.intersections);
|
||||
builder
|
||||
})
|
||||
}
|
||||
// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
|
||||
Type::Intersection(other) => {
|
||||
let db = self.db;
|
||||
for pos in other.positive(db) {
|
||||
self = self.add_positive(*pos);
|
||||
}
|
||||
for neg in other.negative(db) {
|
||||
self = self.add_negative(*neg);
|
||||
}
|
||||
self
|
||||
}
|
||||
Type::NominalInstance(instance)
|
||||
if enum_metadata(self.db, instance.class.class_literal(self.db).0).is_some() =>
|
||||
{
|
||||
let mut contains_enum_literal_as_negative_element = false;
|
||||
for intersection in &self.intersections {
|
||||
if intersection.negative.iter().any(|negative| {
|
||||
negative
|
||||
.into_enum_literal()
|
||||
.is_some_and(|lit| lit.enum_class_instance(self.db) == ty)
|
||||
}) {
|
||||
contains_enum_literal_as_negative_element = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if contains_enum_literal_as_negative_element {
|
||||
// If we have an enum literal of this enum already in the negative side of
|
||||
// the intersection, expand the instance into the union of enum members, and
|
||||
// add that union to the intersection.
|
||||
// Note: we manually construct a `UnionType` here instead of going through
|
||||
// `UnionBuilder` because we would simplify the union to just the enum instance
|
||||
// and end up in this branch again.
|
||||
let db = self.db;
|
||||
self.add_positive(Type::Union(UnionType::new(
|
||||
db,
|
||||
enum_member_literals(db, instance.class.class_literal(db).0, None)
|
||||
.expect("Calling `enum_member_literals` on an enum class")
|
||||
.collect::<Box<[_]>>(),
|
||||
)))
|
||||
} else {
|
||||
for inner in &mut self.intersections {
|
||||
inner.add_positive(self.db, ty);
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// If we are already a union-of-intersections, distribute the new intersected element
|
||||
// across all of those intersections.
|
||||
for inner in &mut self.intersections {
|
||||
inner.add_positive(self.db, ty);
|
||||
}
|
||||
self
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
|
||||
let contains_enum = |enum_instance| {
|
||||
self.intersections
|
||||
.iter()
|
||||
.flat_map(|intersection| &intersection.positive)
|
||||
.any(|ty| *ty == enum_instance)
|
||||
};
|
||||
|
||||
// See comments above in `add_positive`; this is just the negated version.
|
||||
if let Type::Union(union) = ty {
|
||||
for elem in union.elements(self.db) {
|
||||
self = self.add_negative(*elem);
|
||||
match ty {
|
||||
Type::Union(union) => {
|
||||
for elem in union.elements(self.db) {
|
||||
self = self.add_negative(*elem);
|
||||
}
|
||||
self
|
||||
}
|
||||
self
|
||||
} else if let Type::Intersection(intersection) = ty {
|
||||
// (A | B) & ~(C & ~D)
|
||||
// -> (A | B) & (~C | D)
|
||||
// -> ((A | B) & ~C) | ((A | B) & D)
|
||||
// i.e. if we have an intersection of positive constraints C
|
||||
// and negative constraints D, then our new intersection
|
||||
// is (existing & ~C) | (existing & D)
|
||||
Type::Intersection(intersection) => {
|
||||
// (A | B) & ~(C & ~D)
|
||||
// -> (A | B) & (~C | D)
|
||||
// -> ((A | B) & ~C) | ((A | B) & D)
|
||||
// i.e. if we have an intersection of positive constraints C
|
||||
// and negative constraints D, then our new intersection
|
||||
// is (existing & ~C) | (existing & D)
|
||||
|
||||
let positive_side = intersection
|
||||
.positive(self.db)
|
||||
.iter()
|
||||
// we negate all the positive constraints while distributing
|
||||
.map(|elem| self.clone().add_negative(*elem));
|
||||
let positive_side = intersection
|
||||
.positive(self.db)
|
||||
.iter()
|
||||
// we negate all the positive constraints while distributing
|
||||
.map(|elem| self.clone().add_negative(*elem));
|
||||
|
||||
let negative_side = intersection
|
||||
.negative(self.db)
|
||||
.iter()
|
||||
// all negative constraints end up becoming positive constraints
|
||||
.map(|elem| self.clone().add_positive(*elem));
|
||||
let negative_side = intersection
|
||||
.negative(self.db)
|
||||
.iter()
|
||||
// all negative constraints end up becoming positive constraints
|
||||
.map(|elem| self.clone().add_positive(*elem));
|
||||
|
||||
positive_side.chain(negative_side).fold(
|
||||
IntersectionBuilder::empty(self.db),
|
||||
|mut builder, sub| {
|
||||
builder.intersections.extend(sub.intersections);
|
||||
builder
|
||||
},
|
||||
)
|
||||
} else {
|
||||
for inner in &mut self.intersections {
|
||||
inner.add_negative(self.db, ty);
|
||||
positive_side.chain(negative_side).fold(
|
||||
IntersectionBuilder::empty(self.db),
|
||||
|mut builder, sub| {
|
||||
builder.intersections.extend(sub.intersections);
|
||||
builder
|
||||
},
|
||||
)
|
||||
}
|
||||
Type::EnumLiteral(enum_literal)
|
||||
if contains_enum(enum_literal.enum_class_instance(self.db)) =>
|
||||
{
|
||||
let db = self.db;
|
||||
self.add_positive(UnionType::from_elements(
|
||||
db,
|
||||
enum_member_literals(
|
||||
db,
|
||||
enum_literal.enum_class(db),
|
||||
Some(enum_literal.name(db)),
|
||||
)
|
||||
.expect("Calling `enum_member_literals` on an enum class"),
|
||||
))
|
||||
}
|
||||
_ => {
|
||||
for inner in &mut self.intersections {
|
||||
inner.add_negative(self.db, ty);
|
||||
}
|
||||
self
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -643,15 +758,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
|
|||
self.add_positive(db, Type::LiteralString);
|
||||
self.add_negative(db, Type::string_literal(db, ""));
|
||||
}
|
||||
// `(A & B & ~C) & (D & E & ~F)` -> `A & B & D & E & ~C & ~F`
|
||||
Type::Intersection(other) => {
|
||||
for pos in other.positive(db) {
|
||||
self.add_positive(db, *pos);
|
||||
}
|
||||
for neg in other.negative(db) {
|
||||
self.add_negative(db, *neg);
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
let known_instance = new_positive
|
||||
.into_nominal_instance()
|
||||
|
@ -961,7 +1068,10 @@ impl<'db> InnerIntersectionBuilder<'db> {
|
|||
mod tests {
|
||||
use super::{IntersectionBuilder, Type, UnionBuilder, UnionType};
|
||||
|
||||
use crate::KnownModule;
|
||||
use crate::db::tests::setup_db;
|
||||
use crate::place::known_module_symbol;
|
||||
use crate::types::enums::enum_member_literals;
|
||||
use crate::types::{KnownClass, Truthiness};
|
||||
|
||||
use test_case::test_case;
|
||||
|
@ -1044,4 +1154,77 @@ mod tests {
|
|||
.build();
|
||||
assert_eq!(ty, Type::BooleanLiteral(!bool_value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_intersection_enums() {
|
||||
let db = setup_db();
|
||||
|
||||
let safe_uuid_class = known_module_symbol(&db, KnownModule::Uuid, "SafeUUID")
|
||||
.place
|
||||
.ignore_possibly_unbound()
|
||||
.unwrap();
|
||||
|
||||
let literals = enum_member_literals(&db, safe_uuid_class.expect_class_literal(), None)
|
||||
.unwrap()
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(literals.len(), 3);
|
||||
|
||||
// SafeUUID.safe
|
||||
let l_safe = literals[0];
|
||||
assert_eq!(l_safe.expect_enum_literal().name(&db), "safe");
|
||||
// SafeUUID.unsafe
|
||||
let l_unsafe = literals[1];
|
||||
assert_eq!(l_unsafe.expect_enum_literal().name(&db), "unsafe");
|
||||
// SafeUUID.unknown
|
||||
let l_unknown = literals[2];
|
||||
assert_eq!(l_unknown.expect_enum_literal().name(&db), "unknown");
|
||||
|
||||
// The enum itself: SafeUUID
|
||||
let safe_uuid = l_safe.expect_enum_literal().enum_class_instance(&db);
|
||||
|
||||
{
|
||||
let actual = IntersectionBuilder::new(&db)
|
||||
.add_positive(safe_uuid)
|
||||
.add_negative(l_safe)
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
actual.display(&db).to_string(),
|
||||
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
|
||||
);
|
||||
}
|
||||
{
|
||||
// Same as above, but with the order reversed
|
||||
let actual = IntersectionBuilder::new(&db)
|
||||
.add_negative(l_safe)
|
||||
.add_positive(safe_uuid)
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
actual.display(&db).to_string(),
|
||||
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
|
||||
);
|
||||
}
|
||||
{
|
||||
// Also the same, but now with a nested intersection
|
||||
let actual = IntersectionBuilder::new(&db)
|
||||
.add_positive(safe_uuid)
|
||||
.add_positive(IntersectionBuilder::new(&db).add_negative(l_safe).build())
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
actual.display(&db).to_string(),
|
||||
"Literal[SafeUUID.unsafe, SafeUUID.unknown]"
|
||||
);
|
||||
}
|
||||
{
|
||||
let actual = IntersectionBuilder::new(&db)
|
||||
.add_negative(l_safe)
|
||||
.add_positive(safe_uuid)
|
||||
.add_negative(l_unsafe)
|
||||
.build();
|
||||
|
||||
assert_eq!(actual.display(&db).to_string(), "Literal[SafeUUID.unknown]");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ use ruff_python_ast as ast;
|
|||
|
||||
use crate::Db;
|
||||
use crate::types::KnownClass;
|
||||
use crate::types::enums::enum_member_literals;
|
||||
use crate::types::tuple::{TupleSpec, TupleType};
|
||||
|
||||
use super::Type;
|
||||
|
@ -199,13 +200,22 @@ impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<
|
|||
///
|
||||
/// Returns [`None`] if the type cannot be expanded.
|
||||
fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
|
||||
// TODO: Expand enums to their variants
|
||||
match ty {
|
||||
Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => {
|
||||
Some(vec![
|
||||
Type::BooleanLiteral(true),
|
||||
Type::BooleanLiteral(false),
|
||||
])
|
||||
Type::NominalInstance(instance) => {
|
||||
if instance.class.is_known(db, KnownClass::Bool) {
|
||||
return Some(vec![
|
||||
Type::BooleanLiteral(true),
|
||||
Type::BooleanLiteral(false),
|
||||
]);
|
||||
}
|
||||
|
||||
let class_literal = instance.class.class_literal(db).0;
|
||||
|
||||
if let Some(enum_members) = enum_member_literals(db, class_literal, None) {
|
||||
return Some(enum_members.collect());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
Type::Tuple(tuple_type) => {
|
||||
// Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]`
|
||||
|
|
|
@ -2,22 +2,27 @@ use ruff_python_ast::name::Name;
|
|||
use rustc_hash::FxHashMap;
|
||||
|
||||
use crate::{
|
||||
Db,
|
||||
Db, FxOrderSet,
|
||||
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
|
||||
semantic_index::{place_table, use_def_map},
|
||||
types::{ClassLiteral, DynamicType, KnownClass, MemberLookupPolicy, Type, TypeQualifiers},
|
||||
types::{
|
||||
ClassLiteral, DynamicType, EnumLiteralType, KnownClass, MemberLookupPolicy, Type,
|
||||
TypeQualifiers,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, get_size2::GetSize)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub(crate) struct EnumMetadata {
|
||||
pub(crate) members: Box<[Name]>,
|
||||
pub(crate) members: FxOrderSet<Name>,
|
||||
pub(crate) aliases: FxHashMap<Name, Name>,
|
||||
}
|
||||
|
||||
impl get_size2::GetSize for EnumMetadata {}
|
||||
|
||||
impl EnumMetadata {
|
||||
fn empty() -> Self {
|
||||
EnumMetadata {
|
||||
members: Box::new([]),
|
||||
members: FxOrderSet::default(),
|
||||
aliases: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
|
@ -48,7 +53,7 @@ fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option
|
|||
|
||||
/// List all members of an enum.
|
||||
#[allow(clippy::ref_option, clippy::unnecessary_wraps)]
|
||||
#[salsa::tracked(returns(ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)]
|
||||
#[salsa::tracked(returns(as_ref), cycle_fn=enum_metadata_cycle_recover, cycle_initial=enum_metadata_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)]
|
||||
pub(crate) fn enum_metadata<'db>(
|
||||
db: &'db dyn Db,
|
||||
class: ClassLiteral<'db>,
|
||||
|
@ -208,7 +213,7 @@ pub(crate) fn enum_metadata<'db>(
|
|||
Some(name)
|
||||
})
|
||||
.cloned()
|
||||
.collect::<Box<_>>();
|
||||
.collect::<FxOrderSet<_>>();
|
||||
|
||||
if members.is_empty() {
|
||||
// Enum subclasses without members are not considered enums.
|
||||
|
@ -217,3 +222,21 @@ pub(crate) fn enum_metadata<'db>(
|
|||
|
||||
Some(EnumMetadata { members, aliases })
|
||||
}
|
||||
|
||||
pub(crate) fn enum_member_literals<'a, 'db: 'a>(
|
||||
db: &'db dyn Db,
|
||||
class: ClassLiteral<'db>,
|
||||
exclude_member: Option<&'a Name>,
|
||||
) -> Option<impl Iterator<Item = Type<'a>> + 'a> {
|
||||
enum_metadata(db, class).map(|metadata| {
|
||||
metadata
|
||||
.members
|
||||
.iter()
|
||||
.filter(move |name| Some(*name) != exclude_member)
|
||||
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone())))
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool {
|
||||
enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ use super::protocol_class::ProtocolInterface;
|
|||
use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance};
|
||||
use crate::place::PlaceAndQualifiers;
|
||||
use crate::types::cyclic::PairVisitor;
|
||||
use crate::types::enums::is_single_member_enum;
|
||||
use crate::types::protocol_class::walk_protocol_interface;
|
||||
use crate::types::tuple::TupleType;
|
||||
use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance};
|
||||
|
@ -125,12 +126,14 @@ impl<'db> NominalInstanceType<'db> {
|
|||
|
||||
pub(super) fn is_singleton(self, db: &'db dyn Db) -> bool {
|
||||
self.class.known(db).is_some_and(KnownClass::is_singleton)
|
||||
|| is_single_member_enum(db, self.class.class_literal(db).0)
|
||||
}
|
||||
|
||||
pub(super) fn is_single_valued(self, db: &'db dyn Db) -> bool {
|
||||
self.class
|
||||
.known(db)
|
||||
.is_some_and(KnownClass::is_single_valued)
|
||||
|| is_single_member_enum(db, self.class.class_literal(db).0)
|
||||
}
|
||||
|
||||
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
|
||||
|
|
|
@ -5,6 +5,7 @@ use crate::semantic_index::place_table;
|
|||
use crate::semantic_index::predicate::{
|
||||
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
|
||||
};
|
||||
use crate::types::enums::{enum_member_literals, enum_metadata};
|
||||
use crate::types::function::KnownFunction;
|
||||
use crate::types::infer::infer_same_file_expression_type;
|
||||
use crate::types::{
|
||||
|
@ -559,6 +560,17 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
|
||||
)
|
||||
}
|
||||
// Treat enums as a union of their members.
|
||||
Type::NominalInstance(instance)
|
||||
if enum_metadata(db, instance.class.class_literal(db).0).is_some() =>
|
||||
{
|
||||
UnionType::from_elements(
|
||||
db,
|
||||
enum_member_literals(db, instance.class.class_literal(db).0, None)
|
||||
.expect("Calling `enum_member_literals` on an enum class")
|
||||
.map(|ty| filter_to_cannot_be_equal(db, ty, rhs_ty)),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
if ty.is_single_valued(db) && !could_compare_equal(db, ty, rhs_ty) {
|
||||
ty
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue