mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-03 18:28:24 +00:00
[red-knot] Check for invalid overload usages (#17609)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new red-knot panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks (push) Blocked by required conditions
[Knot Playground] Release / publish (push) Waiting to run
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new red-knot panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks (push) Blocked by required conditions
[Knot Playground] Release / publish (push) Waiting to run
## Summary Part of #15383, this PR adds the core infrastructure to check for invalid overloads and adds a diagnostic to raise if there are < 2 overloads for a given definition. ### Design notes The requirements to check the overloads are: * Requires `FunctionType` which has the `to_overloaded` method * The `FunctionType` **should** be for the function that is either the implementation or the last overload if the implementation doesn't exists * Avoid checking any `FunctionType` that are part of an overload chain * Consider visibility constraints This required a couple of iteration to make sure all of the above requirements are fulfilled. #### 1. Use a set to deduplicate The logic would first collect all the `FunctionType` that are part of the overload chain except for the implementation or the last overload if the implementation doesn't exists. Then, when iterating over all the function declarations within the scope, we'd avoid checking these functions. But, this approach would fail to consider visibility constraints as certain overloads _can_ be behind a version check. Those aren't part of the overload chain but those aren't a separate overload chain either. <details><summary>Implementation:</summary> <p> ```rs fn check_overloaded_functions(&mut self) { let function_definitions = || { self.types .declarations .iter() .filter_map(|(definition, ty)| { // Filter out function literals that result from anything other than a function // definition e.g., imports. if let DefinitionKind::Function(function) = definition.kind(self.db()) { ty.inner_type() .into_function_literal() .map(|ty| (ty, definition.symbol(self.db()), function.node())) } else { None } }) }; // A set of all the functions that are part of an overloaded function definition except for // the implementation function and the last overload in case the implementation doesn't // exists. This allows us to collect all the function definitions that needs to be skipped // when checking for invalid overload usages. let mut overloads: HashSet<FunctionType<'db>> = HashSet::default(); for (function, _) in function_definitions() { let Some(overloaded) = function.to_overloaded(self.db()) else { continue; }; if overloaded.implementation.is_some() { overloads.extend(overloaded.overloads.iter().copied()); } else if let Some((_, previous_overloads)) = overloaded.overloads.split_last() { overloads.extend(previous_overloads.iter().copied()); } } for (function, function_node) in function_definitions() { let Some(overloaded) = function.to_overloaded(self.db()) else { continue; }; if overloads.contains(&function) { continue; } // At this point, the `function` variable is either the implementation function or the // last overloaded function if the implementation doesn't exists. if overloaded.overloads.len() < 2 { if let Some(builder) = self .context .report_lint(&INVALID_OVERLOAD, &function_node.name) { let mut diagnostic = builder.into_diagnostic(format_args!( "Function `{}` requires at least two overloads", &function_node.name )); if let Some(first_overload) = overloaded.overloads.first() { diagnostic.annotate( self.context .secondary(first_overload.focus_range(self.db())) .message(format_args!("Only one overload defined here")), ); } } } } } ``` </p> </details> #### 2. Define a `predecessor` query The `predecessor` query would return the previous `FunctionType` for the given `FunctionType` i.e., the current logic would be extracted to be a query instead. This could then be used to make sure that we're checking the entire overload chain once. The way this would've been implemented is to have a `to_overloaded` implementation which would take the root of the overload chain instead of the leaf. But, this would require updates to the use-def map to somehow be able to return the _following_ functions for a given definition. #### 3. Create a successor link This is what Pyrefly uses, we'd create a forward link between two functions that are involved in an overload chain. This means that for a given function, we can get the successor function. This could be used to find the _leaf_ of the overload chain which can then be used with the `to_overloaded` method to get the entire overload chain. But, this would also require updating the use-def map to be able to "see" the _following_ function. ### Implementation This leads us to the final implementation that this PR implements which is to consider the overloaded functions using: * Collect all the **function symbols** that are defined **and** called within the same file. This could potentially be an overloaded function * Use the public bindings to get the leaf of the overload chain and use that to get the entire overload chain via `to_overloaded` and perform the check This has a limitation that in case a function redefines an overload, then that overload will not be checked. For example: ```py from typing import overload @overload def f() -> None: ... @overload def f(x: int) -> int: ... # The above overload will not be checked as the below function with the same name # shadows it def f(*args: int) -> int: ... ``` ## Test Plan Update existing mdtest and add snapshot diagnostics.
This commit is contained in:
parent
0861ecfa55
commit
ad1a8da4d1
6 changed files with 283 additions and 7 deletions
|
@ -309,18 +309,29 @@ reveal_type(func("")) # revealed: Literal[""]
|
|||
|
||||
### At least two overloads
|
||||
|
||||
<!-- snapshot-diagnostics -->
|
||||
|
||||
At least two `@overload`-decorated definitions must be present.
|
||||
|
||||
```py
|
||||
from typing import overload
|
||||
|
||||
# TODO: error
|
||||
@overload
|
||||
def func(x: int) -> int: ...
|
||||
|
||||
# error: [invalid-overload]
|
||||
def func(x: int | str) -> int | str:
|
||||
return x
|
||||
```
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
# error: [invalid-overload]
|
||||
def func(x: int) -> int: ...
|
||||
```
|
||||
|
||||
### Overload without an implementation
|
||||
|
||||
#### Regular modules
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
---
|
||||
source: crates/red_knot_test/src/lib.rs
|
||||
expression: snapshot
|
||||
---
|
||||
---
|
||||
mdtest name: overloads.md - Overloads - Invalid - At least two overloads
|
||||
mdtest path: crates/red_knot_python_semantic/resources/mdtest/overloads.md
|
||||
---
|
||||
|
||||
# Python source files
|
||||
|
||||
## mdtest_snippet.py
|
||||
|
||||
```
|
||||
1 | from typing import overload
|
||||
2 |
|
||||
3 | @overload
|
||||
4 | def func(x: int) -> int: ...
|
||||
5 |
|
||||
6 | # error: [invalid-overload]
|
||||
7 | def func(x: int | str) -> int | str:
|
||||
8 | return x
|
||||
```
|
||||
|
||||
## mdtest_snippet.pyi
|
||||
|
||||
```
|
||||
1 | from typing import overload
|
||||
2 |
|
||||
3 | @overload
|
||||
4 | # error: [invalid-overload]
|
||||
5 | def func(x: int) -> int: ...
|
||||
```
|
||||
|
||||
# Diagnostics
|
||||
|
||||
```
|
||||
error: lint:invalid-overload: Overloaded function `func` requires at least two overloads
|
||||
--> src/mdtest_snippet.py:4:5
|
||||
|
|
||||
3 | @overload
|
||||
4 | def func(x: int) -> int: ...
|
||||
| ---- Only one overload defined here
|
||||
5 |
|
||||
6 | # error: [invalid-overload]
|
||||
7 | def func(x: int | str) -> int | str:
|
||||
| ^^^^
|
||||
8 | return x
|
||||
|
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
error: lint:invalid-overload: Overloaded function `func` requires at least two overloads
|
||||
--> src/mdtest_snippet.pyi:5:5
|
||||
|
|
||||
3 | @overload
|
||||
4 | # error: [invalid-overload]
|
||||
5 | def func(x: int) -> int: ...
|
||||
| ----
|
||||
| |
|
||||
| Only one overload defined here
|
||||
|
|
||||
|
||||
```
|
|
@ -6525,6 +6525,13 @@ pub struct FunctionType<'db> {
|
|||
|
||||
#[salsa::tracked]
|
||||
impl<'db> FunctionType<'db> {
|
||||
/// Returns the [`File`] in which this function is defined.
|
||||
pub(crate) fn file(self, db: &'db dyn Db) -> File {
|
||||
// NOTE: Do not use `self.definition(db).file(db)` here, as that could create a
|
||||
// cross-module dependency on the full AST.
|
||||
self.body_scope(db).file(db)
|
||||
}
|
||||
|
||||
pub(crate) fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool {
|
||||
self.decorators(db).contains(decorator)
|
||||
}
|
||||
|
@ -6546,21 +6553,41 @@ impl<'db> FunctionType<'db> {
|
|||
Type::BoundMethod(BoundMethodType::new(db, self, self_instance))
|
||||
}
|
||||
|
||||
/// Returns the AST node for this function.
|
||||
pub(crate) fn node(self, db: &'db dyn Db, file: File) -> &'db ast::StmtFunctionDef {
|
||||
debug_assert_eq!(
|
||||
file,
|
||||
self.file(db),
|
||||
"FunctionType::node() must be called with the same file as the one where \
|
||||
the function is defined."
|
||||
);
|
||||
|
||||
self.body_scope(db).node(db).expect_function()
|
||||
}
|
||||
|
||||
/// Returns the [`FileRange`] of the function's name.
|
||||
pub fn focus_range(self, db: &dyn Db) -> FileRange {
|
||||
FileRange::new(
|
||||
self.body_scope(db).file(db),
|
||||
self.file(db),
|
||||
self.body_scope(db).node(db).expect_function().name.range,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn full_range(self, db: &dyn Db) -> FileRange {
|
||||
FileRange::new(
|
||||
self.body_scope(db).file(db),
|
||||
self.file(db),
|
||||
self.body_scope(db).node(db).expect_function().range,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the [`Definition`] of this function.
|
||||
///
|
||||
/// ## Warning
|
||||
///
|
||||
/// This uses the semantic index to find the definition of the function. This means that if the
|
||||
/// calling query is not in the same file as this function is defined in, then this will create
|
||||
/// a cross-module dependency directly on the full AST which will lead to cache
|
||||
/// over-invalidation.
|
||||
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
|
||||
let body_scope = self.body_scope(db);
|
||||
let index = semantic_index(db, body_scope.file(db));
|
||||
|
|
|
@ -37,6 +37,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
|
|||
registry.register_lint(&INVALID_EXCEPTION_CAUGHT);
|
||||
registry.register_lint(&INVALID_LEGACY_TYPE_VARIABLE);
|
||||
registry.register_lint(&INVALID_METACLASS);
|
||||
registry.register_lint(&INVALID_OVERLOAD);
|
||||
registry.register_lint(&INVALID_PARAMETER_DEFAULT);
|
||||
registry.register_lint(&INVALID_PROTOCOL);
|
||||
registry.register_lint(&INVALID_RAISE);
|
||||
|
@ -447,6 +448,49 @@ declare_lint! {
|
|||
}
|
||||
}
|
||||
|
||||
declare_lint! {
|
||||
/// ## What it does
|
||||
/// Checks for various invalid `@overload` usages.
|
||||
///
|
||||
/// ## Why is this bad?
|
||||
/// The `@overload` decorator is used to define functions and methods that accepts different
|
||||
/// combinations of arguments and return different types based on the arguments passed. This is
|
||||
/// mainly beneficial for type checkers. But, if the `@overload` usage is invalid, the type
|
||||
/// checker may not be able to provide correct type information.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// Defining only one overload:
|
||||
///
|
||||
/// ```py
|
||||
/// from typing import overload
|
||||
///
|
||||
/// @overload
|
||||
/// def foo(x: int) -> int: ...
|
||||
/// def foo(x: int | None) -> int | None:
|
||||
/// return x
|
||||
/// ```
|
||||
///
|
||||
/// Or, not providing an implementation for the overloaded definition:
|
||||
///
|
||||
/// ```py
|
||||
/// from typing import overload
|
||||
///
|
||||
/// @overload
|
||||
/// def foo() -> None: ...
|
||||
/// @overload
|
||||
/// def foo(x: int) -> int: ...
|
||||
/// ```
|
||||
///
|
||||
/// ## References
|
||||
/// - [Python documentation: `@overload`](https://docs.python.org/3/library/typing.html#typing.overload)
|
||||
pub(crate) static INVALID_OVERLOAD = {
|
||||
summary: "detects invalid `@overload` usages",
|
||||
status: LintStatus::preview("1.0.0"),
|
||||
default_level: Level::Error,
|
||||
}
|
||||
}
|
||||
|
||||
declare_lint! {
|
||||
/// ## What it does
|
||||
/// Checks for default values that can't be assigned to the parameter's annotated type.
|
||||
|
|
|
@ -101,8 +101,8 @@ use super::diagnostic::{
|
|||
report_invalid_exception_raised, report_invalid_type_checking_constant,
|
||||
report_non_subscriptable, report_possibly_unresolved_reference,
|
||||
report_runtime_check_against_non_runtime_checkable_protocol, report_slice_step_size_zero,
|
||||
report_unresolved_reference, INVALID_METACLASS, INVALID_PROTOCOL, REDUNDANT_CAST,
|
||||
STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
|
||||
report_unresolved_reference, INVALID_METACLASS, INVALID_OVERLOAD, INVALID_PROTOCOL,
|
||||
REDUNDANT_CAST, STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
|
||||
};
|
||||
use super::slots::check_class_slots;
|
||||
use super::string_annotation::{
|
||||
|
@ -418,7 +418,7 @@ impl<'db> TypeInference<'db> {
|
|||
.copied()
|
||||
.or(self.cycle_fallback_type)
|
||||
.expect(
|
||||
"definition should belong to this TypeInference region and
|
||||
"definition should belong to this TypeInference region and \
|
||||
TypeInferenceBuilder should have inferred a type for it",
|
||||
)
|
||||
}
|
||||
|
@ -430,7 +430,7 @@ impl<'db> TypeInference<'db> {
|
|||
.copied()
|
||||
.or(self.cycle_fallback_type.map(Into::into))
|
||||
.expect(
|
||||
"definition should belong to this TypeInference region and
|
||||
"definition should belong to this TypeInference region and \
|
||||
TypeInferenceBuilder should have inferred a type for it",
|
||||
)
|
||||
}
|
||||
|
@ -524,6 +524,31 @@ pub(super) struct TypeInferenceBuilder<'db> {
|
|||
/// The returned types and their corresponding ranges of the region, if it is a function body.
|
||||
return_types_and_ranges: Vec<TypeAndRange<'db>>,
|
||||
|
||||
/// A set of functions that have been defined **and** called in this region.
|
||||
///
|
||||
/// This is a set because the same function could be called multiple times in the same region.
|
||||
/// This is mainly used in [`check_overloaded_functions`] to check an overloaded function that
|
||||
/// is shadowed by a function with the same name in this scope but has been called before. For
|
||||
/// example:
|
||||
///
|
||||
/// ```py
|
||||
/// from typing import overload
|
||||
///
|
||||
/// @overload
|
||||
/// def foo() -> None: ...
|
||||
/// @overload
|
||||
/// def foo(x: int) -> int: ...
|
||||
/// def foo(x: int | None) -> int | None: return x
|
||||
///
|
||||
/// foo() # An overloaded function that was defined in this scope have been called
|
||||
///
|
||||
/// def foo(x: int) -> int:
|
||||
/// return x
|
||||
/// ```
|
||||
///
|
||||
/// [`check_overloaded_functions`]: TypeInferenceBuilder::check_overloaded_functions
|
||||
called_functions: FxHashSet<FunctionType<'db>>,
|
||||
|
||||
/// The deferred state of inferring types of certain expressions within the region.
|
||||
///
|
||||
/// This is different from [`InferenceRegion::Deferred`] which works on the entire definition
|
||||
|
@ -556,6 +581,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
index,
|
||||
region,
|
||||
return_types_and_ranges: vec![],
|
||||
called_functions: FxHashSet::default(),
|
||||
deferred_state: DeferredExpressionState::None,
|
||||
types: TypeInference::empty(scope),
|
||||
}
|
||||
|
@ -718,6 +744,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
|
||||
// TODO: Only call this function when diagnostics are enabled.
|
||||
self.check_class_definitions();
|
||||
self.check_overloaded_functions();
|
||||
}
|
||||
|
||||
/// Iterate over all class definitions to check that the definition will not cause an exception
|
||||
|
@ -952,6 +979,86 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Check the overloaded functions in this scope.
|
||||
///
|
||||
/// This only checks the overloaded functions that are:
|
||||
/// 1. Visible publicly at the end of this scope
|
||||
/// 2. Or, defined and called in this scope
|
||||
///
|
||||
/// For (1), this has the consequence of not checking an overloaded function that is being
|
||||
/// shadowed by another function with the same name in this scope.
|
||||
fn check_overloaded_functions(&mut self) {
|
||||
// Collect all the unique overloaded function symbols in this scope. This requires a set
|
||||
// because an overloaded function uses the same symbol for each of the overloads and the
|
||||
// implementation.
|
||||
let overloaded_function_symbols: FxHashSet<_> = self
|
||||
.types
|
||||
.declarations
|
||||
.iter()
|
||||
.filter_map(|(definition, ty)| {
|
||||
// Filter out function literals that result from anything other than a function
|
||||
// definition e.g., imports which would create a cross-module AST dependency.
|
||||
if !matches!(definition.kind(self.db()), DefinitionKind::Function(_)) {
|
||||
return None;
|
||||
}
|
||||
let function = ty.inner_type().into_function_literal()?;
|
||||
if function.has_known_decorator(self.db(), FunctionDecorators::OVERLOAD) {
|
||||
Some(definition.symbol(self.db()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let use_def = self
|
||||
.index
|
||||
.use_def_map(self.scope().file_scope_id(self.db()));
|
||||
|
||||
let mut public_functions = FxHashSet::default();
|
||||
|
||||
for symbol in overloaded_function_symbols {
|
||||
if let Symbol::Type(Type::FunctionLiteral(function), Boundness::Bound) =
|
||||
symbol_from_bindings(self.db(), use_def.public_bindings(symbol))
|
||||
{
|
||||
if function.file(self.db()) != self.file() {
|
||||
// If the function is not in this file, we don't need to check it.
|
||||
// https://github.com/astral-sh/ruff/pull/17609#issuecomment-2839445740
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extend the functions that we need to check with the publicly visible overloaded
|
||||
// function. This is always going to be either the implementation or the last
|
||||
// overload if the implementation doesn't exists.
|
||||
public_functions.insert(function);
|
||||
}
|
||||
}
|
||||
|
||||
for function in self.called_functions.union(&public_functions) {
|
||||
let Some(overloaded) = function.to_overloaded(self.db()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Check that the overloaded function has at least two overloads
|
||||
if let [single_overload] = overloaded.overloads.as_slice() {
|
||||
let function_node = function.node(self.db(), self.file());
|
||||
if let Some(builder) = self
|
||||
.context
|
||||
.report_lint(&INVALID_OVERLOAD, &function_node.name)
|
||||
{
|
||||
let mut diagnostic = builder.into_diagnostic(format_args!(
|
||||
"Overloaded function `{}` requires at least two overloads",
|
||||
&function_node.name
|
||||
));
|
||||
diagnostic.annotate(
|
||||
self.context
|
||||
.secondary(single_overload.focus_range(self.db()))
|
||||
.message(format_args!("Only one overload defined here")),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_region_definition(&mut self, definition: Definition<'db>) {
|
||||
match definition.kind(self.db()) {
|
||||
DefinitionKind::Function(function) => {
|
||||
|
@ -4299,6 +4406,18 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let mut call_arguments = Self::parse_arguments(arguments);
|
||||
let callable_type = self.infer_expression(func);
|
||||
|
||||
if let Type::FunctionLiteral(function) = callable_type {
|
||||
// Make sure that the `function.definition` is only called when the function is defined
|
||||
// in the same file as the one we're currently inferring the types for. This is because
|
||||
// the `definition` method accesses the semantic index, which could create a
|
||||
// cross-module AST dependency.
|
||||
if function.file(self.db()) == self.file()
|
||||
&& function.definition(self.db()).scope(self.db()) == self.scope()
|
||||
{
|
||||
self.called_functions.insert(function);
|
||||
}
|
||||
}
|
||||
|
||||
// It might look odd here that we emit an error for class-literals but not `type[]` types.
|
||||
// But it's deliberate! The typing spec explicitly mandates that `type[]` types can be called
|
||||
// even though class-literals cannot. This is because even though a protocol class `SomeProtocol`
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue