[ty] typecheck dict methods for TypedDict (#19874)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Typecheck `get()`, `setdefault()`, `pop()` for `TypedDict`

```py
from typing import TypedDict
from typing_extensions import NotRequired

class Employee(TypedDict):
    name: str
    department: NotRequired[str]

emp = Employee(name="Alice", department="Engineering")

emp.get("name")
emp.get("departmen", "Unknown")
emp.pop("department")
emp.pop("name")
```

<img width="838" height="529" alt="Screenshot 2025-08-12 at 11 42 12"
src="https://github.com/user-attachments/assets/77ce150a-223c-4931-b914-551095d8a3a6"
/>


part of https://github.com/astral-sh/ty/issues/154

## Test Plan

Updated Markdown tests

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Eric Jolibois 2025-08-29 16:25:03 +02:00 committed by GitHub
parent c2d7c673ca
commit 5a608f7366
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 312 additions and 60 deletions

View file

@ -450,19 +450,51 @@ class Person(TypedDict, total=False):
```py ```py
from typing import TypedDict from typing import TypedDict
from typing_extensions import NotRequired
class Person(TypedDict): class Person(TypedDict):
name: str name: str
age: int | None age: int | None
extra: NotRequired[str]
def _(p: Person) -> None: def _(p: Person) -> None:
reveal_type(p.keys()) # revealed: dict_keys[str, object] reveal_type(p.keys()) # revealed: dict_keys[str, object]
reveal_type(p.values()) # revealed: dict_values[str, object] reveal_type(p.values()) # revealed: dict_values[str, object]
reveal_type(p.setdefault("name", "Alice")) # revealed: @Todo(Support for `TypedDict`) # `get()` returns the field type for required keys (no None union)
reveal_type(p.get("name")) # revealed: str
reveal_type(p.get("age")) # revealed: int | None
reveal_type(p.get("name")) # revealed: @Todo(Support for `TypedDict`) # It doesn't matter if a default is specified:
reveal_type(p.get("name", "Unknown")) # revealed: @Todo(Support for `TypedDict`) reveal_type(p.get("name", "default")) # revealed: str
reveal_type(p.get("age", 999)) # revealed: int | None
# `get()` can return `None` for non-required keys
reveal_type(p.get("extra")) # revealed: str | None
reveal_type(p.get("extra", "default")) # revealed: str
# The type of the default parameter can be anything:
reveal_type(p.get("extra", 0)) # revealed: str | Literal[0]
# We allow access to unknown keys (they could be set for a subtype of Person)
reveal_type(p.get("unknown")) # revealed: Unknown | None
reveal_type(p.get("unknown", "default")) # revealed: Unknown | Literal["default"]
# `pop()` only works on non-required fields
reveal_type(p.pop("extra")) # revealed: str
reveal_type(p.pop("extra", "fallback")) # revealed: str
# error: [invalid-argument-type] "Cannot pop required field 'name' from TypedDict `Person`"
reveal_type(p.pop("name")) # revealed: Unknown
# Similar to above, the default parameter can be of any type:
reveal_type(p.pop("extra", 0)) # revealed: str | Literal[0]
# `setdefault()` always returns the field type
reveal_type(p.setdefault("name", "Alice")) # revealed: str
reveal_type(p.setdefault("extra", "default")) # revealed: str
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extraz" - did you mean "extra"?"
reveal_type(p.setdefault("extraz", "value")) # revealed: Unknown
``` ```
## Unlike normal classes ## Unlike normal classes

View file

@ -7528,6 +7528,28 @@ pub struct BoundTypeVarInstance<'db> {
impl get_size2::GetSize for BoundTypeVarInstance<'_> {} impl get_size2::GetSize for BoundTypeVarInstance<'_> {}
impl<'db> BoundTypeVarInstance<'db> { impl<'db> BoundTypeVarInstance<'db> {
/// Create a new PEP 695 type variable that can be used in signatures
/// of synthetic generic functions.
pub(crate) fn synthetic(
db: &'db dyn Db,
name: &'static str,
variance: TypeVarVariance,
) -> Self {
Self::new(
db,
TypeVarInstance::new(
db,
Name::new_static(name),
None, // definition
None, // _bound_or_constraints
Some(variance),
None, // _default
TypeVarKind::Pep695,
),
BindingContext::Synthetic,
)
}
pub(crate) fn variance_with_polarity( pub(crate) fn variance_with_polarity(
self, self,
db: &'db dyn Db, db: &'db dyn Db,

View file

@ -34,7 +34,7 @@ use crate::types::{
IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams,
VarianceInferable, declaration_type, infer_definition_types, todo_type, UnionBuilder, VarianceInferable, declaration_type, infer_definition_types,
}; };
use crate::{ use crate::{
Db, FxIndexMap, FxOrderSet, Program, Db, FxIndexMap, FxOrderSet, Program,
@ -51,7 +51,7 @@ use crate::{
semantic_index, use_def_map, semantic_index, use_def_map,
}, },
types::{ types::{
CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionBuilder, UnionType, CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionType,
definition_expression_type, definition_expression_type,
}, },
}; };
@ -2331,49 +2331,179 @@ impl<'db> ClassLiteral<'db> {
))) )))
} }
(CodeGeneratorKind::TypedDict, "get") => { (CodeGeneratorKind::TypedDict, "get") => {
// TODO: synthesize a set of overloads with precise types let overloads = self
let signature = Signature::new( .fields(db, specialization, field_policy)
Parameters::new([ .into_iter()
Parameter::positional_only(Some(Name::new_static("self"))) .flat_map(|(name, field)| {
.with_annotated_type(instance_ty), let key_type =
Parameter::positional_only(Some(Name::new_static("key"))), Type::StringLiteral(StringLiteralType::new(db, name.as_str()));
Parameter::positional_only(Some(Name::new_static("default")))
.with_default_type(Type::unknown()),
]),
Some(todo_type!("Support for `TypedDict`")),
);
Some(CallableType::function_like(db, signature)) // For a required key, `.get()` always returns the value type. For a non-required key,
// `.get()` returns the union of the value type and the type of the default argument
// (which defaults to `None`).
// TODO: For now, we use two overloads here. They can be merged into a single function
// once the generics solver takes default arguments into account.
let get_sig = Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
]),
Some(if field.is_required() {
field.declared_ty
} else {
UnionType::from_elements(db, [field.declared_ty, Type::none(db)])
}),
);
let t_default =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);
let get_with_default_sig = Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [t_default])),
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(Type::TypeVar(t_default)),
]),
Some(if field.is_required() {
field.declared_ty
} else {
UnionType::from_elements(
db,
[field.declared_ty, Type::TypeVar(t_default)],
)
}),
);
[get_sig, get_with_default_sig]
})
// Fallback overloads for unknown keys
.chain(std::iter::once({
Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(KnownClass::Str.to_instance(db)),
]),
Some(UnionType::from_elements(
db,
[Type::unknown(), Type::none(db)],
)),
)
}))
.chain(std::iter::once({
let t_default =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);
Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [t_default])),
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(KnownClass::Str.to_instance(db)),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(Type::TypeVar(t_default)),
]),
Some(UnionType::from_elements(
db,
[Type::unknown(), Type::TypeVar(t_default)],
)),
)
}));
Some(Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads(overloads),
true,
)))
} }
(CodeGeneratorKind::TypedDict, "pop") => { (CodeGeneratorKind::TypedDict, "pop") => {
// TODO: synthesize a set of overloads with precise types. let fields = self.fields(db, specialization, field_policy);
// Required keys should be forbidden to be popped. let overloads = fields
let signature = Signature::new( .iter()
Parameters::new([ .filter(|(_, field)| {
Parameter::positional_only(Some(Name::new_static("self"))) // Only synthesize `pop` for fields that are not required.
.with_annotated_type(instance_ty), !field.is_required()
Parameter::positional_only(Some(Name::new_static("key"))), })
Parameter::positional_only(Some(Name::new_static("default"))) .flat_map(|(name, field)| {
.with_default_type(Type::unknown()), let key_type =
]), Type::StringLiteral(StringLiteralType::new(db, name.as_str()));
Some(todo_type!("Support for `TypedDict`")),
);
Some(CallableType::function_like(db, signature)) // TODO: Similar to above: consider merging these two overloads into one
// `.pop()` without default
let pop_sig = Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
]),
Some(field.declared_ty),
);
// `.pop()` with a default value
let t_default =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);
let pop_with_default_sig = Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [t_default])),
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(Type::TypeVar(t_default)),
]),
Some(UnionType::from_elements(
db,
[field.declared_ty, Type::TypeVar(t_default)],
)),
);
[pop_sig, pop_with_default_sig]
});
Some(Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads(overloads),
true,
)))
} }
(CodeGeneratorKind::TypedDict, "setdefault") => { (CodeGeneratorKind::TypedDict, "setdefault") => {
// TODO: synthesize a set of overloads with precise types let fields = self.fields(db, specialization, field_policy);
let signature = Signature::new( let overloads = fields.iter().map(|(name, field)| {
Parameters::new([ let key_type = Type::StringLiteral(StringLiteralType::new(db, name.as_str()));
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key"))),
Parameter::positional_only(Some(Name::new_static("default"))),
]),
Some(todo_type!("Support for `TypedDict`")),
);
Some(CallableType::function_like(db, signature)) // `setdefault` always returns the field type
Signature::new(
Parameters::new([
Parameter::positional_only(Some(Name::new_static("self")))
.with_annotated_type(instance_ty),
Parameter::positional_only(Some(Name::new_static("key")))
.with_annotated_type(key_type),
Parameter::positional_only(Some(Name::new_static("default")))
.with_annotated_type(field.declared_ty),
]),
Some(field.declared_ty),
)
});
Some(Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads(overloads),
true,
)))
} }
(CodeGeneratorKind::TypedDict, "update") => { (CodeGeneratorKind::TypedDict, "update") => {
// TODO: synthesize a set of overloads with precise types // TODO: synthesize a set of overloads with precise types

View file

@ -2952,6 +2952,21 @@ pub(crate) fn report_missing_typed_dict_key<'db>(
} }
} }
pub(crate) fn report_cannot_pop_required_field_on_typed_dict<'db>(
context: &InferContext<'db, '_>,
key_node: AnyNodeRef,
typed_dict_ty: Type<'db>,
field_name: &str,
) {
let db = context.db();
if let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, key_node) {
let typed_dict_name = typed_dict_ty.display(db);
builder.into_diagnostic(format_args!(
"Cannot pop required field '{field_name}' from TypedDict `{typed_dict_name}`",
));
}
}
/// This function receives an unresolved `from foo import bar` import, /// This function receives an unresolved `from foo import bar` import,
/// where `foo` can be resolved to a module but that module does not /// where `foo` can be resolved to a module but that module does not
/// have a `bar` member or submodule. /// have a `bar` member or submodule.

View file

@ -119,13 +119,19 @@ impl<'db> GenericContext<'db> {
binding_context: Definition<'db>, binding_context: Definition<'db>,
type_params_node: &ast::TypeParams, type_params_node: &ast::TypeParams,
) -> Self { ) -> Self {
let variables: FxOrderSet<_> = type_params_node let variables = type_params_node.iter().filter_map(|type_param| {
.iter() Self::variable_from_type_param(db, index, binding_context, type_param)
.filter_map(|type_param| { });
Self::variable_from_type_param(db, index, binding_context, type_param)
}) Self::from_typevar_instances(db, variables)
.collect(); }
Self::new(db, variables)
/// Creates a generic context from a list of `BoundTypeVarInstance`s.
pub(crate) fn from_typevar_instances(
db: &'db dyn Db,
type_params: impl IntoIterator<Item = BoundTypeVarInstance<'db>>,
) -> Self {
Self::new(db, type_params.into_iter().collect::<FxOrderSet<_>>())
} }
fn variable_from_type_param( fn variable_from_type_param(
@ -365,12 +371,12 @@ impl<'db> GenericContext<'db> {
} }
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
let variables: FxOrderSet<_> = self let variables = self
.variables(db) .variables(db)
.iter() .iter()
.map(|bound_typevar| bound_typevar.normalized_impl(db, visitor)) .map(|bound_typevar| bound_typevar.normalized_impl(db, visitor));
.collect();
Self::new(db, variables) Self::from_typevar_instances(db, variables)
} }
fn heap_size((variables,): &(FxOrderSet<BoundTypeVarInstance<'db>>,)) -> usize { fn heap_size((variables,): &(FxOrderSet<BoundTypeVarInstance<'db>>,)) -> usize {

View file

@ -102,7 +102,8 @@ use crate::types::diagnostic::{
INVALID_TYPE_VARIABLE_CONSTRAINTS, IncompatibleBases, POSSIBLY_UNBOUND_IMPLICIT_CALL, INVALID_TYPE_VARIABLE_CONSTRAINTS, IncompatibleBases, POSSIBLY_UNBOUND_IMPLICIT_CALL,
POSSIBLY_UNBOUND_IMPORT, TypeCheckDiagnostics, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, POSSIBLY_UNBOUND_IMPORT, TypeCheckDiagnostics, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE,
UNRESOLVED_GLOBAL, UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE, UNSUPPORTED_OPERATOR, UNRESOLVED_GLOBAL, UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE, UNSUPPORTED_OPERATOR,
report_bad_dunder_set_call, report_implicit_return_type, report_instance_layout_conflict, report_bad_dunder_set_call, report_cannot_pop_required_field_on_typed_dict,
report_implicit_return_type, report_instance_layout_conflict,
report_invalid_argument_number_to_special_form, report_invalid_arguments_to_annotated, report_invalid_argument_number_to_special_form, report_invalid_arguments_to_annotated,
report_invalid_arguments_to_callable, report_invalid_assignment, report_invalid_arguments_to_callable, report_invalid_assignment,
report_invalid_attribute_assignment, report_invalid_generator_function_return_type, report_invalid_attribute_assignment, report_invalid_generator_function_return_type,
@ -6270,6 +6271,58 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let callable_type = self.infer_maybe_standalone_expression(func); let callable_type = self.infer_maybe_standalone_expression(func);
// Special handling for `TypedDict` method calls
if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() {
let value_type = self.expression_type(value);
if let Type::TypedDict(typed_dict_ty) = value_type {
if matches!(attr.id.as_str(), "pop" | "setdefault") && !arguments.args.is_empty() {
// Validate the key argument for `TypedDict` methods
if let Some(first_arg) = arguments.args.first() {
if let ast::Expr::StringLiteral(ast::ExprStringLiteral {
value: key_literal,
..
}) = first_arg
{
let key = key_literal.to_str();
let items = typed_dict_ty.items(self.db());
// Check if key exists
if let Some((_, field)) = items
.iter()
.find(|(field_name, _)| field_name.as_str() == key)
{
// Key exists - check if it's a `pop()` on a required field
if attr.id.as_str() == "pop" && field.is_required() {
report_cannot_pop_required_field_on_typed_dict(
&self.context,
first_arg.into(),
Type::TypedDict(typed_dict_ty),
key,
);
return Type::unknown();
}
} else {
// Key not found, report error with suggestion and return early
let key_ty = Type::StringLiteral(
crate::types::StringLiteralType::new(self.db(), key),
);
report_invalid_key_on_typed_dict(
&self.context,
first_arg.into(),
first_arg.into(),
Type::TypedDict(typed_dict_ty),
key_ty,
&items,
);
// Return `Unknown` to prevent the overload system from generating its own error
return Type::unknown();
}
}
}
}
}
}
if let Type::FunctionLiteral(function) = callable_type { if let Type::FunctionLiteral(function) = callable_type {
// Make sure that the `function.definition` is only called when the function is defined // Make sure that the `function.definition` is only called when the function is defined
// in the same file as the one we're currently inferring the types for. This is because // in the same file as the one we're currently inferring the types for. This is because
@ -7170,13 +7223,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// Infer the type of a [`ast::ExprAttribute`] expression, assuming a load context. /// Infer the type of a [`ast::ExprAttribute`] expression, assuming a load context.
fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> {
let ast::ExprAttribute { let ast::ExprAttribute { value, attr, .. } = attribute;
value,
attr,
range: _,
node_index: _,
ctx: _,
} = attribute;
let value_type = self.infer_maybe_standalone_expression(value); let value_type = self.infer_maybe_standalone_expression(value);
let db = self.db(); let db = self.db();