This commit is contained in:
Alex Waygood 2025-05-05 18:55:48 +01:00
parent fdee512083
commit ea06aefea3
5 changed files with 109 additions and 111 deletions

View file

@ -338,7 +338,7 @@ declare_lint! {
declare_lint! {
/// ## What it does
/// Checks for invalidly defined protocol classes.
/// Checks for protocol classes that will raise `TypeError` at runtime.
///
/// ## Why is this bad?
/// An invalidly defined protocol class may lead to the type checker inferring

View file

@ -58,9 +58,7 @@ use crate::semantic_index::narrowing_constraints::ConstraintKey;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind, ScopedSymbolId,
};
use crate::semantic_index::{
EagerSnapshotResult, SemanticIndex, semantic_index, symbol_table, use_def_map,
};
use crate::semantic_index::{EagerSnapshotResult, SemanticIndex, semantic_index};
use crate::symbol::{
Boundness, LookupError, builtins_module_scope, builtins_symbol, explicit_global_symbol,
global_symbol, module_type_implicit_global_declaration, module_type_implicit_global_symbol,
@ -114,7 +112,7 @@ use super::diagnostic::{
report_invalid_type_checking_constant, report_non_subscriptable,
report_possibly_unresolved_reference,
report_runtime_check_against_non_runtime_checkable_protocol, report_slice_step_size_zero,
report_undeclared_protocol_member, report_unresolved_reference,
report_unresolved_reference,
};
use super::generics::LegacyGenericBase;
use super::slots::check_class_slots;
@ -1079,53 +1077,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
if let Some(protocol) = class.into_protocol_class(self.db()) {
let interface = protocol.interface(self.db());
let class_symbol_table = symbol_table(self.db(), class.body_scope(self.db()));
for (symbol_id, mut bindings_iterator) in
use_def_map(self.db(), class.body_scope(self.db())).all_public_bindings()
{
let symbol_name = class_symbol_table.symbol(symbol_id).name();
if !interface.includes_member(self.db(), symbol_name) {
continue;
}
let has_declaration = class
.iter_mro(self.db(), None)
.filter_map(ClassBase::into_class)
.any(|superclass| {
let superclass_scope =
superclass.class_literal(self.db()).0.body_scope(self.db());
let Some(scoped_symbol_id) = symbol_table(self.db(), superclass_scope)
.symbol_id_by_name(symbol_name)
else {
return false;
};
symbol_from_declarations(
self.db(),
use_def_map(self.db(), superclass_scope)
.public_declarations(scoped_symbol_id),
)
.is_ok_and(|symbol| !symbol.symbol.is_unbound())
});
if has_declaration {
continue;
}
let Some(first_binding) = bindings_iterator.find_map(|binding| binding.binding)
else {
continue;
};
report_undeclared_protocol_member(
&self.context,
first_binding,
protocol,
class_symbol_table,
);
}
protocol.validate_members(&self.context);
}
}
}

View file

@ -5,11 +5,14 @@ use itertools::{Either, Itertools};
use ruff_python_ast::name::Name;
use crate::{
Db, FxOrderSet,
semantic_index::{symbol_table, use_def_map},
symbol::{symbol_from_bindings, symbol_from_declarations},
types::function::KnownFunction,
types::{ClassBase, ClassLiteral, Type, TypeMapping, TypeQualifiers, TypeVarInstance},
{Db, FxOrderSet},
types::{
ClassBase, ClassLiteral, Type, TypeMapping, TypeQualifiers, TypeVarInstance,
context::InferContext, diagnostic::report_undeclared_protocol_member,
function::KnownFunction,
},
};
impl<'db> ClassLiteral<'db> {
@ -47,6 +50,49 @@ impl<'db> ProtocolClassLiteral<'db> {
self.known_function_decorators(db)
.contains(&KnownFunction::RuntimeCheckable)
}
pub(super) fn validate_members(self, context: &InferContext) {
let db = context.db();
let interface = self.interface(db);
let class_symbol_table = symbol_table(db, self.body_scope(db));
for (symbol_id, mut bindings_iterator) in
use_def_map(db, self.body_scope(db)).all_public_bindings()
{
let symbol_name = class_symbol_table.symbol(symbol_id).name();
if !interface.includes_member(db, symbol_name) {
continue;
}
let has_declaration = self
.iter_mro(db, None)
.filter_map(ClassBase::into_class)
.any(|superclass| {
let superclass_scope = superclass.class_literal(db).0.body_scope(db);
let Some(scoped_symbol_id) =
symbol_table(db, superclass_scope).symbol_id_by_name(symbol_name)
else {
return false;
};
symbol_from_declarations(
db,
use_def_map(db, superclass_scope).public_declarations(scoped_symbol_id),
)
.is_ok_and(|symbol| !symbol.symbol.is_unbound())
});
if has_declaration {
continue;
}
let Some(first_binding) = bindings_iterator.find_map(|binding| binding.binding) else {
continue;
};
report_undeclared_protocol_member(context, first_binding, self, class_symbol_table);
}
}
}
impl<'db> Deref for ProtocolClassLiteral<'db> {