Respect local subclasses in flake8-type-checking (#8768)

If you define a subclass of `pydantic.BaseModel`, and then a subclass of
_that_ class in the same file, we'll now correctly treat it as
runtime-evaluated.

Closes https://github.com/astral-sh/ruff/issues/7893.
This commit is contained in:
Charlie Marsh 2023-11-19 06:49:25 -08:00 committed by GitHub
parent 94178a0320
commit 00a015ca24
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 34 deletions

View file

@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence # TCH003 from pandas import DataFrame
from pydantic import BaseModel
class MyBaseClass: class Parent(BaseModel):
pass ...
class Foo(MyBaseClass): class Child(Parent):
foo: Sequence baz: DataFrame

View file

@ -3,8 +3,9 @@ use rustc_hash::FxHashSet;
use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix, Violation}; use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::call_path::CallPath; use ruff_python_ast::call_path::CallPath;
use ruff_python_ast::helpers::map_subscript;
use ruff_python_ast::{ use ruff_python_ast::{
self as ast, Arguments, Expr, Operator, ParameterWithDefault, Parameters, Stmt, UnaryOp, self as ast, Expr, Operator, ParameterWithDefault, Parameters, Stmt, UnaryOp,
}; };
use ruff_python_semantic::{BindingId, ScopeKind, SemanticModel}; use ruff_python_semantic::{BindingId, ScopeKind, SemanticModel};
use ruff_source_file::Locator; use ruff_source_file::Locator;
@ -476,26 +477,25 @@ fn is_enum(class_def: &ast::StmtClassDef, semantic: &SemanticModel) -> bool {
semantic: &SemanticModel, semantic: &SemanticModel,
seen: &mut FxHashSet<BindingId>, seen: &mut FxHashSet<BindingId>,
) -> bool { ) -> bool {
let Some(Arguments { args: bases, .. }) = class_def.arguments.as_deref() else { class_def.bases().iter().any(|expr| {
return false;
};
bases.iter().any(|expr| {
// If the base class is `enum.Enum`, `enum.Flag`, etc., then this is an enum. // If the base class is `enum.Enum`, `enum.Flag`, etc., then this is an enum.
if semantic.resolve_call_path(expr).is_some_and(|call_path| { if semantic
matches!( .resolve_call_path(map_subscript(expr))
call_path.as_slice(), .is_some_and(|call_path| {
[ matches!(
"enum", call_path.as_slice(),
"Enum" | "Flag" | "IntEnum" | "IntFlag" | "StrEnum" | "ReprEnum" [
] "enum",
) "Enum" | "Flag" | "IntEnum" | "IntFlag" | "StrEnum" | "ReprEnum"
}) { ]
)
})
{
return true; return true;
} }
// If the base class extends `enum.Enum`, `enum.Flag`, etc., then this is an enum. // If the base class extends `enum.Enum`, `enum.Flag`, etc., then this is an enum.
if let Some(id) = semantic.lookup_attribute(expr) { if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) {
if seen.insert(id) { if seen.insert(id) {
let binding = semantic.binding(id); let binding = semantic.binding(id);
if let Some(base_class) = binding if let Some(base_class) = binding

View file

@ -1,6 +1,8 @@
use ruff_python_ast::call_path::from_qualified_name; use ruff_python_ast::call_path::from_qualified_name;
use ruff_python_ast::helpers::{map_callable, map_subscript}; use ruff_python_ast::helpers::{map_callable, map_subscript};
use ruff_python_semantic::{Binding, BindingKind, ScopeKind, SemanticModel}; use ruff_python_ast::{self as ast};
use ruff_python_semantic::{Binding, BindingId, BindingKind, ScopeKind, SemanticModel};
use rustc_hash::FxHashSet;
pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticModel) -> bool { pub(crate) fn is_valid_runtime_import(binding: &Binding, semantic: &SemanticModel) -> bool {
if matches!( if matches!(
@ -35,19 +37,54 @@ pub(crate) fn runtime_evaluated(
} }
fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool { fn runtime_evaluated_base_class(base_classes: &[String], semantic: &SemanticModel) -> bool {
let ScopeKind::Class(class_def) = &semantic.current_scope().kind else { fn inner(
return false; class_def: &ast::StmtClassDef,
}; base_classes: &[String],
semantic: &SemanticModel,
seen: &mut FxHashSet<BindingId>,
) -> bool {
class_def.bases().iter().any(|expr| {
// If the base class is itself runtime-evaluated, then this is too.
// Ex) `class Foo(BaseModel): ...`
if semantic
.resolve_call_path(map_subscript(expr))
.is_some_and(|call_path| {
base_classes
.iter()
.any(|base_class| from_qualified_name(base_class) == call_path)
})
{
return true;
}
class_def.bases().iter().any(|base| { // If the base class extends a runtime-evaluated class, then this does too.
semantic // Ex) `class Bar(BaseModel): ...; class Foo(Bar): ...`
.resolve_call_path(map_subscript(base)) if let Some(id) = semantic.lookup_attribute(map_subscript(expr)) {
.is_some_and(|call_path| { if seen.insert(id) {
base_classes let binding = semantic.binding(id);
.iter() if let Some(base_class) = binding
.any(|base_class| from_qualified_name(base_class) == call_path) .kind
}) .as_class_definition()
}) .map(|id| &semantic.scopes[*id])
.and_then(|scope| scope.kind.as_class())
{
if inner(base_class, base_classes, semantic, seen) {
return true;
}
}
}
}
false
})
}
semantic
.current_scope()
.kind
.as_class()
.is_some_and(|class_def| {
inner(class_def, base_classes, semantic, &mut FxHashSet::default())
})
} }
fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel) -> bool { fn runtime_evaluated_decorators(decorators: &[String], semantic: &SemanticModel) -> bool {

View file

@ -98,6 +98,10 @@ mod tests {
Rule::TypingOnlyStandardLibraryImport, Rule::TypingOnlyStandardLibraryImport,
Path::new("runtime_evaluated_base_classes_4.py") Path::new("runtime_evaluated_base_classes_4.py")
)] )]
#[test_case(
Rule::TypingOnlyThirdPartyImport,
Path::new("runtime_evaluated_base_classes_5.py")
)]
fn runtime_evaluated_base_classes(rule_code: Rule, path: &Path) -> Result<()> { fn runtime_evaluated_base_classes(rule_code: Rule, path: &Path) -> Result<()> {
let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy()); let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy());
let diagnostics = test_path( let diagnostics = test_path(