mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 02:38:25 +00:00
[ty] Fix false-positive [invalid-return-type]
diagnostics on generator functions (#17871)
This commit is contained in:
parent
47e3aa40b3
commit
bb6c7cad07
8 changed files with 253 additions and 22 deletions
|
@ -340,3 +340,57 @@ def f() -> int:
|
|||
def f(cond: bool) -> str:
|
||||
return "hello" if cond else NotImplemented
|
||||
```
|
||||
|
||||
## Generator functions
|
||||
|
||||
<!-- snapshot-diagnostics -->
|
||||
|
||||
A function with a `yield` statement anywhere in its body is a
|
||||
[generator function](https://docs.python.org/3/glossary.html#term-generator). A generator function
|
||||
implicitly returns an instance of `types.GeneratorType` even if it does not contain any `return`
|
||||
statements.
|
||||
|
||||
```py
|
||||
import types
|
||||
import typing
|
||||
|
||||
def f() -> types.GeneratorType:
|
||||
yield 42
|
||||
|
||||
def g() -> typing.Generator:
|
||||
yield 42
|
||||
|
||||
def h() -> typing.Iterator:
|
||||
yield 42
|
||||
|
||||
def i() -> typing.Iterable:
|
||||
yield 42
|
||||
|
||||
def j() -> str: # error: [invalid-return-type]
|
||||
yield 42
|
||||
```
|
||||
|
||||
If it is an `async` function with a `yield` statement in its body, it is an
|
||||
[asynchronous generator function](https://docs.python.org/3/glossary.html#term-asynchronous-generator).
|
||||
An asynchronous generator function implicitly returns an instance of `types.AsyncGeneratorType` even
|
||||
if it does not contain any `return` statements.
|
||||
|
||||
```py
|
||||
import types
|
||||
import typing
|
||||
|
||||
async def f() -> types.AsyncGeneratorType:
|
||||
yield 42
|
||||
|
||||
async def g() -> typing.AsyncGenerator:
|
||||
yield 42
|
||||
|
||||
async def h() -> typing.AsyncIterator:
|
||||
yield 42
|
||||
|
||||
async def i() -> typing.AsyncIterable:
|
||||
yield 42
|
||||
|
||||
async def j() -> str: # error: [invalid-return-type]
|
||||
yield 42
|
||||
```
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
---
|
||||
source: crates/ty_test/src/lib.rs
|
||||
expression: snapshot
|
||||
---
|
||||
---
|
||||
mdtest name: return_type.md - Function return type - Generator functions
|
||||
mdtest path: crates/ty_python_semantic/resources/mdtest/function/return_type.md
|
||||
---
|
||||
|
||||
# Python source files
|
||||
|
||||
## mdtest_snippet.py
|
||||
|
||||
```
|
||||
1 | import types
|
||||
2 | import typing
|
||||
3 |
|
||||
4 | def f() -> types.GeneratorType:
|
||||
5 | yield 42
|
||||
6 |
|
||||
7 | def g() -> typing.Generator:
|
||||
8 | yield 42
|
||||
9 |
|
||||
10 | def h() -> typing.Iterator:
|
||||
11 | yield 42
|
||||
12 |
|
||||
13 | def i() -> typing.Iterable:
|
||||
14 | yield 42
|
||||
15 |
|
||||
16 | def j() -> str: # error: [invalid-return-type]
|
||||
17 | yield 42
|
||||
18 | import types
|
||||
19 | import typing
|
||||
20 |
|
||||
21 | async def f() -> types.AsyncGeneratorType:
|
||||
22 | yield 42
|
||||
23 |
|
||||
24 | async def g() -> typing.AsyncGenerator:
|
||||
25 | yield 42
|
||||
26 |
|
||||
27 | async def h() -> typing.AsyncIterator:
|
||||
28 | yield 42
|
||||
29 |
|
||||
30 | async def i() -> typing.AsyncIterable:
|
||||
31 | yield 42
|
||||
32 |
|
||||
33 | async def j() -> str: # error: [invalid-return-type]
|
||||
34 | yield 42
|
||||
```
|
||||
|
||||
# Diagnostics
|
||||
|
||||
```
|
||||
error: lint:invalid-return-type: Return type does not match returned value
|
||||
--> src/mdtest_snippet.py:16:12
|
||||
|
|
||||
14 | yield 42
|
||||
15 |
|
||||
16 | def j() -> str: # error: [invalid-return-type]
|
||||
| ^^^ Expected `str`, found `types.GeneratorType`
|
||||
17 | yield 42
|
||||
18 | import types
|
||||
|
|
||||
info: Function is inferred as returning `types.GeneratorType` because it is a generator function
|
||||
info: See https://docs.python.org/3/glossary.html#term-generator for more details
|
||||
|
||||
```
|
||||
|
||||
```
|
||||
error: lint:invalid-return-type: Return type does not match returned value
|
||||
--> src/mdtest_snippet.py:33:18
|
||||
|
|
||||
31 | yield 42
|
||||
32 |
|
||||
33 | async def j() -> str: # error: [invalid-return-type]
|
||||
| ^^^ Expected `str`, found `types.AsyncGeneratorType`
|
||||
34 | yield 42
|
||||
|
|
||||
info: Function is inferred as returning `types.AsyncGeneratorType` because it is an async generator function
|
||||
info: See https://docs.python.org/3/glossary.html#term-asynchronous-generator for more details
|
||||
|
||||
```
|
|
@ -194,6 +194,9 @@ pub(crate) struct SemanticIndex<'db> {
|
|||
|
||||
/// List of all semantic syntax errors in this file.
|
||||
semantic_syntax_errors: Vec<SemanticSyntaxError>,
|
||||
|
||||
/// Set of all generator functions in this file.
|
||||
generator_functions: FxHashSet<FileScopeId>,
|
||||
}
|
||||
|
||||
impl<'db> SemanticIndex<'db> {
|
||||
|
|
|
@ -109,6 +109,10 @@ pub(super) struct SemanticIndexBuilder<'db> {
|
|||
definitions_by_node: FxHashMap<DefinitionNodeKey, Definitions<'db>>,
|
||||
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
|
||||
imported_modules: FxHashSet<ModuleName>,
|
||||
/// Hashset of all [`FileScopeId`]s that correspond to [generator functions].
|
||||
///
|
||||
/// [generator functions]: https://docs.python.org/3/glossary.html#term-generator
|
||||
generator_functions: FxHashSet<FileScopeId>,
|
||||
eager_bindings: FxHashMap<EagerBindingsKey, ScopedEagerBindingsId>,
|
||||
/// Errors collected by the `semantic_checker`.
|
||||
semantic_syntax_errors: RefCell<Vec<SemanticSyntaxError>>,
|
||||
|
@ -142,6 +146,7 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
expressions_by_node: FxHashMap::default(),
|
||||
|
||||
imported_modules: FxHashSet::default(),
|
||||
generator_functions: FxHashSet::default(),
|
||||
|
||||
eager_bindings: FxHashMap::default(),
|
||||
|
||||
|
@ -1081,6 +1086,7 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
self.scope_ids_by_scope.shrink_to_fit();
|
||||
self.scopes_by_node.shrink_to_fit();
|
||||
self.eager_bindings.shrink_to_fit();
|
||||
self.generator_functions.shrink_to_fit();
|
||||
|
||||
SemanticIndex {
|
||||
symbol_tables,
|
||||
|
@ -1097,6 +1103,7 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
has_future_annotations: self.has_future_annotations,
|
||||
eager_bindings: self.eager_bindings,
|
||||
semantic_syntax_errors: self.semantic_syntax_errors.into_inner(),
|
||||
generator_functions: self.generator_functions,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2305,6 +2312,13 @@ where
|
|||
|
||||
walk_expr(self, expr);
|
||||
}
|
||||
ast::Expr::Yield(_) => {
|
||||
let scope = self.current_scope();
|
||||
if self.scopes[scope].kind() == ScopeKind::Function {
|
||||
self.generator_functions.insert(scope);
|
||||
}
|
||||
walk_expr(self, expr);
|
||||
}
|
||||
_ => {
|
||||
walk_expr(self, expr);
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ use rustc_hash::FxHasher;
|
|||
use crate::ast_node_ref::AstNodeRef;
|
||||
use crate::node_key::NodeKey;
|
||||
use crate::semantic_index::visibility_constraints::ScopedVisibilityConstraintId;
|
||||
use crate::semantic_index::{semantic_index, SymbolMap};
|
||||
use crate::semantic_index::{semantic_index, SemanticIndex, SymbolMap};
|
||||
use crate::Db;
|
||||
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
|
@ -170,6 +170,10 @@ impl FileScopeId {
|
|||
let index = semantic_index(db, file);
|
||||
index.scope_ids_by_scope[self]
|
||||
}
|
||||
|
||||
pub(crate) fn is_generator_function(self, index: &SemanticIndex) -> bool {
|
||||
index.generator_functions.contains(&self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, salsa::Update)]
|
||||
|
|
|
@ -1865,6 +1865,8 @@ pub enum KnownClass {
|
|||
MethodWrapperType,
|
||||
WrapperDescriptorType,
|
||||
UnionType,
|
||||
GeneratorType,
|
||||
AsyncGeneratorType,
|
||||
// Typeshed
|
||||
NoneType, // Part of `types` for Python >= 3.10
|
||||
// Typing
|
||||
|
@ -1929,6 +1931,8 @@ impl<'db> KnownClass {
|
|||
| Self::Super
|
||||
| Self::WrapperDescriptorType
|
||||
| Self::UnionType
|
||||
| Self::GeneratorType
|
||||
| Self::AsyncGeneratorType
|
||||
| Self::MethodWrapperType => Truthiness::AlwaysTrue,
|
||||
|
||||
Self::NoneType => Truthiness::AlwaysFalse,
|
||||
|
@ -2013,6 +2017,8 @@ impl<'db> KnownClass {
|
|||
| Self::BaseExceptionGroup
|
||||
| Self::Classmethod
|
||||
| Self::GenericAlias
|
||||
| Self::GeneratorType
|
||||
| Self::AsyncGeneratorType
|
||||
| Self::ModuleType
|
||||
| Self::FunctionType
|
||||
| Self::MethodType
|
||||
|
@ -2075,6 +2081,8 @@ impl<'db> KnownClass {
|
|||
Self::UnionType => "UnionType",
|
||||
Self::MethodWrapperType => "MethodWrapperType",
|
||||
Self::WrapperDescriptorType => "WrapperDescriptorType",
|
||||
Self::GeneratorType => "GeneratorType",
|
||||
Self::AsyncGeneratorType => "AsyncGeneratorType",
|
||||
Self::NamedTuple => "NamedTuple",
|
||||
Self::NoneType => "NoneType",
|
||||
Self::SpecialForm => "_SpecialForm",
|
||||
|
@ -2116,7 +2124,7 @@ impl<'db> KnownClass {
|
|||
}
|
||||
}
|
||||
|
||||
fn display(self, db: &'db dyn Db) -> impl std::fmt::Display + 'db {
|
||||
pub(super) fn display(self, db: &'db dyn Db) -> impl std::fmt::Display + 'db {
|
||||
struct KnownClassDisplay<'db> {
|
||||
db: &'db dyn Db,
|
||||
class: KnownClass,
|
||||
|
@ -2293,6 +2301,8 @@ impl<'db> KnownClass {
|
|||
| Self::ModuleType
|
||||
| Self::FunctionType
|
||||
| Self::MethodType
|
||||
| Self::GeneratorType
|
||||
| Self::AsyncGeneratorType
|
||||
| Self::MethodWrapperType
|
||||
| Self::UnionType
|
||||
| Self::WrapperDescriptorType => KnownModule::Types,
|
||||
|
@ -2374,6 +2384,8 @@ impl<'db> KnownClass {
|
|||
| Self::GenericAlias
|
||||
| Self::ModuleType
|
||||
| Self::FunctionType
|
||||
| Self::GeneratorType
|
||||
| Self::AsyncGeneratorType
|
||||
| Self::MethodType
|
||||
| Self::MethodWrapperType
|
||||
| Self::WrapperDescriptorType
|
||||
|
@ -2434,6 +2446,8 @@ impl<'db> KnownClass {
|
|||
| Self::MethodType
|
||||
| Self::MethodWrapperType
|
||||
| Self::WrapperDescriptorType
|
||||
| Self::GeneratorType
|
||||
| Self::AsyncGeneratorType
|
||||
| Self::SpecialForm
|
||||
| Self::ChainMap
|
||||
| Self::Counter
|
||||
|
@ -2491,6 +2505,8 @@ impl<'db> KnownClass {
|
|||
"GenericAlias" => Self::GenericAlias,
|
||||
"NoneType" => Self::NoneType,
|
||||
"ModuleType" => Self::ModuleType,
|
||||
"GeneratorType" => Self::GeneratorType,
|
||||
"AsyncGeneratorType" => Self::AsyncGeneratorType,
|
||||
"FunctionType" => Self::FunctionType,
|
||||
"MethodType" => Self::MethodType,
|
||||
"UnionType" => Self::UnionType,
|
||||
|
@ -2574,6 +2590,8 @@ impl<'db> KnownClass {
|
|||
| Self::Super
|
||||
| Self::NotImplementedType
|
||||
| Self::UnionType
|
||||
| Self::GeneratorType
|
||||
| Self::AsyncGeneratorType
|
||||
| Self::WrapperDescriptorType => module == self.canonical_module(db),
|
||||
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
|
||||
Self::SpecialForm
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::context::InferContext;
|
||||
use super::ClassLiteral;
|
||||
use super::{ClassLiteral, KnownClass};
|
||||
use crate::db::Db;
|
||||
use crate::declare_lint;
|
||||
use crate::lint::{Level, LintRegistryBuilder, LintStatus};
|
||||
|
@ -12,7 +12,7 @@ use crate::types::string_annotation::{
|
|||
use crate::types::{protocol_class::ProtocolClassLiteral, KnownFunction, KnownInstanceType, Type};
|
||||
use ruff_db::diagnostic::{Annotation, Diagnostic, Severity, Span, SubDiagnostic};
|
||||
use ruff_python_ast::{self as ast, AnyNodeRef};
|
||||
use ruff_text_size::Ranged;
|
||||
use ruff_text_size::{Ranged, TextRange};
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::fmt::Formatter;
|
||||
|
||||
|
@ -1307,6 +1307,41 @@ pub(super) fn report_invalid_return_type(
|
|||
);
|
||||
}
|
||||
|
||||
pub(super) fn report_invalid_generator_function_return_type(
|
||||
context: &InferContext,
|
||||
return_type_range: TextRange,
|
||||
inferred_return: KnownClass,
|
||||
expected_ty: Type,
|
||||
) {
|
||||
let Some(builder) = context.report_lint(&INVALID_RETURN_TYPE, return_type_range) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut diag = builder.into_diagnostic("Return type does not match returned value");
|
||||
let inferred_ty = inferred_return.display(context.db());
|
||||
diag.set_primary_message(format_args!(
|
||||
"Expected `{expected_ty}`, found `{inferred_ty}`",
|
||||
expected_ty = expected_ty.display(context.db()),
|
||||
));
|
||||
|
||||
let (description, link) = if inferred_return == KnownClass::AsyncGeneratorType {
|
||||
(
|
||||
"an async generator function",
|
||||
"https://docs.python.org/3/glossary.html#term-asynchronous-generator",
|
||||
)
|
||||
} else {
|
||||
(
|
||||
"a generator function",
|
||||
"https://docs.python.org/3/glossary.html#term-generator",
|
||||
)
|
||||
};
|
||||
|
||||
diag.info(format_args!(
|
||||
"Function is inferred as returning `{inferred_ty}` because it is {description}"
|
||||
));
|
||||
diag.info(format_args!("See {link} for more details"));
|
||||
}
|
||||
|
||||
pub(super) fn report_implicit_return_type(
|
||||
context: &InferContext,
|
||||
range: impl Ranged,
|
||||
|
|
|
@ -68,12 +68,12 @@ use crate::types::class::MetaclassErrorKind;
|
|||
use crate::types::diagnostic::{
|
||||
report_implicit_return_type, report_invalid_arguments_to_annotated,
|
||||
report_invalid_arguments_to_callable, report_invalid_assignment,
|
||||
report_invalid_attribute_assignment, report_invalid_return_type,
|
||||
report_possibly_unbound_attribute, TypeCheckDiagnostics, CALL_NON_CALLABLE,
|
||||
CALL_POSSIBLY_UNBOUND_METHOD, CONFLICTING_DECLARATIONS, CONFLICTING_METACLASS,
|
||||
CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_BASE, INCONSISTENT_MRO,
|
||||
INVALID_ARGUMENT_TYPE, INVALID_ASSIGNMENT, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE,
|
||||
INVALID_DECLARATION, INVALID_GENERIC_CLASS, INVALID_LEGACY_TYPE_VARIABLE,
|
||||
report_invalid_attribute_assignment, report_invalid_generator_function_return_type,
|
||||
report_invalid_return_type, report_possibly_unbound_attribute, TypeCheckDiagnostics,
|
||||
CALL_NON_CALLABLE, CALL_POSSIBLY_UNBOUND_METHOD, CONFLICTING_DECLARATIONS,
|
||||
CONFLICTING_METACLASS, CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_BASE,
|
||||
INCONSISTENT_MRO, INVALID_ARGUMENT_TYPE, INVALID_ASSIGNMENT, INVALID_ATTRIBUTE_ACCESS,
|
||||
INVALID_BASE, INVALID_DECLARATION, INVALID_GENERIC_CLASS, INVALID_LEGACY_TYPE_VARIABLE,
|
||||
INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM, INVALID_TYPE_VARIABLE_CONSTRAINTS,
|
||||
POSSIBLY_UNBOUND_IMPORT, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT,
|
||||
UNSUPPORTED_OPERATOR,
|
||||
|
@ -1611,11 +1611,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
self.infer_body(&function.body);
|
||||
|
||||
if let Some(declared_ty) = function
|
||||
.returns
|
||||
.as_deref()
|
||||
.map(|ret| self.file_expression_type(ret))
|
||||
{
|
||||
if let Some(returns) = function.returns.as_deref() {
|
||||
fn is_stub_suite(suite: &[ast::Stmt]) -> bool {
|
||||
match suite {
|
||||
[ast::Stmt::Expr(ast::StmtExpr { value: first, .. }), ast::Stmt::Expr(ast::StmtExpr { value: second, .. }), ..] => {
|
||||
|
@ -1641,6 +1637,36 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
return;
|
||||
}
|
||||
|
||||
let declared_ty = self.file_expression_type(returns);
|
||||
|
||||
let scope_id = self.index.node_scope(NodeWithScopeRef::Function(function));
|
||||
if scope_id.is_generator_function(self.index) {
|
||||
// TODO: `AsyncGeneratorType` and `GeneratorType` are both generic classes.
|
||||
//
|
||||
// If type arguments are supplied to `(Async)Iterable`, `(Async)Iterator`,
|
||||
// `(Async)Generator` or `(Async)GeneratorType` in the return annotation,
|
||||
// we should iterate over the `yield` expressions and `return` statements in the function
|
||||
// to check that they are consistent with the type arguments provided.
|
||||
let inferred_return = if function.is_async {
|
||||
KnownClass::AsyncGeneratorType
|
||||
} else {
|
||||
KnownClass::GeneratorType
|
||||
};
|
||||
|
||||
if !inferred_return
|
||||
.to_instance(self.db())
|
||||
.is_assignable_to(self.db(), declared_ty)
|
||||
{
|
||||
report_invalid_generator_function_return_type(
|
||||
&self.context,
|
||||
returns.range(),
|
||||
inferred_return,
|
||||
declared_ty,
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for invalid in self
|
||||
.return_types_and_ranges
|
||||
.iter()
|
||||
|
@ -1660,23 +1686,18 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
report_invalid_return_type(
|
||||
&self.context,
|
||||
invalid.range,
|
||||
function.returns.as_ref().unwrap().range(),
|
||||
returns.range(),
|
||||
declared_ty,
|
||||
invalid.ty,
|
||||
);
|
||||
}
|
||||
let scope_id = self.index.node_scope(NodeWithScopeRef::Function(function));
|
||||
let use_def = self.index.use_def_map(scope_id);
|
||||
if use_def.can_implicit_return(self.db())
|
||||
&& !KnownClass::NoneType
|
||||
.to_instance(self.db())
|
||||
.is_assignable_to(self.db(), declared_ty)
|
||||
{
|
||||
report_implicit_return_type(
|
||||
&self.context,
|
||||
function.returns.as_ref().unwrap().range(),
|
||||
declared_ty,
|
||||
);
|
||||
report_implicit_return_type(&self.context, returns.range(), declared_ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue