[ty] Strict validation of protocol members

This commit is contained in:
Alex Waygood 2025-04-30 22:56:46 +01:00
parent e658778ced
commit fdee512083
5 changed files with 225 additions and 10 deletions

View file

@ -419,6 +419,8 @@ reveal_type(get_protocol_members(Baz2))
## Protocol members in statically known branches ## Protocol members in statically known branches
<!-- snapshot-diagnostics -->
The list of protocol members does not include any members declared in branches that are statically The list of protocol members does not include any members declared in branches that are statically
known to be unreachable: known to be unreachable:
@ -429,7 +431,7 @@ python-version = "3.9"
```py ```py
import sys import sys
from typing_extensions import Protocol, get_protocol_members from typing_extensions import Protocol, get_protocol_members, reveal_type
class Foo(Protocol): class Foo(Protocol):
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
@ -438,7 +440,7 @@ class Foo(Protocol):
def c(self) -> None: ... def c(self) -> None: ...
else: else:
d: int d: int
e = 56 e = 56 # error: [invalid-protocol] "not declared as a protocol member"
def f(self) -> None: ... def f(self) -> None: ...
reveal_type(get_protocol_members(Foo)) # revealed: frozenset[Literal["d", "e", "f"]] reveal_type(get_protocol_members(Foo)) # revealed: frozenset[Literal["d", "e", "f"]]
@ -656,26 +658,70 @@ class LotsOfBindings(Protocol):
class Nested: ... # also weird, but we should also probably allow it class Nested: ... # also weird, but we should also probably allow it
class NestedProtocol(Protocol): ... # same here... class NestedProtocol(Protocol): ... # same here...
e = 72 # TODO: this should error with `[invalid-protocol]` (`e` is not declared) e = 72 # error: [invalid-protocol] "not declared as a protocol member"
f, g = (1, 2) # TODO: this should error with `[invalid-protocol]` (`f` and `g` are not declared) # error: [invalid-protocol] "Cannot assign to variable `f` in body of protocol class `LotsOfBindings`"
# error: [invalid-protocol] "Cannot assign to variable `g` in body of protocol class `LotsOfBindings`"
f, g = (1, 2)
h: int = (i := 3) # TODO: this should error with `[invalid-protocol]` (`i` is not declared) h: int = (i := 3) # error: [invalid-protocol] "not declared as a protocol member"
for j in range(42): # TODO: this should error with `[invalid-protocol]` (`j` is not declared) for j in range(42): # error: [invalid-protocol] "not declared as a protocol member"
pass pass
with MyContext() as k: # TODO: this should error with `[invalid-protocol]` (`k` is not declared) with MyContext() as k: # error: [invalid-protocol] "not declared as a protocol member"
pass pass
match object(): match object():
case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared) case l: # error: [invalid-protocol] "not declared as a protocol member"
... ...
# revealed: frozenset[Literal["Nested", "NestedProtocol", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"]] # revealed: frozenset[Literal["Nested", "NestedProtocol", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"]]
reveal_type(get_protocol_members(LotsOfBindings)) reveal_type(get_protocol_members(LotsOfBindings))
``` ```
A binding-without-declaration will not be reported if it occurs in a branch that we can statically
determine to be unreachable. The reason is that we don't consider it to be a protocol member at all
if all definitions for the variable are in unreachable blocks:
```py
import sys
class Protocol694(Protocol):
if sys.version_info > (3, 694):
x = 42 # no error!
```
If there are multiple bindings of the variable in the class body, however, and at least one of the
bindings occurs in a block of code that is understood to be (possibly) reachable, a diagnostic will
be reported. The diagnostic will be attached to the first binding that occurs in the class body,
even if that first definition occurs in an unreachable block:
```py
class Protocol695(Protocol):
if sys.version_info > (3, 695):
x = 42
else:
x = 42
# error: [invalid-protocol] "not declared as a protocol member"
x = 56
```
In order for the variable to be considered declared, the declaration of the variable must also take
place in a block of code that is understood to be (possibly) reachable:
```py
class Protocol696(Protocol):
if sys.version_info > (3, 696):
x: int
else:
x = 42 # error: [invalid-protocol] "not declared as a protocol member"
y: int
y = 56 # no error
```
Attribute members are allowed to have assignments in methods on the protocol class, just like Attribute members are allowed to have assignments in methods on the protocol class, just like
non-protocol classes. Unlike other classes, however, instance attributes that are not declared in non-protocol classes. Unlike other classes, however, instance attributes that are not declared in
the class body are disallowed. This is mandated by [the spec][spec_protocol_members]: the class body are disallowed. This is mandated by [the spec][spec_protocol_members]:

