avoid eagerly reporting TypedDict literal diagnostics during overload evaluation

This commit is contained in:
Ibraheem Ahmed 2025-09-26 21:46:04 -04:00
parent e33bf726d7
commit 0cdca09e97
5 changed files with 171 additions and 60 deletions

View file

@ -1694,27 +1694,28 @@ class T(TypedDict):
x: int x: int
@overload @overload
def f(a: list[T], b: int): def f(a: list[T], b: int) -> int:
... ...
@overload @overload
def f(a: list[dict[str, int]], b: str): def f(a: list[dict[str, int]], b: str) -> str:
... ...
def f(a: list[dict[str, int]] | list[T], b: int | str): def f(a: list[dict[str, int]] | list[T], b: int | str) -> int | str:
... return 1
def int_or_str() -> int | str: def int_or_str() -> int | str:
return 1 return 1
f([{ "x": 1 }], int_or_str()) x = f([{ "x": 1 }], int_or_str())
reveal_type(x) # revealed: int | str
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `T` constructor" # we currently incorrectly consider `list[dict[str, int]]` a subtype of `list[T]`
# error: [invalid-key] "Invalid key access on TypedDict `T`: Unknown key "y"" # TODO: error: [no-matching-overload] "No overload of function `f` matches arguments"
f([{ "y": 1 }], int_or_str()) f([{ "y": 1 }], int_or_str())
``` ```
Non-matching overloads do not produce errors: Non-matching overloads do not produce diagnostics:
```py ```py
from typing import TypedDict, overload from typing import TypedDict, overload
@ -1723,16 +1724,16 @@ class T(TypedDict):
x: int x: int
@overload @overload
def f(a: T, b: int): def f(a: T, b: int) -> int:
... ...
@overload @overload
def f(a: dict[str, int], b: str): def f(a: dict[str, int], b: str) -> str:
... ...
def f(a: T | dict[str, int], b: int | str): def f(a: T | dict[str, int], b: int | str) -> int | str:
... return 1
# TODO x = f({ "y": 1 }, "a")
f({ "y": 1 }, "a") reveal_type(x) # revealed: str
``` ```

View file

