mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:24:57 +00:00
[ty] Fix the inferred interface of specialized generic protocols (#19866)
This commit is contained in:
parent
7d0c8e045c
commit
ce1dc21e7e
9 changed files with 200 additions and 57 deletions
|
@ -748,6 +748,10 @@ impl<'db> Bindings<'db> {
|
|||
|
||||
Some(KnownFunction::IsProtocol) => {
|
||||
if let [Some(ty)] = overload.parameter_types() {
|
||||
// We evaluate this to `Literal[True]` only if the runtime function `typing.is_protocol`
|
||||
// would return `True` for the given type. Internally we consider `SupportsAbs[int]` to
|
||||
// be a "(specialised) protocol class", but `typing.is_protocol(SupportsAbs[int])` returns
|
||||
// `False` at runtime, so we do not set the return type to `Literal[True]` in this case.
|
||||
overload.set_return_type(Type::BooleanLiteral(
|
||||
ty.into_class_literal()
|
||||
.is_some_and(|class| class.is_protocol(db)),
|
||||
|
@ -756,6 +760,9 @@ impl<'db> Bindings<'db> {
|
|||
}
|
||||
|
||||
Some(KnownFunction::GetProtocolMembers) => {
|
||||
// Similarly to `is_protocol`, we only evaluate to this a frozenset of literal strings if a
|
||||
// class-literal is passed in, not if a generic alias is passed in, to emulate the behaviour
|
||||
// of `typing.get_protocol_members` at runtime.
|
||||
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
|
||||
if let Some(protocol_class) = class.into_protocol_class(db) {
|
||||
let member_names = protocol_class
|
||||
|
|
|
@ -1198,6 +1198,14 @@ impl<'db> ClassType<'db> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
|
||||
self.class_literal(db).0.is_protocol(db)
|
||||
}
|
||||
|
||||
pub(super) fn header_span(self, db: &'db dyn Db) -> Span {
|
||||
self.class_literal(db).0.header_span(db)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'db> From<GenericAlias<'db>> for ClassType<'db> {
|
||||
|
|
|
@ -20,7 +20,7 @@ use crate::types::string_annotation::{
|
|||
use crate::types::{
|
||||
DynamicType, LintDiagnosticGuard, Protocol, ProtocolInstanceType, SubclassOfInner, binding_type,
|
||||
};
|
||||
use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClassLiteral};
|
||||
use crate::types::{SpecialFormType, Type, protocol_class::ProtocolClass};
|
||||
use crate::util::diagnostics::format_enumeration;
|
||||
use crate::{Db, FxIndexMap, FxOrderMap, Module, ModuleName, Program, declare_lint};
|
||||
use itertools::Itertools;
|
||||
|
@ -2467,7 +2467,7 @@ pub(crate) fn add_type_expression_reference_link<'db, 'ctx>(
|
|||
pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
|
||||
context: &InferContext,
|
||||
call: &ast::ExprCall,
|
||||
protocol: ProtocolClassLiteral,
|
||||
protocol: ProtocolClass,
|
||||
function: KnownFunction,
|
||||
) {
|
||||
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else {
|
||||
|
@ -2504,7 +2504,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
|
|||
pub(crate) fn report_attempted_protocol_instantiation(
|
||||
context: &InferContext,
|
||||
call: &ast::ExprCall,
|
||||
protocol: ProtocolClassLiteral,
|
||||
protocol: ProtocolClass,
|
||||
) {
|
||||
let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, call) else {
|
||||
return;
|
||||
|
@ -2529,7 +2529,7 @@ pub(crate) fn report_attempted_protocol_instantiation(
|
|||
pub(crate) fn report_undeclared_protocol_member(
|
||||
context: &InferContext,
|
||||
definition: Definition,
|
||||
protocol_class: ProtocolClassLiteral,
|
||||
protocol_class: ProtocolClass,
|
||||
class_symbol_table: &PlaceTable,
|
||||
) {
|
||||
/// We want to avoid suggesting an annotation for e.g. `x = None`,
|
||||
|
|
|
@ -1433,7 +1433,7 @@ impl KnownFunction {
|
|||
return;
|
||||
};
|
||||
let Some(protocol_class) = param_type
|
||||
.into_class_literal()
|
||||
.to_class_type(db)
|
||||
.and_then(|class| class.into_protocol_class(db))
|
||||
else {
|
||||
report_bad_argument_to_protocol_interface(
|
||||
|
|
|
@ -1216,8 +1216,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
|
||||
if is_protocol
|
||||
&& !(base_class.class_literal(self.db()).0.is_protocol(self.db())
|
||||
|| base_class.is_known(self.db(), KnownClass::Object))
|
||||
&& !(base_class.is_protocol(self.db()) || base_class.is_object(self.db()))
|
||||
{
|
||||
if let Some(builder) = self
|
||||
.context
|
||||
|
@ -6249,11 +6248,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// subclasses of the protocol to be passed to parameters that accept `type[SomeProtocol]`.
|
||||
// <https://typing.python.org/en/latest/spec/protocol.html#type-and-class-objects-vs-protocols>.
|
||||
if !callable_type.is_subclass_of() {
|
||||
if let Some(protocol) = class
|
||||
.class_literal(self.db())
|
||||
.0
|
||||
.into_protocol_class(self.db())
|
||||
{
|
||||
if let Some(protocol) = class.into_protocol_class(self.db()) {
|
||||
report_attempted_protocol_instantiation(
|
||||
&self.context,
|
||||
call_expression,
|
||||
|
|
|
@ -645,10 +645,8 @@ impl<'db> Protocol<'db> {
|
|||
fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> {
|
||||
match self {
|
||||
Self::FromClass(class) => class
|
||||
.class_literal(db)
|
||||
.0
|
||||
.into_protocol_class(db)
|
||||
.expect("Protocol class literal should be a protocol class")
|
||||
.expect("Class wrapped by `Protocol` should be a protocol class")
|
||||
.interface(db),
|
||||
Self::Synthesized(synthesized) => synthesized.interface(),
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ use rustc_hash::FxHashMap;
|
|||
use super::TypeVarVariance;
|
||||
use crate::semantic_index::place::ScopedPlaceId;
|
||||
use crate::semantic_index::{SemanticIndex, place_table};
|
||||
use crate::types::ClassType;
|
||||
use crate::types::context::InferContext;
|
||||
use crate::types::diagnostic::report_undeclared_protocol_member;
|
||||
use crate::{
|
||||
|
@ -26,16 +27,24 @@ use crate::{
|
|||
|
||||
impl<'db> ClassLiteral<'db> {
|
||||
/// Returns `Some` if this is a protocol class, `None` otherwise.
|
||||
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
|
||||
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
|
||||
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
|
||||
self.is_protocol(db)
|
||||
.then_some(ProtocolClass(ClassType::NonGeneric(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'db> ClassType<'db> {
|
||||
/// Returns `Some` if this is a protocol class, `None` otherwise.
|
||||
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClass<'db>> {
|
||||
self.is_protocol(db).then_some(ProtocolClass(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Representation of a single `Protocol` class definition.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteral<'db>);
|
||||
pub(super) struct ProtocolClass<'db>(ClassType<'db>);
|
||||
|
||||
impl<'db> ProtocolClassLiteral<'db> {
|
||||
impl<'db> ProtocolClass<'db> {
|
||||
/// Returns the protocol members of this class.
|
||||
///
|
||||
/// A protocol's members define the interface declared by the protocol.
|
||||
|
@ -56,7 +65,9 @@ impl<'db> ProtocolClassLiteral<'db> {
|
|||
}
|
||||
|
||||
pub(super) fn is_runtime_checkable(self, db: &'db dyn Db) -> bool {
|
||||
self.known_function_decorators(db)
|
||||
self.class_literal(db)
|
||||
.0
|
||||
.known_function_decorators(db)
|
||||
.contains(&KnownFunction::RuntimeCheckable)
|
||||
}
|
||||
|
||||
|
@ -66,10 +77,11 @@ impl<'db> ProtocolClassLiteral<'db> {
|
|||
pub(super) fn validate_members(self, context: &InferContext, index: &SemanticIndex<'db>) {
|
||||
let db = context.db();
|
||||
let interface = self.interface(db);
|
||||
let class_place_table = index.place_table(self.body_scope(db).file_scope_id(db));
|
||||
let body_scope = self.class_literal(db).0.body_scope(db);
|
||||
let class_place_table = index.place_table(body_scope.file_scope_id(db));
|
||||
|
||||
for (symbol_id, mut bindings_iterator) in
|
||||
use_def_map(db, self.body_scope(db)).all_end_of_scope_symbol_bindings()
|
||||
use_def_map(db, body_scope).all_end_of_scope_symbol_bindings()
|
||||
{
|
||||
let symbol_name = class_place_table.symbol(symbol_id).name();
|
||||
|
||||
|
@ -77,27 +89,27 @@ impl<'db> ProtocolClassLiteral<'db> {
|
|||
continue;
|
||||
}
|
||||
|
||||
let has_declaration = self
|
||||
.iter_mro(db, None)
|
||||
.filter_map(ClassBase::into_class)
|
||||
.any(|superclass| {
|
||||
let superclass_scope = superclass.class_literal(db).0.body_scope(db);
|
||||
let Some(scoped_symbol_id) =
|
||||
place_table(db, superclass_scope).symbol_id(symbol_name)
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
!place_from_declarations(
|
||||
db,
|
||||
index
|
||||
.use_def_map(superclass_scope.file_scope_id(db))
|
||||
.end_of_scope_declarations(ScopedPlaceId::Symbol(scoped_symbol_id)),
|
||||
)
|
||||
.into_place_and_conflicting_declarations()
|
||||
.0
|
||||
.place
|
||||
.is_unbound()
|
||||
});
|
||||
let has_declaration =
|
||||
self.iter_mro(db)
|
||||
.filter_map(ClassBase::into_class)
|
||||
.any(|superclass| {
|
||||
let superclass_scope = superclass.class_literal(db).0.body_scope(db);
|
||||
let Some(scoped_symbol_id) =
|
||||
place_table(db, superclass_scope).symbol_id(symbol_name)
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
!place_from_declarations(
|
||||
db,
|
||||
index
|
||||
.use_def_map(superclass_scope.file_scope_id(db))
|
||||
.end_of_scope_declarations(ScopedPlaceId::Symbol(scoped_symbol_id)),
|
||||
)
|
||||
.into_place_and_conflicting_declarations()
|
||||
.0
|
||||
.place
|
||||
.is_unbound()
|
||||
});
|
||||
|
||||
if has_declaration {
|
||||
continue;
|
||||
|
@ -114,8 +126,8 @@ impl<'db> ProtocolClassLiteral<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'db> Deref for ProtocolClassLiteral<'db> {
|
||||
type Target = ClassLiteral<'db>;
|
||||
impl<'db> Deref for ProtocolClass<'db> {
|
||||
type Target = ClassType<'db>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
|
@ -622,16 +634,19 @@ impl BoundOnClass {
|
|||
#[salsa::tracked(cycle_fn=proto_interface_cycle_recover, cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
|
||||
fn cached_protocol_interface<'db>(
|
||||
db: &'db dyn Db,
|
||||
class: ClassLiteral<'db>,
|
||||
class: ClassType<'db>,
|
||||
) -> ProtocolInterface<'db> {
|
||||
let mut members = BTreeMap::default();
|
||||
|
||||
for parent_protocol in class
|
||||
.iter_mro(db, None)
|
||||
for (parent_protocol, specialization) in class
|
||||
.iter_mro(db)
|
||||
.filter_map(ClassBase::into_class)
|
||||
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
|
||||
.filter_map(|class| {
|
||||
let (class, specialization) = class.class_literal(db);
|
||||
Some((class.into_protocol_class(db)?, specialization))
|
||||
})
|
||||
{
|
||||
let parent_scope = parent_protocol.body_scope(db);
|
||||
let parent_scope = parent_protocol.class_literal(db).0.body_scope(db);
|
||||
let use_def_map = use_def_map(db, parent_scope);
|
||||
let place_table = place_table(db, parent_scope);
|
||||
let mut direct_members = FxHashMap::default();
|
||||
|
@ -676,6 +691,8 @@ fn cached_protocol_interface<'db>(
|
|||
continue;
|
||||
}
|
||||
|
||||
let ty = ty.apply_optional_specialization(db, specialization);
|
||||
|
||||
let member = match ty {
|
||||
Type::PropertyInstance(property) => ProtocolMemberKind::Property(property),
|
||||
Type::Callable(callable)
|
||||
|
@ -702,19 +719,21 @@ fn cached_protocol_interface<'db>(
|
|||
ProtocolInterface::new(db, members)
|
||||
}
|
||||
|
||||
// If we use `expect(clippy::trivially_copy_pass_by_ref)` here,
|
||||
// the lint expectation is unfulfilled on WASM
|
||||
#[allow(clippy::trivially_copy_pass_by_ref)]
|
||||
fn proto_interface_cycle_recover<'db>(
|
||||
_db: &dyn Db,
|
||||
_value: &ProtocolInterface<'db>,
|
||||
_count: u32,
|
||||
_class: ClassLiteral<'db>,
|
||||
_class: ClassType<'db>,
|
||||
) -> salsa::CycleRecoveryAction<ProtocolInterface<'db>> {
|
||||
salsa::CycleRecoveryAction::Iterate
|
||||
}
|
||||
|
||||
fn proto_interface_cycle_initial<'db>(
|
||||
db: &'db dyn Db,
|
||||
_class: ClassLiteral<'db>,
|
||||
_class: ClassType<'db>,
|
||||
) -> ProtocolInterface<'db> {
|
||||
ProtocolInterface::empty(db)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue