[ty] Support async/await, async with and yield from (#19595)

## Summary

- Add support for the return types of `async` functions
- Add type inference for `await` expressions
- Add support for `async with` / async context managers
- Add support for `yield from` expressions

This PR is generally lacking proper error handling in some cases (e.g.
illegal `__await__` attributes). I'm planning to work on this in a
follow-up.

part of https://github.com/astral-sh/ty/issues/151

closes https://github.com/astral-sh/ty/issues/736

## Ecosystem

There are a lot of true positives on `prefect` which look similar to:
```diff
prefect (https://github.com/PrefectHQ/prefect)
+ src/integrations/prefect-aws/tests/workers/test_ecs_worker.py:406:12: error[unresolved-attribute] Type `str` has no attribute `status_code`
```

This is due to a wrong return type annotation
[here](e926b8c4c1/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py (L355-L391)).

```diff
mitmproxy (https://github.com/mitmproxy/mitmproxy)
+ test/mitmproxy/addons/test_clientplayback.py:18:1: error[invalid-argument-type] Argument to function `asynccontextmanager` is incorrect: Expected `(...) -> AsyncIterator[Unknown]`, found `def tcp_server(handle_conn, **server_args) -> Unknown | tuple[str, int]`
```


[This](a4d794c59a/test/mitmproxy/addons/test_clientplayback.py (L18-L19))
is a true positive. That function should return
`AsyncIterator[Address]`, not `Address`.

I looked through almost all of the other new diagnostics and they all
look like known problems or true positives.

## Typing conformance

The typing conformance diff looks good.

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-07-30 11:51:21 +02:00 committed by GitHub
parent c5ac998892
commit 4ecf1d205a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 472 additions and 46 deletions

View file

@ -0,0 +1,123 @@
# `async` / `await`
## Basic
```py
async def retrieve() -> int:
return 42
async def main():
result = await retrieve()
reveal_type(result) # revealed: int
```
## Generic `async` functions
```py
from typing import TypeVar
T = TypeVar("T")
async def persist(x: T) -> T:
return x
async def f(x: int):
result = await persist(x)
reveal_type(result) # revealed: int
```
## Use cases
### `Future`
```py
import asyncio
import concurrent.futures
def blocking_function() -> int:
return 42
async def main():
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as pool:
result = await loop.run_in_executor(pool, blocking_function)
# TODO: should be `int`
reveal_type(result) # revealed: Unknown
```
### `asyncio.Task`
```py
import asyncio
async def f() -> int:
return 1
async def main():
task = asyncio.create_task(f())
result = await task
# TODO: this should be `int`
reveal_type(result) # revealed: Unknown
```
### `asyncio.gather`
```py
import asyncio
async def task(name: str) -> int:
return len(name)
async def main():
(a, b) = await asyncio.gather(
task("A"),
task("B"),
)
# TODO: these should be `int`
reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: Unknown
```
## Under the hood
```toml
[environment]
python-version = "3.12" # Use 3.12 to be able to use PEP 695 generics
```
Let's look at the example from the beginning again:
```py
async def retrieve() -> int:
return 42
```
When we look at the signature of this function, we see that it actually returns a `CoroutineType`:
```py
reveal_type(retrieve) # revealed: def retrieve() -> CoroutineType[Any, Any, int]
```
The expression `await retrieve()` desugars into a call to the `__await__` dunder method on the
`CoroutineType` object, followed by a `yield from`. Let's first see the return type of `__await__`:
```py
reveal_type(retrieve().__await__()) # revealed: Generator[Any, None, int]
```
We can see that this returns a `Generator` that yields `Any`, and eventually returns `int`. For the
final type of the `await` expression, we retrieve that third argument of the `Generator` type:
```py
from typing import Generator
def _():
result = yield from retrieve().__await__()
reveal_type(result) # revealed: int
```

View file

@ -15,8 +15,7 @@ reveal_type(get_int()) # revealed: int
async def get_int_async() -> int: async def get_int_async() -> int:
return 42 return 42
# TODO: we don't yet support `types.CoroutineType`, should be generic `Coroutine[Any, Any, int]` reveal_type(get_int_async()) # revealed: CoroutineType[Any, Any, int]
reveal_type(get_int_async()) # revealed: @Todo(generic types.CoroutineType)
``` ```
## Generic ## Generic

View file

@ -0,0 +1,130 @@
# `yield` and `yield from`
## Basic `yield` and `yield from`
The type of a `yield` expression is the "send" type of the generator function. The type of a
`yield from` expression is the return type of the inner generator:
```py
from typing import Generator
def inner_generator() -> Generator[int, bytes, str]:
yield 1
yield 2
x = yield 3
# TODO: this should be `bytes`
reveal_type(x) # revealed: @Todo(yield expressions)
return "done"
def outer_generator():
result = yield from inner_generator()
reveal_type(result) # revealed: str
```
## `yield from` with a custom iterable
`yield from` can also be used with custom iterable types. In that case, the type of the `yield from`
expression can not be determined
```py
from typing import Generator, TypeVar, Generic
T = TypeVar("T")
class OnceIterator(Generic[T]):
def __init__(self, value: T):
self.value = value
self.returned = False
def __next__(self) -> T:
if self.returned:
raise StopIteration(42)
self.returned = True
return self.value
class Once(Generic[T]):
def __init__(self, value: T):
self.value = value
def __iter__(self) -> OnceIterator[T]:
return OnceIterator(self.value)
for x in Once("a"):
reveal_type(x) # revealed: str
def generator() -> Generator:
result = yield from Once("a")
# At runtime, the value of `result` will be the `.value` attribute of the `StopIteration`
# error raised by `OnceIterator` to signal to the interpreter that the iterator has been
# exhausted. Here that will always be 42, but this information cannot be captured in the
# signature of `OnceIterator.__next__`, since exceptions lie outside the type signature.
# We therefore just infer `Unknown` here.
#
# If the `StopIteration` error in `OnceIterator.__next__` had been simply `raise StopIteration`
# (the more common case), then the `.value` attribute of the `StopIteration` instance
# would default to `None`.
reveal_type(result) # revealed: Unknown
```
## `yield from` with a generator that return `types.GeneratorType`
`types.GeneratorType` is a nominal type that implements the `typing.Generator` protocol:
```py
from types import GeneratorType
def inner_generator() -> GeneratorType[int, bytes, str]:
yield 1
yield 2
x = yield 3
# TODO: this should be `bytes`
reveal_type(x) # revealed: @Todo(yield expressions)
return "done"
def outer_generator():
result = yield from inner_generator()
reveal_type(result) # revealed: str
```
## Error cases
### Non-iterable type
```py
from typing import Generator
def generator() -> Generator:
yield from 42 # error: [not-iterable] "Object of type `Literal[42]` is not iterable"
```
### Invalid `yield` type
```py
from typing import Generator
# TODO: This should be an error. Claims to yield `int`, but yields `str`.
def invalid_generator() -> Generator[int, None, None]:
yield "not an int" # This should be an `int`
```
### Invalid return type
```py
from typing import Generator
# TODO: should emit an error (does not return `str`)
def invalid_generator1() -> Generator[int, None, str]:
yield 1
# TODO: should emit an error (does not return `int`)
def invalid_generator2() -> Generator[int, None, None]:
yield 1
return "done"
```

View file

@ -17,5 +17,80 @@ class Manager:
async def test(): async def test():
async with Manager() as f: async with Manager() as f:
reveal_type(f) # revealed: @Todo(async `with` statement) reveal_type(f) # revealed: Target
```
## Multiple targets
```py
class Manager:
async def __aenter__(self) -> tuple[int, str]:
return 42, "hello"
async def __aexit__(self, exc_type, exc_value, traceback): ...
async def test():
async with Manager() as (x, y):
reveal_type(x) # revealed: int
reveal_type(y) # revealed: str
```
## `@asynccontextmanager`
```py
from contextlib import asynccontextmanager
from typing import AsyncGenerator
class Session: ...
@asynccontextmanager
async def connect() -> AsyncGenerator[Session]:
yield Session()
# TODO: this should be `() -> _AsyncGeneratorContextManager[Session, None]`
reveal_type(connect) # revealed: (...) -> _AsyncGeneratorContextManager[Unknown, None]
async def main():
async with connect() as session:
# TODO: should be `Session`
reveal_type(session) # revealed: Unknown
```
## `asyncio.timeout`
```toml
[environment]
python-version = "3.11"
```
```py
import asyncio
async def long_running_task():
await asyncio.sleep(5)
async def main():
async with asyncio.timeout(1):
await long_running_task()
```
## `asyncio.TaskGroup`
```toml
[environment]
python-version = "3.11"
```
```py
import asyncio
async def long_running_task():
await asyncio.sleep(5)
async def main():
async with asyncio.TaskGroup() as tg:
# TODO: should be `TaskGroup`
reveal_type(tg) # revealed: Unknown
tg.create_task(long_running_task())
``` ```

View file

@ -2805,7 +2805,9 @@ impl<'ast> Unpackable<'ast> {
match self { match self {
Unpackable::Assign(_) => UnpackKind::Assign, Unpackable::Assign(_) => UnpackKind::Assign,
Unpackable::For(_) | Unpackable::Comprehension { .. } => UnpackKind::Iterable, Unpackable::For(_) | Unpackable::Comprehension { .. } => UnpackKind::Iterable,
Unpackable::WithItem { .. } => UnpackKind::ContextManager, Unpackable::WithItem { is_async, .. } => UnpackKind::ContextManager {
is_async: *is_async,
},
} }
} }