View file

@ -0,0 +1,68 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: protocols.md - Protocols - Protocol members in statically known branches
mdtest path: crates/ty_python_semantic/resources/mdtest/protocols.md
---
# Python source files
## mdtest_snippet.py
```
1 | import sys
2 | from typing_extensions import Protocol, get_protocol_members, reveal_type
3 |
4 | class Foo(Protocol):
5 | if sys.version_info >= (3, 10):
6 | a: int
7 | b = 42
8 | def c(self) -> None: ...
9 | else:
10 | d: int
11 | e = 56 # error: [invalid-protocol] "not declared as a protocol member"
12 | def f(self) -> None: ...
13 |
14 | reveal_type(get_protocol_members(Foo)) # revealed: frozenset[Literal["d", "e", "f"]]
```
# Diagnostics
```
error[invalid-protocol]: Cannot assign to variable `e` in body of protocol class `Foo`
--> src/mdtest_snippet.py:11:9
|
9 | else:
10 | d: int
11 | e = 56 # error: [invalid-protocol] "not declared as a protocol member"
| ^ `e` is not declared as a protocol member
12 | def f(self) -> None: ...
|
info: Assigning to an undeclared variable in a protocol class leads to an ambiguous interface
--> src/mdtest_snippet.py:4:7
|
2 | from typing_extensions import Protocol, get_protocol_members, reveal_type
3 |
4 | class Foo(Protocol):
| ^^^^^^^^^^^^^ `Foo` declared as a protocol here
5 | if sys.version_info >= (3, 10):
6 | a: int
|
info: No declarations found for `e` in the body of `Foo` or any of its superclasses
info: rule `invalid-protocol` is enabled by default
```
```
info[revealed-type]: Revealed type
--> src/mdtest_snippet.py:14:13
|
12 | def f(self) -> None: ...
13 |
14 | reveal_type(get_protocol_members(Foo)) # revealed: frozenset[Literal["d", "e", "f"]]
| ^^^^^^^^^^^^^^^^^^^^^^^^^ `frozenset[Literal["d", "e", "f"]]`
|
```

View file

@ -6,6 +6,8 @@ use super::{
add_inferred_python_version_hint_to_diagnostic, add_inferred_python_version_hint_to_diagnostic,
}; };
use crate::lint::{Level, LintRegistryBuilder, LintStatus}; use crate::lint::{Level, LintRegistryBuilder, LintStatus};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::SymbolTable;
use crate::suppression::FileSuppressionId; use crate::suppression::FileSuppressionId;
use crate::types::LintDiagnosticGuard; use crate::types::LintDiagnosticGuard;
use crate::types::function::KnownFunction; use crate::types::function::KnownFunction;
@ -2198,3 +2200,43 @@ pub(super) fn hint_if_stdlib_submodule_exists_on_other_versions(
add_inferred_python_version_hint_to_diagnostic(db, &mut diagnostic, "resolving modules"); add_inferred_python_version_hint_to_diagnostic(db, &mut diagnostic, "resolving modules");
} }
pub(crate) fn report_undeclared_protocol_member(
context: &InferContext,
definition: Definition,
protocol_class: ProtocolClassLiteral,
class_symbol_table: &SymbolTable,
) {
let db = context.db();
let Some(builder) = context.report_lint(&INVALID_PROTOCOL, definition.full_range(db)) else {
return;
};
let symbol_name = class_symbol_table.symbol(definition.symbol(db)).name();
let class_name = protocol_class.name(db);
let mut diagnostic = builder.into_diagnostic(format_args!(
"Cannot assign to variable `{symbol_name}` \
in body of protocol class `{class_name}`",
));
diagnostic.set_primary_message(format_args!(
"`{symbol_name}` is not declared as a protocol member"
));
let mut class_def_diagnostic = SubDiagnostic::new(
Severity::Info,
"Assigning to an undeclared variable in a protocol class \
leads to an ambiguous interface",
);
class_def_diagnostic.annotate(
Annotation::primary(protocol_class.header_span(db))
.message(format_args!("`{class_name}` declared as a protocol here",)),
);
diagnostic.sub(class_def_diagnostic);
diagnostic.info(format_args!(
"No declarations found for `{symbol_name}` \
in the body of `{class_name}` or any of its superclasses"
));
}

