diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 19c53533a4..d5544b3d79 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -450,19 +450,51 @@ class Person(TypedDict, total=False): ```py from typing import TypedDict +from typing_extensions import NotRequired class Person(TypedDict): name: str age: int | None + extra: NotRequired[str] def _(p: Person) -> None: reveal_type(p.keys()) # revealed: dict_keys[str, object] reveal_type(p.values()) # revealed: dict_values[str, object] - reveal_type(p.setdefault("name", "Alice")) # revealed: @Todo(Support for `TypedDict`) + # `get()` returns the field type for required keys (no None union) + reveal_type(p.get("name")) # revealed: str + reveal_type(p.get("age")) # revealed: int | None - reveal_type(p.get("name")) # revealed: @Todo(Support for `TypedDict`) - reveal_type(p.get("name", "Unknown")) # revealed: @Todo(Support for `TypedDict`) + # It doesn't matter if a default is specified: + reveal_type(p.get("name", "default")) # revealed: str + reveal_type(p.get("age", 999)) # revealed: int | None + + # `get()` can return `None` for non-required keys + reveal_type(p.get("extra")) # revealed: str | None + reveal_type(p.get("extra", "default")) # revealed: str + + # The type of the default parameter can be anything: + reveal_type(p.get("extra", 0)) # revealed: str | Literal[0] + + # We allow access to unknown keys (they could be set for a subtype of Person) + reveal_type(p.get("unknown")) # revealed: Unknown | None + reveal_type(p.get("unknown", "default")) # revealed: Unknown | Literal["default"] + + # `pop()` only works on non-required fields + reveal_type(p.pop("extra")) # revealed: str + reveal_type(p.pop("extra", "fallback")) # revealed: str + # error: [invalid-argument-type] "Cannot pop required field 'name' from TypedDict `Person`" + reveal_type(p.pop("name")) # revealed: Unknown + + # Similar to above, the default parameter can be of any type: + reveal_type(p.pop("extra", 0)) # revealed: str | Literal[0] + + # `setdefault()` always returns the field type + reveal_type(p.setdefault("name", "Alice")) # revealed: str + reveal_type(p.setdefault("extra", "default")) # revealed: str + + # error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extraz" - did you mean "extra"?" + reveal_type(p.setdefault("extraz", "value")) # revealed: Unknown ``` ## Unlike normal classes diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 0b90b9a31c..f75e7beb5e 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -7528,6 +7528,28 @@ pub struct BoundTypeVarInstance<'db> { impl get_size2::GetSize for BoundTypeVarInstance<'_> {} impl<'db> BoundTypeVarInstance<'db> { + /// Create a new PEP 695 type variable that can be used in signatures + /// of synthetic generic functions. + pub(crate) fn synthetic( + db: &'db dyn Db, + name: &'static str, + variance: TypeVarVariance, + ) -> Self { + Self::new( + db, + TypeVarInstance::new( + db, + Name::new_static(name), + None, // definition + None, // _bound_or_constraints + Some(variance), + None, // _default + TypeVarKind::Pep695, + ), + BindingContext::Synthetic, + ) + } + pub(crate) fn variance_with_polarity( self, db: &'db dyn Db, diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 7f4175eb37..8881729751 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -34,7 +34,7 @@ use crate::types::{ IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, - VarianceInferable, declaration_type, infer_definition_types, todo_type, + UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -51,7 +51,7 @@ use crate::{ semantic_index, use_def_map, }, types::{ - CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionBuilder, UnionType, + CallArguments, CallError, CallErrorKind, MetaclassCandidate, UnionType, definition_expression_type, }, }; @@ -2331,49 +2331,179 @@ impl<'db> ClassLiteral<'db> { ))) } (CodeGeneratorKind::TypedDict, "get") => { - // TODO: synthesize a set of overloads with precise types - let signature = Signature::new( - Parameters::new([ - Parameter::positional_only(Some(Name::new_static("self"))) - .with_annotated_type(instance_ty), - Parameter::positional_only(Some(Name::new_static("key"))), - Parameter::positional_only(Some(Name::new_static("default"))) - .with_default_type(Type::unknown()), - ]), - Some(todo_type!("Support for `TypedDict`")), - ); + let overloads = self + .fields(db, specialization, field_policy) + .into_iter() + .flat_map(|(name, field)| { + let key_type = + Type::StringLiteral(StringLiteralType::new(db, name.as_str())); - Some(CallableType::function_like(db, signature)) + // For a required key, `.get()` always returns the value type. For a non-required key, + // `.get()` returns the union of the value type and the type of the default argument + // (which defaults to `None`). + + // TODO: For now, we use two overloads here. They can be merged into a single function + // once the generics solver takes default arguments into account. + + let get_sig = Signature::new( + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("self"))) + .with_annotated_type(instance_ty), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(key_type), + ]), + Some(if field.is_required() { + field.declared_ty + } else { + UnionType::from_elements(db, [field.declared_ty, Type::none(db)]) + }), + ); + + let t_default = + BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); + + let get_with_default_sig = Signature::new_generic( + Some(GenericContext::from_typevar_instances(db, [t_default])), + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("self"))) + .with_annotated_type(instance_ty), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(key_type), + Parameter::positional_only(Some(Name::new_static("default"))) + .with_annotated_type(Type::TypeVar(t_default)), + ]), + Some(if field.is_required() { + field.declared_ty + } else { + UnionType::from_elements( + db, + [field.declared_ty, Type::TypeVar(t_default)], + ) + }), + ); + + [get_sig, get_with_default_sig] + }) + // Fallback overloads for unknown keys + .chain(std::iter::once({ + Signature::new( + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("self"))) + .with_annotated_type(instance_ty), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(KnownClass::Str.to_instance(db)), + ]), + Some(UnionType::from_elements( + db, + [Type::unknown(), Type::none(db)], + )), + ) + })) + .chain(std::iter::once({ + let t_default = + BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); + + Signature::new_generic( + Some(GenericContext::from_typevar_instances(db, [t_default])), + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("self"))) + .with_annotated_type(instance_ty), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(KnownClass::Str.to_instance(db)), + Parameter::positional_only(Some(Name::new_static("default"))) + .with_annotated_type(Type::TypeVar(t_default)), + ]), + Some(UnionType::from_elements( + db, + [Type::unknown(), Type::TypeVar(t_default)], + )), + ) + })); + + Some(Type::Callable(CallableType::new( + db, + CallableSignature::from_overloads(overloads), + true, + ))) } (CodeGeneratorKind::TypedDict, "pop") => { - // TODO: synthesize a set of overloads with precise types. - // Required keys should be forbidden to be popped. - let signature = Signature::new( - Parameters::new([ - Parameter::positional_only(Some(Name::new_static("self"))) - .with_annotated_type(instance_ty), - Parameter::positional_only(Some(Name::new_static("key"))), - Parameter::positional_only(Some(Name::new_static("default"))) - .with_default_type(Type::unknown()), - ]), - Some(todo_type!("Support for `TypedDict`")), - ); + let fields = self.fields(db, specialization, field_policy); + let overloads = fields + .iter() + .filter(|(_, field)| { + // Only synthesize `pop` for fields that are not required. + !field.is_required() + }) + .flat_map(|(name, field)| { + let key_type = + Type::StringLiteral(StringLiteralType::new(db, name.as_str())); - Some(CallableType::function_like(db, signature)) + // TODO: Similar to above: consider merging these two overloads into one + + // `.pop()` without default + let pop_sig = Signature::new( + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("self"))) + .with_annotated_type(instance_ty), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(key_type), + ]), + Some(field.declared_ty), + ); + + // `.pop()` with a default value + let t_default = + BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); + + let pop_with_default_sig = Signature::new_generic( + Some(GenericContext::from_typevar_instances(db, [t_default])), + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("self"))) + .with_annotated_type(instance_ty), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(key_type), + Parameter::positional_only(Some(Name::new_static("default"))) + .with_annotated_type(Type::TypeVar(t_default)), + ]), + Some(UnionType::from_elements( + db, + [field.declared_ty, Type::TypeVar(t_default)], + )), + ); + + [pop_sig, pop_with_default_sig] + }); + + Some(Type::Callable(CallableType::new( + db, + CallableSignature::from_overloads(overloads), + true, + ))) } (CodeGeneratorKind::TypedDict, "setdefault") => { - // TODO: synthesize a set of overloads with precise types - let signature = Signature::new( - Parameters::new([ - Parameter::positional_only(Some(Name::new_static("self"))) - .with_annotated_type(instance_ty), - Parameter::positional_only(Some(Name::new_static("key"))), - Parameter::positional_only(Some(Name::new_static("default"))), - ]), - Some(todo_type!("Support for `TypedDict`")), - ); + let fields = self.fields(db, specialization, field_policy); + let overloads = fields.iter().map(|(name, field)| { + let key_type = Type::StringLiteral(StringLiteralType::new(db, name.as_str())); - Some(CallableType::function_like(db, signature)) + // `setdefault` always returns the field type + Signature::new( + Parameters::new([ + Parameter::positional_only(Some(Name::new_static("self"))) + .with_annotated_type(instance_ty), + Parameter::positional_only(Some(Name::new_static("key"))) + .with_annotated_type(key_type), + Parameter::positional_only(Some(Name::new_static("default"))) + .with_annotated_type(field.declared_ty), + ]), + Some(field.declared_ty), + ) + }); + + Some(Type::Callable(CallableType::new( + db, + CallableSignature::from_overloads(overloads), + true, + ))) } (CodeGeneratorKind::TypedDict, "update") => { // TODO: synthesize a set of overloads with precise types diff --git a/crates/ty_python_semantic/src/types/diagnostic.rs b/crates/ty_python_semantic/src/types/diagnostic.rs index 7449c73828..a080300e65 100644 --- a/crates/ty_python_semantic/src/types/diagnostic.rs +++ b/crates/ty_python_semantic/src/types/diagnostic.rs @@ -2952,6 +2952,21 @@ pub(crate) fn report_missing_typed_dict_key<'db>( } } +pub(crate) fn report_cannot_pop_required_field_on_typed_dict<'db>( + context: &InferContext<'db, '_>, + key_node: AnyNodeRef, + typed_dict_ty: Type<'db>, + field_name: &str, +) { + let db = context.db(); + if let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, key_node) { + let typed_dict_name = typed_dict_ty.display(db); + builder.into_diagnostic(format_args!( + "Cannot pop required field '{field_name}' from TypedDict `{typed_dict_name}`", + )); + } +} + /// This function receives an unresolved `from foo import bar` import, /// where `foo` can be resolved to a module but that module does not /// have a `bar` member or submodule. diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index f4c7c1bdb2..7bc10a1496 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -119,13 +119,19 @@ impl<'db> GenericContext<'db> { binding_context: Definition<'db>, type_params_node: &ast::TypeParams, ) -> Self { - let variables: FxOrderSet<_> = type_params_node - .iter() - .filter_map(|type_param| { - Self::variable_from_type_param(db, index, binding_context, type_param) - }) - .collect(); - Self::new(db, variables) + let variables = type_params_node.iter().filter_map(|type_param| { + Self::variable_from_type_param(db, index, binding_context, type_param) + }); + + Self::from_typevar_instances(db, variables) + } + + /// Creates a generic context from a list of `BoundTypeVarInstance`s. + pub(crate) fn from_typevar_instances( + db: &'db dyn Db, + type_params: impl IntoIterator>, + ) -> Self { + Self::new(db, type_params.into_iter().collect::>()) } fn variable_from_type_param( @@ -365,12 +371,12 @@ impl<'db> GenericContext<'db> { } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - let variables: FxOrderSet<_> = self + let variables = self .variables(db) .iter() - .map(|bound_typevar| bound_typevar.normalized_impl(db, visitor)) - .collect(); - Self::new(db, variables) + .map(|bound_typevar| bound_typevar.normalized_impl(db, visitor)); + + Self::from_typevar_instances(db, variables) } fn heap_size((variables,): &(FxOrderSet>,)) -> usize { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index a120610785..896d075636 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -102,7 +102,8 @@ use crate::types::diagnostic::{ INVALID_TYPE_VARIABLE_CONSTRAINTS, IncompatibleBases, POSSIBLY_UNBOUND_IMPLICIT_CALL, POSSIBLY_UNBOUND_IMPORT, TypeCheckDiagnostics, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_GLOBAL, UNRESOLVED_IMPORT, UNRESOLVED_REFERENCE, UNSUPPORTED_OPERATOR, - report_bad_dunder_set_call, report_implicit_return_type, report_instance_layout_conflict, + report_bad_dunder_set_call, report_cannot_pop_required_field_on_typed_dict, + report_implicit_return_type, report_instance_layout_conflict, report_invalid_argument_number_to_special_form, report_invalid_arguments_to_annotated, report_invalid_arguments_to_callable, report_invalid_assignment, report_invalid_attribute_assignment, report_invalid_generator_function_return_type, @@ -6270,6 +6271,58 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let callable_type = self.infer_maybe_standalone_expression(func); + // Special handling for `TypedDict` method calls + if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() { + let value_type = self.expression_type(value); + if let Type::TypedDict(typed_dict_ty) = value_type { + if matches!(attr.id.as_str(), "pop" | "setdefault") && !arguments.args.is_empty() { + // Validate the key argument for `TypedDict` methods + if let Some(first_arg) = arguments.args.first() { + if let ast::Expr::StringLiteral(ast::ExprStringLiteral { + value: key_literal, + .. + }) = first_arg + { + let key = key_literal.to_str(); + let items = typed_dict_ty.items(self.db()); + + // Check if key exists + if let Some((_, field)) = items + .iter() + .find(|(field_name, _)| field_name.as_str() == key) + { + // Key exists - check if it's a `pop()` on a required field + if attr.id.as_str() == "pop" && field.is_required() { + report_cannot_pop_required_field_on_typed_dict( + &self.context, + first_arg.into(), + Type::TypedDict(typed_dict_ty), + key, + ); + return Type::unknown(); + } + } else { + // Key not found, report error with suggestion and return early + let key_ty = Type::StringLiteral( + crate::types::StringLiteralType::new(self.db(), key), + ); + report_invalid_key_on_typed_dict( + &self.context, + first_arg.into(), + first_arg.into(), + Type::TypedDict(typed_dict_ty), + key_ty, + &items, + ); + // Return `Unknown` to prevent the overload system from generating its own error + return Type::unknown(); + } + } + } + } + } + } + if let Type::FunctionLiteral(function) = callable_type { // Make sure that the `function.definition` is only called when the function is defined // in the same file as the one we're currently inferring the types for. This is because @@ -7170,13 +7223,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// Infer the type of a [`ast::ExprAttribute`] expression, assuming a load context. fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { - let ast::ExprAttribute { - value, - attr, - range: _, - node_index: _, - ctx: _, - } = attribute; + let ast::ExprAttribute { value, attr, .. } = attribute; let value_type = self.infer_maybe_standalone_expression(value); let db = self.db();