diff --git a/Cargo.lock b/Cargo.lock index 87fabedf36..7e109660db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2499,6 +2499,8 @@ dependencies = [ "serde", "smallvec", "static_assertions", + "strum", + "strum_macros", "tempfile", "test-case", "thiserror 2.0.11", diff --git a/crates/red_knot_python_semantic/Cargo.toml b/crates/red_knot_python_semantic/Cargo.toml index e024e234de..c1ef36cb14 100644 --- a/crates/red_knot_python_semantic/Cargo.toml +++ b/crates/red_knot_python_semantic/Cargo.toml @@ -42,6 +42,8 @@ smallvec = { workspace = true } static_assertions = { workspace = true } test-case = { workspace = true } memchr = { workspace = true } +strum = { workspace = true} +strum_macros = { workspace = true} [dev-dependencies] ruff_db = { workspace = true, features = ["testing", "os"] } diff --git a/crates/red_knot_python_semantic/src/module_resolver/module.rs b/crates/red_knot_python_semantic/src/module_resolver/module.rs index d85a19a23d..14d7608f1e 100644 --- a/crates/red_knot_python_semantic/src/module_resolver/module.rs +++ b/crates/red_knot_python_semantic/src/module_resolver/module.rs @@ -1,4 +1,5 @@ use std::fmt::Formatter; +use std::str::FromStr; use std::sync::Arc; use ruff_db::files::File; @@ -98,10 +99,13 @@ impl ModuleKind { } /// Enumeration of various core stdlib modules in which important types are located -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum_macros::EnumString)] +#[cfg_attr(test, derive(strum_macros::EnumIter))] +#[strum(serialize_all = "snake_case")] pub enum KnownModule { Builtins, Types, + #[strum(serialize = "_typeshed")] Typeshed, TypingExtensions, Typing, @@ -139,21 +143,10 @@ impl KnownModule { search_path: &SearchPath, name: &ModuleName, ) -> Option { - if !search_path.is_standard_library() { - return None; - } - match name.as_str() { - "builtins" => Some(Self::Builtins), - "types" => Some(Self::Types), - "typing" => Some(Self::Typing), - "_typeshed" => Some(Self::Typeshed), - "typing_extensions" => Some(Self::TypingExtensions), - "sys" => Some(Self::Sys), - "abc" => Some(Self::Abc), - "collections" => Some(Self::Collections), - "inspect" => Some(Self::Inspect), - "knot_extensions" => Some(Self::KnotExtensions), - _ => None, + if search_path.is_standard_library() { + Self::from_str(name.as_str()).ok() + } else { + None } } @@ -168,4 +161,29 @@ impl KnownModule { pub const fn is_knot_extensions(self) -> bool { matches!(self, Self::KnotExtensions) } + + pub const fn is_inspect(self) -> bool { + matches!(self, Self::Inspect) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + + #[test] + fn known_module_roundtrip_from_str() { + let stdlib_search_path = SearchPath::vendored_stdlib(); + + for module in KnownModule::iter() { + let module_name = module.name(); + + assert_eq!( + KnownModule::try_from_search_path_and_name(&stdlib_search_path, &module_name), + Some(module), + "The strum `EnumString` implementation appears to be incorrect for `{module_name}`" + ); + } + } } diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 65f484f48b..17002e939d 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,4 +1,5 @@ use std::hash::Hash; +use std::str::FromStr; use bitflags::bitflags; use call::{CallDunderError, CallError}; @@ -3234,9 +3235,16 @@ impl<'db> FunctionType<'db> { /// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might /// have special behavior. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, strum_macros::EnumString)] +#[strum(serialize_all = "snake_case")] +#[cfg_attr(test, derive(strum_macros::EnumIter, strum_macros::IntoStaticStr))] pub enum KnownFunction { - ConstraintFunction(KnownConstraintFunction), + /// `builtins.isinstance` + #[strum(serialize = "isinstance")] + IsInstance, + /// `builtins.issubclass` + #[strum(serialize = "issubclass")] + IsSubclass, /// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type` RevealType, /// `builtins.len` @@ -3280,9 +3288,10 @@ pub enum KnownFunction { } impl KnownFunction { - pub fn constraint_function(self) -> Option { + pub fn into_constraint_function(self) -> Option { match self { - Self::ConstraintFunction(f) => Some(f), + Self::IsInstance => Some(KnownConstraintFunction::IsInstance), + Self::IsSubclass => Some(KnownConstraintFunction::IsSubclass), _ => None, } } @@ -3292,30 +3301,7 @@ impl KnownFunction { definition: Definition<'db>, name: &str, ) -> Option { - let candidate = match name { - "isinstance" => Self::ConstraintFunction(KnownConstraintFunction::IsInstance), - "issubclass" => Self::ConstraintFunction(KnownConstraintFunction::IsSubclass), - "reveal_type" => Self::RevealType, - "len" => Self::Len, - "repr" => Self::Repr, - "final" => Self::Final, - "no_type_check" => Self::NoTypeCheck, - "assert_type" => Self::AssertType, - "cast" => Self::Cast, - "overload" => Self::Overload, - "getattr_static" => Self::GetattrStatic, - "static_assert" => Self::StaticAssert, - "is_subtype_of" => Self::IsSubtypeOf, - "is_disjoint_from" => Self::IsDisjointFrom, - "is_equivalent_to" => Self::IsEquivalentTo, - "is_assignable_to" => Self::IsAssignableTo, - "is_gradual_equivalent_to" => Self::IsGradualEquivalentTo, - "is_fully_static" => Self::IsFullyStatic, - "is_singleton" => Self::IsSingleton, - "is_single_valued" => Self::IsSingleValued, - _ => return None, - }; - + let candidate = Self::from_str(name).ok()?; candidate .check_module(file_to_module(db, definition.file(db))?.known()?) .then_some(candidate) @@ -3324,12 +3310,7 @@ impl KnownFunction { /// Return `true` if `self` is defined in `module` at runtime. const fn check_module(self, module: KnownModule) -> bool { match self { - Self::ConstraintFunction(constraint_function) => match constraint_function { - KnownConstraintFunction::IsInstance | KnownConstraintFunction::IsSubclass => { - module.is_builtins() - } - }, - Self::Len | Self::Repr => module.is_builtins(), + Self::IsInstance | Self::IsSubclass | Self::Len | Self::Repr => module.is_builtins(), Self::AssertType | Self::Cast | Self::Overload @@ -3338,9 +3319,7 @@ impl KnownFunction { | Self::NoTypeCheck => { matches!(module, KnownModule::Typing | KnownModule::TypingExtensions) } - Self::GetattrStatic => { - matches!(module, KnownModule::Inspect) - } + Self::GetattrStatic => module.is_inspect(), Self::IsAssignableTo | Self::IsDisjointFrom | Self::IsEquivalentTo @@ -3369,7 +3348,8 @@ impl KnownFunction { Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression, Self::Cast => ParameterExpectations::TypeExpressionAndValueExpression, - Self::ConstraintFunction(_) + Self::IsInstance + | Self::IsSubclass | Self::Len | Self::Repr | Self::Overload @@ -4026,12 +4006,15 @@ static_assertions::assert_eq_size!(Type, [u8; 16]); pub(crate) mod tests { use super::*; use crate::db::tests::{setup_db, TestDbBuilder}; - use crate::symbol::{global_symbol, typing_extensions_symbol, typing_symbol}; + use crate::symbol::{ + global_symbol, known_module_symbol, typing_extensions_symbol, typing_symbol, + }; use ruff_db::files::system_path_to_file; use ruff_db::parsed::parsed_module; use ruff_db::system::DbWithTestSystem; use ruff_db::testing::assert_function_query_was_not_run; use ruff_python_ast::PythonVersion; + use strum::IntoEnumIterator; use test_case::test_case; /// Explicitly test for Python version <3.13 and >=3.13, to ensure that @@ -4176,4 +4159,55 @@ pub(crate) mod tests { .build() .is_todo()); } + + #[test] + fn known_function_roundtrip_from_str() { + let db = setup_db(); + + for function in KnownFunction::iter() { + let function_name: &'static str = function.into(); + + let module = match function { + KnownFunction::Len + | KnownFunction::Repr + | KnownFunction::IsInstance + | KnownFunction::IsSubclass => KnownModule::Builtins, + + KnownFunction::GetattrStatic => KnownModule::Inspect, + + KnownFunction::Cast + | KnownFunction::Final + | KnownFunction::Overload + | KnownFunction::RevealType + | KnownFunction::AssertType + | KnownFunction::NoTypeCheck => KnownModule::TypingExtensions, + + KnownFunction::IsSingleton + | KnownFunction::IsSubtypeOf + | KnownFunction::StaticAssert + | KnownFunction::IsFullyStatic + | KnownFunction::IsDisjointFrom + | KnownFunction::IsSingleValued + | KnownFunction::IsAssignableTo + | KnownFunction::IsEquivalentTo + | KnownFunction::IsGradualEquivalentTo => KnownModule::KnotExtensions, + }; + + let function_body_scope = known_module_symbol(&db, module, function_name) + .expect_type() + .expect_function_literal() + .body_scope(&db); + + let function_node = function_body_scope.node(&db).expect_function(); + + let function_definition = + semantic_index(&db, function_body_scope.file(&db)).definition(function_node); + + assert_eq!( + KnownFunction::try_from_definition_and_name(&db, function_definition, function_name), + Some(function), + "The strum `EnumString` implementation appears to be incorrect for `{function_name}`" + ); + } + } } diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 644cfc0759..cdd35a6966 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -670,6 +670,7 @@ impl<'db> From> for Type<'db> { /// places. /// Note: good candidates are any classes in `[crate::module_resolver::module::KnownModule]` #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(test, derive(strum_macros::EnumIter))] pub enum KnownClass { // To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored` // and grep for the symbol name in any `.pyi` file. @@ -1026,10 +1027,13 @@ impl<'db> KnownClass { } } - pub fn try_from_file_and_name(db: &dyn Db, file: File, class_name: &str) -> Option { - // Note: if this becomes hard to maintain (as rust can't ensure at compile time that all - // variants of `Self` are covered), we might use a macro (in-house or dependency) - // See: https://stackoverflow.com/q/39070244 + pub(super) fn try_from_file_and_name( + db: &dyn Db, + file: File, + class_name: &str, + ) -> Option { + // We assert that this match is exhaustive over the right-hand side in the unit test + // `known_class_roundtrip_from_str()` let candidate = match class_name { "bool" => Self::Bool, "object" => Self::Object, @@ -1498,3 +1502,26 @@ pub(super) enum MetaclassErrorKind<'db> { /// The metaclass is of a union type whose some members are not callable PartlyNotCallable(Type<'db>), } + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::tests::setup_db; + use crate::module_resolver::resolve_module; + use strum::IntoEnumIterator; + + #[test] + fn known_class_roundtrip_from_str() { + let db = setup_db(); + for class in KnownClass::iter() { + let class_name = class.as_str(&db); + let class_module = resolve_module(&db, &class.canonical_module(&db).name()).unwrap(); + + assert_eq!( + KnownClass::try_from_file_and_name(&db, class_module.file(), class_name), + Some(class), + "`KnownClass::candidate_from_str` appears to be missing a case for `{class_name}`" + ); + } + } +} diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index eb9bf231ad..bd1d975011 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -8,8 +8,8 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol_table; use crate::types::infer::infer_same_file_expression_type; use crate::types::{ - infer_expression_types, IntersectionBuilder, KnownClass, KnownFunction, SubclassOfType, - Truthiness, Type, UnionBuilder, + infer_expression_types, IntersectionBuilder, KnownClass, SubclassOfType, Truthiness, Type, + UnionBuilder, }; use crate::Db; use itertools::Itertools; @@ -429,9 +429,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> { // and `issubclass`, for example `isinstance(x, str | (int | float))`. match callable_ty { Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => { - let function = function_type - .known(self.db) - .and_then(KnownFunction::constraint_function)?; + let function = function_type.known(self.db)?.into_constraint_function()?; let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] = &*expr_call.arguments.args