[red-knot] Add a test to ensure that KnownClass::try_from_file_and_name() is kept up to date (#16326)

This commit is contained in:
Alex Waygood 2025-02-24 12:14:20 +00:00 committed by GitHub
parent 320a3c68ae
commit 5bac4f6bd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 145 additions and 64 deletions

2
Cargo.lock generated
View file

@ -2499,6 +2499,8 @@ dependencies = [
"serde", "serde",
"smallvec", "smallvec",
"static_assertions", "static_assertions",
"strum",
"strum_macros",
"tempfile", "tempfile",
"test-case", "test-case",
"thiserror 2.0.11", "thiserror 2.0.11",

View file

@ -42,6 +42,8 @@ smallvec = { workspace = true }
static_assertions = { workspace = true } static_assertions = { workspace = true }
test-case = { workspace = true } test-case = { workspace = true }
memchr = { workspace = true } memchr = { workspace = true }
strum = { workspace = true}
strum_macros = { workspace = true}
[dev-dependencies] [dev-dependencies]
ruff_db = { workspace = true, features = ["testing", "os"] } ruff_db = { workspace = true, features = ["testing", "os"] }

View file

@ -1,4 +1,5 @@
use std::fmt::Formatter; use std::fmt::Formatter;
use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use ruff_db::files::File; use ruff_db::files::File;
@ -98,10 +99,13 @@ impl ModuleKind {
} }
/// Enumeration of various core stdlib modules in which important types are located /// Enumeration of various core stdlib modules in which important types are located
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum_macros::EnumString)]
#[cfg_attr(test, derive(strum_macros::EnumIter))]
#[strum(serialize_all = "snake_case")]
pub enum KnownModule { pub enum KnownModule {
Builtins, Builtins,
Types, Types,
#[strum(serialize = "_typeshed")]
Typeshed, Typeshed,
TypingExtensions, TypingExtensions,
Typing, Typing,
@ -139,21 +143,10 @@ impl KnownModule {
search_path: &SearchPath, search_path: &SearchPath,
name: &ModuleName, name: &ModuleName,
) -> Option<Self> { ) -> Option<Self> {
if !search_path.is_standard_library() { if search_path.is_standard_library() {
return None; Self::from_str(name.as_str()).ok()
} } else {
match name.as_str() { None
"builtins" => Some(Self::Builtins),
"types" => Some(Self::Types),
"typing" => Some(Self::Typing),
"_typeshed" => Some(Self::Typeshed),
"typing_extensions" => Some(Self::TypingExtensions),
"sys" => Some(Self::Sys),
"abc" => Some(Self::Abc),
"collections" => Some(Self::Collections),
"inspect" => Some(Self::Inspect),
"knot_extensions" => Some(Self::KnotExtensions),
_ => None,
} }
} }
@ -168,4 +161,29 @@ impl KnownModule {
pub const fn is_knot_extensions(self) -> bool { pub const fn is_knot_extensions(self) -> bool {
matches!(self, Self::KnotExtensions) matches!(self, Self::KnotExtensions)
} }
/// Return `true` if `self` is the `Inspect` variant of this enum.
pub const fn is_inspect(self) -> bool {
matches!(self, Self::Inspect)
}
}
#[cfg(test)]
mod tests {
use super::*;
use strum::IntoEnumIterator;
#[test]
fn known_module_roundtrip_from_str() {
    // Each variant's own name, looked up on the vendored stdlib search path, must
    // resolve back to exactly that variant.
    let stdlib = SearchPath::vendored_stdlib();

    KnownModule::iter().for_each(|expected| {
        let module_name = expected.name();
        let parsed = KnownModule::try_from_search_path_and_name(&stdlib, &module_name);
        assert_eq!(
            parsed,
            Some(expected),
            "The strum `EnumString` implementation appears to be incorrect for `{module_name}`"
        );
    });
}
} }

View file