@ -2467,7 +2467,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// an annotated assignment, to closer match the order of any unions written in the type // an annotated assignment, to closer match the order of any unions written in the type
// annotation. // annotation.
if let Some(return_ty) = self.signature.return_ty if let Some(return_ty) = self.signature.return_ty
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation && let Some(call_expression_tcx) = self.call_expression_tcx.annotation()
{ {
match call_expression_tcx { match call_expression_tcx {
// A type variable is not a useful type-context for expression inference, and applying it // A type variable is not a useful type-context for expression inference, and applying it

View file

@ -330,7 +330,7 @@ impl<'db> InferExpression<'db> {
expression: Expression<'db>, expression: Expression<'db>,
tcx: TypeContext<'db>, tcx: TypeContext<'db>,
) -> InferExpression<'db> { ) -> InferExpression<'db> {
if tcx.annotation.is_some() { if tcx.annotation().is_some() {
InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx)) InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx))
} else { } else {
// Drop the empty `TypeContext` to avoid the interning cost. // Drop the empty `TypeContext` to avoid the interning cost.
@ -370,23 +370,51 @@ struct ExpressionWithContext<'db> {
/// Knowing the outer type context when inferring an expression can enable /// Knowing the outer type context when inferring an expression can enable
/// more precise inference results, aka "bidirectional type inference". /// more precise inference results, aka "bidirectional type inference".
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)] #[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(crate) struct TypeContext<'db> { pub(crate) enum TypeContext<'db> {
pub(crate) annotation: Option<Type<'db>>, /// No type context.
#[default]
None,
/// The type of an annotated assignment, used as context for the RHS.
AnnotatedAssignment(Type<'db>),
/// The type of an annotated parameter, used as context for arguments that are
/// matched against it in function call expressions.
AnnotatedParameter(Type<'db>),
/// The unique type of annotated parameter, with no overloads.
UniqueAnnotatedParameter(Type<'db>),
} }
impl<'db> TypeContext<'db> { impl<'db> TypeContext<'db> {
pub(crate) fn new(annotation: Option<Type<'db>>) -> Self { pub(crate) fn annotation(self) -> Option<Type<'db>> {
Self { annotation } match self {
TypeContext::None => None,
TypeContext::AnnotatedAssignment(annotation) => Some(annotation),
TypeContext::AnnotatedParameter(annotation) => Some(annotation),
TypeContext::UniqueAnnotatedParameter(annotation) => Some(annotation),
}
}
pub(crate) fn with_annotation(self, annotation: Type<'db>) -> TypeContext<'db> {
match self {
TypeContext::None => panic!("no annotation to modify"),
TypeContext::AnnotatedAssignment(_) => TypeContext::AnnotatedAssignment(annotation),
TypeContext::AnnotatedParameter(_) => TypeContext::AnnotatedParameter(annotation),
TypeContext::UniqueAnnotatedParameter(_) => {
TypeContext::UniqueAnnotatedParameter(annotation)
}
}
} }
// If the type annotation is a specialized instance of the given `KnownClass`, returns the // If the type annotation is a specialized instance of the given `KnownClass`, returns the
// specialization. // specialization.
fn known_specialization( fn known_specialization(
&self, self,
known_class: KnownClass, known_class: KnownClass,
db: &'db dyn Db, db: &'db dyn Db,
) -> Option<Specialization<'db>> { ) -> Option<Specialization<'db>> {
self.annotation self.annotation()
.and_then(|ty| ty.known_specialization(known_class, db)) .and_then(|ty| ty.known_specialization(known_class, db))
} }
} }

View file

@ -3263,6 +3263,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
key, key,
assigned_ty, assigned_ty,
value.as_ref(), value.as_ref(),
true,
slice.as_ref(), slice.as_ref(),
rhs, rhs,
TypedDictAssignmentKind::Subscript, TypedDictAssignmentKind::Subscript,
@ -4019,7 +4020,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(value) = value { if let Some(value) = value {
self.infer_maybe_standalone_expression( self.infer_maybe_standalone_expression(
value, value,
TypeContext::new(Some(annotated.inner_type())), TypeContext::AnnotatedAssignment(annotated.inner_type()),
); );
} }
@ -4114,7 +4115,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(value) = value { if let Some(value) = value {
let inferred_ty = self.infer_maybe_standalone_expression( let inferred_ty = self.infer_maybe_standalone_expression(
value, value,
TypeContext::new(Some(declared.inner_type())), TypeContext::AnnotatedAssignment(declared.inner_type()),
); );
let inferred_ty = if target let inferred_ty = if target
.as_name_expr() .as_name_expr()
@ -4973,6 +4974,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast_arguments.arguments_source_order() ast_arguments.arguments_source_order()
); );
// If there is only a single binding and overload, there is a unique parameter type annotation for each argument.
let unique_parameter_type = bindings
.into_iter()
.map(|binding| match binding.matching_overload_index() {
MatchingOverloadIndex::Single(_) => 1,
MatchingOverloadIndex::Multiple(items) => items.len(),
MatchingOverloadIndex::None => {
// If there is a single overload that does not match, we still infer the
// arguments against it for better diagnostics.
if binding.overloads().len() == 1 { 1 } else { 0 }
}
})
.sum::<usize>()
== 1;
for (argument_index, (_, argument_type), argument_form, ast_argument) in iter { for (argument_index, (_, argument_type), argument_form, ast_argument) in iter {
let ast_argument = match ast_argument { let ast_argument = match ast_argument {
// Splatted arguments are inferred before parameter matching to // Splatted arguments are inferred before parameter matching to
@ -4995,13 +5011,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Otherwise, we infer the type of each argument once for each matching overload signature, // Otherwise, we infer the type of each argument once for each matching overload signature,
// with the given annotated type as type context. // with the given annotated type as type context.
for binding in bindings { for binding in bindings {
// TODO: What if there are multiple bindings?
let overloads = match binding.matching_overload_index() { let overloads = match binding.matching_overload_index() {
MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => { MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => {
Either::Right(binding.matching_overloads()) Either::Right(binding.matching_overloads())
} }
// If there is a single overload that does not match, we still infer the // If there is a single overload that does not match, we still infer the
// argument types for better diagnostics. // arguments against it for better diagnostics.
MatchingOverloadIndex::None => match binding.overloads() { MatchingOverloadIndex::None => match binding.overloads() {
[overload] => Either::Left([(0, overload)].into_iter()), [overload] => Either::Left([(0, overload)].into_iter()),
_ => continue, _ => continue,
@ -5027,7 +5044,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let parameter_type = let parameter_type =
overload.signature.parameters()[*parameter_index].annotated_type(); overload.signature.parameters()[*parameter_index].annotated_type();
self.infer_expression_impl(ast_argument, TypeContext::new(parameter_type)); let tcx = parameter_type
.map(|ty| {
if unique_parameter_type {
TypeContext::UniqueAnnotatedParameter(ty)
} else {
TypeContext::AnnotatedParameter(ty)
}
})
.unwrap_or_default();
self.infer_expression_impl(ast_argument, tcx);
} }
*argument_type = self.try_expression_type(ast_argument); *argument_type = self.try_expression_type(ast_argument);
@ -5374,7 +5401,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let divergent = Type::divergent(self.scope()); let divergent = Type::divergent(self.scope());
let element_types = elts.iter().map(|element| { let element_types = elts.iter().map(|element| {
let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied(); let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty)); let elt_tcx = annotated_elt_ty
.map(|ty| tcx.with_annotation(ty))
.unwrap_or_default();
let element_type = self.infer_expression(element, elt_tcx);
if element_type.has_divergent_type(self.db(), divergent) { if element_type.has_divergent_type(self.db(), divergent) {
divergent divergent
@ -5423,7 +5453,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = dict; } = dict;
// Validate `TypedDict` dictionary literal assignments. // Validate `TypedDict` dictionary literal assignments.
if let Some(typed_dict) = tcx.annotation.and_then(Type::into_typed_dict) { if let Some(typed_dict) = tcx.annotation().and_then(Type::into_typed_dict) {
let typed_dict_items = typed_dict.items(self.db()); let typed_dict_items = typed_dict.items(self.db());
for item in items { for item in items {
@ -5433,28 +5463,43 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&& let Some(key) = key.as_single_part_string() && let Some(key) = key.as_single_part_string()
&& let Some(field) = typed_dict_items.get(key.as_str()) && let Some(field) = typed_dict_items.get(key.as_str())
{ {
self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty))); self.infer_expression(&item.value, tcx.with_annotation(field.declared_ty));
} else { } else {
self.infer_expression(&item.value, TypeContext::default()); self.infer_expression(&item.value, TypeContext::default());
} }
} }
// TODO: Fall-back to `Unknown` so that we don't eagerly error when matching against a let report_diagnostics = match tcx {
// potential overload. // Do not eagerly report diagnostics when performing overload evaluation
validate_typed_dict_dict_literal( // with multiple potential overloads.
TypeContext::AnnotatedParameter(_) => false,
_ => true,
};
let result = validate_typed_dict_dict_literal(
&self.context, &self.context,
typed_dict, typed_dict,
dict, dict,
dict.into(), dict.into(),
report_diagnostics,
|expr| self.expression_type(expr), |expr| self.expression_type(expr),
); );
return Type::TypedDict(typed_dict); match result {
// Successfully validated the dictionary literal.
Ok(_) => return Type::TypedDict(typed_dict),
// The dictionary is not valid, but we are eagerly reporting diagnostics.
Err(_) if report_diagnostics => return Type::TypedDict(typed_dict),
// Otherwise, fallback to an untyped dictionary literal.
Err(_) => {}
}
} }
// Avoid false positives for the functional `TypedDict` form, which is currently // Avoid false positives for the functional `TypedDict` form, which is currently
// unsupported. // unsupported.
if let Some(Type::Dynamic(DynamicType::Todo(_))) = tcx.annotation { if let Some(Type::Dynamic(DynamicType::Todo(_))) = tcx.annotation() {
return KnownClass::Dict return KnownClass::Dict
.to_specialized_instance(self.db(), [Type::unknown(), Type::unknown()]); .to_specialized_instance(self.db(), [Type::unknown(), Type::unknown()]);
} }
@ -5526,7 +5571,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let elt_tcxs = match annotated_elt_tys { let elt_tcxs = match annotated_elt_tys {
None => Either::Left(iter::repeat(TypeContext::default())), None => Either::Left(iter::repeat(TypeContext::default())),
Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))), Some(tys) => Either::Right(tys.iter().map(|ty| tcx.with_annotation(*ty))),
}; };
for elts in elts { for elts in elts {

View file

@ -132,6 +132,7 @@ impl TypedDictAssignmentKind {
} }
/// Validates assignment of a value to a specific key on a `TypedDict`. /// Validates assignment of a value to a specific key on a `TypedDict`.
///
/// Returns true if the assignment is valid, false otherwise. /// Returns true if the assignment is valid, false otherwise.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>( pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
@ -140,6 +141,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
key: &str, key: &str,
value_ty: Type<'db>, value_ty: Type<'db>,
typed_dict_node: impl Into<AnyNodeRef<'ast>>, typed_dict_node: impl Into<AnyNodeRef<'ast>>,
report_diagnostics: bool,
key_node: impl Into<AnyNodeRef<'ast>>, key_node: impl Into<AnyNodeRef<'ast>>,
value_node: impl Into<AnyNodeRef<'ast>>, value_node: impl Into<AnyNodeRef<'ast>>,
assignment_kind: TypedDictAssignmentKind, assignment_kind: TypedDictAssignmentKind,
@ -149,6 +151,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
// Check if key exists in `TypedDict` // Check if key exists in `TypedDict`
let Some((_, item)) = items.iter().find(|(name, _)| *name == key) else { let Some((_, item)) = items.iter().find(|(name, _)| *name == key) else {
if report_diagnostics {
report_invalid_key_on_typed_dict( report_invalid_key_on_typed_dict(
context, context,
typed_dict_node.into(), typed_dict_node.into(),
@ -157,6 +160,8 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
Type::string_literal(db, key), Type::string_literal(db, key),
&items, &items,
); );
}
return false; return false;
}; };
@ -177,7 +182,8 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
}; };
if assignment_kind.is_subscript() && item.is_read_only() { if assignment_kind.is_subscript() && item.is_read_only() {
if let Some(builder) = if report_diagnostics
&& let Some(builder) =
context.report_lint(assignment_kind.diagnostic_type(), key_node.into()) context.report_lint(assignment_kind.diagnostic_type(), key_node.into())
{ {
let typed_dict_ty = Type::TypedDict(typed_dict); let typed_dict_ty = Type::TypedDict(typed_dict);
@ -207,7 +213,9 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
} }
// Invalid assignment - emit diagnostic // Invalid assignment - emit diagnostic
if let Some(builder) = context.report_lint(assignment_kind.diagnostic_type(), value_node.into()) if report_diagnostics
&& let Some(builder) =
context.report_lint(assignment_kind.diagnostic_type(), value_node.into())
{ {
let typed_dict_ty = Type::TypedDict(typed_dict); let typed_dict_ty = Type::TypedDict(typed_dict);
let typed_dict_d = typed_dict_ty.display(db); let typed_dict_d = typed_dict_ty.display(db);
@ -240,13 +248,17 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
} }
/// Validates that all required keys are provided in a `TypedDict` construction. /// Validates that all required keys are provided in a `TypedDict` construction.
///
/// Reports errors for any keys that are required but not provided. /// Reports errors for any keys that are required but not provided.
///
/// Returns true if the assignment is valid, false otherwise.
pub(super) fn validate_typed_dict_required_keys<'db, 'ast>( pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
context: &InferContext<'db, 'ast>, context: &InferContext<'db, 'ast>,
typed_dict: TypedDictType<'db>, typed_dict: TypedDictType<'db>,
provided_keys: &OrderSet<&str>, provided_keys: &OrderSet<&str>,
error_node: AnyNodeRef<'ast>, error_node: AnyNodeRef<'ast>,
) { report_diagnostics: bool,
) -> bool {
let db = context.db(); let db = context.db();
let items = typed_dict.items(db); let items = typed_dict.items(db);
@ -255,7 +267,13 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
.filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str())) .filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str()))
.collect(); .collect();
for missing_key in required_keys.difference(provided_keys) { let missing_keys = required_keys.difference(provided_keys);
let mut has_missing_key = false;
for missing_key in missing_keys {
has_missing_key = true;
if report_diagnostics {
report_missing_typed_dict_key( report_missing_typed_dict_key(
context, context,
error_node, error_node,
@ -263,6 +281,9 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
missing_key, missing_key,
); );
} }
}
!has_missing_key
} }
pub(super) fn validate_typed_dict_constructor<'db, 'ast>( pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
@ -292,7 +313,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
) )
}; };
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node); validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node, true);
} }
/// Validates a `TypedDict` constructor call with a single positional dictionary argument /// Validates a `TypedDict` constructor call with a single positional dictionary argument
@ -325,6 +346,7 @@ fn validate_from_dict_literal<'db, 'ast>(
key_str, key_str,
value_type, value_type,
error_node, error_node,
true,
key_expr, key_expr,
&dict_item.value, &dict_item.value,
TypedDictAssignmentKind::Constructor, TypedDictAssignmentKind::Constructor,
@ -363,6 +385,7 @@ fn validate_from_keywords<'db, 'ast>(
arg_name.as_str(), arg_name.as_str(),
arg_type, arg_type,
error_node, error_node,
true,
keyword, keyword,
&keyword.value, &keyword.value,
TypedDictAssignmentKind::Constructor, TypedDictAssignmentKind::Constructor,
@ -373,15 +396,17 @@ fn validate_from_keywords<'db, 'ast>(
provided_keys provided_keys
} }
/// Validates a `TypedDict` dictionary literal assignment /// Validates a `TypedDict` dictionary literal assignment,
/// e.g. `person: Person = {"name": "Alice", "age": 30}` /// e.g. `person: Person = {"name": "Alice", "age": 30}`
pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>( pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
context: &InferContext<'db, 'ast>, context: &InferContext<'db, 'ast>,
typed_dict: TypedDictType<'db>, typed_dict: TypedDictType<'db>,
dict_expr: &'ast ast::ExprDict, dict_expr: &'ast ast::ExprDict,
error_node: AnyNodeRef<'ast>, error_node: AnyNodeRef<'ast>,
report_diagnostics: bool,
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>, expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
) -> OrderSet<&'ast str> { ) -> Result<OrderSet<&'ast str>, OrderSet<&'ast str>> {
let mut valid = true;
let mut provided_keys = OrderSet::new(); let mut provided_keys = OrderSet::new();
// Validate each key-value pair in the dictionary literal // Validate each key-value pair in the dictionary literal
@ -392,12 +417,14 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
provided_keys.insert(key_str); provided_keys.insert(key_str);
let value_type = expression_type_fn(&item.value); let value_type = expression_type_fn(&item.value);
validate_typed_dict_key_assignment(
valid &= validate_typed_dict_key_assignment(
context, context,
typed_dict, typed_dict,
key_str, key_str,
value_type, value_type,
error_node, error_node,
report_diagnostics,
key_expr, key_expr,
&item.value, &item.value,
TypedDictAssignmentKind::Constructor, TypedDictAssignmentKind::Constructor,
@ -406,7 +433,17 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
} }
} }
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node); valid &= validate_typed_dict_required_keys(
context,
typed_dict,
&provided_keys,
error_node,
report_diagnostics,
);
provided_keys if valid {
Ok(provided_keys)
} else {
Err(provided_keys)
}
} }