[ty] Type inference for genererator expressions (#21437)

## Summary

Add type inference for (async) generator expressions.

closes https://github.com/astral-sh/ty/issues/1510

## Test Plan

New Markdown tests.
This commit is contained in:
David Peter 2025-11-14 14:04:11 +01:00 committed by GitHub
parent 6a26f86778
commit 05cf53aae8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 87 additions and 13 deletions

View file

@ -1,6 +1,62 @@
# Generator expressions
## Basic
We infer specialized `GeneratorType` instance types for generator expressions:
```py
# revealed: GeneratorType[@Todo(generator expression yield type), @Todo(generator expression send type), @Todo(generator expression return type)]
reveal_type((x for x in range(42)))
# revealed: GeneratorType[int, None, None]
reveal_type(x for x in range(10))
# revealed: GeneratorType[tuple[int, str], None, None]
reveal_type((x, str(y)) for x in range(3) for y in range(3))
```
When used in a loop, the yielded type can be inferred:
```py
squares = (x**2 for x in range(10))
for s in squares:
reveal_type(s) # revealed: int
```
`GeneratorType` is covariant in its yielded type, so it can be used where a wider yielded type is
expected:
```py
from typing import Iterator
def process_numbers(x: Iterator[float]): ...
numbers = (x for x in range(10))
reveal_type(numbers) # revealed: GeneratorType[int, None, None]
process_numbers(numbers)
```
## Async generators
For async generator expressions, we infer specialized `AsyncGeneratorType` instance types:
```py
import asyncio
from typing import AsyncGenerator
async def slow_numbers() -> AsyncGenerator[int, None]:
current = 0
while True:
await asyncio.sleep(1)
yield current
current += 1
async def main() -> None:
slow_squares = (x**2 async for x in slow_numbers())
reveal_type(slow_squares) # revealed: AsyncGeneratorType[int, None]
async for s in slow_squares:
reveal_type(s) # revealed: int
print(s)
asyncio.run(main())
```

View file

@ -7392,33 +7392,51 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
/// Infer the type of the `iter` expression of the first comprehension.
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
/// Returns the evaluation mode (async or sync) of the comprehension.
fn infer_first_comprehension_iter(
&mut self,
comprehensions: &[ast::Comprehension],
) -> EvaluationMode {
let mut comprehensions_iter = comprehensions.iter();
let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_standalone_expression(&first_comprehension.iter, TypeContext::default());
if first_comprehension.is_async {
EvaluationMode::Async
} else {
EvaluationMode::Sync
}
}
fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
let ast::ExprGenerator {
range: _,
node_index: _,
elt: _,
elt,
generators,
parenthesized: _,
} = generator;
self.infer_first_comprehension_iter(generators);
let evaluation_mode = self.infer_first_comprehension_iter(generators);
KnownClass::GeneratorType.to_specialized_instance(
self.db(),
[
todo_type!("generator expression yield type"),
todo_type!("generator expression send type"),
todo_type!("generator expression return type"),
],
)
let scope_id = self
.index
.node_scope(NodeWithScopeRef::GeneratorExpression(generator));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let yield_type = inference.expression_type(elt.as_ref());
if evaluation_mode.is_async() {
KnownClass::AsyncGeneratorType
.to_specialized_instance(self.db(), [yield_type, Type::none(self.db())])
} else {
KnownClass::GeneratorType.to_specialized_instance(
self.db(),
[yield_type, Type::none(self.db()), Type::none(self.db())],
)
}
}
/// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred