avoid eagerly reporting TypedDict literal diagnostics during overload evaluation

This commit is contained in:
Ibraheem Ahmed 2025-09-26 21:46:04 -04:00
parent e33bf726d7
commit 0cdca09e97
5 changed files with 171 additions and 60 deletions

View file

@ -1694,27 +1694,28 @@ class T(TypedDict):
x: int
@overload
def f(a: list[T], b: int):
def f(a: list[T], b: int) -> int:
...
@overload
def f(a: list[dict[str, int]], b: str):
def f(a: list[dict[str, int]], b: str) -> str:
...
def f(a: list[dict[str, int]] | list[T], b: int | str):
...
def f(a: list[dict[str, int]] | list[T], b: int | str) -> int | str:
return 1
def int_or_str() -> int | str:
return 1
f([{ "x": 1 }], int_or_str())
x = f([{ "x": 1 }], int_or_str())
reveal_type(x) # revealed: int | str
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `T` constructor"
# error: [invalid-key] "Invalid key access on TypedDict `T`: Unknown key "y""
# we currently incorrectly consider `list[dict[str, int]]` a subtype of `list[T]`
# TODO: error: [no-matching-overload] "No overload of function `f` matches arguments"
f([{ "y": 1 }], int_or_str())
```
Non-matching overloads do not produce errors:
Non-matching overloads do not produce diagnostics:
```py
from typing import TypedDict, overload
@ -1723,16 +1724,16 @@ class T(TypedDict):
x: int
@overload
def f(a: T, b: int):
def f(a: T, b: int) -> int:
...
@overload
def f(a: dict[str, int], b: str):
def f(a: dict[str, int], b: str) -> str:
...
def f(a: T | dict[str, int], b: int | str):
...
def f(a: T | dict[str, int], b: int | str) -> int | str:
return 1
# TODO
f({ "y": 1 }, "a")
x = f({ "y": 1 }, "a")
reveal_type(x) # revealed: str
```

View file

@ -2467,7 +2467,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// an annotated assignment, to closer match the order of any unions written in the type
// annotation.
if let Some(return_ty) = self.signature.return_ty
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation()
{
match call_expression_tcx {
// A type variable is not a useful type-context for expression inference, and applying it

View file

@ -330,7 +330,7 @@ impl<'db> InferExpression<'db> {
expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> InferExpression<'db> {
if tcx.annotation.is_some() {
if tcx.annotation().is_some() {
InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx))
} else {
// Drop the empty `TypeContext` to avoid the interning cost.
@ -370,23 +370,51 @@ struct ExpressionWithContext<'db> {
/// Knowing the outer type context when inferring an expression can enable
/// more precise inference results, aka "bidirectional type inference".
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(crate) struct TypeContext<'db> {
pub(crate) annotation: Option<Type<'db>>,
pub(crate) enum TypeContext<'db> {
/// No type context.
#[default]
None,
/// The type of an annotated assignment, used as context for the RHS.
AnnotatedAssignment(Type<'db>),
/// The type of an annotated parameter, used as context for arguments that are
/// matched against it in function call expressions.
AnnotatedParameter(Type<'db>),
/// The unique type of annotated parameter, with no overloads.
UniqueAnnotatedParameter(Type<'db>),
}
impl<'db> TypeContext<'db> {
pub(crate) fn new(annotation: Option<Type<'db>>) -> Self {
Self { annotation }
pub(crate) fn annotation(self) -> Option<Type<'db>> {
match self {
TypeContext::None => None,
TypeContext::AnnotatedAssignment(annotation) => Some(annotation),
TypeContext::AnnotatedParameter(annotation) => Some(annotation),
TypeContext::UniqueAnnotatedParameter(annotation) => Some(annotation),
}
}
pub(crate) fn with_annotation(self, annotation: Type<'db>) -> TypeContext<'db> {
match self {
TypeContext::None => panic!("no annotation to modify"),
TypeContext::AnnotatedAssignment(_) => TypeContext::AnnotatedAssignment(annotation),
TypeContext::AnnotatedParameter(_) => TypeContext::AnnotatedParameter(annotation),
TypeContext::UniqueAnnotatedParameter(_) => {
TypeContext::UniqueAnnotatedParameter(annotation)
}
}
}
// If the type annotation is a specialized instance of the given `KnownClass`, returns the
// specialization.
fn known_specialization(
&self,
self,
known_class: KnownClass,
db: &'db dyn Db,
) -> Option<Specialization<'db>> {
self.annotation
self.annotation()
.and_then(|ty| ty.known_specialization(known_class, db))
}
}

View file

@ -3263,6 +3263,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
key,
assigned_ty,
value.as_ref(),
true,
slice.as_ref(),
rhs,
TypedDictAssignmentKind::Subscript,
@ -4019,7 +4020,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(value) = value {
self.infer_maybe_standalone_expression(
value,
TypeContext::new(Some(annotated.inner_type())),
TypeContext::AnnotatedAssignment(annotated.inner_type()),
);
}
@ -4114,7 +4115,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(value) = value {
let inferred_ty = self.infer_maybe_standalone_expression(
value,
TypeContext::new(Some(declared.inner_type())),
TypeContext::AnnotatedAssignment(declared.inner_type()),
);
let inferred_ty = if target
.as_name_expr()
@ -4973,6 +4974,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast_arguments.arguments_source_order()
);
// If there is only a single binding and overload, there is a unique parameter type annotation for each argument.
let unique_parameter_type = bindings
.into_iter()
.map(|binding| match binding.matching_overload_index() {
MatchingOverloadIndex::Single(_) => 1,
MatchingOverloadIndex::Multiple(items) => items.len(),
MatchingOverloadIndex::None => {
// If there is a single overload that does not match, we still infer the
// arguments against it for better diagnostics.
if binding.overloads().len() == 1 { 1 } else { 0 }
}
})
.sum::<usize>()
== 1;
for (argument_index, (_, argument_type), argument_form, ast_argument) in iter {
let ast_argument = match ast_argument {
// Splatted arguments are inferred before parameter matching to
@ -4995,13 +5011,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Otherwise, we infer the type of each argument once for each matching overload signature,
// with the given annotated type as type context.
for binding in bindings {
// TODO: What if there are multiple bindings?
let overloads = match binding.matching_overload_index() {
MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => {
Either::Right(binding.matching_overloads())
}
// If there is a single overload that does not match, we still infer the
// argument types for better diagnostics.
// arguments against it for better diagnostics.
MatchingOverloadIndex::None => match binding.overloads() {
[overload] => Either::Left([(0, overload)].into_iter()),
_ => continue,
@ -5027,7 +5044,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let parameter_type =
overload.signature.parameters()[*parameter_index].annotated_type();
self.infer_expression_impl(ast_argument, TypeContext::new(parameter_type));
let tcx = parameter_type
.map(|ty| {
if unique_parameter_type {
TypeContext::UniqueAnnotatedParameter(ty)
} else {
TypeContext::AnnotatedParameter(ty)
}
})
.unwrap_or_default();
self.infer_expression_impl(ast_argument, tcx);
}
*argument_type = self.try_expression_type(ast_argument);
@ -5374,7 +5401,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let divergent = Type::divergent(self.scope());
let element_types = elts.iter().map(|element| {
let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty));
let elt_tcx = annotated_elt_ty
.map(|ty| tcx.with_annotation(ty))
.unwrap_or_default();
let element_type = self.infer_expression(element, elt_tcx);
if element_type.has_divergent_type(self.db(), divergent) {
divergent
@ -5423,7 +5453,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = dict;
// Validate `TypedDict` dictionary literal assignments.
if let Some(typed_dict) = tcx.annotation.and_then(Type::into_typed_dict) {
if let Some(typed_dict) = tcx.annotation().and_then(Type::into_typed_dict) {
let typed_dict_items = typed_dict.items(self.db());
for item in items {
@ -5433,28 +5463,43 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&& let Some(key) = key.as_single_part_string()
&& let Some(field) = typed_dict_items.get(key.as_str())
{
self.infer_expression(&item.value, TypeContext::new(Some(field.declared_ty)));
self.infer_expression(&item.value, tcx.with_annotation(field.declared_ty));
} else {
self.infer_expression(&item.value, TypeContext::default());
}
}
// TODO: Fall-back to `Unknown` so that we don't eagerly error when matching against a
// potential overload.
validate_typed_dict_dict_literal(
let report_diagnostics = match tcx {
// Do not eagerly report diagnostics when performing overload evaluation
// with multiple potential overloads.
TypeContext::AnnotatedParameter(_) => false,
_ => true,
};
let result = validate_typed_dict_dict_literal(
&self.context,
typed_dict,
dict,
dict.into(),
report_diagnostics,
|expr| self.expression_type(expr),
);
return Type::TypedDict(typed_dict);
match result {
// Successfully validated the dictionary literal.
Ok(_) => return Type::TypedDict(typed_dict),
// The dictionary is not valid, but we are eagerly reporting diagnostics.
Err(_) if report_diagnostics => return Type::TypedDict(typed_dict),
// Otherwise, fallback to an untyped dictionary literal.
Err(_) => {}
}
}
// Avoid false positives for the functional `TypedDict` form, which is currently
// unsupported.
if let Some(Type::Dynamic(DynamicType::Todo(_))) = tcx.annotation {
if let Some(Type::Dynamic(DynamicType::Todo(_))) = tcx.annotation() {
return KnownClass::Dict
.to_specialized_instance(self.db(), [Type::unknown(), Type::unknown()]);
}
@ -5526,7 +5571,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let elt_tcxs = match annotated_elt_tys {
None => Either::Left(iter::repeat(TypeContext::default())),
Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))),
Some(tys) => Either::Right(tys.iter().map(|ty| tcx.with_annotation(*ty))),
};
for elts in elts {

View file

@ -132,6 +132,7 @@ impl TypedDictAssignmentKind {
}
/// Validates assignment of a value to a specific key on a `TypedDict`.
///
/// Returns true if the assignment is valid, false otherwise.
#[allow(clippy::too_many_arguments)]
pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
@ -140,6 +141,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
key: &str,
value_ty: Type<'db>,
typed_dict_node: impl Into<AnyNodeRef<'ast>>,
report_diagnostics: bool,
key_node: impl Into<AnyNodeRef<'ast>>,
value_node: impl Into<AnyNodeRef<'ast>>,
assignment_kind: TypedDictAssignmentKind,
@ -149,14 +151,17 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
// Check if key exists in `TypedDict`
let Some((_, item)) = items.iter().find(|(name, _)| *name == key) else {
report_invalid_key_on_typed_dict(
context,
typed_dict_node.into(),
key_node.into(),
Type::TypedDict(typed_dict),
Type::string_literal(db, key),
&items,
);
if report_diagnostics {
report_invalid_key_on_typed_dict(
context,
typed_dict_node.into(),
key_node.into(),
Type::TypedDict(typed_dict),
Type::string_literal(db, key),
&items,
);
}
return false;
};
@ -177,8 +182,9 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
};
if assignment_kind.is_subscript() && item.is_read_only() {
if let Some(builder) =
context.report_lint(assignment_kind.diagnostic_type(), key_node.into())
if report_diagnostics
&& let Some(builder) =
context.report_lint(assignment_kind.diagnostic_type(), key_node.into())
{
let typed_dict_ty = Type::TypedDict(typed_dict);
let typed_dict_d = typed_dict_ty.display(db);
@ -207,7 +213,9 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
}
// Invalid assignment - emit diagnostic
if let Some(builder) = context.report_lint(assignment_kind.diagnostic_type(), value_node.into())
if report_diagnostics
&& let Some(builder) =
context.report_lint(assignment_kind.diagnostic_type(), value_node.into())
{
let typed_dict_ty = Type::TypedDict(typed_dict);
let typed_dict_d = typed_dict_ty.display(db);
@ -240,13 +248,17 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
}
/// Validates that all required keys are provided in a `TypedDict` construction.
///
/// Reports errors for any keys that are required but not provided.
///
/// Returns true if the assignment is valid, false otherwise.
pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
context: &InferContext<'db, 'ast>,
typed_dict: TypedDictType<'db>,
provided_keys: &OrderSet<&str>,
error_node: AnyNodeRef<'ast>,
) {
report_diagnostics: bool,
) -> bool {
let db = context.db();
let items = typed_dict.items(db);
@ -255,14 +267,23 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
.filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str()))
.collect();
for missing_key in required_keys.difference(provided_keys) {
report_missing_typed_dict_key(
context,
error_node,
Type::TypedDict(typed_dict),
missing_key,
);
let missing_keys = required_keys.difference(provided_keys);
let mut has_missing_key = false;
for missing_key in missing_keys {
has_missing_key = true;
if report_diagnostics {
report_missing_typed_dict_key(
context,
error_node,
Type::TypedDict(typed_dict),
missing_key,
);
}
}
!has_missing_key
}
pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
@ -292,7 +313,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
)
};
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node, true);
}
/// Validates a `TypedDict` constructor call with a single positional dictionary argument
@ -325,6 +346,7 @@ fn validate_from_dict_literal<'db, 'ast>(
key_str,
value_type,
error_node,
true,
key_expr,
&dict_item.value,
TypedDictAssignmentKind::Constructor,
@ -363,6 +385,7 @@ fn validate_from_keywords<'db, 'ast>(
arg_name.as_str(),
arg_type,
error_node,
true,
keyword,
&keyword.value,
TypedDictAssignmentKind::Constructor,
@ -373,15 +396,17 @@ fn validate_from_keywords<'db, 'ast>(
provided_keys
}
/// Validates a `TypedDict` dictionary literal assignment
/// Validates a `TypedDict` dictionary literal assignment,
/// e.g. `person: Person = {"name": "Alice", "age": 30}`
pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
context: &InferContext<'db, 'ast>,
typed_dict: TypedDictType<'db>,
dict_expr: &'ast ast::ExprDict,
error_node: AnyNodeRef<'ast>,
report_diagnostics: bool,
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
) -> OrderSet<&'ast str> {
) -> Result<OrderSet<&'ast str>, OrderSet<&'ast str>> {
let mut valid = true;
let mut provided_keys = OrderSet::new();
// Validate each key-value pair in the dictionary literal
@ -392,12 +417,14 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
provided_keys.insert(key_str);
let value_type = expression_type_fn(&item.value);
validate_typed_dict_key_assignment(
valid &= validate_typed_dict_key_assignment(
context,
typed_dict,
key_str,
value_type,
error_node,
report_diagnostics,
key_expr,
&item.value,
TypedDictAssignmentKind::Constructor,
@ -406,7 +433,17 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
}
}
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
valid &= validate_typed_dict_required_keys(
context,
typed_dict,
&provided_keys,
error_node,
report_diagnostics,
);
provided_keys
if valid {
Ok(provided_keys)
} else {
Err(provided_keys)
}
}