From 5d52902e186cea6fa73dac1ceebf53208f6e08ba Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 5 Sep 2025 17:56:06 +0100 Subject: [PATCH] [ty] Implement the legacy PEP-484 convention for indicating positional-only parameters (#20248) Co-authored-by: Carl Meyer --- .../rules/pre_pep570_positional_argument.rs | 15 +-- crates/ruff_python_ast/src/nodes.rs | 9 +- .../resources/mdtest/call/function.md | 74 ++++++++++- .../resources/mdtest/protocols.md | 43 +++++-- crates/ty_python_semantic/src/ast_node_ref.rs | 6 + crates/ty_python_semantic/src/node_key.rs | 6 + .../ty_python_semantic/src/semantic_index.rs | 36 +++++- .../src/semantic_index/builder.rs | 2 +- .../src/semantic_index/definition.rs | 9 ++ .../src/semantic_index/scope.rs | 56 +++----- crates/ty_python_semantic/src/types.rs | 10 +- crates/ty_python_semantic/src/types/class.rs | 46 +++---- .../ty_python_semantic/src/types/context.rs | 2 +- .../ty_python_semantic/src/types/display.rs | 8 +- .../ty_python_semantic/src/types/function.rs | 120 +++++++++++++++--- crates/ty_python_semantic/src/types/infer.rs | 32 +---- .../src/types/signatures.rs | 52 ++++++-- 17 files changed, 376 insertions(+), 150 deletions(-) diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/pre_pep570_positional_argument.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/pre_pep570_positional_argument.rs index ec5fd871bb..e481a82dc2 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/pre_pep570_positional_argument.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/pre_pep570_positional_argument.rs @@ -1,6 +1,6 @@ use ruff_macros::{ViolationMetadata, derive_message_formats}; +use ruff_python_ast as ast; use ruff_python_ast::identifier::Identifier; -use ruff_python_ast::{self as ast, ParameterWithDefault}; use ruff_python_semantic::analyze::function_type; use crate::Violation; @@ -85,16 +85,9 @@ pub(crate) fn pep_484_positional_parameter(checker: &Checker, function_def: &ast function_type::FunctionType::Method | function_type::FunctionType::ClassMethod )); - if let Some(arg) = function_def.parameters.args.get(skip) { - if is_old_style_positional_only(arg) { - checker.report_diagnostic(Pep484StylePositionalOnlyParameter, arg.identifier()); + if let Some(param) = function_def.parameters.args.get(skip) { + if param.uses_pep_484_positional_only_convention() { + checker.report_diagnostic(Pep484StylePositionalOnlyParameter, param.identifier()); } } } - -/// Returns `true` if the [`ParameterWithDefault`] is an old-style positional-only parameter (i.e., -/// its name starts with `__` and does not end with `__`). -fn is_old_style_positional_only(param: &ParameterWithDefault) -> bool { - let arg_name = param.name(); - arg_name.starts_with("__") && !arg_name.ends_with("__") -} diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs index 6a824603bb..27cfc06da7 100644 --- a/crates/ruff_python_ast/src/nodes.rs +++ b/crates/ruff_python_ast/src/nodes.rs @@ -3219,7 +3219,6 @@ impl<'a> IntoIterator for &'a Box { /// Used by `Arguments` original type. /// /// NOTE: This type is different from original Python AST. - #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "get-size", derive(get_size2::GetSize))] pub struct ParameterWithDefault { @@ -3241,6 +3240,14 @@ impl ParameterWithDefault { pub fn annotation(&self) -> Option<&Expr> { self.parameter.annotation() } + + /// Return `true` if the parameter name uses the pre-PEP-570 convention + /// (specified in PEP 484) to indicate to a type checker that it should be treated + /// as positional-only. + pub fn uses_pep_484_positional_only_convention(&self) -> bool { + let name = self.name(); + name.starts_with("__") && !name.ends_with("__") + } } /// An AST node used to represent the arguments passed to a function call or class definition. diff --git a/crates/ty_python_semantic/resources/mdtest/call/function.md b/crates/ty_python_semantic/resources/mdtest/call/function.md index 6be6a8d9ec..852623a4f4 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/function.md +++ b/crates/ty_python_semantic/resources/mdtest/call/function.md @@ -68,6 +68,78 @@ def _(flag: bool): reveal_type(foo()) # revealed: int ``` +## PEP-484 convention for positional-only parameters + +PEP 570, introduced in Python 3.8, added dedicated Python syntax for denoting positional-only +parameters (the `/` in a function signature). However, functions implemented in C were able to have +positional-only parameters prior to Python 3.8 (there was just no syntax for expressing this at the +Python level). + +Stub files describing functions implemented in C nonetheless needed a way of expressing that certain +parameters were positional-only. In the absence of dedicated Python syntax, PEP 484 described a +convention that type checkers were expected to understand: + +> Some functions are designed to take their arguments only positionally, and expect their callers +> never to use the argument’s name to provide that argument by keyword. All arguments with names +> beginning with `__` are assumed to be positional-only, except if their names also end with `__`. + +While this convention is now redundant (following the implementation of PEP 570), many projects +still continue to use the old convention, so it is supported by ty as well. + +```py +def f(__x: int): ... + +f(1) +# error: [missing-argument] +# error: [unknown-argument] +f(__x=1) +``` + +But not if they follow a non-positional-only parameter: + +```py +def g(x: int, __y: str): ... + +g(x=1, __y="foo") +``` + +And also not if they both start and end with `__`: + +```py +def h(__x__: str): ... + +h(__x__="foo") +``` + +And if *any* parameters use the new PEP-570 convention, the old convention does not apply: + +```py +def i(x: str, /, __y: int): ... + +i("foo", __y=42) # fine +``` + +And `self`/`cls` are implicitly positional-only: + +```py +class C: + def method(self, __x: int): ... + @classmethod + def class_method(cls, __x: str): ... + # (the name of the first parameter is irrelevant; + # a staticmethod works the same as a free function in the global scope) + @staticmethod + def static_method(self, __x: int): ... + +# error: [missing-argument] +# error: [unknown-argument] +C().method(__x=1) +# error: [missing-argument] +# error: [unknown-argument] +C.class_method(__x="1") +C.static_method("x", __x=42) # fine +``` + ## Splatted arguments ### Unknown argument length @@ -545,7 +617,7 @@ def _(args: str) -> None: This is a regression that was highlighted by the ecosystem check, which shows that we might need to rethink how we perform argument expansion during overload resolution. In particular, we might need -to retry both `match_parameters` _and_ `check_types` for each expansion. Currently we only retry +to retry both `match_parameters` *and* `check_types` for each expansion. Currently we only retry `check_types`. The issue is that argument expansion might produce a splatted value with a different arity than what diff --git a/crates/ty_python_semantic/resources/mdtest/protocols.md b/crates/ty_python_semantic/resources/mdtest/protocols.md index 7e4e710f09..1267bb27fd 100644 --- a/crates/ty_python_semantic/resources/mdtest/protocols.md +++ b/crates/ty_python_semantic/resources/mdtest/protocols.md @@ -413,13 +413,13 @@ To see the kinds and types of the protocol members, you can use the debugging ai from ty_extensions import reveal_protocol_interface from typing import SupportsIndex, SupportsAbs, ClassVar, Iterator -# revealed: {"method_member": MethodMember(`(self) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self) -> str` }, "z": PropertyMember { getter: `def z(self) -> int`, setter: `def z(self, z: int) -> None` }} +# revealed: {"method_member": MethodMember(`(self, /) -> bytes`), "x": AttributeMember(`int`), "y": PropertyMember { getter: `def y(self, /) -> str` }, "z": PropertyMember { getter: `def z(self, /) -> int`, setter: `def z(self, /, z: int) -> None` }} reveal_protocol_interface(Foo) -# revealed: {"__index__": MethodMember(`(self) -> int`)} +# revealed: {"__index__": MethodMember(`(self, /) -> int`)} reveal_protocol_interface(SupportsIndex) -# revealed: {"__abs__": MethodMember(`(self) -> Unknown`)} +# revealed: {"__abs__": MethodMember(`(self, /) -> Unknown`)} reveal_protocol_interface(SupportsAbs) -# revealed: {"__iter__": MethodMember(`(self) -> Iterator[Unknown]`), "__next__": MethodMember(`(self) -> Unknown`)} +# revealed: {"__iter__": MethodMember(`(self, /) -> Iterator[Unknown]`), "__next__": MethodMember(`(self, /) -> Unknown`)} reveal_protocol_interface(Iterator) # error: [invalid-argument-type] "Invalid argument to `reveal_protocol_interface`: Only protocol classes can be passed to `reveal_protocol_interface`" @@ -439,9 +439,9 @@ do not implement any special handling for generic aliases passed to the function reveal_type(get_protocol_members(SupportsAbs[int])) # revealed: frozenset[str] reveal_type(get_protocol_members(Iterator[int])) # revealed: frozenset[str] -# revealed: {"__abs__": MethodMember(`(self) -> int`)} +# revealed: {"__abs__": MethodMember(`(self, /) -> int`)} reveal_protocol_interface(SupportsAbs[int]) -# revealed: {"__iter__": MethodMember(`(self) -> Iterator[int]`), "__next__": MethodMember(`(self) -> int`)} +# revealed: {"__iter__": MethodMember(`(self, /) -> Iterator[int]`), "__next__": MethodMember(`(self, /) -> int`)} reveal_protocol_interface(Iterator[int]) class BaseProto(Protocol): @@ -450,10 +450,10 @@ class BaseProto(Protocol): class SubProto(BaseProto, Protocol): def member(self) -> bool: ... -# revealed: {"member": MethodMember(`(self) -> int`)} +# revealed: {"member": MethodMember(`(self, /) -> int`)} reveal_protocol_interface(BaseProto) -# revealed: {"member": MethodMember(`(self) -> bool`)} +# revealed: {"member": MethodMember(`(self, /) -> bool`)} reveal_protocol_interface(SubProto) class ProtoWithClassVar(Protocol): @@ -1767,7 +1767,7 @@ class Foo(Protocol): def method(self) -> str: ... def f(x: Foo): - reveal_type(type(x).method) # revealed: def method(self) -> str + reveal_type(type(x).method) # revealed: def method(self, /) -> str class Bar: def __init__(self): @@ -1776,6 +1776,31 @@ class Bar: f(Bar()) # error: [invalid-argument-type] ``` +Some protocols use the old convention (specified in PEP-484) for denoting positional-only +parameters. This is supported by ty: + +```py +class HasPosOnlyDunders: + def __invert__(self, /) -> "HasPosOnlyDunders": + return self + + def __lt__(self, other, /) -> bool: + return True + +class SupportsLessThan(Protocol): + def __lt__(self, __other) -> bool: ... + +class Invertable(Protocol): + # `self` and `cls` are always implicitly positional-only for methods defined in `Protocol` + # classes, even if no parameters in the method use the PEP-484 convention. + def __invert__(self) -> object: ... + +static_assert(is_assignable_to(HasPosOnlyDunders, SupportsLessThan)) +static_assert(is_assignable_to(HasPosOnlyDunders, Invertable)) +static_assert(is_assignable_to(str, SupportsLessThan)) +static_assert(is_assignable_to(int, Invertable)) +``` + ## Equivalence of protocols with method or property members Two protocols `P1` and `P2`, both with a method member `x`, are considered equivalent if the diff --git a/crates/ty_python_semantic/src/ast_node_ref.rs b/crates/ty_python_semantic/src/ast_node_ref.rs index 9d1dd60433..ed28bc396b 100644 --- a/crates/ty_python_semantic/src/ast_node_ref.rs +++ b/crates/ty_python_semantic/src/ast_node_ref.rs @@ -49,6 +49,12 @@ pub struct AstNodeRef { _node: PhantomData, } +impl AstNodeRef { + pub(crate) fn index(&self) -> NodeIndex { + self.index + } +} + impl AstNodeRef where T: HasNodeIndex + Ranged + PartialEq + Debug, diff --git a/crates/ty_python_semantic/src/node_key.rs b/crates/ty_python_semantic/src/node_key.rs index 18edfe1a04..a93931294b 100644 --- a/crates/ty_python_semantic/src/node_key.rs +++ b/crates/ty_python_semantic/src/node_key.rs @@ -1,5 +1,7 @@ use ruff_python_ast::{HasNodeIndex, NodeIndex}; +use crate::ast_node_ref::AstNodeRef; + /// Compact key for a node for use in a hash map. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, get_size2::GetSize)] pub(super) struct NodeKey(NodeIndex); @@ -11,4 +13,8 @@ impl NodeKey { { NodeKey(node.node_index().load()) } + + pub(super) fn from_node_ref(node_ref: &AstNodeRef) -> Self { + NodeKey(node_ref.index()) + } } diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index 2c770ee0a6..fc57d0b2f6 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -159,7 +159,6 @@ pub(crate) fn attribute_scopes<'db, 's>( class_body_scope: ScopeId<'db>, ) -> impl Iterator + use<'s, 'db> { let file = class_body_scope.file(db); - let module = parsed_module(db, file).load(db); let index = semantic_index(db, file); let class_scope_id = class_body_scope.file_scope_id(db); @@ -175,7 +174,7 @@ pub(crate) fn attribute_scopes<'db, 's>( (child_scope_id, scope) }; - function_scope.node().as_function(&module)?; + function_scope.node().as_function()?; Some(function_scope_id) }) } @@ -332,6 +331,39 @@ impl<'db> SemanticIndex<'db> { Some(&self.scopes[self.parent_scope_id(scope_id)?]) } + /// Return the [`Definition`] of the class enclosing this method, given the + /// method's body scope, or `None` if it is not a method. + pub(crate) fn class_definition_of_method( + &self, + function_body_scope: FileScopeId, + ) -> Option> { + let current_scope = self.scope(function_body_scope); + if current_scope.kind() != ScopeKind::Function { + return None; + } + let parent_scope_id = current_scope.parent()?; + let parent_scope = self.scope(parent_scope_id); + + let class_scope = match parent_scope.kind() { + ScopeKind::Class => parent_scope, + ScopeKind::TypeParams => { + let class_scope_id = parent_scope.parent()?; + let potentially_class_scope = self.scope(class_scope_id); + + match potentially_class_scope.kind() { + ScopeKind::Class => potentially_class_scope, + _ => return None, + } + } + _ => return None, + }; + + class_scope + .node() + .as_class() + .map(|node_ref| self.expect_single_definition(node_ref)) + } + fn is_scope_reachable(&self, db: &'db dyn Db, scope_id: FileScopeId) -> bool { self.parent_scope_id(scope_id) .is_none_or(|parent_scope_id| { diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 943715fd11..95bce62545 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -2644,7 +2644,7 @@ impl SemanticSyntaxContext for SemanticIndexBuilder<'_, '_> { match scope.kind() { ScopeKind::Class | ScopeKind::Lambda => return false, ScopeKind::Function => { - return scope.node().expect_function(self.module).is_async; + return scope.node().expect_function().node(self.module).is_async; } ScopeKind::Comprehension | ScopeKind::Module diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index 0075c0ff41..b06390fa84 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs @@ -1227,3 +1227,12 @@ impl From<&ast::TypeParamTypeVarTuple> for DefinitionNodeKey { Self(NodeKey::from_node(value)) } } + +impl From<&AstNodeRef> for DefinitionNodeKey +where + for<'a> &'a T: Into, +{ + fn from(value: &AstNodeRef) -> Self { + Self(NodeKey::from_node_ref(value)) + } +} diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index b29e732305..b807232c78 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -397,52 +397,38 @@ impl NodeWithScopeKind { } } - pub(crate) fn expect_class<'ast>( - &self, - module: &'ast ParsedModuleRef, - ) -> &'ast ast::StmtClassDef { + pub(crate) fn as_class(&self) -> Option<&AstNodeRef> { match self { - Self::Class(class) => class.node(module), - _ => panic!("expected class"), - } - } - - pub(crate) fn as_class<'ast>( - &self, - module: &'ast ParsedModuleRef, - ) -> Option<&'ast ast::StmtClassDef> { - match self { - Self::Class(class) => Some(class.node(module)), + Self::Class(class) => Some(class), _ => None, } } - pub(crate) fn expect_function<'ast>( - &self, - module: &'ast ParsedModuleRef, - ) -> &'ast ast::StmtFunctionDef { - self.as_function(module).expect("expected function") + pub(crate) fn expect_class(&self) -> &AstNodeRef { + self.as_class().expect("expected class") } - pub(crate) fn expect_type_alias<'ast>( - &self, - module: &'ast ParsedModuleRef, - ) -> &'ast ast::StmtTypeAlias { + pub(crate) fn as_function(&self) -> Option<&AstNodeRef> { match self { - Self::TypeAlias(type_alias) => type_alias.node(module), - _ => panic!("expected type alias"), - } - } - - pub(crate) fn as_function<'ast>( - &self, - module: &'ast ParsedModuleRef, - ) -> Option<&'ast ast::StmtFunctionDef> { - match self { - Self::Function(function) => Some(function.node(module)), + Self::Function(function) => Some(function), _ => None, } } + + pub(crate) fn expect_function(&self) -> &AstNodeRef { + self.as_function().expect("expected function") + } + + pub(crate) fn as_type_alias(&self) -> Option<&AstNodeRef> { + match self { + Self::TypeAlias(type_alias) => Some(type_alias), + _ => None, + } + } + + pub(crate) fn expect_type_alias(&self) -> &AstNodeRef { + self.as_type_alias().expect("expected type alias") + } } #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, get_size2::GetSize)] diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 82f81659c7..80a9206c8c 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5649,7 +5649,7 @@ impl<'db> Type<'db> { SpecialFormType::TypingSelf => { let module = parsed_module(db, scope_id.file(db)).load(db); let index = semantic_index(db, scope_id.file(db)); - let Some(class) = nearest_enclosing_class(db, index, scope_id, &module) else { + let Some(class) = nearest_enclosing_class(db, index, scope_id) else { return Err(InvalidTypeExpressionError { fallback_type: Type::unknown(), invalid_expressions: smallvec::smallvec_inline![ @@ -9364,9 +9364,7 @@ fn walk_pep_695_type_alias<'db, V: visitor::TypeVisitor<'db> + ?Sized>( impl<'db> PEP695TypeAliasType<'db> { pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { let scope = self.rhs_scope(db); - let module = parsed_module(db, scope.file(db)).load(db); - let type_alias_stmt_node = scope.node(db).expect_type_alias(&module); - + let type_alias_stmt_node = scope.node(db).expect_type_alias(); semantic_index(db, scope.file(db)).expect_single_definition(type_alias_stmt_node) } @@ -9374,9 +9372,9 @@ impl<'db> PEP695TypeAliasType<'db> { pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.rhs_scope(db); let module = parsed_module(db, scope.file(db)).load(db); - let type_alias_stmt_node = scope.node(db).expect_type_alias(&module); + let type_alias_stmt_node = scope.node(db).expect_type_alias(); let definition = self.definition(db); - definition_expression_type(db, definition, &type_alias_stmt_node.value) + definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value) } fn normalized_impl(self, _db: &'db dyn Db, _visitor: &NormalizedVisitor<'db>) -> Self { diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 501bea8844..ac8b245a63 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1403,7 +1403,7 @@ impl<'db> ClassLiteral<'db> { let scope = self.body_scope(db); let file = scope.file(db); let parsed = parsed_module(db, file).load(db); - let class_def_node = scope.node(db).expect_class(&parsed); + let class_def_node = scope.node(db).expect_class().node(&parsed); class_def_node.type_params.as_ref().map(|type_params| { let index = semantic_index(db, scope.file(db)); let definition = index.expect_single_definition(class_def_node); @@ -1445,14 +1445,13 @@ impl<'db> ClassLiteral<'db> { /// query depends on the AST of another file (bad!). fn node<'ast>(self, db: &'db dyn Db, module: &'ast ParsedModuleRef) -> &'ast ast::StmtClassDef { let scope = self.body_scope(db); - scope.node(db).expect_class(module) + scope.node(db).expect_class().node(module) } pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { let body_scope = self.body_scope(db); - let module = parsed_module(db, body_scope.file(db)).load(db); let index = semantic_index(db, body_scope.file(db)); - index.expect_single_definition(body_scope.node(db).expect_class(&module)) + index.expect_single_definition(body_scope.node(db).expect_class()) } pub(crate) fn apply_specialization( @@ -2870,8 +2869,8 @@ impl<'db> ClassLiteral<'db> { let class_table = place_table(db, class_body_scope); let is_valid_scope = |method_scope: ScopeId<'db>| { - if let Some(method_def) = method_scope.node(db).as_function(&module) { - let method_name = method_def.name.as_str(); + if let Some(method_def) = method_scope.node(db).as_function() { + let method_name = method_def.node(&module).name.as_str(); if let Place::Type(Type::FunctionLiteral(method_type), _) = class_symbol(db, class_body_scope, method_name).place { @@ -2946,20 +2945,22 @@ impl<'db> ClassLiteral<'db> { } // The attribute assignment inherits the reachability of the method which contains it - let is_method_reachable = - if let Some(method_def) = method_scope.node(db).as_function(&module) { - let method = index.expect_single_definition(method_def); - let method_place = class_table.symbol_id(&method_def.name).unwrap(); - class_map - .all_reachable_symbol_bindings(method_place) - .find_map(|bind| { - (bind.binding.is_defined_and(|def| def == method)) - .then(|| class_map.binding_reachability(db, &bind)) - }) - .unwrap_or(Truthiness::AlwaysFalse) - } else { - Truthiness::AlwaysFalse - }; + let is_method_reachable = if let Some(method_def) = method_scope.node(db).as_function() + { + let method = index.expect_single_definition(method_def); + let method_place = class_table + .symbol_id(&method_def.node(&module).name) + .unwrap(); + class_map + .all_reachable_symbol_bindings(method_place) + .find_map(|bind| { + (bind.binding.is_defined_and(|def| def == method)) + .then(|| class_map.binding_reachability(db, &bind)) + }) + .unwrap_or(Truthiness::AlwaysFalse) + } else { + Truthiness::AlwaysFalse + }; if is_method_reachable.is_always_false() { continue; } @@ -3323,7 +3324,7 @@ impl<'db> ClassLiteral<'db> { pub(super) fn header_range(self, db: &'db dyn Db) -> TextRange { let class_scope = self.body_scope(db); let module = parsed_module(db, class_scope.file(db)).load(db); - let class_node = class_scope.node(db).expect_class(&module); + let class_node = class_scope.node(db).expect_class().node(&module); let class_name = &class_node.name; TextRange::new( class_name.start(), @@ -4784,8 +4785,7 @@ impl KnownClass { // 2. The first parameter of the current function (typically `self` or `cls`) match overload.parameter_types() { [] => { - let Some(enclosing_class) = - nearest_enclosing_class(db, index, scope, module) + let Some(enclosing_class) = nearest_enclosing_class(db, index, scope) else { BoundSuperError::UnavailableImplicitArguments .report_diagnostic(context, call_expression.into()); diff --git a/crates/ty_python_semantic/src/types/context.rs b/crates/ty_python_semantic/src/types/context.rs index 2230603d82..d04bbf28e9 100644 --- a/crates/ty_python_semantic/src/types/context.rs +++ b/crates/ty_python_semantic/src/types/context.rs @@ -172,7 +172,7 @@ impl<'db, 'ast> InferContext<'db, 'ast> { // Inspect all ancestor function scopes by walking bottom up and infer the function's type. let mut function_scope_tys = index .ancestor_scopes(scope_id) - .filter_map(|(_, scope)| scope.node().as_function(self.module())) + .filter_map(|(_, scope)| scope.node().as_function()) .map(|node| binding_type(self.db, index.expect_single_definition(node))) .filter_map(Type::into_function_literal); diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index a54d0e92f7..65a350b7ca 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -200,15 +200,15 @@ impl ClassDisplay<'_> { match ancestor_scope.kind() { ScopeKind::Class => { - if let Some(class_def) = node.as_class(&module_ast) { - name_parts.push(class_def.name.as_str().to_string()); + if let Some(class_def) = node.as_class() { + name_parts.push(class_def.node(&module_ast).name.as_str().to_string()); } } ScopeKind::Function => { - if let Some(function_def) = node.as_function(&module_ast) { + if let Some(function_def) = node.as_function() { name_parts.push(format!( "", - function_def.name.as_str() + function_def.node(&module_ast).name.as_str() )); } } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index c3cb87f134..0a544a1e5f 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -55,7 +55,7 @@ use bitflags::bitflags; use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity, Span}; use ruff_db::files::{File, FileRange}; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; -use ruff_python_ast as ast; +use ruff_python_ast::{self as ast, ParameterWithDefault}; use ruff_text_size::Ranged; use crate::module_resolver::{KnownModule, file_to_module}; @@ -63,7 +63,7 @@ use crate::place::{Boundness, Place, place_from_bindings}; use crate::semantic_index::ast_ids::HasScopedUseId; use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::semantic_index; +use crate::semantic_index::{FileScopeId, SemanticIndex, semantic_index}; use crate::types::call::{Binding, CallArguments}; use crate::types::constraints::{ConstraintSet, Constraints}; use crate::types::context::InferContext; @@ -80,7 +80,7 @@ use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, DeprecatedInstance, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, NormalizedVisitor, SpecialFormType, Truthiness, Type, - TypeMapping, TypeRelation, UnionBuilder, all_members, walk_type_mapping, + TypeMapping, TypeRelation, UnionBuilder, all_members, binding_type, walk_type_mapping, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; @@ -236,6 +236,22 @@ impl<'db> OverloadLiteral<'db> { self.has_known_decorator(db, FunctionDecorators::OVERLOAD) } + /// Returns true if this overload is decorated with `@staticmethod`, or if it is implicitly a + /// staticmethod. + pub(crate) fn is_staticmethod(self, db: &dyn Db) -> bool { + self.has_known_decorator(db, FunctionDecorators::STATICMETHOD) || self.name(db) == "__new__" + } + + /// Returns true if this overload is decorated with `@classmethod`, or if it is implicitly a + /// classmethod. + pub(crate) fn is_classmethod(self, db: &dyn Db) -> bool { + self.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) + || matches!( + self.name(db).as_str(), + "__init_subclass__" | "__class_getitem__" + ) + } + fn node<'ast>( self, db: &dyn Db, @@ -249,7 +265,7 @@ impl<'db> OverloadLiteral<'db> { the function is defined." ); - self.body_scope(db).node(db).expect_function(module) + self.body_scope(db).node(db).expect_function().node(module) } /// Returns the [`FileRange`] of the function's name. @@ -258,7 +274,8 @@ impl<'db> OverloadLiteral<'db> { self.file(db), self.body_scope(db) .node(db) - .expect_function(module) + .expect_function() + .node(module) .name .range, ) @@ -274,9 +291,8 @@ impl<'db> OverloadLiteral<'db> { /// over-invalidation. fn definition(self, db: &'db dyn Db) -> Definition<'db> { let body_scope = self.body_scope(db); - let module = parsed_module(db, self.file(db)).load(db); let index = semantic_index(db, body_scope.file(db)); - index.expect_single_definition(body_scope.node(db).expect_function(&module)) + index.expect_single_definition(body_scope.node(db).expect_function()) } /// Returns the overload immediately before this one in the AST. Returns `None` if there is no @@ -290,7 +306,8 @@ impl<'db> OverloadLiteral<'db> { let use_id = self .body_scope(db) .node(db) - .expect_function(&module) + .expect_function() + .node(&module) .name .scoped_use_id(db, scope); @@ -325,17 +342,79 @@ impl<'db> OverloadLiteral<'db> { db: &'db dyn Db, inherited_generic_context: Option>, ) -> Signature<'db> { + /// `self` or `cls` can be implicitly positional-only if: + /// - It is a method AND + /// - No parameters in the method use PEP-570 syntax AND + /// - It is not a `@staticmethod` AND + /// - `self`/`cls` is not explicitly positional-only using the PEP-484 convention AND + /// - Either the next parameter after `self`/`cls` uses the PEP-484 convention, + /// or the enclosing class is a `Protocol` class + fn has_implicitly_positional_only_first_param<'db>( + db: &'db dyn Db, + literal: OverloadLiteral<'db>, + node: &ast::StmtFunctionDef, + scope: FileScopeId, + index: &SemanticIndex, + ) -> bool { + let parameters = &node.parameters; + + if !parameters.posonlyargs.is_empty() { + return false; + } + + let Some(first_param) = parameters.args.first() else { + return false; + }; + + if first_param.uses_pep_484_positional_only_convention() { + return false; + } + + if literal.is_staticmethod(db) { + return false; + } + + let Some(class_definition) = index.class_definition_of_method(scope) else { + return false; + }; + + // `self` and `cls` are always positional-only if the next parameter uses the + // PEP-484 convention. + if parameters + .args + .get(1) + .is_some_and(ParameterWithDefault::uses_pep_484_positional_only_convention) + { + return true; + } + + // If there isn't any parameter other than `self`/`cls`, + // or there is but it isn't using the PEP-484 convention, + // then `self`/`cls` are only implicitly positional-only if + // it is a protocol class. + let class_type = binding_type(db, class_definition); + class_type + .to_class_type(db) + .is_some_and(|class| class.is_protocol(db)) + } + let scope = self.body_scope(db); let module = parsed_module(db, self.file(db)).load(db); - let function_stmt_node = scope.node(db).expect_function(&module); + let function_stmt_node = scope.node(db).expect_function().node(&module); let definition = self.definition(db); + let index = semantic_index(db, scope.file(db)); let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| { - let index = semantic_index(db, scope.file(db)); GenericContext::from_type_params(db, index, definition, type_params) }); - - let index = semantic_index(db, scope.file(db)); - let is_generator = scope.file_scope_id(db).is_generator_function(index); + let file_scope_id = scope.file_scope_id(db); + let is_generator = file_scope_id.is_generator_function(index); + let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param( + db, + self, + function_stmt_node, + file_scope_id, + index, + ); Signature::from_function( db, @@ -344,6 +423,7 @@ impl<'db> OverloadLiteral<'db> { definition, function_stmt_node, is_generator, + has_implicitly_positional_first_parameter, ) } @@ -356,7 +436,7 @@ impl<'db> OverloadLiteral<'db> { let span = Span::from(function_scope.file(db)); let node = function_scope.node(db); let module = parsed_module(db, self.file(db)).load(db); - let func_def = node.as_function(&module)?; + let func_def = node.as_function()?.node(&module); let range = parameter_index .and_then(|parameter_index| { func_def @@ -376,7 +456,7 @@ impl<'db> OverloadLiteral<'db> { let span = Span::from(function_scope.file(db)); let node = function_scope.node(db); let module = parsed_module(db, self.file(db)).load(db); - let func_def = node.as_function(&module)?; + let func_def = node.as_function()?.node(&module); let return_type_range = func_def.returns.as_ref().map(|returns| returns.range()); let mut signature = func_def.name.range.cover(func_def.parameters.range); if let Some(return_type_range) = return_type_range { @@ -713,17 +793,15 @@ impl<'db> FunctionType<'db> { /// Returns true if this method is decorated with `@classmethod`, or if it is implicitly a /// classmethod. pub(crate) fn is_classmethod(self, db: &'db dyn Db) -> bool { - self.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) - || matches!( - self.name(db).as_str(), - "__init_subclass__" | "__class_getitem__" - ) + self.iter_overloads_and_implementation(db) + .any(|overload| overload.is_classmethod(db)) } /// Returns true if this method is decorated with `@staticmethod`, or if it is implicitly a /// static method. pub(crate) fn is_staticmethod(self, db: &'db dyn Db) -> bool { - self.has_known_decorator(db, FunctionDecorators::STATICMETHOD) || self.name(db) == "__new__" + self.iter_overloads_and_implementation(db) + .any(|overload| overload.is_staticmethod(db)) } /// If the implementation of this function is deprecated, returns the `@warnings.deprecated`. diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 8dc5379e0e..ae1fd7f409 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -422,12 +422,11 @@ pub(crate) fn nearest_enclosing_class<'db>( db: &'db dyn Db, semantic: &SemanticIndex<'db>, scope: ScopeId, - parsed: &ParsedModuleRef, ) -> Option> { semantic .ancestor_scopes(scope.file_scope_id(db)) .find_map(|(_, ancestor_scope)| { - let class = ancestor_scope.node().as_class(parsed)?; + let class = ancestor_scope.node().as_class()?; let definition = semantic.expect_single_definition(class); infer_definition_types(db, definition) .declaration_type(definition) @@ -2418,29 +2417,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// behaviour to the [`nearest_enclosing_class`] function. fn class_context_of_current_method(&self) -> Option> { let current_scope_id = self.scope().file_scope_id(self.db()); - let current_scope = self.index.scope(current_scope_id); - if current_scope.kind() != ScopeKind::Function { - return None; - } - let parent_scope_id = current_scope.parent()?; - let parent_scope = self.index.scope(parent_scope_id); - - let class_scope = match parent_scope.kind() { - ScopeKind::Class => parent_scope, - ScopeKind::TypeParams => { - let class_scope_id = parent_scope.parent()?; - let potentially_class_scope = self.index.scope(class_scope_id); - - match potentially_class_scope.kind() { - ScopeKind::Class => potentially_class_scope, - _ => return None, - } - } - _ => return None, - }; - - let class_stmt = class_scope.node().as_class(self.module())?; - let class_definition = self.index.expect_single_definition(class_stmt); + let class_definition = self.index.class_definition_of_method(current_scope_id)?; binding_type(self.db(), class_definition).to_class_type(self.db()) } @@ -2453,7 +2430,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if !current_scope.kind().is_non_lambda_function() { return None; } - current_scope.node().as_function(self.module()) + current_scope + .node() + .as_function() + .map(|node_ref| node_ref.node(self.module())) } fn function_decorator_types<'a>( diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 029869a0b2..ffd11bf77f 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -12,7 +12,7 @@ use std::{collections::HashMap, slice::Iter}; -use itertools::EitherOrBoth; +use itertools::{EitherOrBoth, Itertools}; use smallvec::{SmallVec, smallvec_inline}; use super::{DynamicType, Type, TypeVarVariance, definition_expression_type}; @@ -352,9 +352,14 @@ impl<'db> Signature<'db> { definition: Definition<'db>, function_node: &ast::StmtFunctionDef, is_generator: bool, + has_implicitly_positional_first_parameter: bool, ) -> Self { - let parameters = - Parameters::from_parameters(db, definition, function_node.parameters.as_ref()); + let parameters = Parameters::from_parameters( + db, + definition, + function_node.parameters.as_ref(), + has_implicitly_positional_first_parameter, + ); let return_ty = function_node.returns.as_ref().map(|returns| { let plain_return_ty = definition_expression_type(db, definition, returns.as_ref()) .apply_type_mapping( @@ -1139,6 +1144,7 @@ impl<'db> Parameters<'db> { db: &'db dyn Db, definition: Definition<'db>, parameters: &ast::Parameters, + has_implicitly_positional_first_parameter: bool, ) -> Self { let ast::Parameters { posonlyargs, @@ -1149,23 +1155,46 @@ impl<'db> Parameters<'db> { range: _, node_index: _, } = parameters; + let default_type = |param: &ast::ParameterWithDefault| { param .default() .map(|default| definition_expression_type(db, definition, default)) }; - let positional_only = posonlyargs.iter().map(|arg| { + + let pos_only_param = |param: &ast::ParameterWithDefault| { Parameter::from_node_and_kind( db, definition, - &arg.parameter, + ¶m.parameter, ParameterKind::PositionalOnly { - name: Some(arg.parameter.name.id.clone()), - default_type: default_type(arg), + name: Some(param.parameter.name.id.clone()), + default_type: default_type(param), }, ) - }); - let positional_or_keyword = args.iter().map(|arg| { + }; + + let mut positional_only: Vec = posonlyargs.iter().map(pos_only_param).collect(); + + let mut pos_or_keyword_iter = args.iter(); + + // If there are no PEP-570 positional-only parameters, check for the legacy PEP-484 convention + // for denoting positional-only parameters (parameters that start with `__` and do not end with `__`) + if positional_only.is_empty() { + let pos_or_keyword_iter = pos_or_keyword_iter.by_ref(); + + if has_implicitly_positional_first_parameter { + positional_only.extend(pos_or_keyword_iter.next().map(pos_only_param)); + } + + positional_only.extend( + pos_or_keyword_iter + .peeking_take_while(|param| param.uses_pep_484_positional_only_convention()) + .map(pos_only_param), + ); + } + + let positional_or_keyword = pos_or_keyword_iter.map(|arg| { Parameter::from_node_and_kind( db, definition, @@ -1176,6 +1205,7 @@ impl<'db> Parameters<'db> { }, ) }); + let variadic = vararg.as_ref().map(|arg| { Parameter::from_node_and_kind( db, @@ -1186,6 +1216,7 @@ impl<'db> Parameters<'db> { }, ) }); + let keyword_only = kwonlyargs.iter().map(|arg| { Parameter::from_node_and_kind( db, @@ -1197,6 +1228,7 @@ impl<'db> Parameters<'db> { }, ) }); + let keywords = kwarg.as_ref().map(|arg| { Parameter::from_node_and_kind( db, @@ -1207,8 +1239,10 @@ impl<'db> Parameters<'db> { }, ) }); + Self::new( positional_only + .into_iter() .chain(positional_or_keyword) .chain(variadic) .chain(keyword_only)