[ty] Fix false-positive [invalid-return-type] diagnostics on generator functions (#17871)

This commit is contained in:
Alex Waygood 2025-05-05 22:44:59 +01:00 committed by GitHub
parent 47e3aa40b3
commit bb6c7cad07
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 253 additions and 22 deletions

View file

@ -340,3 +340,57 @@ def f() -> int:
def f(cond: bool) -> str:
return "hello" if cond else NotImplemented
```
## Generator functions
<!-- snapshot-diagnostics -->
A function with a `yield` statement anywhere in its body is a
[generator function](https://docs.python.org/3/glossary.html#term-generator). A generator function
implicitly returns an instance of `types.GeneratorType` even if it does not contain any `return`
statements.
```py
import types
import typing
def f() -> types.GeneratorType:
yield 42
def g() -> typing.Generator:
yield 42
def h() -> typing.Iterator:
yield 42
def i() -> typing.Iterable:
yield 42
def j() -> str: # error: [invalid-return-type]
yield 42
```
If it is an `async` function with a `yield` statement in its body, it is an
[asynchronous generator function](https://docs.python.org/3/glossary.html#term-asynchronous-generator).
An asynchronous generator function implicitly returns an instance of `types.AsyncGeneratorType` even
if it does not contain any `return` statements.
```py
import types
import typing
async def f() -> types.AsyncGeneratorType:
yield 42
async def g() -> typing.AsyncGenerator:
yield 42
async def h() -> typing.AsyncIterator:
yield 42
async def i() -> typing.AsyncIterable:
yield 42
async def j() -> str: # error: [invalid-return-type]
yield 42
```

View file

@ -0,0 +1,82 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: return_type.md - Function return type - Generator functions
mdtest path: crates/ty_python_semantic/resources/mdtest/function/return_type.md
---
# Python source files
## mdtest_snippet.py
```
1 | import types
2 | import typing
3 |
4 | def f() -> types.GeneratorType:
5 | yield 42
6 |
7 | def g() -> typing.Generator:
8 | yield 42
9 |
10 | def h() -> typing.Iterator:
11 | yield 42
12 |
13 | def i() -> typing.Iterable:
14 | yield 42
15 |
16 | def j() -> str: # error: [invalid-return-type]
17 | yield 42
18 | import types
19 | import typing
20 |
21 | async def f() -> types.AsyncGeneratorType:
22 | yield 42
23 |
24 | async def g() -> typing.AsyncGenerator:
25 | yield 42
26 |
27 | async def h() -> typing.AsyncIterator:
28 | yield 42
29 |
30 | async def i() -> typing.AsyncIterable:
31 | yield 42
32 |
33 | async def j() -> str: # error: [invalid-return-type]
34 | yield 42
```
# Diagnostics
```
error: lint:invalid-return-type: Return type does not match returned value
--> src/mdtest_snippet.py:16:12
|
14 | yield 42
15 |
16 | def j() -> str: # error: [invalid-return-type]
| ^^^ Expected `str`, found `types.GeneratorType`
17 | yield 42
18 | import types
|
info: Function is inferred as returning `types.GeneratorType` because it is a generator function
info: See https://docs.python.org/3/glossary.html#term-generator for more details
```
```
error: lint:invalid-return-type: Return type does not match returned value
--> src/mdtest_snippet.py:33:18
|
31 | yield 42
32 |
33 | async def j() -> str: # error: [invalid-return-type]
| ^^^ Expected `str`, found `types.AsyncGeneratorType`
34 | yield 42
|
info: Function is inferred as returning `types.AsyncGeneratorType` because it is an async generator function
info: See https://docs.python.org/3/glossary.html#term-asynchronous-generator for more details
```

View file

@ -194,6 +194,9 @@ pub(crate) struct SemanticIndex<'db> {
/// List of all semantic syntax errors in this file.
semantic_syntax_errors: Vec<SemanticSyntaxError>,
/// Set of all generator functions in this file.
generator_functions: FxHashSet<FileScopeId>,
}
impl<'db> SemanticIndex<'db> {

View file

@ -109,6 +109,10 @@ pub(super) struct SemanticIndexBuilder<'db> {
definitions_by_node: FxHashMap<DefinitionNodeKey, Definitions<'db>>,
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
imported_modules: FxHashSet<ModuleName>,
/// Hashset of all [`FileScopeId`]s that correspond to [generator functions].
///
/// [generator functions]: https://docs.python.org/3/glossary.html#term-generator
generator_functions: FxHashSet<FileScopeId>,
eager_bindings: FxHashMap<EagerBindingsKey, ScopedEagerBindingsId>,
/// Errors collected by the `semantic_checker`.
semantic_syntax_errors: RefCell<Vec<SemanticSyntaxError>>,
@ -142,6 +146,7 @@ impl<'db> SemanticIndexBuilder<'db> {
expressions_by_node: FxHashMap::default(),
imported_modules: FxHashSet::default(),
generator_functions: FxHashSet::default(),
eager_bindings: FxHashMap::default(),
@ -1081,6 +1086,7 @@ impl<'db> SemanticIndexBuilder<'db> {
self.scope_ids_by_scope.shrink_to_fit();
self.scopes_by_node.shrink_to_fit();
self.eager_bindings.shrink_to_fit();
self.generator_functions.shrink_to_fit();
SemanticIndex {
symbol_tables,
@ -1097,6 +1103,7 @@ impl<'db> SemanticIndexBuilder<'db> {
has_future_annotations: self.has_future_annotations,
eager_bindings: self.eager_bindings,
semantic_syntax_errors: self.semantic_syntax_errors.into_inner(),
generator_functions: self.generator_functions,
}
}
@ -2305,6 +2312,13 @@ where
walk_expr(self, expr);
}
ast::Expr::Yield(_) => {
let scope = self.current_scope();
if self.scopes[scope].kind() == ScopeKind::Function {
self.generator_functions.insert(scope);
}
walk_expr(self, expr);
}
_ => {
walk_expr(self, expr);
}

View file

@ -13,7 +13,7 @@ use rustc_hash::FxHasher;
use crate::ast_node_ref::AstNodeRef;
use crate::node_key::NodeKey;
use crate::semantic_index::visibility_constraints::ScopedVisibilityConstraintId;
use crate::semantic_index::{semantic_index, SymbolMap};
use crate::semantic_index::{semantic_index, SemanticIndex, SymbolMap};
use crate::Db;
#[derive(Eq, PartialEq, Debug)]
@ -170,6 +170,10 @@ impl FileScopeId {
let index = semantic_index(db, file);
index.scope_ids_by_scope[self]
}
pub(crate) fn is_generator_function(self, index: &SemanticIndex) -> bool {
index.generator_functions.contains(&self)
}
}
#[derive(Debug, salsa::Update)]

View file

@ -1865,6 +1865,8 @@ pub enum KnownClass {
MethodWrapperType,
WrapperDescriptorType,
UnionType,
GeneratorType,
AsyncGeneratorType,
// Typeshed
NoneType, // Part of `types` for Python >= 3.10
// Typing
@ -1929,6 +1931,8 @@ impl<'db> KnownClass {
| Self::Super
| Self::WrapperDescriptorType
| Self::UnionType
| Self::GeneratorType
| Self::AsyncGeneratorType
| Self::MethodWrapperType => Truthiness::AlwaysTrue,
Self::NoneType => Truthiness::AlwaysFalse,
@ -2013,6 +2017,8 @@ impl<'db> KnownClass {
| Self::BaseExceptionGroup
| Self::Classmethod
| Self::GenericAlias
| Self::GeneratorType
| Self::AsyncGeneratorType
| Self::ModuleType
| Self::FunctionType
| Self::MethodType
@ -2075,6 +2081,8 @@ impl<'db> KnownClass {
Self::UnionType => "UnionType",
Self::MethodWrapperType => "MethodWrapperType",
Self::WrapperDescriptorType => "WrapperDescriptorType",
Self::GeneratorType => "GeneratorType",
Self::AsyncGeneratorType => "AsyncGeneratorType",
Self::NamedTuple => "NamedTuple",
Self::NoneType => "NoneType",
Self::SpecialForm => "_SpecialForm",
@ -2116,7 +2124,7 @@ impl<'db> KnownClass {
}
}
fn display(self, db: &'db dyn Db) -> impl std::fmt::Display + 'db {
pub(super) fn display(self, db: &'db dyn Db) -> impl std::fmt::Display + 'db {
struct KnownClassDisplay<'db> {
db: &'db dyn Db,
class: KnownClass,
@ -2293,6 +2301,8 @@ impl<'db> KnownClass {
| Self::ModuleType
| Self::FunctionType
| Self::MethodType
| Self::GeneratorType
| Self::AsyncGeneratorType
| Self::MethodWrapperType
| Self::UnionType
| Self::WrapperDescriptorType => KnownModule::Types,
@ -2374,6 +2384,8 @@ impl<'db> KnownClass {
| Self::GenericAlias
| Self::ModuleType
| Self::FunctionType
| Self::GeneratorType
| Self::AsyncGeneratorType
| Self::MethodType
| Self::MethodWrapperType
| Self::WrapperDescriptorType
@ -2434,6 +2446,8 @@ impl<'db> KnownClass {
| Self::MethodType
| Self::MethodWrapperType
| Self::WrapperDescriptorType
| Self::GeneratorType
| Self::AsyncGeneratorType
| Self::SpecialForm
| Self::ChainMap
| Self::Counter
@ -2491,6 +2505,8 @@ impl<'db> KnownClass {
"GenericAlias" => Self::GenericAlias,
"NoneType" => Self::NoneType,
"ModuleType" => Self::ModuleType,
"GeneratorType" => Self::GeneratorType,
"AsyncGeneratorType" => Self::AsyncGeneratorType,
"FunctionType" => Self::FunctionType,
"MethodType" => Self::MethodType,
"UnionType" => Self::UnionType,
@ -2574,6 +2590,8 @@ impl<'db> KnownClass {
| Self::Super
| Self::NotImplementedType
| Self::UnionType
| Self::GeneratorType
| Self::AsyncGeneratorType
| Self::WrapperDescriptorType => module == self.canonical_module(db),
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
Self::SpecialForm

View file

@ -1,5 +1,5 @@
use super::context::InferContext;
use super::ClassLiteral;
use super::{ClassLiteral, KnownClass};
use crate::db::Db;
use crate::declare_lint;
use crate::lint::{Level, LintRegistryBuilder, LintStatus};
@ -12,7 +12,7 @@ use crate::types::string_annotation::{
use crate::types::{protocol_class::ProtocolClassLiteral, KnownFunction, KnownInstanceType, Type};
use ruff_db::diagnostic::{Annotation, Diagnostic, Severity, Span, SubDiagnostic};
use ruff_python_ast::{self as ast, AnyNodeRef};
use ruff_text_size::Ranged;
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::FxHashSet;
use std::fmt::Formatter;
@ -1307,6 +1307,41 @@ pub(super) fn report_invalid_return_type(
);
}
pub(super) fn report_invalid_generator_function_return_type(
context: &InferContext,
return_type_range: TextRange,
inferred_return: KnownClass,
expected_ty: Type,
) {
let Some(builder) = context.report_lint(&INVALID_RETURN_TYPE, return_type_range) else {
return;
};
let mut diag = builder.into_diagnostic("Return type does not match returned value");
let inferred_ty = inferred_return.display(context.db());
diag.set_primary_message(format_args!(
"Expected `{expected_ty}`, found `{inferred_ty}`",
expected_ty = expected_ty.display(context.db()),
));
let (description, link) = if inferred_return == KnownClass::AsyncGeneratorType {
(
"an async generator function",
"https://docs.python.org/3/glossary.html#term-asynchronous-generator",
)
} else {
(
"a generator function",
"https://docs.python.org/3/glossary.html#term-generator",
)
};
diag.info(format_args!(
"Function is inferred as returning `{inferred_ty}` because it is {description}"
));
diag.info(format_args!("See {link} for more details"));
}
pub(super) fn report_implicit_return_type(
context: &InferContext,
range: impl Ranged,

View file

@ -68,12 +68,12 @@ use crate::types::class::MetaclassErrorKind;
use crate::types::diagnostic::{
report_implicit_return_type, report_invalid_arguments_to_annotated,
report_invalid_arguments_to_callable, report_invalid_assignment,
report_invalid_attribute_assignment, report_invalid_return_type,
report_possibly_unbound_attribute, TypeCheckDiagnostics, CALL_NON_CALLABLE,
CALL_POSSIBLY_UNBOUND_METHOD, CONFLICTING_DECLARATIONS, CONFLICTING_METACLASS,
CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_BASE, INCONSISTENT_MRO,
INVALID_ARGUMENT_TYPE, INVALID_ASSIGNMENT, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE,
INVALID_DECLARATION, INVALID_GENERIC_CLASS, INVALID_LEGACY_TYPE_VARIABLE,
report_invalid_attribute_assignment, report_invalid_generator_function_return_type,
report_invalid_return_type, report_possibly_unbound_attribute, TypeCheckDiagnostics,
CALL_NON_CALLABLE, CALL_POSSIBLY_UNBOUND_METHOD, CONFLICTING_DECLARATIONS,
CONFLICTING_METACLASS, CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_BASE,
INCONSISTENT_MRO, INVALID_ARGUMENT_TYPE, INVALID_ASSIGNMENT, INVALID_ATTRIBUTE_ACCESS,
INVALID_BASE, INVALID_DECLARATION, INVALID_GENERIC_CLASS, INVALID_LEGACY_TYPE_VARIABLE,
INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM, INVALID_TYPE_VARIABLE_CONSTRAINTS,
POSSIBLY_UNBOUND_IMPORT, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT,
UNSUPPORTED_OPERATOR,
@ -1611,11 +1611,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
self.infer_body(&function.body);
if let Some(declared_ty) = function
.returns
.as_deref()
.map(|ret| self.file_expression_type(ret))
{
if let Some(returns) = function.returns.as_deref() {
fn is_stub_suite(suite: &[ast::Stmt]) -> bool {
match suite {
[ast::Stmt::Expr(ast::StmtExpr { value: first, .. }), ast::Stmt::Expr(ast::StmtExpr { value: second, .. }), ..] => {
@ -1641,6 +1637,36 @@ impl<'db> TypeInferenceBuilder<'db> {
return;
}
let declared_ty = self.file_expression_type(returns);
let scope_id = self.index.node_scope(NodeWithScopeRef::Function(function));
if scope_id.is_generator_function(self.index) {
// TODO: `AsyncGeneratorType` and `GeneratorType` are both generic classes.
//
// If type arguments are supplied to `(Async)Iterable`, `(Async)Iterator`,
// `(Async)Generator` or `(Async)GeneratorType` in the return annotation,
// we should iterate over the `yield` expressions and `return` statements in the function
// to check that they are consistent with the type arguments provided.
let inferred_return = if function.is_async {
KnownClass::AsyncGeneratorType
} else {
KnownClass::GeneratorType
};
if !inferred_return
.to_instance(self.db())
.is_assignable_to(self.db(), declared_ty)
{
report_invalid_generator_function_return_type(
&self.context,
returns.range(),
inferred_return,
declared_ty,
);
}
return;
}
for invalid in self
.return_types_and_ranges
.iter()
@ -1660,23 +1686,18 @@ impl<'db> TypeInferenceBuilder<'db> {
report_invalid_return_type(
&self.context,
invalid.range,
function.returns.as_ref().unwrap().range(),
returns.range(),
declared_ty,
invalid.ty,
);
}
let scope_id = self.index.node_scope(NodeWithScopeRef::Function(function));
let use_def = self.index.use_def_map(scope_id);
if use_def.can_implicit_return(self.db())
&& !KnownClass::NoneType
.to_instance(self.db())
.is_assignable_to(self.db(), declared_ty)
{
report_implicit_return_type(
&self.context,
function.returns.as_ref().unwrap().range(),
declared_ty,
);
report_implicit_return_type(&self.context, returns.range(), declared_ty);
}
}
}