@ -1,4 +1,5 @@
use std::hash::Hash; use std::hash::Hash;
use std::str::FromStr;
use bitflags::bitflags; use bitflags::bitflags;
use call::{CallDunderError, CallError}; use call::{CallDunderError, CallError};
@ -3234,9 +3235,16 @@ impl<'db> FunctionType<'db> {
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might /// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
/// have special behavior. /// have special behavior.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, strum_macros::EnumString)]
#[strum(serialize_all = "snake_case")]
#[cfg_attr(test, derive(strum_macros::EnumIter, strum_macros::IntoStaticStr))]
pub enum KnownFunction { pub enum KnownFunction {
ConstraintFunction(KnownConstraintFunction), /// `builtins.isinstance`
#[strum(serialize = "isinstance")]
IsInstance,
/// `builtins.issubclass`
#[strum(serialize = "issubclass")]
IsSubclass,
/// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type` /// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type`
RevealType, RevealType,
/// `builtins.len` /// `builtins.len`
@ -3280,9 +3288,10 @@ pub enum KnownFunction {
} }
impl KnownFunction { impl KnownFunction {
pub fn constraint_function(self) -> Option<KnownConstraintFunction> { pub fn into_constraint_function(self) -> Option<KnownConstraintFunction> {
match self { match self {
Self::ConstraintFunction(f) => Some(f), Self::IsInstance => Some(KnownConstraintFunction::IsInstance),
Self::IsSubclass => Some(KnownConstraintFunction::IsSubclass),
_ => None, _ => None,
} }
} }
@ -3292,30 +3301,7 @@ impl KnownFunction {
definition: Definition<'db>, definition: Definition<'db>,
name: &str, name: &str,
) -> Option<Self> { ) -> Option<Self> {
let candidate = match name { let candidate = Self::from_str(name).ok()?;
"isinstance" => Self::ConstraintFunction(KnownConstraintFunction::IsInstance),
"issubclass" => Self::ConstraintFunction(KnownConstraintFunction::IsSubclass),
"reveal_type" => Self::RevealType,
"len" => Self::Len,
"repr" => Self::Repr,
"final" => Self::Final,
"no_type_check" => Self::NoTypeCheck,
"assert_type" => Self::AssertType,
"cast" => Self::Cast,
"overload" => Self::Overload,
"getattr_static" => Self::GetattrStatic,
"static_assert" => Self::StaticAssert,
"is_subtype_of" => Self::IsSubtypeOf,
"is_disjoint_from" => Self::IsDisjointFrom,
"is_equivalent_to" => Self::IsEquivalentTo,
"is_assignable_to" => Self::IsAssignableTo,
"is_gradual_equivalent_to" => Self::IsGradualEquivalentTo,
"is_fully_static" => Self::IsFullyStatic,
"is_singleton" => Self::IsSingleton,
"is_single_valued" => Self::IsSingleValued,
_ => return None,
};
candidate candidate
.check_module(file_to_module(db, definition.file(db))?.known()?) .check_module(file_to_module(db, definition.file(db))?.known()?)
.then_some(candidate) .then_some(candidate)
@ -3324,12 +3310,7 @@ impl KnownFunction {
/// Return `true` if `self` is defined in `module` at runtime. /// Return `true` if `self` is defined in `module` at runtime.
const fn check_module(self, module: KnownModule) -> bool { const fn check_module(self, module: KnownModule) -> bool {
match self { match self {
Self::ConstraintFunction(constraint_function) => match constraint_function { Self::IsInstance | Self::IsSubclass | Self::Len | Self::Repr => module.is_builtins(),
KnownConstraintFunction::IsInstance | KnownConstraintFunction::IsSubclass => {
module.is_builtins()
}
},
Self::Len | Self::Repr => module.is_builtins(),
Self::AssertType Self::AssertType
| Self::Cast | Self::Cast
| Self::Overload | Self::Overload
@ -3338,9 +3319,7 @@ impl KnownFunction {
| Self::NoTypeCheck => { | Self::NoTypeCheck => {
matches!(module, KnownModule::Typing | KnownModule::TypingExtensions) matches!(module, KnownModule::Typing | KnownModule::TypingExtensions)
} }
Self::GetattrStatic => { Self::GetattrStatic => module.is_inspect(),
matches!(module, KnownModule::Inspect)
}
Self::IsAssignableTo Self::IsAssignableTo
| Self::IsDisjointFrom | Self::IsDisjointFrom
| Self::IsEquivalentTo | Self::IsEquivalentTo
@ -3369,7 +3348,8 @@ impl KnownFunction {
Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression, Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression,
Self::Cast => ParameterExpectations::TypeExpressionAndValueExpression, Self::Cast => ParameterExpectations::TypeExpressionAndValueExpression,
Self::ConstraintFunction(_) Self::IsInstance
| Self::IsSubclass
| Self::Len | Self::Len
| Self::Repr | Self::Repr
| Self::Overload | Self::Overload
@ -4026,12 +4006,15 @@ static_assertions::assert_eq_size!(Type, [u8; 16]);
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;
use crate::db::tests::{setup_db, TestDbBuilder}; use crate::db::tests::{setup_db, TestDbBuilder};
use crate::symbol::{global_symbol, typing_extensions_symbol, typing_symbol}; use crate::symbol::{
global_symbol, known_module_symbol, typing_extensions_symbol, typing_symbol,
};
use ruff_db::files::system_path_to_file; use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module; use ruff_db::parsed::parsed_module;
use ruff_db::system::DbWithTestSystem; use ruff_db::system::DbWithTestSystem;
use ruff_db::testing::assert_function_query_was_not_run; use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::PythonVersion; use ruff_python_ast::PythonVersion;
use strum::IntoEnumIterator;
use test_case::test_case; use test_case::test_case;
/// Explicitly test for Python version <3.13 and >=3.13, to ensure that /// Explicitly test for Python version <3.13 and >=3.13, to ensure that
@ -4176,4 +4159,55 @@ pub(crate) mod tests {
.build() .build()
.is_todo()); .is_todo());
} }
#[test]
fn known_function_roundtrip_from_str() {
    let db = setup_db();

    for function in KnownFunction::iter() {
        let function_name: &'static str = function.into();

        // Choose the module in which the symbol is looked up. The match has no
        // wildcard arm, so a newly added variant fails to compile here until it is
        // assigned a module.
        let module = match function {
            KnownFunction::IsInstance
            | KnownFunction::IsSubclass
            | KnownFunction::Len
            | KnownFunction::Repr => KnownModule::Builtins,

            KnownFunction::GetattrStatic => KnownModule::Inspect,

            KnownFunction::AssertType
            | KnownFunction::Cast
            | KnownFunction::Final
            | KnownFunction::NoTypeCheck
            | KnownFunction::Overload
            | KnownFunction::RevealType => KnownModule::TypingExtensions,

            KnownFunction::IsAssignableTo
            | KnownFunction::IsDisjointFrom
            | KnownFunction::IsEquivalentTo
            | KnownFunction::IsFullyStatic
            | KnownFunction::IsGradualEquivalentTo
            | KnownFunction::IsSingleValued
            | KnownFunction::IsSingleton
            | KnownFunction::IsSubtypeOf
            | KnownFunction::StaticAssert => KnownModule::KnotExtensions,
        };

        // Resolve the symbol to its function literal, then walk back to the
        // definition that introduced it.
        let function_literal = known_module_symbol(&db, module, function_name)
            .expect_type()
            .expect_function_literal();
        let body_scope = function_literal.body_scope(&db);
        let node = body_scope.node(&db).expect_function();
        let definition = semantic_index(&db, body_scope.file(&db)).definition(node);

        let roundtripped =
            KnownFunction::try_from_definition_and_name(&db, definition, function_name);

        assert_eq!(
            roundtripped,
            Some(function),
            "The strum `EnumString` implementation appears to be incorrect for `{function_name}`"
        );
    }
}
} }

View file

@ -670,6 +670,7 @@ impl<'db> From<InstanceType<'db>> for Type<'db> {
/// places. /// places.
/// Note: good candidates are any classes in `[crate::module_resolver::module::KnownModule]` /// Note: good candidates are any classes in `[crate::module_resolver::module::KnownModule]`
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(test, derive(strum_macros::EnumIter))]
pub enum KnownClass { pub enum KnownClass {
// To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored` // To figure out where an stdlib symbol is defined, you can go into `crates/red_knot_vendored`
// and grep for the symbol name in any `.pyi` file. // and grep for the symbol name in any `.pyi` file.
@ -1026,10 +1027,13 @@ impl<'db> KnownClass {
} }
} }
pub fn try_from_file_and_name(db: &dyn Db, file: File, class_name: &str) -> Option<Self> { pub(super) fn try_from_file_and_name(
// Note: if this becomes hard to maintain (as rust can't ensure at compile time that all db: &dyn Db,
// variants of `Self` are covered), we might use a macro (in-house or dependency) file: File,
// See: https://stackoverflow.com/q/39070244 class_name: &str,
) -> Option<Self> {
// We assert that this match is exhaustive over the right-hand side in the unit test
// `known_class_roundtrip_from_str()`
let candidate = match class_name { let candidate = match class_name {
"bool" => Self::Bool, "bool" => Self::Bool,
"object" => Self::Object, "object" => Self::Object,
@ -1498,3 +1502,26 @@ pub(super) enum MetaclassErrorKind<'db> {
/// The metaclass is of a union type whose some members are not callable /// The metaclass is of a union type whose some members are not callable
PartlyNotCallable(Type<'db>), PartlyNotCallable(Type<'db>),
} }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::tests::setup_db;
    use crate::module_resolver::resolve_module;
    use strum::IntoEnumIterator;

    /// Checks that every `KnownClass` variant is recognised by
    /// `KnownClass::try_from_file_and_name()` when it is given the variant's own
    /// name together with the file of the variant's canonical module.
    #[test]
    fn known_class_roundtrip_from_str() {
        let db = setup_db();

        for class in KnownClass::iter() {
            let class_name = class.as_str(&db);
            let class_module = resolve_module(&db, &class.canonical_module(&db).name())
                .expect("the canonical module of a `KnownClass` should be resolvable");

            // The failure message names the function that actually needs a new match
            // arm, so whoever adds a variant knows where to look.
            assert_eq!(
                KnownClass::try_from_file_and_name(&db, class_module.file(), class_name),
                Some(class),
                "`KnownClass::try_from_file_and_name` appears to be missing a case for `{class_name}`"
            );
        }
    }
}

View file

@ -8,8 +8,8 @@ use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::symbol_table; use crate::semantic_index::symbol_table;
use crate::types::infer::infer_same_file_expression_type; use crate::types::infer::infer_same_file_expression_type;
use crate::types::{ use crate::types::{
infer_expression_types, IntersectionBuilder, KnownClass, KnownFunction, SubclassOfType, infer_expression_types, IntersectionBuilder, KnownClass, SubclassOfType, Truthiness, Type,
Truthiness, Type, UnionBuilder, UnionBuilder,
}; };
use crate::Db; use crate::Db;
use itertools::Itertools; use itertools::Itertools;
@ -429,9 +429,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
// and `issubclass`, for example `isinstance(x, str | (int | float))`. // and `issubclass`, for example `isinstance(x, str | (int | float))`.
match callable_ty { match callable_ty {
Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => { Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => {
let function = function_type let function = function_type.known(self.db)?.into_constraint_function()?;
.known(self.db)
.and_then(KnownFunction::constraint_function)?;
let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] = let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] =
&*expr_call.arguments.args &*expr_call.arguments.args