View file

@ -58,7 +58,9 @@ use crate::semantic_index::narrowing_constraints::ConstraintKey;
use crate::semantic_index::symbol::{ use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind, ScopedSymbolId, FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind, ScopedSymbolId,
}; };
use crate::semantic_index::{EagerSnapshotResult, SemanticIndex, semantic_index}; use crate::semantic_index::{
EagerSnapshotResult, SemanticIndex, semantic_index, symbol_table, use_def_map,
};
use crate::symbol::{ use crate::symbol::{
Boundness, LookupError, builtins_module_scope, builtins_symbol, explicit_global_symbol, Boundness, LookupError, builtins_module_scope, builtins_symbol, explicit_global_symbol,
global_symbol, module_type_implicit_global_declaration, module_type_implicit_global_symbol, global_symbol, module_type_implicit_global_declaration, module_type_implicit_global_symbol,
@ -112,7 +114,7 @@ use super::diagnostic::{
report_invalid_type_checking_constant, report_non_subscriptable, report_invalid_type_checking_constant, report_non_subscriptable,
report_possibly_unresolved_reference, report_possibly_unresolved_reference,
report_runtime_check_against_non_runtime_checkable_protocol, report_slice_step_size_zero, report_runtime_check_against_non_runtime_checkable_protocol, report_slice_step_size_zero,
report_unresolved_reference, report_undeclared_protocol_member, report_unresolved_reference,
}; };
use super::generics::LegacyGenericBase; use super::generics::LegacyGenericBase;
use super::slots::check_class_slots; use super::slots::check_class_slots;
@ -1075,6 +1077,56 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
} }
} }
if let Some(protocol) = class.into_protocol_class(self.db()) {
let interface = protocol.interface(self.db());
let class_symbol_table = symbol_table(self.db(), class.body_scope(self.db()));
for (symbol_id, mut bindings_iterator) in
use_def_map(self.db(), class.body_scope(self.db())).all_public_bindings()
{
let symbol_name = class_symbol_table.symbol(symbol_id).name();
if !interface.includes_member(self.db(), symbol_name) {
continue;
}
let has_declaration = class
.iter_mro(self.db(), None)
.filter_map(ClassBase::into_class)
.any(|superclass| {
let superclass_scope =
superclass.class_literal(self.db()).0.body_scope(self.db());
let Some(scoped_symbol_id) = symbol_table(self.db(), superclass_scope)
.symbol_id_by_name(symbol_name)
else {
return false;
};
symbol_from_declarations(
self.db(),
use_def_map(self.db(), superclass_scope)
.public_declarations(scoped_symbol_id),
)
.is_ok_and(|symbol| !symbol.symbol.is_unbound())
});
if has_declaration {
continue;
}
let Some(first_binding) = bindings_iterator.find_map(|binding| binding.binding)
else {
continue;
};
report_undeclared_protocol_member(
&self.context,
first_binding,
protocol,
class_symbol_table,
);
}
}
} }
} }

View file

@ -132,6 +132,13 @@ impl<'db> ProtocolInterface<'db> {
} }
} }
pub(super) fn includes_member(self, db: &'db dyn Db, name: &str) -> bool {
match self {
Self::Members(members) => members.inner(db).contains_key(name),
Self::SelfReference => false,
}
}
/// Return `true` if all members of this protocol are fully static. /// Return `true` if all members of this protocol are fully static.
pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool { pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
self.members(db).all(|member| member.ty.is_fully_static(db)) self.members(db).all(|member| member.ty.is_fully_static(db))