[ty] Add diagnostics for invalid await expressions (#19711)
Some checks are pending
CI / mkdocs (push) Waiting to run
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

This PR adds a new lint, `invalid-await`, for all sorts of reasons why
an object may not be `await`able, as discussed in astral-sh/ty#919.
Precisely, `__await__` is guarded against being missing, possibly
unbound, or improperly defined (expects additional arguments or doesn't
return an iterator).

Of course, diagnostics need to be fine-tuned. If `__await__` cannot be
called with no extra arguments, it indicates an error (or a quirk?) in
the method signature, not at the call site. Without any doubt, such an
object is not `Awaitable`, but I feel like talking about arguments for
an *implicit* call is a bit leaky.
I didn't reference any actual diagnostic messages in the lint
definition, because I want to hear feedback first.

Also, there's no mention of the actual required method signature for
`__await__` anywhere in the docs. The only reference I had is the
`typing` stub. I basically ended up linking `[Awaitable]` to ["must
implement
`__await__`"](https://docs.python.org/3/library/collections.abc.html#collections.abc.Awaitable),
which is insufficient on its own.

## Test Plan

The following code was tested:
```python
import asyncio
import typing


class Awaitable:
    def __await__(self) -> typing.Generator[typing.Any, None, int]:
        yield None
        return 5


class NoDunderMethod:
    pass


class InvalidAwaitArgs:
    def __await__(self, value: int) -> int:
        return value


class InvalidAwaitReturn:
    def __await__(self) -> int:
        return 5


class InvalidAwaitReturnImplicit:
    def __await__(self):
        pass


async def main() -> None:
    result = await Awaitable()  # valid
    result = await NoDunderMethod()  # `__await__` is missing
    result = await InvalidAwaitReturn()  # `__await__` returns `int`, which is not a valid iterator 
    result = await InvalidAwaitArgs()  # `__await__` expects additional arguments and cannot be called implicitly
    result = await InvalidAwaitReturnImplicit()  # `__await__` returns `Unknown`, which is not a valid iterator


asyncio.run(main())
```

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Andrii Turov 2025-08-15 00:38:33 +03:00 committed by GitHub
parent f6093452ed
commit 957320c0f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 661 additions and 84 deletions

View file

@ -0,0 +1,105 @@
# Invalid await diagnostics
<!-- snapshot-diagnostics -->
## Basic
This is a test showcasing a primitive case where an object is not awaitable.
```py
async def main() -> None:
await 1 # error: [invalid-await]
```
## Custom type with missing `__await__`
This diagnostic also points to the class definition if available.
```py
class MissingAwait:
pass
async def main() -> None:
await MissingAwait() # error: [invalid-await]
```
## Custom type with possibly unbound `__await__`
This diagnostic also points to the method definition if available.
```py
from datetime import datetime
class PossiblyUnbound:
if datetime.today().weekday() == 0:
def __await__(self):
yield
async def main() -> None:
await PossiblyUnbound() # error: [invalid-await]
```
## `__await__` definition with extra arguments
Currently, the signature of `__await__` isn't checked for conformity with the `Awaitable` protocol
directly. Instead, individual anomalies are reported, such as the following. Here, the diagnostic
reports that the object is not implicitly awaitable, while also pointing at the function parameters.
```py
class InvalidAwaitArgs:
def __await__(self, value: int):
yield value
async def main() -> None:
await InvalidAwaitArgs() # error: [invalid-await]
```
## Non-callable `__await__`
This diagnostic doesn't point to the attribute definition, but complains about it being possibly not
awaitable.
```py
class NonCallableAwait:
__await__ = 42
async def main() -> None:
await NonCallableAwait() # error: [invalid-await]
```
## `__await__` definition with explicit invalid return type
`__await__` must return a valid iterator. This diagnostic also points to the method definition if
available.
```py
class InvalidAwaitReturn:
def __await__(self) -> int:
return 5
async def main() -> None:
await InvalidAwaitReturn() # error: [invalid-await]
```
## Invalid union return type
When multiple potential definitions of `__await__` exist, all of them must be proper in order for an
instance to be awaitable. In this specific case, no specific function definition is highlighted.
```py
import typing
from datetime import datetime
class UnawaitableUnion:
if datetime.today().weekday() == 6:
def __await__(self) -> typing.Generator[typing.Any, None, None]:
yield
else:
def __await__(self) -> int:
return 5
async def main() -> None:
await UnawaitableUnion() # error: [invalid-await]
```

View file

@ -0,0 +1,41 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_await.md - Invalid await diagnostics - Basic
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_await.md
---
# Python source files
## mdtest_snippet.py
```
1 | async def main() -> None:
2 | await 1 # error: [invalid-await]
```
# Diagnostics
```
error[invalid-await]: `Literal[1]` is not awaitable
--> src/mdtest_snippet.py:2:11
|
1 | async def main() -> None:
2 | await 1 # error: [invalid-await]
| ^
|
::: stdlib/builtins.pyi:337:7
|
335 | _LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed
336 |
337 | class int:
| --- type defined here
338 | """int([x]) -> integer
339 | int(x, base=10) -> integer
|
info: `__await__` is missing
info: rule `invalid-await` is enabled by default
```

View file

@ -0,0 +1,41 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_await.md - Invalid await diagnostics - Custom type with missing `__await__`
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_await.md
---
# Python source files
## mdtest_snippet.py
```
1 | class MissingAwait:
2 | pass
3 |
4 | async def main() -> None:
5 | await MissingAwait() # error: [invalid-await]
```
# Diagnostics
```
error[invalid-await]: `MissingAwait` is not awaitable
--> src/mdtest_snippet.py:5:11
|
4 | async def main() -> None:
5 | await MissingAwait() # error: [invalid-await]
| ^^^^^^^^^^^^^^
|
::: src/mdtest_snippet.py:1:7
|
1 | class MissingAwait:
| ------------ type defined here
2 | pass
|
info: `__await__` is missing
info: rule `invalid-await` is enabled by default
```

View file

@ -0,0 +1,47 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_await.md - Invalid await diagnostics - Custom type with possibly unbound `__await__`
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_await.md
---
# Python source files
## mdtest_snippet.py
```
1 | from datetime import datetime
2 |
3 | class PossiblyUnbound:
4 | if datetime.today().weekday() == 0:
5 | def __await__(self):
6 | yield
7 |
8 | async def main() -> None:
9 | await PossiblyUnbound() # error: [invalid-await]
```
# Diagnostics
```
error[invalid-await]: `PossiblyUnbound` is not awaitable
--> src/mdtest_snippet.py:9:11
|
8 | async def main() -> None:
9 | await PossiblyUnbound() # error: [invalid-await]
| ^^^^^^^^^^^^^^^^^
|
::: src/mdtest_snippet.py:5:13
|
3 | class PossiblyUnbound:
4 | if datetime.today().weekday() == 0:
5 | def __await__(self):
| --------------- method defined here
6 | yield
|
info: `__await__` is possibly unbound
info: rule `invalid-await` is enabled by default
```

View file

@ -0,0 +1,45 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_await.md - Invalid await diagnostics - Invalid union return type
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_await.md
---
# Python source files
## mdtest_snippet.py
```
1 | import typing
2 | from datetime import datetime
3 |
4 | class UnawaitableUnion:
5 | if datetime.today().weekday() == 6:
6 |
7 | def __await__(self) -> typing.Generator[typing.Any, None, None]:
8 | yield
9 | else:
10 |
11 | def __await__(self) -> int:
12 | return 5
13 |
14 | async def main() -> None:
15 | await UnawaitableUnion() # error: [invalid-await]
```
# Diagnostics
```
error[invalid-await]: `UnawaitableUnion` is not awaitable
--> src/mdtest_snippet.py:15:11
|
14 | async def main() -> None:
15 | await UnawaitableUnion() # error: [invalid-await]
| ^^^^^^^^^^^^^^^^^^
|
info: `__await__` returns `Generator[Any, None, None] | int`, which is not a valid iterator
info: rule `invalid-await` is enabled by default
```

View file

@ -0,0 +1,35 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_await.md - Invalid await diagnostics - Non-callable `__await__`
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_await.md
---
# Python source files
## mdtest_snippet.py
```
1 | class NonCallableAwait:
2 | __await__ = 42
3 |
4 | async def main() -> None:
5 | await NonCallableAwait() # error: [invalid-await]
```
# Diagnostics
```
error[invalid-await]: `NonCallableAwait` is not awaitable
--> src/mdtest_snippet.py:5:11
|
4 | async def main() -> None:
5 | await NonCallableAwait() # error: [invalid-await]
| ^^^^^^^^^^^^^^^^^^
|
info: `__await__` is possibly not callable
info: rule `invalid-await` is enabled by default
```

View file

@ -0,0 +1,43 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_await.md - Invalid await diagnostics - `__await__` definition with extra arguments
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_await.md
---
# Python source files
## mdtest_snippet.py
```
1 | class InvalidAwaitArgs:
2 | def __await__(self, value: int):
3 | yield value
4 |
5 | async def main() -> None:
6 | await InvalidAwaitArgs() # error: [invalid-await]
```
# Diagnostics
```
error[invalid-await]: `InvalidAwaitArgs` is not awaitable
--> src/mdtest_snippet.py:6:11
|
5 | async def main() -> None:
6 | await InvalidAwaitArgs() # error: [invalid-await]
| ^^^^^^^^^^^^^^^^^^
|
::: src/mdtest_snippet.py:2:18
|
1 | class InvalidAwaitArgs:
2 | def __await__(self, value: int):
| ------------------ parameters here
3 | yield value
|
info: `__await__` requires arguments and cannot be called implicitly
info: rule `invalid-await` is enabled by default
```

View file

@ -0,0 +1,43 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: invalid_await.md - Invalid await diagnostics - `__await__` definition with explicit invalid return type
mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/invalid_await.md
---
# Python source files
## mdtest_snippet.py
```
1 | class InvalidAwaitReturn:
2 | def __await__(self) -> int:
3 | return 5
4 |
5 | async def main() -> None:
6 | await InvalidAwaitReturn() # error: [invalid-await]
```
# Diagnostics
```
error[invalid-await]: `InvalidAwaitReturn` is not awaitable
--> src/mdtest_snippet.py:6:11
|
5 | async def main() -> None:
6 | await InvalidAwaitReturn() # error: [invalid-await]
| ^^^^^^^^^^^^^^^^^^^^
|
::: src/mdtest_snippet.py:2:9
|
1 | class InvalidAwaitReturn:
2 | def __await__(self) -> int:
| ---------------------- method defined here
3 | return 5
|
info: `__await__` returns `int`, which is not a valid iterator
info: rule `invalid-await` is enabled by default
```

View file

@ -41,7 +41,7 @@ use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding};
use crate::types::class::{CodeGeneratorKind, Field};
pub(crate) use crate::types::class_base::ClassBase;
use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder};
use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
use crate::types::diagnostic::{INVALID_AWAIT, INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
use crate::types::enums::{enum_metadata, is_single_member_enum};
use crate::types::function::{
DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction,
@ -4778,20 +4778,24 @@ impl<'db> Type<'db> {
mode: EvaluationMode,
) -> Result<Cow<'db, TupleSpec<'db>>, IterationError<'db>> {
if mode.is_async() {
let try_call_dunder_anext_on_iterator = |iterator: Type<'db>| {
let try_call_dunder_anext_on_iterator = |iterator: Type<'db>| -> Result<
Result<Type<'db>, AwaitError<'db>>,
CallDunderError<'db>,
> {
iterator
.try_call_dunder(db, "__anext__", CallArguments::none())
.map(|dunder_anext_outcome| {
dunder_anext_outcome.return_type(db).resolve_await(db)
})
.map(|dunder_anext_outcome| dunder_anext_outcome.return_type(db).try_await(db))
};
return match self.try_call_dunder(db, "__aiter__", CallArguments::none()) {
Ok(dunder_aiter_bindings) => {
let iterator = dunder_aiter_bindings.return_type(db);
match try_call_dunder_anext_on_iterator(iterator) {
Ok(result) => Ok(Cow::Owned(TupleSpec::homogeneous(result))),
Err(dunder_anext_error) => {
Ok(Ok(result)) => Ok(Cow::Owned(TupleSpec::homogeneous(result))),
Ok(Err(AwaitError::InvalidReturnType(..))) => {
Err(IterationError::UnboundAiterError)
} // TODO: __anext__ is bound, but is not properly awaitable
Err(dunder_anext_error) | Ok(Err(AwaitError::Call(dunder_anext_error))) => {
Err(IterationError::IterReturnsInvalidIterator {
iterator,
dunder_error: dunder_anext_error,
@ -4996,7 +5000,7 @@ impl<'db> Type<'db> {
(Ok(enter), Ok(_)) => {
let ty = enter.return_type(db);
Ok(if mode.is_async() {
ty.resolve_await(db)
ty.try_await(db).unwrap_or(Type::unknown())
} else {
ty
})
@ -5005,7 +5009,7 @@ impl<'db> Type<'db> {
let ty = enter.return_type(db);
Err(ContextManagerError::Exit {
enter_return_type: if mode.is_async() {
ty.resolve_await(db)
ty.try_await(db).unwrap_or(Type::unknown())
} else {
ty
},
@ -5024,15 +5028,17 @@ impl<'db> Type<'db> {
}
/// Resolve the type of an `await …` expression where `self` is the type of the awaitable.
fn resolve_await(self, db: &'db dyn Db) -> Type<'db> {
// TODO: Add proper error handling and rename this method to `try_await`.
self.try_call_dunder(db, "__await__", CallArguments::none())
.map_or(Type::unknown(), |result| {
result
.return_type(db)
.generator_return_type(db)
.unwrap_or_else(Type::unknown)
})
fn try_await(self, db: &'db dyn Db) -> Result<Type<'db>, AwaitError<'db>> {
let await_result = self.try_call_dunder(db, "__await__", CallArguments::none());
match await_result {
Ok(bindings) => {
let return_type = bindings.return_type(db);
Ok(return_type.generator_return_type(db).ok_or_else(|| {
AwaitError::InvalidReturnType(return_type, Box::new(bindings))
})?)
}
Err(call_error) => Err(AwaitError::Call(call_error)),
}
}
/// Get the return type of a `yield from …` expression where `self` is the type of the generator.
@ -5068,6 +5074,8 @@ impl<'db> Type<'db> {
None
}
}
Type::Union(union) => union.try_map(db, |ty| ty.generator_return_type(db)),
ty @ (Type::Dynamic(_) | Type::Never) => Some(ty),
_ => None,
}
}
@ -7224,6 +7232,97 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
}
}
/// Error returned if a type is not awaitable.
#[derive(Debug)]
enum AwaitError<'db> {
/// `__await__` is either missing, potentially unbound or cannot be called with provided
/// arguments.
Call(CallDunderError<'db>),
/// `__await__` resolved successfully, but its return type is known not to be a generator.
InvalidReturnType(Type<'db>, Box<Bindings<'db>>),
}
impl<'db> AwaitError<'db> {
fn report_diagnostic(
&self,
context: &InferContext<'db, '_>,
context_expression_type: Type<'db>,
context_expression_node: ast::AnyNodeRef,
) {
let Some(builder) = context.report_lint(&INVALID_AWAIT, context_expression_node) else {
return;
};
let db = context.db();
let mut diag = builder.into_diagnostic(
format_args!("`{type}` is not awaitable", type = context_expression_type.display(db)),
);
match self {
Self::Call(CallDunderError::CallError(CallErrorKind::BindingError, bindings)) => {
diag.info("`__await__` requires arguments and cannot be called implicitly");
if let Some(definition_spans) = bindings.callable_type().function_spans(db) {
diag.annotate(
Annotation::secondary(definition_spans.parameters)
.message("parameters here"),
);
}
}
Self::Call(CallDunderError::CallError(
kind @ (CallErrorKind::NotCallable | CallErrorKind::PossiblyNotCallable),
bindings,
)) => {
let possibly = if matches!(kind, CallErrorKind::PossiblyNotCallable) {
" possibly"
} else {
""
};
diag.info(format_args!("`__await__` is{possibly} not callable"));
if let Some(definition) = bindings.callable_type().definition(db) {
if let Some(definition_range) = definition.focus_range(db) {
diag.annotate(
Annotation::secondary(definition_range.into())
.message("attribute defined here"),
);
}
}
}
Self::Call(CallDunderError::PossiblyUnbound(bindings)) => {
diag.info("`__await__` is possibly unbound");
if let Some(definition_spans) = bindings.callable_type().function_spans(db) {
diag.annotate(
Annotation::secondary(definition_spans.signature)
.message("method defined here"),
);
}
}
Self::Call(CallDunderError::MethodNotAvailable) => {
diag.info("`__await__` is missing");
if let Some(type_definition) = context_expression_type.definition(db) {
if let Some(definition_range) = type_definition.focus_range(db) {
diag.annotate(
Annotation::secondary(definition_range.into())
.message("type defined here"),
);
}
}
}
Self::InvalidReturnType(return_type, bindings) => {
diag.info(format_args!(
"`__await__` returns `{return_type}`, which is not a valid iterator",
return_type = return_type.display(db)
));
if let Some(definition_spans) = bindings.callable_type().function_spans(db) {
diag.annotate(
Annotation::secondary(definition_spans.signature)
.message("method defined here"),
);
}
}
}
}
}
/// Error returned if a type is not (or may not be) a context manager.
#[derive(Debug)]
enum ContextManagerError<'db> {
@ -7447,11 +7546,11 @@ impl<'db> IterationError<'db> {
match self {
Self::IterReturnsInvalidIterator {
dunder_error, mode, ..
} => dunder_error.return_type(db).map(|ty| {
} => dunder_error.return_type(db).and_then(|ty| {
if mode.is_async() {
ty.resolve_await(db)
ty.try_await(db).ok()
} else {
ty
Some(ty)
}
}),
@ -7466,7 +7565,7 @@ impl<'db> IterationError<'db> {
"__anext__",
CallArguments::none(),
))
.map(|ty| ty.resolve_await(db))
.and_then(|ty| ty.try_await(db).ok())
} else {
return_type(dunder_iter_bindings.return_type(db).try_call_dunder(
db,

View file

@ -45,6 +45,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&INVALID_ARGUMENT_TYPE);
registry.register_lint(&INVALID_RETURN_TYPE);
registry.register_lint(&INVALID_ASSIGNMENT);
registry.register_lint(&INVALID_AWAIT);
registry.register_lint(&INVALID_BASE);
registry.register_lint(&INVALID_CONTEXT_MANAGER);
registry.register_lint(&INVALID_DECLARATION);
@ -578,6 +579,36 @@ declare_lint! {
}
}
declare_lint! {
/// ## What it does
/// Checks for `await` being used with types that are not [Awaitable].
///
/// ## Why is this bad?
/// Such expressions will lead to `TypeError` being raised at runtime.
///
/// ## Examples
/// ```python
/// import asyncio
///
/// class InvalidAwait:
/// def __await__(self) -> int:
/// return 5
///
/// async def main() -> None:
/// await InvalidAwait() # error: [invalid-await]
/// await 42 # error: [invalid-await]
///
/// asyncio.run(main())
/// ```
///
/// [Awaitable]: https://docs.python.org/3/library/collections.abc.html#collections.abc.Awaitable
pub(crate) static INVALID_AWAIT = {
summary: "detects awaiting on types that don't support it",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! {
/// ## What it does
/// Checks for class definitions that have bases which are not instances of `type`.

View file

@ -94,7 +94,6 @@ pub(crate) struct FunctionSpans {
pub(crate) name: Span,
/// The span of the parameter list, including the opening and
/// closing parentheses.
#[expect(dead_code)]
pub(crate) parameters: Span,
/// The span of the annotated return type, if present.
pub(crate) return_type: Option<Span>,

View file

@ -6380,7 +6380,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _,
value,
} = await_expression;
self.infer_expression(value).resolve_await(self.db())
let expr_type = self.infer_expression(value);
expr_type.try_await(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, expr_type, value.as_ref().into());
Type::unknown()
})
}
// Perform narrowing with applicable constraints between the current scope and the enclosing scope.