[ty] use type context more aggressively to infer values ​​when constructing a TypedDict (#20806)

## Summary

Based on @ibraheemdev's comment on #20792:

> I think we can also update our bidirectional inference code, [which
makes the same
assumption](https://github.com/astral-sh/ruff/blob/main/crates/ty_python_semantic/src/types/infer/builder.rs?rgh-link-date=2025-10-09T21%3A30%3A31Z#L5860).

This PR also adds more test cases for how `TypedDict` annotations affect
generic call inference.

## Test Plan

New tests in `typed_dict.md`
This commit is contained in:
Shunsuke Shibayama 2025-10-11 08:51:16 +09:00 committed by GitHub
parent bbd3856de8
commit 11a9e7ee44
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 65 additions and 4 deletions

View file

@ -5,6 +5,11 @@ specific value types for each valid key. Each string key can be either required
## Basic
```toml
[environment]
python-version = "3.12"
```
Here, we define a `TypedDict` using the class-based syntax:
```py
@ -105,6 +110,39 @@ eve3a: Person = {"name": "Eve", "age": 25, "extra": True}
eve3b = Person(name="Eve", age=25, extra=True)
```
Also, the value types declared in a `TypedDict` affect generic call inference:
```py
class Plot(TypedDict):
y: list[int]
x: list[int] | None
plot1: Plot = {"y": [1, 2, 3], "x": None}
def homogeneous_list[T](*args: T) -> list[T]:
return list(args)
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
reveal_type(plot2["y"]) # revealed: list[int]
# TODO: no error
# error: [invalid-argument-type]
plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
Y = "y"
X = "x"
plot4: Plot = {Y: [1, 2, 3], X: None}
plot5: Plot = {Y: homogeneous_list(1, 2, 3), X: None}
class Items(TypedDict):
items: list[int | str]
items1: Items = {"items": homogeneous_list(1, 2, 3)}
ITEMS = "items"
items2: Items = {ITEMS: homogeneous_list(1, 2, 3)}
```
Assignments to keys are also validated:
```py
@ -796,6 +834,18 @@ p2: TaggedData[str] = {"data": "Hello", "tag": "text"}
# error: [invalid-argument-type] "Invalid argument to key "data" with declared type `int` on TypedDict `TaggedData`: value of type `Literal["not a number"]`"
p3: TaggedData[int] = {"data": "not a number", "tag": "number"}
class Items(TypedDict, Generic[T]):
items: list[T]
def homogeneous_list(*args: T) -> list[T]:
return list(args)
items1: Items[int] = {"items": [1, 2, 3]}
items2: Items[str] = {"items": ["a", "b", "c"]}
items3: Items[int] = {"items": homogeneous_list(1, 2, 3)}
items4: Items[str] = {"items": homogeneous_list("a", "b", "c")}
items5: Items[int | str] = {"items": homogeneous_list(1, 2, 3)}
```
### PEP-695 generics
@ -817,6 +867,18 @@ p2: TaggedData[str] = {"data": "Hello", "tag": "text"}
# error: [invalid-argument-type] "Invalid argument to key "data" with declared type `int` on TypedDict `TaggedData`: value of type `Literal["not a number"]`"
p3: TaggedData[int] = {"data": "not a number", "tag": "number"}
class Items[T](TypedDict):
items: list[T]
def homogeneous_list[T](*args: T) -> list[T]:
return list(args)
items1: Items[int] = {"items": [1, 2, 3]}
items2: Items[str] = {"items": ["a", "b", "c"]}
items3: Items[int] = {"items": homogeneous_list(1, 2, 3)}
items4: Items[str] = {"items": homogeneous_list("a", "b", "c")}
items5: Items[int | str] = {"items": homogeneous_list(1, 2, 3)}
```
## Recursive `TypedDict`

View file

@ -5858,11 +5858,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let typed_dict_items = typed_dict.items(self.db());
for item in items {
self.infer_optional_expression(item.key.as_ref(), TypeContext::default());
let key_ty = self.infer_optional_expression(item.key.as_ref(), TypeContext::default());
if let Some(ast::Expr::StringLiteral(ref key)) = item.key
&& let Some(key) = key.as_single_part_string()
&& let Some(field) = typed_dict_items.get(key.as_str())
if let Some(Type::StringLiteral(key)) = key_ty
&& let Some(field) = typed_dict_items.get(key.value(self.db()))
{
self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty)));
} else {