[ty] Fix the inferred interface of specialized generic protocols (#19866)

This commit is contained in:
Alex Waygood 2025-08-27 18:16:15 +01:00 committed by GitHub
parent 7d0c8e045c
commit ce1dc21e7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 200 additions and 57 deletions

View file

@ -748,6 +748,10 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsProtocol) => {
if let [Some(ty)] = overload.parameter_types() {
// We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol`
// would return `True` for the given type. Internally we consider `SupportsAbs[int]` to
// be a "(specialised) protocol class", but `typing.is_protocol(SupportsAbs[int])` returns
// `False` at runtime, so we do not set the return type to `Literal[True]` in this case.
overload.set_return_type(Type::BooleanLiteral(
ty.into_class_literal()
.is_some_and(|class| class.is_protocol(db)),
@ -756,6 +760,9 @@ impl<'db> Bindings<'db> {
}
Some(KnownFunction::GetProtocolMembers) => {
// Similarly to `is_protocol`, we only evaluate to this a frozenset of literal strings if a
// class-literal is passed in, not if a generic alias is passed in, to emulate the behaviour
// of `typing.get_protocol_members` at runtime.
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
if let Some(protocol_class) = class.into_protocol_class(db) {
let member_names = protocol_class

View file

@ -1198,6 +1198,14 @@ impl<'db> ClassType<'db> {
}
}
}
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
self.class_literal(db).0.is_protocol(db)
}
pub(super) fn header_span(self, db: &'db dyn Db) -> Span {
self.class_literal(db).0.header_span(db)
}
}
impl<'db> From<GenericAlias<'db>> for ClassType<'db> {

View file

@ -20,7 +20,7 @@ use crate::types::string_annotation::{
use crate::types::{
DynamicType, LintDiagnosticGuard, Protocol, ProtocolInstanceType, SubclassOfInner, binding_type,
};
use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClassLiteral};
use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClass};
use crate::util::diagnostics::format_enumeration;
use crate::{Db, FxIndexMap, FxOrderMap, Module, ModuleName, Program, declare_lint};
use itertools::Itertools;
@ -2467,7 +2467,7 @@ pub(crate) fn add_type_expression_reference_link<'db, 'ctx>(
pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
context: &InferContext,
call: &ast::ExprCall,
protocol: ProtocolClassLiteral,
protocol: ProtocolClass,
function: KnownFunction,
) {
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else {
@ -2504,7 +2504,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
pub(crate) fn report_attempted_protocol_instantiation(
context: &InferContext,
call: &ast::ExprCall,
protocol: ProtocolClassLiteral,
protocol: ProtocolClass,
) {
let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, call) else {
return;
@ -2529,7 +2529,7 @@ pub(crate) fn report_attempted_protocol_instantiation(
pub(crate) fn report_undeclared_protocol_member(
context: &InferContext,
definition: Definition,
protocol_class: ProtocolClassLiteral,
protocol_class: ProtocolClass,
class_symbol_table: &PlaceTable,
) {
/// We want to avoid suggesting an annotation for e.g. `x = None`,

View file

@ -1433,7 +1433,7 @@ impl KnownFunction {
return;
};
let Some(protocol_class) = param_type
.into_class_literal()
.to_class_type(db)
.and_then(|class| class.into_protocol_class(db))
else {
report_bad_argument_to_protocol_interface(

View file

@ -1216,8 +1216,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
if is_protocol
&& !(base_class.class_literal(self.db()).0.is_protocol(self.db())
|| base_class.is_known(self.db(), KnownClass::Object))
&& !(base_class.is_protocol(self.db()) || base_class.is_object(self.db()))
{
if let Some(builder) = self
.context
@ -6249,11 +6248,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// subclasses of the protocol to be passed to parameters that accept `type[SomeProtocol]`.
// <https://typing.python.org/en/latest/spec/protocol.html#type-and-class-objects-vs-protocols>.
if !callable_type.is_subclass_of() {
if let Some(protocol) = class
.class_literal(self.db())
.0
.into_protocol_class(self.db())
{
if let Some(protocol) = class.into_protocol_class(self.db()) {
report_attempted_protocol_instantiation(
&self.context,
call_expression,

View file

@ -645,10 +645,8 @@ impl<'db> Protocol<'db> {
fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
match self {
Self::FromClass(class) => class
.class_literal(db)
.0
.into_protocol_class(db)
.expect("Protocol class literal should be a protocol class")
.expect("Class wrapped by `Protocol` should be a protocol class")
.interface(db),
Self::Synthesized(synthesized) => synthesized.interface(),
}

View file

@ -9,6 +9,7 @@ use rustc_hash::FxHashMap;
use super::TypeVarVariance;
use crate::semantic_index::place::ScopedPlaceId;
use crate::semantic_index::{SemanticIndex, place_table};
use crate::types::ClassType;
use crate::types::context::InferContext;
use crate::types::diagnostic::report_undeclared_protocol_member;
use crate::{
@ -26,16 +27,24 @@ use crate::{
impl<'db> ClassLiteral<'db> {
/// Returns `Some` if this is a protocol class, `None` otherwise.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
self.is_protocol(db)
.then_some(ProtocolClass(ClassType::NonGeneric(self)))
}
}
impl<'db> ClassType<'db> {
/// Returns `Some` if this is a protocol class, `None` otherwise.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
self.is_protocol(db).then_some(ProtocolClass(self))
}
}
/// Representation of a single `Protocol` class definition.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteral<'db>);
pub(super) struct ProtocolClass<'db>(ClassType<'db>);
impl<'db> ProtocolClassLiteral<'db> {
impl<'db> ProtocolClass<'db> {
/// Returns the protocol members of this class.
///
/// A protocol's members define the interface declared by the protocol.
@ -56,7 +65,9 @@ impl<'db> ProtocolClassLiteral<'db> {
}
pub(super) fn is_runtime_checkable(self, db: &'db dyn Db) -> bool {
self.known_function_decorators(db)
self.class_literal(db)
.0
.known_function_decorators(db)
.contains(&KnownFunction::RuntimeCheckable)
}
@ -66,10 +77,11 @@ impl<'db> ProtocolClassLiteral<'db> {
pub(super) fn validate_members(self, context: &InferContext, index: &SemanticIndex<'db>) {
let db = context.db();
let interface = self.interface(db);
let class_place_table = index.place_table(self.body_scope(db).file_scope_id(db));
let body_scope = self.class_literal(db).0.body_scope(db);
let class_place_table = index.place_table(body_scope.file_scope_id(db));
for (symbol_id, mut bindings_iterator) in
use_def_map(db, self.body_scope(db)).all_end_of_scope_symbol_bindings()
use_def_map(db, body_scope).all_end_of_scope_symbol_bindings()
{
let symbol_name = class_place_table.symbol(symbol_id).name();
@ -77,27 +89,27 @@ impl<'db> ProtocolClassLiteral<'db> {
continue;
}
let has_declaration = self
.iter_mro(db, None)
.filter_map(ClassBase::into_class)
.any(|superclass| {
let superclass_scope = superclass.class_literal(db).0.body_scope(db);
let Some(scoped_symbol_id) =
place_table(db, superclass_scope).symbol_id(symbol_name)
else {
return false;
};
!place_from_declarations(
db,
index
.use_def_map(superclass_scope.file_scope_id(db))
.end_of_scope_declarations(ScopedPlaceId::Symbol(scoped_symbol_id)),
)
.into_place_and_conflicting_declarations()
.0
.place
.is_unbound()
});
let has_declaration =
self.iter_mro(db)
.filter_map(ClassBase::into_class)
.any(|superclass| {
let superclass_scope = superclass.class_literal(db).0.body_scope(db);
let Some(scoped_symbol_id) =
place_table(db, superclass_scope).symbol_id(symbol_name)
else {
return false;
};
!place_from_declarations(
db,
index
.use_def_map(superclass_scope.file_scope_id(db))
.end_of_scope_declarations(ScopedPlaceId::Symbol(scoped_symbol_id)),
)
.into_place_and_conflicting_declarations()
.0
.place
.is_unbound()
});
if has_declaration {
continue;
@ -114,8 +126,8 @@ impl<'db> ProtocolClassLiteral<'db> {
}
}
impl<'db> Deref for ProtocolClassLiteral<'db> {
type Target = ClassLiteral<'db>;
impl<'db> Deref for ProtocolClass<'db> {
type Target = ClassType<'db>;
fn deref(&self) -> &Self::Target {
&self.0
@ -622,16 +634,19 @@ impl BoundOnClass {
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn cached_protocol_interface<'db>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
class: ClassType<'db>,
) -> ProtocolInterface<'db> {
let mut members = BTreeMap::default();
for parent_protocol in class
.iter_mro(db, None)
for (parent_protocol, specialization) in class
.iter_mro(db)
.filter_map(ClassBase::into_class)
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
.filter_map(|class| {
let (class, specialization) = class.class_literal(db);
Some((class.into_protocol_class(db)?, specialization))
})
{
let parent_scope = parent_protocol.body_scope(db);
let parent_scope = parent_protocol.class_literal(db).0.body_scope(db);
let use_def_map = use_def_map(db, parent_scope);
let place_table = place_table(db, parent_scope);
let mut direct_members = FxHashMap::default();
@ -676,6 +691,8 @@ fn cached_protocol_interface<'db>(
continue;
}
let ty = ty.apply_optional_specialization(db, specialization);
let member = match ty {
Type::PropertyInstance(property) => ProtocolMemberKind::Property(property),
Type::Callable(callable)
@ -702,19 +719,21 @@ fn cached_protocol_interface<'db>(
ProtocolInterface::new(db, members)
}
// If we use `expect(clippy::trivially_copy_pass_by_ref)` here,
// the lint expectation is unfulfilled on WASM
#[allow(clippy::trivially_copy_pass_by_ref)]
fn proto_interface_cycle_recover<'db>(
_db: &dyn Db,
_value: &ProtocolInterface<'db>,
_count: u32,
_class: ClassLiteral<'db>,
_class: ClassType<'db>,
) -> salsa::CycleRecoveryAction<ProtocolInterface<'db>> {
salsa::CycleRecoveryAction::Iterate
}
fn proto_interface_cycle_initial<'db>(
db: &'db dyn Db,
_class: ClassLiteral<'db>,
_class: ClassType<'db>,
) -> ProtocolInterface<'db> {
ProtocolInterface::empty(db)
}