diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/literal.md b/crates/ty_python_semantic/resources/mdtest/annotations/literal.md index 05b0868523..29f6ea7ca2 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/literal.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/literal.md @@ -51,6 +51,13 @@ invalid4: Literal[ hello, # error: [invalid-type-form] (1, 2, 3), # error: [invalid-type-form] ] + +class NotAnEnum: + x: int = 1 + +# error: [invalid-type-form] +def _(invalid: Literal[NotAnEnum.x]) -> None: + reveal_type(invalid) # revealed: Unknown ``` ## Shortening unions of literals diff --git a/crates/ty_python_semantic/src/types/enums.rs b/crates/ty_python_semantic/src/types/enums.rs index 59c814c147..5c026beff5 100644 --- a/crates/ty_python_semantic/src/types/enums.rs +++ b/crates/ty_python_semantic/src/types/enums.rs @@ -240,3 +240,10 @@ pub(crate) fn enum_member_literals<'a, 'db: 'a>( pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool { enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1) } + +pub(crate) fn is_enum_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { + match ty { + Type::ClassLiteral(class_literal) => enum_metadata(db, class_literal).is_some(), + _ => false, + } +} diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 6bbfe74f80..d163eb0373 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -102,6 +102,7 @@ use crate::types::diagnostic::{ report_invalid_generator_function_return_type, report_invalid_return_type, report_possibly_unbound_attribute, }; +use crate::types::enums::{enum_metadata, is_enum_class}; use crate::types::function::{ FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral, }; @@ -10033,14 +10034,19 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // For enum values ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { let value_ty = self.infer_expression(value); - // TODO: Check that value type is enum otherwise return None - let ty = value_ty - .member(self.db(), &attr.id) - .place - .ignore_possibly_unbound() - .unwrap_or(Type::unknown()); - self.store_expression_type(parameters, ty); - ty + + if is_enum_class(self.db(), value_ty) { + let ty = value_ty + .member(self.db(), &attr.id) + .place + .ignore_possibly_unbound() + .unwrap_or(Type::unknown()); + self.store_expression_type(parameters, ty); + ty + } else { + self.store_expression_type(parameters, Type::unknown()); + return Err(vec![parameters]); + } } // for negative and positive numbers ast::Expr::UnaryOp(u)