[ty] Add diagnostics for invalid await expressions (#19711)
Some checks are pending
CI / mkdocs (push) Waiting to run
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

This PR adds a new lint, `invalid-await`, for all sorts of reasons why
an object may not be `await`able, as discussed in astral-sh/ty#919.
Precisely, `__await__` is guarded against being missing, possibly
unbound, or improperly defined (expects additional arguments or doesn't
return an iterator).

Of course, diagnostics need to be fine-tuned. If `__await__` cannot be
called with no extra arguments, it indicates an error (or a quirk?) in
the method signature, not at the call site. Without any doubt, such an
object is not `Awaitable`, but I feel like talking about arguments for
an *implicit* call is a bit leaky.
I didn't reference any actual diagnostic messages in the lint
definition, because I want to hear feedback first.

Also, there's no mention of the actual required method signature for
`__await__` anywhere in the docs. The only reference I had is the
`typing` stub. I basically ended up linking `[Awaitable]` to ["must
implement
`__await__`"](https://docs.python.org/3/library/collections.abc.html#collections.abc.Awaitable),
which is insufficient on its own.

## Test Plan

The following code was tested:
```python
import asyncio
import typing


class Awaitable:
    def __await__(self) -> typing.Generator[typing.Any, None, int]:
        yield None
        return 5


class NoDunderMethod:
    pass


class InvalidAwaitArgs:
    def __await__(self, value: int) -> int:
        return value


class InvalidAwaitReturn:
    def __await__(self) -> int:
        return 5


class InvalidAwaitReturnImplicit:
    def __await__(self):
        pass


async def main() -> None:
    result = await Awaitable()  # valid
    result = await NoDunderMethod()  # `__await__` is missing
    result = await InvalidAwaitReturn()  # `__await__` returns `int`, which is not a valid iterator 
    result = await InvalidAwaitArgs()  # `__await__` expects additional arguments and cannot be called implicitly
    result = await InvalidAwaitReturnImplicit()  # `__await__` returns `Unknown`, which is not a valid iterator


asyncio.run(main())
```

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Andrii Turov 2025-08-15 00:38:33 +03:00 committed by GitHub
parent f6093452ed
commit 957320c0f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 661 additions and 84 deletions

View file

@ -41,7 +41,7 @@ use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding};
use crate::types::class::{CodeGeneratorKind, Field};
pub(crate) use crate::types::class_base::ClassBase;
use crate::types::context::{LintDiagnosticGuard, LintDiagnosticGuardBuilder};
use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
use crate::types::diagnostic::{INVALID_AWAIT, INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
use crate::types::enums::{enum_metadata, is_single_member_enum};
use crate::types::function::{
DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction,
@ -4778,20 +4778,24 @@ impl<'db> Type<'db> {
mode: EvaluationMode,
) -> Result<Cow<'db, TupleSpec<'db>>, IterationError<'db>> {
if mode.is_async() {
let try_call_dunder_anext_on_iterator = |iterator: Type<'db>| {
let try_call_dunder_anext_on_iterator = |iterator: Type<'db>| -> Result<
Result<Type<'db>, AwaitError<'db>>,
CallDunderError<'db>,
> {
iterator
.try_call_dunder(db, "__anext__", CallArguments::none())
.map(|dunder_anext_outcome| {
dunder_anext_outcome.return_type(db).resolve_await(db)
})
.map(|dunder_anext_outcome| dunder_anext_outcome.return_type(db).try_await(db))
};
return match self.try_call_dunder(db, "__aiter__", CallArguments::none()) {
Ok(dunder_aiter_bindings) => {
let iterator = dunder_aiter_bindings.return_type(db);
match try_call_dunder_anext_on_iterator(iterator) {
Ok(result) => Ok(Cow::Owned(TupleSpec::homogeneous(result))),
Err(dunder_anext_error) => {
Ok(Ok(result)) => Ok(Cow::Owned(TupleSpec::homogeneous(result))),
Ok(Err(AwaitError::InvalidReturnType(..))) => {
Err(IterationError::UnboundAiterError)
} // TODO: __anext__ is bound, but is not properly awaitable
Err(dunder_anext_error) | Ok(Err(AwaitError::Call(dunder_anext_error))) => {
Err(IterationError::IterReturnsInvalidIterator {
iterator,
dunder_error: dunder_anext_error,
@ -4996,7 +5000,7 @@ impl<'db> Type<'db> {
(Ok(enter), Ok(_)) => {
let ty = enter.return_type(db);
Ok(if mode.is_async() {
ty.resolve_await(db)
ty.try_await(db).unwrap_or(Type::unknown())
} else {
ty
})
@ -5005,7 +5009,7 @@ impl<'db> Type<'db> {
let ty = enter.return_type(db);
Err(ContextManagerError::Exit {
enter_return_type: if mode.is_async() {
ty.resolve_await(db)
ty.try_await(db).unwrap_or(Type::unknown())
} else {
ty
},
@ -5024,15 +5028,17 @@ impl<'db> Type<'db> {
}
/// Resolve the type of an `await …` expression where `self` is the type of the awaitable.
fn resolve_await(self, db: &'db dyn Db) -> Type<'db> {
// TODO: Add proper error handling and rename this method to `try_await`.
self.try_call_dunder(db, "__await__", CallArguments::none())
.map_or(Type::unknown(), |result| {
result
.return_type(db)
.generator_return_type(db)
.unwrap_or_else(Type::unknown)
})
fn try_await(self, db: &'db dyn Db) -> Result<Type<'db>, AwaitError<'db>> {
let await_result = self.try_call_dunder(db, "__await__", CallArguments::none());
match await_result {
Ok(bindings) => {
let return_type = bindings.return_type(db);
Ok(return_type.generator_return_type(db).ok_or_else(|| {
AwaitError::InvalidReturnType(return_type, Box::new(bindings))
})?)
}
Err(call_error) => Err(AwaitError::Call(call_error)),
}
}
/// Get the return type of a `yield from …` expression where `self` is the type of the generator.
@ -5068,6 +5074,8 @@ impl<'db> Type<'db> {
None
}
}
Type::Union(union) => union.try_map(db, |ty| ty.generator_return_type(db)),
ty @ (Type::Dynamic(_) | Type::Never) => Some(ty),
_ => None,
}
}
@ -7224,6 +7232,97 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
}
}
/// Error returned if a type is not awaitable.
#[derive(Debug)]
enum AwaitError<'db> {
/// `__await__` is either missing, potentially unbound or cannot be called with provided
/// arguments.
Call(CallDunderError<'db>),
/// `__await__` resolved successfully, but its return type is known not to be a generator.
InvalidReturnType(Type<'db>, Box<Bindings<'db>>),
}
impl<'db> AwaitError<'db> {
fn report_diagnostic(
&self,
context: &InferContext<'db, '_>,
context_expression_type: Type<'db>,
context_expression_node: ast::AnyNodeRef,
) {
let Some(builder) = context.report_lint(&INVALID_AWAIT, context_expression_node) else {
return;
};
let db = context.db();
let mut diag = builder.into_diagnostic(
format_args!("`{type}` is not awaitable", type = context_expression_type.display(db)),
);
match self {
Self::Call(CallDunderError::CallError(CallErrorKind::BindingError, bindings)) => {
diag.info("`__await__` requires arguments and cannot be called implicitly");
if let Some(definition_spans) = bindings.callable_type().function_spans(db) {
diag.annotate(
Annotation::secondary(definition_spans.parameters)
.message("parameters here"),
);
}
}
Self::Call(CallDunderError::CallError(
kind @ (CallErrorKind::NotCallable | CallErrorKind::PossiblyNotCallable),
bindings,
)) => {
let possibly = if matches!(kind, CallErrorKind::PossiblyNotCallable) {
" possibly"
} else {
""
};
diag.info(format_args!("`__await__` is{possibly} not callable"));
if let Some(definition) = bindings.callable_type().definition(db) {
if let Some(definition_range) = definition.focus_range(db) {
diag.annotate(
Annotation::secondary(definition_range.into())
.message("attribute defined here"),
);
}
}
}
Self::Call(CallDunderError::PossiblyUnbound(bindings)) => {
diag.info("`__await__` is possibly unbound");
if let Some(definition_spans) = bindings.callable_type().function_spans(db) {
diag.annotate(
Annotation::secondary(definition_spans.signature)
.message("method defined here"),
);
}
}
Self::Call(CallDunderError::MethodNotAvailable) => {
diag.info("`__await__` is missing");
if let Some(type_definition) = context_expression_type.definition(db) {
if let Some(definition_range) = type_definition.focus_range(db) {
diag.annotate(
Annotation::secondary(definition_range.into())
.message("type defined here"),
);
}
}
}
Self::InvalidReturnType(return_type, bindings) => {
diag.info(format_args!(
"`__await__` returns `{return_type}`, which is not a valid iterator",
return_type = return_type.display(db)
));
if let Some(definition_spans) = bindings.callable_type().function_spans(db) {
diag.annotate(
Annotation::secondary(definition_spans.signature)
.message("method defined here"),
);
}
}
}
}
}
/// Error returned if a type is not (or may not be) a context manager.
#[derive(Debug)]
enum ContextManagerError<'db> {
@ -7447,11 +7546,11 @@ impl<'db> IterationError<'db> {
match self {
Self::IterReturnsInvalidIterator {
dunder_error, mode, ..
} => dunder_error.return_type(db).map(|ty| {
} => dunder_error.return_type(db).and_then(|ty| {
if mode.is_async() {
ty.resolve_await(db)
ty.try_await(db).ok()
} else {
ty
Some(ty)
}
}),
@ -7466,7 +7565,7 @@ impl<'db> IterationError<'db> {
"__anext__",
CallArguments::none(),
))
.map(|ty| ty.resolve_await(db))
.and_then(|ty| ty.try_await(db).ok())
} else {
return_type(dunder_iter_bindings.return_type(db).try_call_dunder(
db,

View file

@ -45,6 +45,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&INVALID_ARGUMENT_TYPE);
registry.register_lint(&INVALID_RETURN_TYPE);
registry.register_lint(&INVALID_ASSIGNMENT);
registry.register_lint(&INVALID_AWAIT);
registry.register_lint(&INVALID_BASE);
registry.register_lint(&INVALID_CONTEXT_MANAGER);
registry.register_lint(&INVALID_DECLARATION);
@ -578,6 +579,36 @@ declare_lint! {
}
}
declare_lint! {
/// ## What it does
/// Checks for `await` being used with types that are not [Awaitable].
///
/// ## Why is this bad?
/// Such expressions will lead to `TypeError` being raised at runtime.
///
/// ## Examples
/// ```python
/// import asyncio
///
/// class InvalidAwait:
/// def __await__(self) -> int:
/// return 5
///
/// async def main() -> None:
/// await InvalidAwait() # error: [invalid-await]
/// await 42 # error: [invalid-await]
///
/// asyncio.run(main())
/// ```
///
/// [Awaitable]: https://docs.python.org/3/library/collections.abc.html#collections.abc.Awaitable
pub(crate) static INVALID_AWAIT = {
summary: "detects awaiting on types that don't support it",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}
declare_lint! {
/// ## What it does
/// Checks for class definitions that have bases which are not instances of `type`.

View file

@ -94,7 +94,6 @@ pub(crate) struct FunctionSpans {
pub(crate) name: Span,
/// The span of the parameter list, including the opening and
/// closing parentheses.
#[expect(dead_code)]
pub(crate) parameters: Span,
/// The span of the annotated return type, if present.
pub(crate) return_type: Option<Span>,

View file

@ -6380,7 +6380,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _,
value,
} = await_expression;
self.infer_expression(value).resolve_await(self.db())
let expr_type = self.infer_expression(value);
expr_type.try_await(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, expr_type, value.as_ref().into());
Type::unknown()
})
}
// Perform narrowing with applicable constraints between the current scope and the enclosing scope.