View file

@ -4790,6 +4790,64 @@ impl<'db> Type<'db> {
} }
} }
/// Similar to [`Self::try_enter`], but for async context managers.
fn aenter(self, db: &'db dyn Db) -> Type<'db> {
// TODO: Add proper error handling and rename this method to `try_aenter`.
self.try_call_dunder(db, "__aenter__", CallArguments::none())
.map_or(Type::unknown(), |result| {
result.return_type(db).resolve_await(db)
})
}
/// Resolve the type of an `await …` expression where `self` is the type of the awaitable.
fn resolve_await(self, db: &'db dyn Db) -> Type<'db> {
// TODO: Add proper error handling and rename this method to `try_await`.
self.try_call_dunder(db, "__await__", CallArguments::none())
.map_or(Type::unknown(), |result| {
result
.return_type(db)
.generator_return_type(db)
.unwrap_or_else(Type::unknown)
})
}
/// Get the return type of a `yield from …` expression where `self` is the type of the generator.
///
/// This corresponds to the `ReturnT` parameter of the generic `typing.Generator[YieldT, SendT, ReturnT]`
/// protocol.
fn generator_return_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
// TODO: Ideally, we would first try to upcast `self` to an instance of `Generator` and *then*
// match on the protocol instance to get the `ReturnType` type parameter. For now, implement
// an ad-hoc solution that works for protocols and instances of classes that directly inherit
// from the `Generator` protocol, such as `types.GeneratorType`.
let from_class_base = |base: ClassBase<'db>| {
let class = base.into_class()?;
if class.is_known(db, KnownClass::Generator) {
if let Some(specialization) = class.class_literal_specialized(db, None).1 {
if let [_, _, return_ty] = specialization.types(db) {
return Some(*return_ty);
}
}
}
None
};
match self {
Type::NominalInstance(instance) => {
instance.class.iter_mro(db).find_map(from_class_base)
}
Type::ProtocolInstance(instance) => {
if let Protocol::FromClass(class) = instance.inner {
class.iter_mro(db).find_map(from_class_base)
} else {
None
}
}
_ => None,
}
}
/// Given a class literal or non-dynamic SubclassOf type, try calling it (creating an instance) /// Given a class literal or non-dynamic SubclassOf type, try calling it (creating an instance)
/// and return the resulting instance type. /// and return the resulting instance type.
/// ///

