[ty] Type inference for comprehensions (#20962)

## Summary

Adds type inference for list/dict/set comprehensions, including
bidirectional inference:

```py
reveal_type({k: v for k, v in [("a", 1), ("b", 2)]})  # dict[Unknown | str, Unknown | int]

squares: list[int | None] = [x for x in range(10)]
reveal_type(squares)  # list[int | None]
```

## Ecosystem impact

I did spot check the changes and most of them seem like known
limitations or true positives. Without proper bidirectional inference,
we saw a lot of false positives.

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-11-02 14:35:33 +01:00 committed by GitHub
parent de1a6fb8ad
commit 73107a083c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 266 additions and 28 deletions

View file

@ -0,0 +1,18 @@
name_1
{0: 0 for unique_name_0 in unique_name_1 if name_1}
@[name_2 for unique_name_2 in name_2]
def name_2():
pass
def name_2():
pass
match 0:
case name_2():
pass
case []:
name_1 = 0

View file

@ -103,3 +103,92 @@ async def _():
# revealed: Unknown
[reveal_type(x) async for x in range(3)]
```
## Comprehension expression types
The type of the comprehension expression itself should reflect the inferred element type:
```py
from typing import TypedDict, Literal
# revealed: list[int | Unknown]
reveal_type([x for x in range(10)])
# revealed: set[int | Unknown]
reveal_type({x for x in range(10)})
# revealed: dict[int | Unknown, str | Unknown]
reveal_type({x: str(x) for x in range(10)})
# revealed: list[tuple[int, Unknown | str] | Unknown]
reveal_type([(x, y) for x in range(5) for y in ["a", "b", "c"]])
squares: list[int | None] = [x**2 for x in range(10)]
reveal_type(squares) # revealed: list[int | None]
```
Inference for comprehensions takes the type context into account:
```py
# Without type context:
reveal_type([x for x in [1, 2, 3]]) # revealed: list[Unknown | int]
reveal_type({x: "a" for x in [1, 2, 3]}) # revealed: dict[Unknown | int, str | Unknown]
reveal_type({str(x): x for x in [1, 2, 3]}) # revealed: dict[str | Unknown, Unknown | int]
reveal_type({x for x in [1, 2, 3]}) # revealed: set[Unknown | int]
# With type context:
xs: list[int] = [x for x in [1, 2, 3]]
reveal_type(xs) # revealed: list[int]
ys: dict[int, str] = {x: str(x) for x in [1, 2, 3]}
reveal_type(ys) # revealed: dict[int, str]
zs: set[int] = {x for x in [1, 2, 3]}
```
This also works for nested comprehensions:
```py
table = [[(x, y) for x in range(3)] for y in range(3)]
reveal_type(table) # revealed: list[list[tuple[int, int] | Unknown] | Unknown]
table_with_content: list[list[tuple[int, int, str | None]]] = [[(x, y, None) for x in range(3)] for y in range(3)]
reveal_type(table_with_content) # revealed: list[list[tuple[int, int, str | None]]]
```
The type context is propagated down into the comprehension:
```py
class Person(TypedDict):
name: str
persons: list[Person] = [{"name": n} for n in ["Alice", "Bob"]]
reveal_type(persons) # revealed: list[Person]
# TODO: This should be an error
invalid: list[Person] = [{"misspelled": n} for n in ["Alice", "Bob"]]
```
We promote literals to avoid overly-precise types in invariant positions:
```py
reveal_type([x for x in ("a", "b", "c")]) # revealed: list[str | Unknown]
reveal_type({x for x in (1, 2, 3)}) # revealed: set[int | Unknown]
reveal_type({k: 0 for k in ("a", "b", "c")}) # revealed: dict[str | Unknown, int | Unknown]
```
Type context can prevent this promotion from happening:
```py
list_of_literals: list[Literal["a", "b", "c"]] = [x for x in ("a", "b", "c")]
reveal_type(list_of_literals) # revealed: list[Literal["a", "b", "c"]]
dict_with_literal_keys: dict[Literal["a", "b", "c"], int] = {k: 0 for k in ("a", "b", "c")}
reveal_type(dict_with_literal_keys) # revealed: dict[Literal["a", "b", "c"], int]
dict_with_literal_values: dict[str, Literal[1, 2, 3]] = {str(k): k for k in (1, 2, 3)}
reveal_type(dict_with_literal_values) # revealed: dict[str, Literal[1, 2, 3]]
set_with_literals: set[Literal[1, 2, 3]] = {k for k in (1, 2, 3)}
reveal_type(set_with_literals) # revealed: set[Literal[1, 2, 3]]
```

View file

@ -51,6 +51,6 @@ reveal_type({"a": 1, "b": (1, 2), "c": (1, 2, 3)})
## Dict comprehensions
```py
# revealed: dict[@Todo(dict comprehension key type), @Todo(dict comprehension value type)]
# revealed: dict[int | Unknown, int | Unknown]
reveal_type({x: y for x, y in enumerate(range(42))})
```

View file

@ -41,5 +41,5 @@ reveal_type([1, (1, 2), (1, 2, 3)])
## List comprehensions
```py
reveal_type([x for x in range(42)]) # revealed: list[@Todo(list comprehension element type)]
reveal_type([x for x in range(42)]) # revealed: list[int | Unknown]
```

View file

@ -35,5 +35,5 @@ reveal_type({1, (1, 2), (1, 2, 3)})
## Set comprehensions
```py
reveal_type({x for x in range(42)}) # revealed: set[@Todo(set comprehension element type)]
reveal_type({x for x in range(42)}) # revealed: set[int | Unknown]
```

View file

@ -0,0 +1,50 @@
# Documentation of two fuzzer panics involving comprehensions
Type inference for comprehensions was added in <https://github.com/astral-sh/ruff/pull/20962>. It
added two new fuzzer panics that are documented here for regression testing.
## Too many cycle iterations in `place_by_id`
<!-- expect-panic: too many cycle iterations -->
```py
name_5(name_3)
[0 for unique_name_0 in unique_name_1 for unique_name_2 in name_3]
@{name_3 for unique_name_3 in unique_name_4}
class name_4[**name_3](0, name_2=name_5):
pass
try:
name_0 = name_4
except* 0:
pass
else:
match unique_name_12:
case 0:
from name_2 import name_3
case name_0():
@name_4
def name_3():
pass
(name_3 := 0)
@name_3
async def name_5():
pass
```
## Too many cycle iterations in `infer_definition_types`
<!-- expect-panic: too many cycle iterations -->
```py
for name_1 in {
{{0: name_4 for unique_name_0 in unique_name_1}: 0 for unique_name_2 in unique_name_3 if name_4}: 0
for unique_name_4 in name_1
for name_4 in name_1
}:
pass
```

View file

@ -534,6 +534,14 @@ pub struct FunctionLiteral<'db> {
// The Salsa heap is tracked separately.
impl get_size2::GetSize for FunctionLiteral<'_> {}
fn overloads_and_implementation_cycle_initial<'db>(
_db: &'db dyn Db,
_id: salsa::Id,
_self: FunctionLiteral<'db>,
) -> (Box<[OverloadLiteral<'db>]>, Option<OverloadLiteral<'db>>) {
(Box::new([]), None)
}
#[salsa::tracked]
impl<'db> FunctionLiteral<'db> {
fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
@ -576,7 +584,7 @@ impl<'db> FunctionLiteral<'db> {
self.last_definition(db).spans(db)
}
#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)]
#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size, cycle_initial=overloads_and_implementation_cycle_initial)]
fn overloads_and_implementation(
self,
db: &'db dyn Db,

View file

@ -5943,9 +5943,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::Set(set) => self.infer_set_expression(set, tcx),
ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx),
ast::Expr::Generator(generator) => self.infer_generator_expression(generator),
ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp),
ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp),
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp),
ast::Expr::ListComp(listcomp) => {
self.infer_list_comprehension_expression(listcomp, tcx)
}
ast::Expr::DictComp(dictcomp) => {
self.infer_dict_comprehension_expression(dictcomp, tcx)
}
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp, tcx),
ast::Expr::Name(name) => self.infer_name_expression(name),
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
ast::Expr::UnaryOp(unary_op) => self.infer_unary_expression(unary_op),
@ -6450,52 +6454,121 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
)
}
fn infer_list_comprehension_expression(&mut self, listcomp: &ast::ExprListComp) -> Type<'db> {
/// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred
/// element / key-value types from the comprehension expression.
fn infer_comprehension_specialization(
&self,
collection_class: KnownClass,
inferred_element_types: &[Type<'db>],
tcx: TypeContext<'db>,
) -> Type<'db> {
// Remove any union elements of that are unrelated to the collection type.
let tcx = tcx.map(|annotation| {
annotation.filter_disjoint_elements(
self.db(),
collection_class.to_instance(self.db()),
InferableTypeVars::None,
)
});
if let Some(annotated_element_types) = tcx
.known_specialization(self.db(), collection_class)
.map(|specialization| specialization.types(self.db()))
&& annotated_element_types
.iter()
.zip(inferred_element_types.iter())
.all(|(annotated, inferred)| inferred.is_assignable_to(self.db(), *annotated))
{
collection_class
.to_specialized_instance(self.db(), annotated_element_types.iter().copied())
} else {
collection_class.to_specialized_instance(
self.db(),
inferred_element_types.iter().map(|ty| {
UnionType::from_elements(
self.db(),
[
ty.promote_literals(self.db(), TypeContext::default()),
Type::unknown(),
],
)
}),
)
}
}
fn infer_list_comprehension_expression(
&mut self,
listcomp: &ast::ExprListComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprListComp {
range: _,
node_index: _,
elt: _,
elt,
generators,
} = listcomp;
self.infer_first_comprehension_iter(generators);
KnownClass::List
.to_specialized_instance(self.db(), [todo_type!("list comprehension element type")])
let scope_id = self
.index
.node_scope(NodeWithScopeRef::ListComprehension(listcomp));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let element_type = inference.expression_type(elt.as_ref());
self.infer_comprehension_specialization(KnownClass::List, &[element_type], tcx)
}
fn infer_dict_comprehension_expression(&mut self, dictcomp: &ast::ExprDictComp) -> Type<'db> {
fn infer_dict_comprehension_expression(
&mut self,
dictcomp: &ast::ExprDictComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprDictComp {
range: _,
node_index: _,
key: _,
value: _,
key,
value,
generators,
} = dictcomp;
self.infer_first_comprehension_iter(generators);
KnownClass::Dict.to_specialized_instance(
self.db(),
[
todo_type!("dict comprehension key type"),
todo_type!("dict comprehension value type"),
],
)
let scope_id = self
.index
.node_scope(NodeWithScopeRef::DictComprehension(dictcomp));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let key_type = inference.expression_type(key.as_ref());
let value_type = inference.expression_type(value.as_ref());
self.infer_comprehension_specialization(KnownClass::Dict, &[key_type, value_type], tcx)
}
fn infer_set_comprehension_expression(&mut self, setcomp: &ast::ExprSetComp) -> Type<'db> {
fn infer_set_comprehension_expression(
&mut self,
setcomp: &ast::ExprSetComp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprSetComp {
range: _,
node_index: _,
elt: _,
elt,
generators,
} = setcomp;
self.infer_first_comprehension_iter(generators);
KnownClass::Set
.to_specialized_instance(self.db(), [todo_type!("set comprehension element type")])
let scope_id = self
.index
.node_scope(NodeWithScopeRef::SetComprehension(setcomp));
let scope = scope_id.to_scope_id(self.db(), self.file());
let inference = infer_scope_types(self.db(), scope);
let element_type = inference.expression_type(elt.as_ref());
self.infer_comprehension_specialization(KnownClass::Set, &[element_type], tcx)
}
fn infer_generator_expression_scope(&mut self, generator: &ast::ExprGenerator) {

View file

@ -346,7 +346,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ast::Expr::DictComp(dictcomp) => {
self.infer_dict_comprehension_expression(dictcomp);
self.infer_dict_comprehension_expression(dictcomp, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("Dict comprehensions are not allowed in type expressions"),
@ -355,7 +355,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ast::Expr::ListComp(listcomp) => {
self.infer_list_comprehension_expression(listcomp);
self.infer_list_comprehension_expression(listcomp, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("List comprehensions are not allowed in type expressions"),
@ -364,7 +364,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ast::Expr::SetComp(setcomp) => {
self.infer_set_comprehension_expression(setcomp);
self.infer_set_comprehension_expression(setcomp, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("Set comprehensions are not allowed in type expressions"),