From 0cdca09e97945d9de46782a3095741abb432955c Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 26 Sep 2025 21:46:04 -0400 Subject: [PATCH] avoid eagerly reporting `TypedDict` literal diagnostics during overload evaluation --- .../resources/mdtest/call/overloads.md | 29 ++++--- .../ty_python_semantic/src/types/call/bind.rs | 2 +- crates/ty_python_semantic/src/types/infer.rs | 42 +++++++-- .../src/types/infer/builder.rs | 71 ++++++++++++--- .../src/types/typed_dict.rs | 87 +++++++++++++------ 5 files changed, 171 insertions(+), 60 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index cd1fede988..f793d11ec9 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -1694,27 +1694,28 @@ class T(TypedDict): x: int @overload -def f(a: list[T], b: int): +def f(a: list[T], b: int) -> int: ... @overload -def f(a: list[dict[str, int]], b: str): +def f(a: list[dict[str, int]], b: str) -> str: ... -def f(a: list[dict[str, int]] | list[T], b: int | str): - ... +def f(a: list[dict[str, int]] | list[T], b: int | str) -> int | str: + return 1 def int_or_str() -> int | str: return 1 -f([{ "x": 1 }], int_or_str()) +x = f([{ "x": 1 }], int_or_str()) +reveal_type(x) # revealed: int | str -# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `T` constructor" -# error: [invalid-key] "Invalid key access on TypedDict `T`: Unknown key "y"" +# we currently incorrectly consider `list[dict[str, int]]` a subtype of `list[T]` +# TODO: error: [no-matching-overload] "No overload of function `f` matches arguments" f([{ "y": 1 }], int_or_str()) ``` -Non-matching overloads do not produce errors: +Non-matching overloads do not produce diagnostics: ```py from typing import TypedDict, overload @@ -1723,16 +1724,16 @@ class T(TypedDict): x: int @overload -def f(a: T, b: int): +def f(a: T, b: int) -> int: ... @overload -def f(a: dict[str, int], b: str): +def f(a: dict[str, int], b: str) -> str: ... -def f(a: T | dict[str, int], b: int | str): - ... +def f(a: T | dict[str, int], b: int | str) -> int | str: + return 1 -# TODO -f({ "y": 1 }, "a") +x = f({ "y": 1 }, "a") +reveal_type(x) # revealed: str ``` diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index ff20563749..2fc25200c5 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2467,7 +2467,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // an annotated assignment, to closer match the order of any unions written in the type // annotation. if let Some(return_ty) = self.signature.return_ty - && let Some(call_expression_tcx) = self.call_expression_tcx.annotation + && let Some(call_expression_tcx) = self.call_expression_tcx.annotation() { match call_expression_tcx { // A type variable is not a useful type-context for expression inference, and applying it diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 22aeaba771..06aacee92d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -330,7 +330,7 @@ impl<'db> InferExpression<'db> { expression: Expression<'db>, tcx: TypeContext<'db>, ) -> InferExpression<'db> { - if tcx.annotation.is_some() { + if tcx.annotation().is_some() { InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx)) } else { // Drop the empty `TypeContext` to avoid the interning cost. @@ -370,23 +370,51 @@ struct ExpressionWithContext<'db> { /// Knowing the outer type context when inferring an expression can enable /// more precise inference results, aka "bidirectional type inference". #[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)] -pub(crate) struct TypeContext<'db> { - pub(crate) annotation: Option>, +pub(crate) enum TypeContext<'db> { + /// No type context. + #[default] + None, + + /// The type of an annotated assignment, used as context for the RHS. + AnnotatedAssignment(Type<'db>), + + /// The type of an annotated parameter, used as context for arguments that are + /// matched against it in function call expressions. + AnnotatedParameter(Type<'db>), + + /// The unique type of annotated parameter, with no overloads. + UniqueAnnotatedParameter(Type<'db>), } impl<'db> TypeContext<'db> { - pub(crate) fn new(annotation: Option>) -> Self { - Self { annotation } + pub(crate) fn annotation(self) -> Option> { + match self { + TypeContext::None => None, + TypeContext::AnnotatedAssignment(annotation) => Some(annotation), + TypeContext::AnnotatedParameter(annotation) => Some(annotation), + TypeContext::UniqueAnnotatedParameter(annotation) => Some(annotation), + } + } + + pub(crate) fn with_annotation(self, annotation: Type<'db>) -> TypeContext<'db> { + match self { + TypeContext::None => panic!("no annotation to modify"), + TypeContext::AnnotatedAssignment(_) => TypeContext::AnnotatedAssignment(annotation), + TypeContext::AnnotatedParameter(_) => TypeContext::AnnotatedParameter(annotation), + TypeContext::UniqueAnnotatedParameter(_) => { + TypeContext::UniqueAnnotatedParameter(annotation) + } + } } // If the type annotation is a specialized instance of the given `KnownClass`, returns the // specialization. fn known_specialization( - &self, + self, known_class: KnownClass, db: &'db dyn Db, ) -> Option> { - self.annotation + self.annotation() .and_then(|ty| ty.known_specialization(known_class, db)) } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index a47078568b..1ab1ab5b90 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -3263,6 +3263,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { key, assigned_ty, value.as_ref(), + true, slice.as_ref(), rhs, TypedDictAssignmentKind::Subscript, @@ -4019,7 +4020,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(value) = value { self.infer_maybe_standalone_expression( value, - TypeContext::new(Some(annotated.inner_type())), + TypeContext::AnnotatedAssignment(annotated.inner_type()), ); } @@ -4114,7 +4115,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(value) = value { let inferred_ty = self.infer_maybe_standalone_expression( value, - TypeContext::new(Some(declared.inner_type())), + TypeContext::AnnotatedAssignment(declared.inner_type()), ); let inferred_ty = if target .as_name_expr() @@ -4973,6 +4974,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast_arguments.arguments_source_order() ); + // If there is only a single binding and overload, there is a unique parameter type annotation for each argument. + let unique_parameter_type = bindings + .into_iter() + .map(|binding| match binding.matching_overload_index() { + MatchingOverloadIndex::Single(_) => 1, + MatchingOverloadIndex::Multiple(items) => items.len(), + MatchingOverloadIndex::None => { + // If there is a single overload that does not match, we still infer the + // arguments against it for better diagnostics. + if binding.overloads().len() == 1 { 1 } else { 0 } + } + }) + .sum::() + == 1; + for (argument_index, (_, argument_type), argument_form, ast_argument) in iter { let ast_argument = match ast_argument { // Splatted arguments are inferred before parameter matching to @@ -4995,13 +5011,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Otherwise, we infer the type of each argument once for each matching overload signature, // with the given annotated type as type context. for binding in bindings { + // TODO: What if there are multiple bindings? let overloads = match binding.matching_overload_index() { MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => { Either::Right(binding.matching_overloads()) } // If there is a single overload that does not match, we still infer the - // argument types for better diagnostics. + // arguments against it for better diagnostics. MatchingOverloadIndex::None => match binding.overloads() { [overload] => Either::Left([(0, overload)].into_iter()), _ => continue, @@ -5027,7 +5044,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let parameter_type = overload.signature.parameters()[*parameter_index].annotated_type(); - self.infer_expression_impl(ast_argument, TypeContext::new(parameter_type)); + let tcx = parameter_type + .map(|ty| { + if unique_parameter_type { + TypeContext::UniqueAnnotatedParameter(ty) + } else { + TypeContext::AnnotatedParameter(ty) + } + }) + .unwrap_or_default(); + + self.infer_expression_impl(ast_argument, tcx); } *argument_type = self.try_expression_type(ast_argument); @@ -5374,7 +5401,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let divergent = Type::divergent(self.scope()); let element_types = elts.iter().map(|element| { let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied(); - let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty)); + let elt_tcx = annotated_elt_ty + .map(|ty| tcx.with_annotation(ty)) + .unwrap_or_default(); + let element_type = self.infer_expression(element, elt_tcx); if element_type.has_divergent_type(self.db(), divergent) { divergent @@ -5423,7 +5453,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = dict; // Validate `TypedDict` dictionary literal assignments. - if let Some(typed_dict) = tcx.annotation.and_then(Type::into_typed_dict) { + if let Some(typed_dict) = tcx.annotation().and_then(Type::into_typed_dict) { let typed_dict_items = typed_dict.items(self.db()); for item in items { @@ -5433,28 +5463,43 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { && let Some(key) = key.as_single_part_string() && let Some(field) = typed_dict_items.get(key.as_str()) { - self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty))); + self.infer_expression(&item.value, tcx.with_annotation(field.declared_ty)); } else { self.infer_expression(&item.value, TypeContext::default()); } } - // TODO: Fall-back to `Unknown` so that we don't eagerly error when matching against a - // potential overload. - validate_typed_dict_dict_literal( + let report_diagnostics = match tcx { + // Do not eagerly report diagnostics when performing overload evaluation + // with multiple potential overloads. + TypeContext::AnnotatedParameter(_) => false, + _ => true, + }; + + let result = validate_typed_dict_dict_literal( &self.context, typed_dict, dict, dict.into(), + report_diagnostics, |expr| self.expression_type(expr), ); - return Type::TypedDict(typed_dict); + match result { + // Successfully validated the dictionary literal. + Ok(_) => return Type::TypedDict(typed_dict), + + // The dictionary is not valid, but we are eagerly reporting diagnostics. + Err(_) if report_diagnostics => return Type::TypedDict(typed_dict), + + // Otherwise, fallback to an untyped dictionary literal. + Err(_) => {} + } } // Avoid false positives for the functional `TypedDict` form, which is currently // unsupported. - if let Some(Type::Dynamic(DynamicType::Todo(_))) = tcx.annotation { + if let Some(Type::Dynamic(DynamicType::Todo(_))) = tcx.annotation() { return KnownClass::Dict .to_specialized_instance(self.db(), [Type::unknown(), Type::unknown()]); } @@ -5526,7 +5571,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let elt_tcxs = match annotated_elt_tys { None => Either::Left(iter::repeat(TypeContext::default())), - Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))), + Some(tys) => Either::Right(tys.iter().map(|ty| tcx.with_annotation(*ty))), }; for elts in elts { diff --git a/crates/ty_python_semantic/src/types/typed_dict.rs b/crates/ty_python_semantic/src/types/typed_dict.rs index c1b241093b..1144f36fa2 100644 --- a/crates/ty_python_semantic/src/types/typed_dict.rs +++ b/crates/ty_python_semantic/src/types/typed_dict.rs @@ -132,6 +132,7 @@ impl TypedDictAssignmentKind { } /// Validates assignment of a value to a specific key on a `TypedDict`. +/// /// Returns true if the assignment is valid, false otherwise. #[allow(clippy::too_many_arguments)] pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( @@ -140,6 +141,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( key: &str, value_ty: Type<'db>, typed_dict_node: impl Into>, + report_diagnostics: bool, key_node: impl Into>, value_node: impl Into>, assignment_kind: TypedDictAssignmentKind, @@ -149,14 +151,17 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( // Check if key exists in `TypedDict` let Some((_, item)) = items.iter().find(|(name, _)| *name == key) else { - report_invalid_key_on_typed_dict( - context, - typed_dict_node.into(), - key_node.into(), - Type::TypedDict(typed_dict), - Type::string_literal(db, key), - &items, - ); + if report_diagnostics { + report_invalid_key_on_typed_dict( + context, + typed_dict_node.into(), + key_node.into(), + Type::TypedDict(typed_dict), + Type::string_literal(db, key), + &items, + ); + } + return false; }; @@ -177,8 +182,9 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( }; if assignment_kind.is_subscript() && item.is_read_only() { - if let Some(builder) = - context.report_lint(assignment_kind.diagnostic_type(), key_node.into()) + if report_diagnostics + && let Some(builder) = + context.report_lint(assignment_kind.diagnostic_type(), key_node.into()) { let typed_dict_ty = Type::TypedDict(typed_dict); let typed_dict_d = typed_dict_ty.display(db); @@ -207,7 +213,9 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( } // Invalid assignment - emit diagnostic - if let Some(builder) = context.report_lint(assignment_kind.diagnostic_type(), value_node.into()) + if report_diagnostics + && let Some(builder) = + context.report_lint(assignment_kind.diagnostic_type(), value_node.into()) { let typed_dict_ty = Type::TypedDict(typed_dict); let typed_dict_d = typed_dict_ty.display(db); @@ -240,13 +248,17 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( } /// Validates that all required keys are provided in a `TypedDict` construction. +/// /// Reports errors for any keys that are required but not provided. +/// +/// Returns true if the assignment is valid, false otherwise. pub(super) fn validate_typed_dict_required_keys<'db, 'ast>( context: &InferContext<'db, 'ast>, typed_dict: TypedDictType<'db>, provided_keys: &OrderSet<&str>, error_node: AnyNodeRef<'ast>, -) { + report_diagnostics: bool, +) -> bool { let db = context.db(); let items = typed_dict.items(db); @@ -255,14 +267,23 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>( .filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str())) .collect(); - for missing_key in required_keys.difference(provided_keys) { - report_missing_typed_dict_key( - context, - error_node, - Type::TypedDict(typed_dict), - missing_key, - ); + let missing_keys = required_keys.difference(provided_keys); + + let mut has_missing_key = false; + for missing_key in missing_keys { + has_missing_key = true; + + if report_diagnostics { + report_missing_typed_dict_key( + context, + error_node, + Type::TypedDict(typed_dict), + missing_key, + ); + } } + + !has_missing_key } pub(super) fn validate_typed_dict_constructor<'db, 'ast>( @@ -292,7 +313,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>( ) }; - validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node); + validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node, true); } /// Validates a `TypedDict` constructor call with a single positional dictionary argument @@ -325,6 +346,7 @@ fn validate_from_dict_literal<'db, 'ast>( key_str, value_type, error_node, + true, key_expr, &dict_item.value, TypedDictAssignmentKind::Constructor, @@ -363,6 +385,7 @@ fn validate_from_keywords<'db, 'ast>( arg_name.as_str(), arg_type, error_node, + true, keyword, &keyword.value, TypedDictAssignmentKind::Constructor, @@ -373,15 +396,17 @@ fn validate_from_keywords<'db, 'ast>( provided_keys } -/// Validates a `TypedDict` dictionary literal assignment +/// Validates a `TypedDict` dictionary literal assignment, /// e.g. `person: Person = {"name": "Alice", "age": 30}` pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( context: &InferContext<'db, 'ast>, typed_dict: TypedDictType<'db>, dict_expr: &'ast ast::ExprDict, error_node: AnyNodeRef<'ast>, + report_diagnostics: bool, expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>, -) -> OrderSet<&'ast str> { +) -> Result, OrderSet<&'ast str>> { + let mut valid = true; let mut provided_keys = OrderSet::new(); // Validate each key-value pair in the dictionary literal @@ -392,12 +417,14 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( provided_keys.insert(key_str); let value_type = expression_type_fn(&item.value); - validate_typed_dict_key_assignment( + + valid &= validate_typed_dict_key_assignment( context, typed_dict, key_str, value_type, error_node, + report_diagnostics, key_expr, &item.value, TypedDictAssignmentKind::Constructor, @@ -406,7 +433,17 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( } } - validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node); + valid &= validate_typed_dict_required_keys( + context, + typed_dict, + &provided_keys, + error_node, + report_diagnostics, + ); - provided_keys + if valid { + Ok(provided_keys) + } else { + Err(provided_keys) + } }