Enable auto-return-type involving Optional and Union annotations (#8885)

## Summary

Previously, this was only supported for Python 3.10 and later, since we
always use the PEP 604-style unions.
This commit is contained in:
Charlie Marsh 2023-11-28 18:35:55 -08:00 committed by GitHub
parent ec7456bac0
commit 6435e4e4aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 526 additions and 69 deletions

View file

@ -42,8 +42,24 @@ def func(x: int):
return {"foo": 1}
def func():
def func(x: int):
if not x:
return 1
else:
return True
def func(x: int):
if not x:
return 1
else:
return None
def func(x: int):
if not x:
return 1
elif x > 5:
return "str"
else:
return None

View file

@ -1,12 +1,17 @@
use itertools::Itertools;
use ruff_diagnostics::Edit;
use rustc_hash::FxHashSet;
use ruff_python_ast::helpers::{pep_604_union, ReturnStatementVisitor};
use crate::importer::{ImportRequest, Importer};
use ruff_python_ast::helpers::{
pep_604_union, typing_optional, typing_union, ReturnStatementVisitor,
};
use ruff_python_ast::visitor::Visitor;
use ruff_python_ast::{self as ast, Expr, ExprContext};
use ruff_python_semantic::analyze::type_inference::{NumberLike, PythonType, ResolvedPythonType};
use ruff_python_semantic::analyze::visibility;
use ruff_python_semantic::{Definition, SemanticModel};
use ruff_text_size::TextRange;
use ruff_text_size::{TextRange, TextSize};
use crate::settings::types::PythonVersion;
@ -38,10 +43,7 @@ pub(crate) fn is_overload_impl(
}
/// Given a function, guess its return type.
pub(crate) fn auto_return_type(
function: &ast::StmtFunctionDef,
target_version: PythonVersion,
) -> Option<Expr> {
pub(crate) fn auto_return_type(function: &ast::StmtFunctionDef) -> Option<AutoPythonType> {
// Collect all the `return` statements.
let returns = {
let mut visitor = ReturnStatementVisitor::default();
@ -68,24 +70,94 @@ pub(crate) fn auto_return_type(
}
match return_type {
ResolvedPythonType::Atom(python_type) => type_expr(python_type),
ResolvedPythonType::Union(python_types) if target_version >= PythonVersion::Py310 => {
// Aggregate all the individual types (e.g., `int`, `float`).
let names = python_types
.iter()
.sorted_unstable()
.map(|python_type| type_expr(*python_type))
.collect::<Option<Vec<_>>>()?;
// Wrap in a bitwise union (e.g., `int | float`).
Some(pep_604_union(&names))
}
ResolvedPythonType::Union(_) => None,
ResolvedPythonType::Atom(python_type) => Some(AutoPythonType::Atom(python_type)),
ResolvedPythonType::Union(python_types) => Some(AutoPythonType::Union(python_types)),
ResolvedPythonType::Unknown => None,
ResolvedPythonType::TypeError => None,
}
}
#[derive(Debug)]
pub(crate) enum AutoPythonType {
Atom(PythonType),
Union(FxHashSet<PythonType>),
}
impl AutoPythonType {
/// Convert an [`AutoPythonType`] into an [`Expr`].
///
/// If the [`Expr`] relies on importing any external symbols, those imports will be returned as
/// additional edits.
pub(crate) fn into_expression(
self,
importer: &Importer,
at: TextSize,
semantic: &SemanticModel,
target_version: PythonVersion,
) -> Option<(Expr, Vec<Edit>)> {
match self {
AutoPythonType::Atom(python_type) => {
let expr = type_expr(python_type)?;
Some((expr, vec![]))
}
AutoPythonType::Union(python_types) => {
if target_version >= PythonVersion::Py310 {
// Aggregate all the individual types (e.g., `int`, `float`).
let names = python_types
.iter()
.sorted_unstable()
.map(|python_type| type_expr(*python_type))
.collect::<Option<Vec<_>>>()?;
// Wrap in a bitwise union (e.g., `int | float`).
let expr = pep_604_union(&names);
Some((expr, vec![]))
} else {
let python_types = python_types
.into_iter()
.sorted_unstable()
.collect::<Vec<_>>();
match python_types.as_slice() {
[python_type, PythonType::None] | [PythonType::None, python_type] => {
let element = type_expr(*python_type)?;
// Ex) `Optional[int]`
let (optional_edit, binding) = importer
.get_or_import_symbol(
&ImportRequest::import_from("typing", "Optional"),
at,
semantic,
)
.ok()?;
let expr = typing_optional(element, binding);
Some((expr, vec![optional_edit]))
}
_ => {
let elements = python_types
.into_iter()
.map(type_expr)
.collect::<Option<Vec<_>>>()?;
// Ex) `Union[int, str]`
let (union_edit, binding) = importer
.get_or_import_symbol(
&ImportRequest::import_from("typing", "Union"),
at,
semantic,
)
.ok()?;
let expr = typing_union(&elements, binding);
Some((expr, vec![union_edit]))
}
}
}
}
}
}
}
/// Given a [`PythonType`], return an [`Expr`] that resolves to that type.
fn type_expr(python_type: PythonType) -> Option<Expr> {
fn name(name: &str) -> Expr {

View file

@ -11,6 +11,7 @@ mod tests {
use crate::assert_messages;
use crate::registry::Rule;
use crate::settings::types::PythonVersion;
use crate::settings::LinterSettings;
use crate::test::test_path;
@ -128,6 +129,25 @@ mod tests {
Ok(())
}
#[test]
fn auto_return_type_py38() -> Result<()> {
let diagnostics = test_path(
Path::new("flake8_annotations/auto_return_type.py"),
&LinterSettings {
target_version: PythonVersion::Py38,
..LinterSettings::for_rules(vec![
Rule::MissingReturnTypeUndocumentedPublicFunction,
Rule::MissingReturnTypePrivateFunction,
Rule::MissingReturnTypeSpecialMethod,
Rule::MissingReturnTypeStaticMethod,
Rule::MissingReturnTypeClassMethod,
])
},
)?;
assert_messages!(diagnostics);
Ok(())
}
#[test]
fn suppress_none_returning() -> Result<()> {
let diagnostics = test_path(

View file

@ -725,39 +725,55 @@ pub(crate) fn definition(
) {
if is_method && visibility::is_classmethod(decorator_list, checker.semantic()) {
if checker.enabled(Rule::MissingReturnTypeClassMethod) {
let return_type = auto_return_type(function, checker.settings.target_version)
.map(|return_type| checker.generator().expr(&return_type));
let return_type = auto_return_type(function)
.and_then(|return_type| {
return_type.into_expression(
checker.importer(),
function.parameters.start(),
checker.semantic(),
checker.settings.target_version,
)
})
.map(|(return_type, edits)| (checker.generator().expr(&return_type), edits));
let mut diagnostic = Diagnostic::new(
MissingReturnTypeClassMethod {
name: name.to_string(),
annotation: return_type.clone(),
annotation: return_type.clone().map(|(return_type, ..)| return_type),
},
function.identifier(),
);
if let Some(return_type) = return_type {
diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion(
format!(" -> {return_type}"),
function.parameters.range().end(),
)));
if let Some((return_type, edits)) = return_type {
diagnostic.set_fix(Fix::unsafe_edits(
Edit::insertion(format!(" -> {return_type}"), function.parameters.end()),
edits,
));
}
diagnostics.push(diagnostic);
}
} else if is_method && visibility::is_staticmethod(decorator_list, checker.semantic()) {
if checker.enabled(Rule::MissingReturnTypeStaticMethod) {
let return_type = auto_return_type(function, checker.settings.target_version)
.map(|return_type| checker.generator().expr(&return_type));
let return_type = auto_return_type(function)
.and_then(|return_type| {
return_type.into_expression(
checker.importer(),
function.parameters.start(),
checker.semantic(),
checker.settings.target_version,
)
})
.map(|(return_type, edits)| (checker.generator().expr(&return_type), edits));
let mut diagnostic = Diagnostic::new(
MissingReturnTypeStaticMethod {
name: name.to_string(),
annotation: return_type.clone(),
annotation: return_type.clone().map(|(return_type, ..)| return_type),
},
function.identifier(),
);
if let Some(return_type) = return_type {
diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion(
format!(" -> {return_type}"),
function.parameters.range().end(),
)));
if let Some((return_type, edits)) = return_type {
diagnostic.set_fix(Fix::unsafe_edits(
Edit::insertion(format!(" -> {return_type}"), function.parameters.end()),
edits,
));
}
diagnostics.push(diagnostic);
}
@ -775,7 +791,7 @@ pub(crate) fn definition(
);
diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion(
" -> None".to_string(),
function.parameters.range().end(),
function.parameters.end(),
)));
diagnostics.push(diagnostic);
}
@ -793,7 +809,7 @@ pub(crate) fn definition(
if let Some(return_type) = return_type {
diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion(
format!(" -> {return_type}"),
function.parameters.range().end(),
function.parameters.end(),
)));
}
diagnostics.push(diagnostic);
@ -802,42 +818,70 @@ pub(crate) fn definition(
match visibility {
visibility::Visibility::Public => {
if checker.enabled(Rule::MissingReturnTypeUndocumentedPublicFunction) {
let return_type =
auto_return_type(function, checker.settings.target_version)
.map(|return_type| checker.generator().expr(&return_type));
let return_type = auto_return_type(function)
.and_then(|return_type| {
return_type.into_expression(
checker.importer(),
function.parameters.start(),
checker.semantic(),
checker.settings.target_version,
)
})
.map(|(return_type, edits)| {
(checker.generator().expr(&return_type), edits)
});
let mut diagnostic = Diagnostic::new(
MissingReturnTypeUndocumentedPublicFunction {
name: name.to_string(),
annotation: return_type.clone(),
annotation: return_type
.clone()
.map(|(return_type, ..)| return_type),
},
function.identifier(),
);
if let Some(return_type) = return_type {
diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion(
format!(" -> {return_type}"),
function.parameters.range().end(),
)));
if let Some((return_type, edits)) = return_type {
diagnostic.set_fix(Fix::unsafe_edits(
Edit::insertion(
format!(" -> {return_type}"),
function.parameters.end(),
),
edits,
));
}
diagnostics.push(diagnostic);
}
}
visibility::Visibility::Private => {
if checker.enabled(Rule::MissingReturnTypePrivateFunction) {
let return_type =
auto_return_type(function, checker.settings.target_version)
.map(|return_type| checker.generator().expr(&return_type));
let return_type = auto_return_type(function)
.and_then(|return_type| {
return_type.into_expression(
checker.importer(),
function.parameters.start(),
checker.semantic(),
checker.settings.target_version,
)
})
.map(|(return_type, edits)| {
(checker.generator().expr(&return_type), edits)
});
let mut diagnostic = Diagnostic::new(
MissingReturnTypePrivateFunction {
name: name.to_string(),
annotation: return_type.clone(),
annotation: return_type
.clone()
.map(|(return_type, ..)| return_type),
},
function.identifier(),
);
if let Some(return_type) = return_type {
diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion(
format!(" -> {return_type}"),
function.parameters.range().end(),
)));
if let Some((return_type, edits)) = return_type {
diagnostic.set_fix(Fix::unsafe_edits(
Edit::insertion(
format!(" -> {return_type}"),
function.parameters.end(),
),
edits,
));
}
diagnostics.push(diagnostic);
}

View file

@ -145,7 +145,7 @@ auto_return_type.py:41:5: ANN201 Missing return type annotation for public funct
auto_return_type.py:45:5: ANN201 [*] Missing return type annotation for public function `func`
|
45 | def func():
45 | def func(x: int):
| ^^^^ ANN201
46 | if not x:
47 | return 1
@ -156,10 +156,48 @@ auto_return_type.py:45:5: ANN201 [*] Missing return type annotation for public f
42 42 | return {"foo": 1}
43 43 |
44 44 |
45 |-def func():
45 |+def func() -> int:
45 |-def func(x: int):
45 |+def func(x: int) -> int:
46 46 | if not x:
47 47 | return 1
48 48 | else:
auto_return_type.py:52:5: ANN201 [*] Missing return type annotation for public function `func`
|
52 | def func(x: int):
| ^^^^ ANN201
53 | if not x:
54 | return 1
|
= help: Add return type annotation: `int | None`
Unsafe fix
49 49 | return True
50 50 |
51 51 |
52 |-def func(x: int):
52 |+def func(x: int) -> int | None:
53 53 | if not x:
54 54 | return 1
55 55 | else:
auto_return_type.py:59:5: ANN201 [*] Missing return type annotation for public function `func`
|
59 | def func(x: int):
| ^^^^ ANN201
60 | if not x:
61 | return 1
|
= help: Add return type annotation: `str | int | None`
Unsafe fix
56 56 | return None
57 57 |
58 58 |
59 |-def func(x: int):
59 |+def func(x: int) -> str | int | None:
60 60 | if not x:
61 61 | return 1
62 62 | elif x > 5:

View file

@ -0,0 +1,223 @@
---
source: crates/ruff_linter/src/rules/flake8_annotations/mod.rs
---
auto_return_type.py:1:5: ANN201 [*] Missing return type annotation for public function `func`
|
1 | def func():
| ^^^^ ANN201
2 | return 1
|
= help: Add return type annotation: `int`
Unsafe fix
1 |-def func():
1 |+def func() -> int:
2 2 | return 1
3 3 |
4 4 |
auto_return_type.py:5:5: ANN201 [*] Missing return type annotation for public function `func`
|
5 | def func():
| ^^^^ ANN201
6 | return 1.5
|
= help: Add return type annotation: `float`
Unsafe fix
2 2 | return 1
3 3 |
4 4 |
5 |-def func():
5 |+def func() -> float:
6 6 | return 1.5
7 7 |
8 8 |
auto_return_type.py:9:5: ANN201 [*] Missing return type annotation for public function `func`
|
9 | def func(x: int):
| ^^^^ ANN201
10 | if x > 0:
11 | return 1
|
= help: Add return type annotation: `float`
Unsafe fix
6 6 | return 1.5
7 7 |
8 8 |
9 |-def func(x: int):
9 |+def func(x: int) -> float:
10 10 | if x > 0:
11 11 | return 1
12 12 | else:
auto_return_type.py:16:5: ANN201 [*] Missing return type annotation for public function `func`
|
16 | def func():
| ^^^^ ANN201
17 | return True
|
= help: Add return type annotation: `bool`
Unsafe fix
13 13 | return 1.5
14 14 |
15 15 |
16 |-def func():
16 |+def func() -> bool:
17 17 | return True
18 18 |
19 19 |
auto_return_type.py:20:5: ANN201 [*] Missing return type annotation for public function `func`
|
20 | def func(x: int):
| ^^^^ ANN201
21 | if x > 0:
22 | return None
|
= help: Add return type annotation: `None`
Unsafe fix
17 17 | return True
18 18 |
19 19 |
20 |-def func(x: int):
20 |+def func(x: int) -> None:
21 21 | if x > 0:
22 22 | return None
23 23 | else:
auto_return_type.py:27:5: ANN201 [*] Missing return type annotation for public function `func`
|
27 | def func(x: int):
| ^^^^ ANN201
28 | return 1 or 2.5 if x > 0 else 1.5 or "str"
|
= help: Add return type annotation: `Union[str | float]`
Unsafe fix
1 |+from typing import Union
1 2 | def func():
2 3 | return 1
3 4 |
--------------------------------------------------------------------------------
24 25 | return
25 26 |
26 27 |
27 |-def func(x: int):
28 |+def func(x: int) -> Union[str | float]:
28 29 | return 1 or 2.5 if x > 0 else 1.5 or "str"
29 30 |
30 31 |
auto_return_type.py:31:5: ANN201 [*] Missing return type annotation for public function `func`
|
31 | def func(x: int):
| ^^^^ ANN201
32 | return 1 + 2.5 if x > 0 else 1.5 or "str"
|
= help: Add return type annotation: `Union[str | float]`
Unsafe fix
1 |+from typing import Union
1 2 | def func():
2 3 | return 1
3 4 |
--------------------------------------------------------------------------------
28 29 | return 1 or 2.5 if x > 0 else 1.5 or "str"
29 30 |
30 31 |
31 |-def func(x: int):
32 |+def func(x: int) -> Union[str | float]:
32 33 | return 1 + 2.5 if x > 0 else 1.5 or "str"
33 34 |
34 35 |
auto_return_type.py:35:5: ANN201 Missing return type annotation for public function `func`
|
35 | def func(x: int):
| ^^^^ ANN201
36 | if not x:
37 | return None
|
= help: Add return type annotation
auto_return_type.py:41:5: ANN201 Missing return type annotation for public function `func`
|
41 | def func(x: int):
| ^^^^ ANN201
42 | return {"foo": 1}
|
= help: Add return type annotation
auto_return_type.py:45:5: ANN201 [*] Missing return type annotation for public function `func`
|
45 | def func(x: int):
| ^^^^ ANN201
46 | if not x:
47 | return 1
|
= help: Add return type annotation: `int`
Unsafe fix
42 42 | return {"foo": 1}
43 43 |
44 44 |
45 |-def func(x: int):
45 |+def func(x: int) -> int:
46 46 | if not x:
47 47 | return 1
48 48 | else:
auto_return_type.py:52:5: ANN201 [*] Missing return type annotation for public function `func`
|
52 | def func(x: int):
| ^^^^ ANN201
53 | if not x:
54 | return 1
|
= help: Add return type annotation: `Optional[int]`
Unsafe fix
1 |+from typing import Optional
1 2 | def func():
2 3 | return 1
3 4 |
--------------------------------------------------------------------------------
49 50 | return True
50 51 |
51 52 |
52 |-def func(x: int):
53 |+def func(x: int) -> Optional[int]:
53 54 | if not x:
54 55 | return 1
55 56 | else:
auto_return_type.py:59:5: ANN201 [*] Missing return type annotation for public function `func`
|
59 | def func(x: int):
| ^^^^ ANN201
60 | if not x:
61 | return 1
|
= help: Add return type annotation: `Union[str | int | None]`
Unsafe fix
1 |+from typing import Union
1 2 | def func():
2 3 | return 1
3 4 |
--------------------------------------------------------------------------------
56 57 | return None
57 58 |
58 59 |
59 |-def func(x: int):
60 |+def func(x: int) -> Union[str | int | None]:
60 61 | if not x:
61 62 | return 1
62 63 | elif x > 5:

View file

@ -550,7 +550,7 @@ fn check_duplicates(checker: &mut Checker, values: &Expr) {
element.range(),
);
if let Some(prev) = prev {
let values_end = values.range().end() - TextSize::new(1);
let values_end = values.end() - TextSize::new(1);
let previous_end =
trailing_comma(prev, checker.locator().contents()).unwrap_or(values_end);
let element_end =

View file

@ -1293,6 +1293,50 @@ pub fn pep_604_union(elts: &[Expr]) -> Expr {
}
}
pub fn typing_optional(elt: Expr, binding: String) -> Expr {
Expr::Subscript(ast::ExprSubscript {
value: Box::new(Expr::Name(ast::ExprName {
id: binding,
range: TextRange::default(),
ctx: ExprContext::Load,
})),
slice: Box::new(elt),
ctx: ExprContext::Load,
range: TextRange::default(),
})
}
pub fn typing_union(elts: &[Expr], binding: String) -> Expr {
fn tuple(elts: &[Expr]) -> Expr {
match elts {
[] => Expr::Tuple(ast::ExprTuple {
elts: vec![],
ctx: ExprContext::Load,
range: TextRange::default(),
}),
[Expr::Tuple(ast::ExprTuple { elts, .. })] => pep_604_union(elts),
[elt] => elt.clone(),
[rest @ .., elt] => Expr::BinOp(ast::ExprBinOp {
left: Box::new(tuple(rest)),
op: Operator::BitOr,
right: Box::new(elt.clone()),
range: TextRange::default(),
}),
}
}
Expr::Subscript(ast::ExprSubscript {
value: Box::new(Expr::Name(ast::ExprName {
id: binding,
range: TextRange::default(),
ctx: ExprContext::Load,
})),
slice: Box::new(tuple(elts)),
ctx: ExprContext::Load,
range: TextRange::default(),
})
}
#[cfg(test)]
mod tests {
use std::borrow::Cow;

View file

@ -399,7 +399,7 @@ IpyHelpEndEscapeCommandStatement: ast::Stmt = {
_ => {
return Err(LexicalError {
error: LexicalErrorType::OtherError("only Name, Subscript and Attribute expressions are allowed in help end escape command".to_string()),
location: expr.range().start(),
location: expr.start(),
});
}
}
@ -1642,12 +1642,12 @@ FStringReplacementField: ast::Expr = {
} else {
format_spec.as_ref().map_or_else(
|| end_location - "}".text_len(),
|spec| spec.range().start() - ":".text_len(),
|spec| spec.start() - ":".text_len(),
)
};
ast::DebugText {
leading: source_code[TextRange::new(start_offset, value.range().start())].to_string(),
trailing: source_code[TextRange::new(value.range().end(), end_offset)].to_string(),
leading: source_code[TextRange::new(start_offset, value.start())].to_string(),
trailing: source_code[TextRange::new(value.end(), end_offset)].to_string(),
}
});
Ok(

View file

@ -1,5 +1,5 @@
// auto-generated: "lalrpop 0.20.0"
// sha3: e78a653b50980f07fcb78bfa43c9f023870e26514f59acdec0bec5bf84c2a133
// sha3: e999c9c9ca8fe5a29655244aa995b8cf4e639f0bda95099d8f2a395bc06b6408
use ruff_text_size::{Ranged, TextLen, TextRange, TextSize};
use ruff_python_ast::{self as ast, Int, IpyEscapeKind};
use crate::{
@ -33700,7 +33700,7 @@ fn __action76<
_ => {
return Err(LexicalError {
error: LexicalErrorType::OtherError("only Name, Subscript and Attribute expressions are allowed in help end escape command".to_string()),
location: expr.range().start(),
location: expr.start(),
});
}
}
@ -36420,12 +36420,12 @@ fn __action221<
} else {
format_spec.as_ref().map_or_else(
|| end_location - "}".text_len(),
|spec| spec.range().start() - ":".text_len(),
|spec| spec.start() - ":".text_len(),
)
};
ast::DebugText {
leading: source_code[TextRange::new(start_offset, value.range().start())].to_string(),
trailing: source_code[TextRange::new(value.range().end(), end_offset)].to_string(),
leading: source_code[TextRange::new(start_offset, value.start())].to_string(),
trailing: source_code[TextRange::new(value.end(), end_offset)].to_string(),
}
});
Ok(