mirror of
https://github.com/astral-sh/ruff.git
synced 2025-11-18 19:41:34 +00:00
[ty] Type inference for comprehensions (#20962)
## Summary
Adds type inference for list/dict/set comprehensions, including
bidirectional inference:
```py
reveal_type({k: v for k, v in [("a", 1), ("b", 2)]}) # dict[Unknown | str, Unknown | int]
squares: list[int | None] = [x**2 for x in range(10)]
reveal_type(squares) # list[int | None]
```
## Ecosystem impact
I spot-checked the changes, and most of them appear to be known
limitations or true positives. Without proper bidirectional inference,
we saw a lot of false positives.
## Test Plan
New Markdown tests
This commit is contained in:
parent
de1a6fb8ad
commit
73107a083c
9 changed files with 266 additions and 28 deletions
|
|
@ -0,0 +1,18 @@
|
|||
name_1
|
||||
{0: 0 for unique_name_0 in unique_name_1 if name_1}
|
||||
|
||||
|
||||
@[name_2 for unique_name_2 in name_2]
|
||||
def name_2():
|
||||
pass
|
||||
|
||||
|
||||
def name_2():
|
||||
pass
|
||||
|
||||
|
||||
match 0:
|
||||
case name_2():
|
||||
pass
|
||||
case []:
|
||||
name_1 = 0
|
||||
|
|
@ -103,3 +103,92 @@ async def _():
|
|||
# revealed: Unknown
|
||||
[reveal_type(x) async for x in range(3)]
|
||||
```
|
||||
|
||||
## Comprehension expression types
|
||||
|
||||
The type of the comprehension expression itself should reflect the inferred element type:
|
||||
|
||||
```py
|
||||
from typing import TypedDict, Literal
|
||||
|
||||
# revealed: list[int | Unknown]
|
||||
reveal_type([x for x in range(10)])
|
||||
|
||||
# revealed: set[int | Unknown]
|
||||
reveal_type({x for x in range(10)})
|
||||
|
||||
# revealed: dict[int | Unknown, str | Unknown]
|
||||
reveal_type({x: str(x) for x in range(10)})
|
||||
|
||||
# revealed: list[tuple[int, Unknown | str] | Unknown]
|
||||
reveal_type([(x, y) for x in range(5) for y in ["a", "b", "c"]])
|
||||
|
||||
squares: list[int | None] = [x**2 for x in range(10)]
|
||||
reveal_type(squares) # revealed: list[int | None]
|
||||
```
|
||||
|
||||
Inference for comprehensions takes the type context into account:
|
||||
|
||||
```py
|
||||
# Without type context:
|
||||
reveal_type([x for x in [1, 2, 3]]) # revealed: list[Unknown | int]
|
||||
reveal_type({x: "a" for x in [1, 2, 3]}) # revealed: dict[Unknown | int, str | Unknown]
|
||||
reveal_type({str(x): x for x in [1, 2, 3]}) # revealed: dict[str | Unknown, Unknown | int]
|
||||
reveal_type({x for x in [1, 2, 3]}) # revealed: set[Unknown | int]
|
||||
|
||||
# With type context:
|
||||
xs: list[int] = [x for x in [1, 2, 3]]
|
||||
reveal_type(xs) # revealed: list[int]
|
||||
|
||||
ys: dict[int, str] = {x: str(x) for x in [1, 2, 3]}
|
||||
reveal_type(ys) # revealed: dict[int, str]
|
||||
|
||||
zs: set[int] = {x for x in [1, 2, 3]}
|
||||
```
|
||||
|
||||
This also works for nested comprehensions:
|
||||
|
||||
```py
|
||||
table = [[(x, y) for x in range(3)] for y in range(3)]
|
||||
reveal_type(table) # revealed: list[list[tuple[int, int] | Unknown] | Unknown]
|
||||
|
||||
table_with_content: list[list[tuple[int, int, str | None]]] = [[(x, y, None) for x in range(3)] for y in range(3)]
|
||||
reveal_type(table_with_content) # revealed: list[list[tuple[int, int, str | None]]]
|
||||
```
|
||||
|
||||
The type context is propagated down into the comprehension:
|
||||
|
||||
```py
|
||||
class Person(TypedDict):
|
||||
name: str
|
||||
|
||||
persons: list[Person] = [{"name": n} for n in ["Alice", "Bob"]]
|
||||
reveal_type(persons) # revealed: list[Person]
|
||||
|
||||
# TODO: This should be an error
|
||||
invalid: list[Person] = [{"misspelled": n} for n in ["Alice", "Bob"]]
|
||||
```
|
||||
|
||||
We promote literals to avoid overly-precise types in invariant positions:
|
||||
|
||||
```py
|
||||
reveal_type([x for x in ("a", "b", "c")]) # revealed: list[str | Unknown]
|
||||
reveal_type({x for x in (1, 2, 3)}) # revealed: set[int | Unknown]
|
||||
reveal_type({k: 0 for k in ("a", "b", "c")}) # revealed: dict[str | Unknown, int | Unknown]
|
||||
```
|
||||
|
||||
Type context can prevent this promotion from happening:
|
||||
|
||||
```py
|
||||
list_of_literals: list[Literal["a", "b", "c"]] = [x for x in ("a", "b", "c")]
|
||||
reveal_type(list_of_literals) # revealed: list[Literal["a", "b", "c"]]
|
||||
|
||||
dict_with_literal_keys: dict[Literal["a", "b", "c"], int] = {k: 0 for k in ("a", "b", "c")}
|
||||
reveal_type(dict_with_literal_keys) # revealed: dict[Literal["a", "b", "c"], int]
|
||||
|
||||
dict_with_literal_values: dict[str, Literal[1, 2, 3]] = {str(k): k for k in (1, 2, 3)}
|
||||
reveal_type(dict_with_literal_values) # revealed: dict[str, Literal[1, 2, 3]]
|
||||
|
||||
set_with_literals: set[Literal[1, 2, 3]] = {k for k in (1, 2, 3)}
|
||||
reveal_type(set_with_literals) # revealed: set[Literal[1, 2, 3]]
|
||||
```
|
||||
|
|
|
|||
|
|
@ -51,6 +51,6 @@ reveal_type({"a": 1, "b": (1, 2), "c": (1, 2, 3)})
|
|||
## Dict comprehensions
|
||||
|
||||
```py
|
||||
# revealed: dict[@Todo(dict comprehension key type), @Todo(dict comprehension value type)]
|
||||
# revealed: dict[int | Unknown, int | Unknown]
|
||||
reveal_type({x: y for x, y in enumerate(range(42))})
|
||||
```
|
||||
|
|
|
|||
|
|
@ -41,5 +41,5 @@ reveal_type([1, (1, 2), (1, 2, 3)])
|
|||
## List comprehensions
|
||||
|
||||
```py
|
||||
reveal_type([x for x in range(42)]) # revealed: list[@Todo(list comprehension element type)]
|
||||
reveal_type([x for x in range(42)]) # revealed: list[int | Unknown]
|
||||
```
|
||||
|
|
|
|||
|
|
@ -35,5 +35,5 @@ reveal_type({1, (1, 2), (1, 2, 3)})
|
|||
## Set comprehensions
|
||||
|
||||
```py
|
||||
reveal_type({x for x in range(42)}) # revealed: set[@Todo(set comprehension element type)]
|
||||
reveal_type({x for x in range(42)}) # revealed: set[int | Unknown]
|
||||
```
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
# Documentation of two fuzzer panics involving comprehensions
|
||||
|
||||
Type inference for comprehensions was added in <https://github.com/astral-sh/ruff/pull/20962>. It
|
||||
introduced two new fuzzer panics, which are documented here for regression testing.
|
||||
|
||||
## Too many cycle iterations in `place_by_id`
|
||||
|
||||
<!-- expect-panic: too many cycle iterations -->
|
||||
|
||||
```py
|
||||
name_5(name_3)
|
||||
[0 for unique_name_0 in unique_name_1 for unique_name_2 in name_3]
|
||||
|
||||
@{name_3 for unique_name_3 in unique_name_4}
|
||||
class name_4[**name_3](0, name_2=name_5):
|
||||
pass
|
||||
|
||||
try:
|
||||
name_0 = name_4
|
||||
except* 0:
|
||||
pass
|
||||
else:
|
||||
match unique_name_12:
|
||||
case 0:
|
||||
from name_2 import name_3
|
||||
case name_0():
|
||||
|
||||
@name_4
|
||||
def name_3():
|
||||
pass
|
||||
|
||||
(name_3 := 0)
|
||||
|
||||
@name_3
|
||||
async def name_5():
|
||||
pass
|
||||
```
|
||||
|
||||
## Too many cycle iterations in `infer_definition_types`
|
||||
|
||||
<!-- expect-panic: too many cycle iterations -->
|
||||
|
||||
```py
|
||||
for name_1 in {
|
||||
{{0: name_4 for unique_name_0 in unique_name_1}: 0 for unique_name_2 in unique_name_3 if name_4}: 0
|
||||
for unique_name_4 in name_1
|
||||
for name_4 in name_1
|
||||
}:
|
||||
pass
|
||||
```
|
||||
|
|
@ -534,6 +534,14 @@ pub struct FunctionLiteral<'db> {
|
|||
// The Salsa heap is tracked separately.
|
||||
impl get_size2::GetSize for FunctionLiteral<'_> {}
|
||||
|
||||
fn overloads_and_implementation_cycle_initial<'db>(
|
||||
_db: &'db dyn Db,
|
||||
_id: salsa::Id,
|
||||
_self: FunctionLiteral<'db>,
|
||||
) -> (Box<[OverloadLiteral<'db>]>, Option<OverloadLiteral<'db>>) {
|
||||
(Box::new([]), None)
|
||||
}
|
||||
|
||||
#[salsa::tracked]
|
||||
impl<'db> FunctionLiteral<'db> {
|
||||
fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
|
||||
|
|
@ -576,7 +584,7 @@ impl<'db> FunctionLiteral<'db> {
|
|||
self.last_definition(db).spans(db)
|
||||
}
|
||||
|
||||
#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)]
|
||||
#[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size, cycle_initial=overloads_and_implementation_cycle_initial)]
|
||||
fn overloads_and_implementation(
|
||||
self,
|
||||
db: &'db dyn Db,
|
||||
|
|
|
|||
|
|
@ -5943,9 +5943,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
ast::Expr::Set(set) => self.infer_set_expression(set, tcx),
|
||||
ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx),
|
||||
ast::Expr::Generator(generator) => self.infer_generator_expression(generator),
|
||||
ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp),
|
||||
ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp),
|
||||
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp),
|
||||
ast::Expr::ListComp(listcomp) => {
|
||||
self.infer_list_comprehension_expression(listcomp, tcx)
|
||||
}
|
||||
ast::Expr::DictComp(dictcomp) => {
|
||||
self.infer_dict_comprehension_expression(dictcomp, tcx)
|
||||
}
|
||||
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp, tcx),
|
||||
ast::Expr::Name(name) => self.infer_name_expression(name),
|
||||
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
|
||||
ast::Expr::UnaryOp(unary_op) => self.infer_unary_expression(unary_op),
|
||||
|
|
@ -6450,52 +6454,121 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
)
|
||||
}
|
||||
|
||||
fn infer_list_comprehension_expression(&mut self, listcomp: &ast::ExprListComp) -> Type<'db> {
|
||||
/// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred
|
||||
/// element / key-value types from the comprehension expression.
|
||||
fn infer_comprehension_specialization(
|
||||
&self,
|
||||
collection_class: KnownClass,
|
||||
inferred_element_types: &[Type<'db>],
|
||||
tcx: TypeContext<'db>,
|
||||
) -> Type<'db> {
|
||||
// Remove any union elements of the annotation that are unrelated to the collection type.
|
||||
let tcx = tcx.map(|annotation| {
|
||||
annotation.filter_disjoint_elements(
|
||||
self.db(),
|
||||
collection_class.to_instance(self.db()),
|
||||
InferableTypeVars::None,
|
||||
)
|
||||
});
|
||||
|
||||
if let Some(annotated_element_types) = tcx
|
||||
.known_specialization(self.db(), collection_class)
|
||||
.map(|specialization| specialization.types(self.db()))
|
||||
&& annotated_element_types
|
||||
.iter()
|
||||
.zip(inferred_element_types.iter())
|
||||
.all(|(annotated, inferred)| inferred.is_assignable_to(self.db(), *annotated))
|
||||
{
|
||||
collection_class
|
||||
.to_specialized_instance(self.db(), annotated_element_types.iter().copied())
|
||||
} else {
|
||||
collection_class.to_specialized_instance(
|
||||
self.db(),
|
||||
inferred_element_types.iter().map(|ty| {
|
||||
UnionType::from_elements(
|
||||
self.db(),
|
||||
[
|
||||
ty.promote_literals(self.db(), TypeContext::default()),
|
||||
Type::unknown(),
|
||||
],
|
||||
)
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_list_comprehension_expression(
|
||||
&mut self,
|
||||
listcomp: &ast::ExprListComp,
|
||||
tcx: TypeContext<'db>,
|
||||
) -> Type<'db> {
|
||||
let ast::ExprListComp {
|
||||
range: _,
|
||||
node_index: _,
|
||||
elt: _,
|
||||
elt,
|
||||
generators,
|
||||
} = listcomp;
|
||||
|
||||
self.infer_first_comprehension_iter(generators);
|
||||
|
||||
KnownClass::List
|
||||
.to_specialized_instance(self.db(), [todo_type!("list comprehension element type")])
|
||||
let scope_id = self
|
||||
.index
|
||||
.node_scope(NodeWithScopeRef::ListComprehension(listcomp));
|
||||
let scope = scope_id.to_scope_id(self.db(), self.file());
|
||||
let inference = infer_scope_types(self.db(), scope);
|
||||
let element_type = inference.expression_type(elt.as_ref());
|
||||
|
||||
self.infer_comprehension_specialization(KnownClass::List, &[element_type], tcx)
|
||||
}
|
||||
|
||||
fn infer_dict_comprehension_expression(&mut self, dictcomp: &ast::ExprDictComp) -> Type<'db> {
|
||||
fn infer_dict_comprehension_expression(
|
||||
&mut self,
|
||||
dictcomp: &ast::ExprDictComp,
|
||||
tcx: TypeContext<'db>,
|
||||
) -> Type<'db> {
|
||||
let ast::ExprDictComp {
|
||||
range: _,
|
||||
node_index: _,
|
||||
key: _,
|
||||
value: _,
|
||||
key,
|
||||
value,
|
||||
generators,
|
||||
} = dictcomp;
|
||||
|
||||
self.infer_first_comprehension_iter(generators);
|
||||
|
||||
KnownClass::Dict.to_specialized_instance(
|
||||
self.db(),
|
||||
[
|
||||
todo_type!("dict comprehension key type"),
|
||||
todo_type!("dict comprehension value type"),
|
||||
],
|
||||
)
|
||||
let scope_id = self
|
||||
.index
|
||||
.node_scope(NodeWithScopeRef::DictComprehension(dictcomp));
|
||||
let scope = scope_id.to_scope_id(self.db(), self.file());
|
||||
let inference = infer_scope_types(self.db(), scope);
|
||||
let key_type = inference.expression_type(key.as_ref());
|
||||
let value_type = inference.expression_type(value.as_ref());
|
||||
|
||||
self.infer_comprehension_specialization(KnownClass::Dict, &[key_type, value_type], tcx)
|
||||
}
|
||||
|
||||
fn infer_set_comprehension_expression(&mut self, setcomp: &ast::ExprSetComp) -> Type<'db> {
|
||||
fn infer_set_comprehension_expression(
|
||||
&mut self,
|
||||
setcomp: &ast::ExprSetComp,
|
||||
tcx: TypeContext<'db>,
|
||||
) -> Type<'db> {
|
||||
let ast::ExprSetComp {
|
||||
range: _,
|
||||
node_index: _,
|
||||
elt: _,
|
||||
elt,
|
||||
generators,
|
||||
} = setcomp;
|
||||
|
||||
self.infer_first_comprehension_iter(generators);
|
||||
|
||||
KnownClass::Set
|
||||
.to_specialized_instance(self.db(), [todo_type!("set comprehension element type")])
|
||||
let scope_id = self
|
||||
.index
|
||||
.node_scope(NodeWithScopeRef::SetComprehension(setcomp));
|
||||
let scope = scope_id.to_scope_id(self.db(), self.file());
|
||||
let inference = infer_scope_types(self.db(), scope);
|
||||
let element_type = inference.expression_type(elt.as_ref());
|
||||
|
||||
self.infer_comprehension_specialization(KnownClass::Set, &[element_type], tcx)
|
||||
}
|
||||
|
||||
fn infer_generator_expression_scope(&mut self, generator: &ast::ExprGenerator) {
|
||||
|
|
|
|||
|
|
@ -346,7 +346,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
|
|||
}
|
||||
|
||||
ast::Expr::DictComp(dictcomp) => {
|
||||
self.infer_dict_comprehension_expression(dictcomp);
|
||||
self.infer_dict_comprehension_expression(dictcomp, TypeContext::default());
|
||||
self.report_invalid_type_expression(
|
||||
expression,
|
||||
format_args!("Dict comprehensions are not allowed in type expressions"),
|
||||
|
|
@ -355,7 +355,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
|
|||
}
|
||||
|
||||
ast::Expr::ListComp(listcomp) => {
|
||||
self.infer_list_comprehension_expression(listcomp);
|
||||
self.infer_list_comprehension_expression(listcomp, TypeContext::default());
|
||||
self.report_invalid_type_expression(
|
||||
expression,
|
||||
format_args!("List comprehensions are not allowed in type expressions"),
|
||||
|
|
@ -364,7 +364,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
|
|||
}
|
||||
|
||||
ast::Expr::SetComp(setcomp) => {
|
||||
self.infer_set_comprehension_expression(setcomp);
|
||||
self.infer_set_comprehension_expression(setcomp, TypeContext::default());
|
||||
self.report_invalid_type_expression(
|
||||
expression,
|
||||
format_args!("Set comprehensions are not allowed in type expressions"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue