diff --git a/resources/test/fixtures/flake8_annotations/allow_overload.py b/resources/test/fixtures/flake8_annotations/allow_overload.py new file mode 100644 index 0000000000..72d4ee3cff --- /dev/null +++ b/resources/test/fixtures/flake8_annotations/allow_overload.py @@ -0,0 +1,49 @@ +from typing import overload + + +@overload +def foo(i: int) -> "int": + ... + + +@overload +def foo(i: "str") -> "str": + ... + + +def foo(i): + return i + + +@overload +def bar(i: int) -> "int": + ... + + +@overload +def bar(i: "str") -> "str": + ... + + +class X: + def bar(i): + return i + + +# TODO(charlie): This third case should raise an error (as in Mypy), because we have a +# statement between the interfaces and implementation. +@overload +def baz(i: int) -> "int": + ... + + +@overload +def baz(i: "str") -> "str": + ... + + +x = 1 + + +def baz(i): + return i diff --git a/src/check_ast.rs b/src/check_ast.rs index fb66ec4f92..2a4eb8c05a 100644 --- a/src/check_ast.rs +++ b/src/check_ast.rs @@ -3132,6 +3132,8 @@ impl<'a> Checker<'a> { } fn check_definitions(&mut self) { + let mut overloaded_name: Option = None; + self.definitions.reverse(); while let Some((definition, visibility)) = self.definitions.pop() { // flake8-annotations if self.settings.enabled.contains(&CheckCode::ANN001) @@ -3146,7 +3148,21 @@ impl<'a> Checker<'a> { || self.settings.enabled.contains(&CheckCode::ANN206) || self.settings.enabled.contains(&CheckCode::ANN401) { - flake8_annotations::plugins::definition(self, &definition, &visibility); + // TODO(charlie): This should be even stricter, in that an overload + // implementation should come immediately after the overloaded + // interfaces, without any AST nodes in between. Right now, we + // only error when traversing definition boundaries (functions, + // classes, etc.). + if !overloaded_name.map_or(false, |overloaded_name| { + flake8_annotations::helpers::is_overload_impl( + self, + &definition, + &overloaded_name, + ) + }) { + flake8_annotations::plugins::definition(self, &definition, &visibility); + } + overloaded_name = flake8_annotations::helpers::overloaded_name(self, &definition); } // pydocstyle diff --git a/src/flake8_annotations/helpers.rs b/src/flake8_annotations/helpers.rs new file mode 100644 index 0000000000..1e2ae2f86e --- /dev/null +++ b/src/flake8_annotations/helpers.rs @@ -0,0 +1,62 @@ +use rustpython_ast::{Arguments, Expr, Stmt, StmtKind}; + +use crate::check_ast::Checker; +use crate::docstrings::definition::{Definition, DefinitionKind}; +use crate::visibility; + +pub(super) fn match_function_def( + stmt: &Stmt, +) -> (&str, &Arguments, &Option>, &Vec) { + match &stmt.node { + StmtKind::FunctionDef { + name, + args, + returns, + body, + .. + } + | StmtKind::AsyncFunctionDef { + name, + args, + returns, + body, + .. + } => (name, args, returns, body), + _ => panic!("Found non-FunctionDef in match_name"), + } +} + +/// Return the name of the function, if it's overloaded. +pub fn overloaded_name(checker: &Checker, definition: &Definition) -> Option { + if let DefinitionKind::Function(stmt) + | DefinitionKind::NestedFunction(stmt) + | DefinitionKind::Method(stmt) = definition.kind + { + if visibility::is_overload(checker, stmt) { + let (name, ..) = match_function_def(stmt); + Some(name.to_string()) + } else { + None + } + } else { + None + } +} + +/// Return `true` if the definition is the implementation for an overloaded +/// function. +pub fn is_overload_impl(checker: &Checker, definition: &Definition, overloaded_name: &str) -> bool { + if let DefinitionKind::Function(stmt) + | DefinitionKind::NestedFunction(stmt) + | DefinitionKind::Method(stmt) = definition.kind + { + if visibility::is_overload(checker, stmt) { + false + } else { + let (name, ..) = match_function_def(stmt); + name == overloaded_name + } + } else { + false + } +} diff --git a/src/flake8_annotations/mod.rs b/src/flake8_annotations/mod.rs index 56f92b217e..6110cd1526 100644 --- a/src/flake8_annotations/mod.rs +++ b/src/flake8_annotations/mod.rs @@ -1,3 +1,4 @@ +pub mod helpers; pub mod plugins; pub mod settings; @@ -134,4 +135,24 @@ mod tests { insta::assert_yaml_snapshot!(checks); Ok(()) } + + #[test] + fn allow_overload() -> Result<()> { + let mut checks = test_path( + Path::new("./resources/test/fixtures/flake8_annotations/allow_overload.py"), + &Settings { + ..Settings::for_rules(vec![ + CheckCode::ANN201, + CheckCode::ANN202, + CheckCode::ANN204, + CheckCode::ANN205, + CheckCode::ANN206, + ]) + }, + true, + )?; + checks.sort_by_key(|check| check.location); + insta::assert_yaml_snapshot!(checks); + Ok(()) + } } diff --git a/src/flake8_annotations/plugins.rs b/src/flake8_annotations/plugins.rs index afd05d4cbe..fb5018bd16 100644 --- a/src/flake8_annotations/plugins.rs +++ b/src/flake8_annotations/plugins.rs @@ -1,4 +1,4 @@ -use rustpython_ast::{Arguments, Constant, Expr, ExprKind, Stmt, StmtKind}; +use rustpython_ast::{Constant, Expr, ExprKind, Stmt, StmtKind}; use crate::ast::types::Range; use crate::ast::visitor; @@ -6,6 +6,7 @@ use crate::ast::visitor::Visitor; use crate::check_ast::Checker; use crate::checks::{CheckCode, CheckKind}; use crate::docstrings::definition::{Definition, DefinitionKind}; +use crate::flake8_annotations::helpers::match_function_def; use crate::visibility::Visibility; use crate::{visibility, Check}; @@ -61,26 +62,6 @@ where }; } -fn match_function_def(stmt: &Stmt) -> (&str, &Arguments, &Option>, &Vec) { - match &stmt.node { - StmtKind::FunctionDef { - name, - args, - returns, - body, - .. - } - | StmtKind::AsyncFunctionDef { - name, - args, - returns, - body, - .. - } => (name, args, returns, body), - _ => panic!("Found non-FunctionDef in match_name"), - } -} - /// Generate flake8-annotation checks for a given `Definition`. pub fn definition(checker: &mut Checker, definition: &Definition, visibility: &Visibility) { // TODO(charlie): Consider using the AST directly here rather than `Definition`. diff --git a/src/flake8_annotations/snapshots/ruff__flake8_annotations__tests__allow_overload.snap b/src/flake8_annotations/snapshots/ruff__flake8_annotations__tests__allow_overload.snap new file mode 100644 index 0000000000..de6ed83da2 --- /dev/null +++ b/src/flake8_annotations/snapshots/ruff__flake8_annotations__tests__allow_overload.snap @@ -0,0 +1,14 @@ +--- +source: src/flake8_annotations/mod.rs +expression: checks +--- +- kind: + MissingReturnTypePublicFunction: bar + location: + row: 29 + column: 4 + end_location: + row: 35 + column: 0 + fix: ~ +