View file

@ -2614,10 +2614,13 @@ pub enum KnownClass {
UnionType, UnionType,
GeneratorType, GeneratorType,
AsyncGeneratorType, AsyncGeneratorType,
CoroutineType,
// Typeshed // Typeshed
NoneType, // Part of `types` for Python >= 3.10 NoneType, // Part of `types` for Python >= 3.10
// Typing // Typing
Any, Any,
Awaitable,
Generator,
Deprecated, Deprecated,
StdlibAlias, StdlibAlias,
SpecialForm, SpecialForm,
@ -2689,7 +2692,8 @@ impl KnownClass {
| Self::UnionType | Self::UnionType
| Self::GeneratorType | Self::GeneratorType
| Self::AsyncGeneratorType | Self::AsyncGeneratorType
| Self::MethodWrapperType => Truthiness::AlwaysTrue, | Self::MethodWrapperType
| Self::CoroutineType => Truthiness::AlwaysTrue,
Self::NoneType => Truthiness::AlwaysFalse, Self::NoneType => Truthiness::AlwaysFalse,
@ -2740,6 +2744,8 @@ impl KnownClass {
| Self::NotImplementedType | Self::NotImplementedType
| Self::Staticmethod | Self::Staticmethod
| Self::Classmethod | Self::Classmethod
| Self::Awaitable
| Self::Generator
| Self::Deprecated | Self::Deprecated
| Self::Field | Self::Field
| Self::KwOnly | Self::KwOnly
@ -2805,12 +2811,15 @@ impl KnownClass {
| Self::InitVar | Self::InitVar
| Self::VersionInfo | Self::VersionInfo
| Self::Bool | Self::Bool
| Self::NoneType => false, | Self::NoneType
| Self::CoroutineType => false,
// Anything with a *runtime* MRO (N.B. sometimes different from the MRO that typeshed gives!) // Anything with a *runtime* MRO (N.B. sometimes different from the MRO that typeshed gives!)
// with length >2, or anything that is implemented in pure Python, is not a solid base. // with length >2, or anything that is implemented in pure Python, is not a solid base.
Self::ABCMeta Self::ABCMeta
| Self::Any | Self::Any
| Self::Awaitable
| Self::Generator
| Self::Enum | Self::Enum
| Self::EnumType | Self::EnumType
| Self::Auto | Self::Auto
@ -2859,6 +2868,8 @@ impl KnownClass {
| KnownClass::ExceptionGroup | KnownClass::ExceptionGroup
| KnownClass::Staticmethod | KnownClass::Staticmethod
| KnownClass::Classmethod | KnownClass::Classmethod
| KnownClass::Awaitable
| KnownClass::Generator
| KnownClass::Deprecated | KnownClass::Deprecated
| KnownClass::Super | KnownClass::Super
| KnownClass::Enum | KnownClass::Enum
@ -2876,6 +2887,7 @@ impl KnownClass {
| KnownClass::UnionType | KnownClass::UnionType
| KnownClass::GeneratorType | KnownClass::GeneratorType
| KnownClass::AsyncGeneratorType | KnownClass::AsyncGeneratorType
| KnownClass::CoroutineType
| KnownClass::NoneType | KnownClass::NoneType
| KnownClass::Any | KnownClass::Any
| KnownClass::StdlibAlias | KnownClass::StdlibAlias
@ -2921,7 +2933,11 @@ impl KnownClass {
/// 2. It's probably more performant. /// 2. It's probably more performant.
const fn is_protocol(self) -> bool { const fn is_protocol(self) -> bool {
match self { match self {
Self::SupportsIndex | Self::Iterable | Self::Iterator => true, Self::SupportsIndex
| Self::Iterable
| Self::Iterator
| Self::Awaitable
| Self::Generator => true,
Self::Any Self::Any
| Self::Bool | Self::Bool
@ -2950,6 +2966,7 @@ impl KnownClass {
| Self::GenericAlias | Self::GenericAlias
| Self::GeneratorType | Self::GeneratorType
| Self::AsyncGeneratorType | Self::AsyncGeneratorType
| Self::CoroutineType
| Self::ModuleType | Self::ModuleType
| Self::FunctionType | Self::FunctionType
| Self::MethodType | Self::MethodType
@ -3015,6 +3032,8 @@ impl KnownClass {
Self::ExceptionGroup => "ExceptionGroup", Self::ExceptionGroup => "ExceptionGroup",
Self::Staticmethod => "staticmethod", Self::Staticmethod => "staticmethod",
Self::Classmethod => "classmethod", Self::Classmethod => "classmethod",
Self::Awaitable => "Awaitable",
Self::Generator => "Generator",
Self::Deprecated => "deprecated", Self::Deprecated => "deprecated",
Self::GenericAlias => "GenericAlias", Self::GenericAlias => "GenericAlias",
Self::ModuleType => "ModuleType", Self::ModuleType => "ModuleType",
@ -3025,6 +3044,7 @@ impl KnownClass {
Self::WrapperDescriptorType => "WrapperDescriptorType", Self::WrapperDescriptorType => "WrapperDescriptorType",
Self::GeneratorType => "GeneratorType", Self::GeneratorType => "GeneratorType",
Self::AsyncGeneratorType => "AsyncGeneratorType", Self::AsyncGeneratorType => "AsyncGeneratorType",
Self::CoroutineType => "CoroutineType",
Self::NamedTuple => "NamedTuple", Self::NamedTuple => "NamedTuple",
Self::NoneType => "NoneType", Self::NoneType => "NoneType",
Self::SpecialForm => "_SpecialForm", Self::SpecialForm => "_SpecialForm",
@ -3285,11 +3305,14 @@ impl KnownClass {
| Self::MethodType | Self::MethodType
| Self::GeneratorType | Self::GeneratorType
| Self::AsyncGeneratorType | Self::AsyncGeneratorType
| Self::CoroutineType
| Self::MethodWrapperType | Self::MethodWrapperType
| Self::UnionType | Self::UnionType
| Self::WrapperDescriptorType => KnownModule::Types, | Self::WrapperDescriptorType => KnownModule::Types,
Self::NoneType => KnownModule::Typeshed, Self::NoneType => KnownModule::Typeshed,
Self::Any Self::Any
| Self::Awaitable
| Self::Generator
| Self::SpecialForm | Self::SpecialForm
| Self::TypeVar | Self::TypeVar
| Self::NamedTuple | Self::NamedTuple
@ -3370,12 +3393,15 @@ impl KnownClass {
| Self::ExceptionGroup | Self::ExceptionGroup
| Self::Staticmethod | Self::Staticmethod
| Self::Classmethod | Self::Classmethod
| Self::Awaitable
| Self::Generator
| Self::Deprecated | Self::Deprecated
| Self::GenericAlias | Self::GenericAlias
| Self::ModuleType | Self::ModuleType
| Self::FunctionType | Self::FunctionType
| Self::GeneratorType | Self::GeneratorType
| Self::AsyncGeneratorType | Self::AsyncGeneratorType
| Self::CoroutineType
| Self::MethodType | Self::MethodType
| Self::MethodWrapperType | Self::MethodWrapperType
| Self::WrapperDescriptorType | Self::WrapperDescriptorType
@ -3447,6 +3473,7 @@ impl KnownClass {
| Self::WrapperDescriptorType | Self::WrapperDescriptorType
| Self::GeneratorType | Self::GeneratorType
| Self::AsyncGeneratorType | Self::AsyncGeneratorType
| Self::CoroutineType
| Self::SpecialForm | Self::SpecialForm
| Self::ChainMap | Self::ChainMap
| Self::Counter | Self::Counter
@ -3461,6 +3488,8 @@ impl KnownClass {
| Self::ExceptionGroup | Self::ExceptionGroup
| Self::Staticmethod | Self::Staticmethod
| Self::Classmethod | Self::Classmethod
| Self::Awaitable
| Self::Generator
| Self::Deprecated | Self::Deprecated
| Self::TypeVar | Self::TypeVar
| Self::ParamSpec | Self::ParamSpec
@ -3517,12 +3546,15 @@ impl KnownClass {
"ExceptionGroup" => Self::ExceptionGroup, "ExceptionGroup" => Self::ExceptionGroup,
"staticmethod" => Self::Staticmethod, "staticmethod" => Self::Staticmethod,
"classmethod" => Self::Classmethod, "classmethod" => Self::Classmethod,
"Awaitable" => Self::Awaitable,
"Generator" => Self::Generator,
"deprecated" => Self::Deprecated, "deprecated" => Self::Deprecated,
"GenericAlias" => Self::GenericAlias, "GenericAlias" => Self::GenericAlias,
"NoneType" => Self::NoneType, "NoneType" => Self::NoneType,
"ModuleType" => Self::ModuleType, "ModuleType" => Self::ModuleType,
"GeneratorType" => Self::GeneratorType, "GeneratorType" => Self::GeneratorType,
"AsyncGeneratorType" => Self::AsyncGeneratorType, "AsyncGeneratorType" => Self::AsyncGeneratorType,
"CoroutineType" => Self::CoroutineType,
"FunctionType" => Self::FunctionType, "FunctionType" => Self::FunctionType,
"MethodType" => Self::MethodType, "MethodType" => Self::MethodType,
"UnionType" => Self::UnionType, "UnionType" => Self::UnionType,
@ -3627,11 +3659,14 @@ impl KnownClass {
| Self::UnionType | Self::UnionType
| Self::GeneratorType | Self::GeneratorType
| Self::AsyncGeneratorType | Self::AsyncGeneratorType
| Self::CoroutineType
| Self::WrapperDescriptorType | Self::WrapperDescriptorType
| Self::Field | Self::Field
| Self::KwOnly | Self::KwOnly
| Self::InitVar | Self::InitVar
| Self::NamedTupleFallback => module == self.canonical_module(db), | Self::NamedTupleFallback
| Self::Awaitable
| Self::Generator => module == self.canonical_module(db),
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types), Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
Self::SpecialForm Self::SpecialForm
| Self::TypeVar | Self::TypeVar

View file

@ -341,12 +341,16 @@ impl<'db> OverloadLiteral<'db> {
GenericContext::from_type_params(db, index, type_params) GenericContext::from_type_params(db, index, type_params)
}); });
let index = semantic_index(db, scope.file(db));
let is_generator = scope.file_scope_id(db).is_generator_function(index);
Signature::from_function( Signature::from_function(
db, db,
generic_context, generic_context,
inherited_generic_context, inherited_generic_context,
definition, definition,
function_stmt_node, function_stmt_node,
is_generator,
) )
} }

View file

@ -3169,26 +3169,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let context_expr = with_item.context_expr(self.module()); let context_expr = with_item.context_expr(self.module());
let target = with_item.target(self.module()); let target = with_item.target(self.module());
let target_ty = if with_item.is_async() { let target_ty = match with_item.target_kind() {
let _context_expr_ty = self.infer_standalone_expression(context_expr); TargetKind::Sequence(unpack_position, unpack) => {
todo_type!("async `with` statement") let unpacked = infer_unpack_types(self.db(), unpack);
} else { if unpack_position == UnpackPosition::First {
match with_item.target_kind() { self.context.extend(unpacked.diagnostics());
TargetKind::Sequence(unpack_position, unpack) => {
let unpacked = infer_unpack_types(self.db(), unpack);
if unpack_position == UnpackPosition::First {
self.context.extend(unpacked.diagnostics());
}
unpacked.expression_type(target)
}
TargetKind::Single => {
let context_expr_ty = self.infer_standalone_expression(context_expr);
self.infer_context_expression(
context_expr,
context_expr_ty,
with_item.is_async(),
)
} }
unpacked.expression_type(target)
}
TargetKind::Single => {
let context_expr_ty = self.infer_standalone_expression(context_expr);
self.infer_context_expression(context_expr, context_expr_ty, with_item.is_async())
} }
}; };
@ -3208,9 +3199,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
context_expression_type: Type<'db>, context_expression_type: Type<'db>,
is_async: bool, is_async: bool,
) -> Type<'db> { ) -> Type<'db> {
// TODO: Handle async with statements (they use `aenter` and `aexit`)
if is_async { if is_async {
return todo_type!("async `with` statement"); return context_expression_type.aenter(self.db());
} }
context_expression_type context_expression_type
@ -6102,8 +6092,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
err.fallback_element_type(self.db()) err.fallback_element_type(self.db())
}); });
// TODO get type from `ReturnType` of generator iterable_type
todo_type!("Generic `typing.Generator` type") .generator_return_type(self.db())
.unwrap_or_else(Type::unknown)
} }
fn infer_await_expression(&mut self, await_expression: &ast::ExprAwait) -> Type<'db> { fn infer_await_expression(&mut self, await_expression: &ast::ExprAwait) -> Type<'db> {
@ -6112,8 +6103,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _, node_index: _,
value, value,
} = await_expression; } = await_expression;
self.infer_expression(value); self.infer_expression(value).resolve_await(self.db())
todo_type!("generic `typing.Awaitable` type")
} }
// Perform narrowing with applicable constraints between the current scope and the enclosing scope. // Perform narrowing with applicable constraints between the current scope and the enclosing scope.

View file

@ -18,7 +18,7 @@ use smallvec::{SmallVec, smallvec_inline};
use super::{DynamicType, Type, TypeTransformer, TypeVarVariance, definition_expression_type}; use super::{DynamicType, Type, TypeTransformer, TypeVarVariance, definition_expression_type};
use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::Definition;
use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::generics::{GenericContext, walk_generic_context};
use crate::types::{TypeMapping, TypeRelation, TypeVarInstance, todo_type}; use crate::types::{KnownClass, TypeMapping, TypeRelation, TypeVarInstance, todo_type};
use crate::{Db, FxOrderSet}; use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name}; use ruff_python_ast::{self as ast, name::Name};
@ -320,14 +320,18 @@ impl<'db> Signature<'db> {
inherited_generic_context: Option<GenericContext<'db>>, inherited_generic_context: Option<GenericContext<'db>>,
definition: Definition<'db>, definition: Definition<'db>,
function_node: &ast::StmtFunctionDef, function_node: &ast::StmtFunctionDef,
is_generator: bool,
) -> Self { ) -> Self {
let parameters = let parameters =
Parameters::from_parameters(db, definition, function_node.parameters.as_ref()); Parameters::from_parameters(db, definition, function_node.parameters.as_ref());
let return_ty = function_node.returns.as_ref().map(|returns| { let return_ty = function_node.returns.as_ref().map(|returns| {
if function_node.is_async { let plain_return_ty = definition_expression_type(db, definition, returns.as_ref());
todo_type!("generic types.CoroutineType")
if function_node.is_async && !is_generator {
KnownClass::CoroutineType
.to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
} else { } else {
definition_expression_type(db, definition, returns.as_ref()) plain_return_ty
} }
}); });
let legacy_generic_context = let legacy_generic_context =

View file

@ -75,14 +75,20 @@ impl<'db, 'ast> Unpacker<'db, 'ast> {
); );
err.fallback_element_type(self.db()) err.fallback_element_type(self.db())
}), }),
UnpackKind::ContextManager => value_type.try_enter(self.db()).unwrap_or_else(|err| { UnpackKind::ContextManager { is_async } => {
err.report_diagnostic( if is_async {
&self.context, value_type.aenter(self.db())
value_type, } else {
value.as_any_node_ref(self.db(), self.module()), value_type.try_enter(self.db()).unwrap_or_else(|err| {
); err.report_diagnostic(
err.fallback_enter_type(self.db()) &self.context,
}), value_type,
value.as_any_node_ref(self.db(), self.module()),
);
err.fallback_enter_type(self.db())
})
}
}
}; };
self.unpack_inner( self.unpack_inner(

View file

@ -104,7 +104,7 @@ pub(crate) enum UnpackKind {
/// An iterable expression like the one in a `for` loop or a comprehension. /// An iterable expression like the one in a `for` loop or a comprehension.
Iterable, Iterable,
/// An context manager expression like the one in a `with` statement. /// An context manager expression like the one in a `with` statement.
ContextManager, ContextManager { is_async: bool },
/// An expression that is being assigned to a target. /// An expression that is being assigned to a target.
Assign, Assign,
} }