use key and value parameter types as type context for __setitem__ dunder calls

This commit is contained in:
Ibraheem Ahmed 2025-12-22 17:51:43 -05:00
parent 664686bdbc
commit ca9593eec0
3 changed files with 136 additions and 39 deletions

View file

@ -3380,7 +3380,7 @@ impl Arguments {
/// 2
/// {'4': 5}
/// ```
pub fn arguments_source_order(&self) -> impl Iterator<Item = ArgOrKeyword<'_>> {
pub fn arguments_source_order(&self) -> impl Iterator<Item = ArgOrKeyword<'_>> + Clone {
let args = self.args.iter().map(ArgOrKeyword::Arg);
let keywords = self.keywords.iter().map(ArgOrKeyword::Keyword);
args.merge_by(keywords, |left, right| left.start() <= right.start())

View file

@ -297,6 +297,28 @@ def _(flag: bool):
reveal_type(x2) # revealed: list[int | None]
```
## Dunder Calls
The key and value parameters types are used as type context for `__setitem__` dunder calls:
```py
from typing import TypedDict
class Bar(TypedDict):
baz: float
def _(x: dict[str, Bar]):
x["bar"] = reveal_type({"baz": 2}) # revealed: Bar
class X:
def __setitem__(self, key: Bar, value: Bar):
...
def _(x: X):
# revealed: Bar
x[reveal_type({"baz": 1})] = reveal_type({"baz": 2}) # revealed: Bar
```
## Multi-inference diagnostics
```toml

View file

@ -7,7 +7,7 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_db::source::source_text;
use ruff_python_ast::visitor::{Visitor, walk_expr};
use ruff_python_ast::{
self as ast, AnyNodeRef, ExprContext, HasNodeIndex, NodeIndex, PythonVersion,
self as ast, AnyNodeRef, ArgOrKeyword, ExprContext, HasNodeIndex, NodeIndex, PythonVersion,
};
use ruff_python_stdlib::builtins::version_builtin_was_added;
use ruff_text_size::{Ranged, TextRange};
@ -3951,7 +3951,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&mut self,
target: &ast::ExprSubscript,
rhs_value: &ast::Expr,
rhs_value_ty: Type<'db>,
infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
) -> bool {
let ast::ExprSubscript {
range: _,
@ -3962,28 +3962,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = target;
let object_ty = self.infer_expression(object, TypeContext::default());
let slice_ty = self.infer_expression(slice, TypeContext::default());
let mut infer_slice_ty = |builder: &mut Self, tcx| builder.infer_expression(slice, tcx);
self.validate_subscript_assignment_impl(
target,
None,
object_ty,
slice_ty,
&mut infer_slice_ty,
rhs_value,
rhs_value_ty,
infer_rhs_value,
true,
)
}
#[expect(clippy::too_many_arguments)]
fn validate_subscript_assignment_impl(
&self,
target: &'ast ast::ExprSubscript,
&mut self,
target: &ast::ExprSubscript,
full_object_ty: Option<Type<'db>>,
object_ty: Type<'db>,
slice_ty: Type<'db>,
rhs_value_node: &'ast ast::Expr,
rhs_value_ty: Type<'db>,
infer_slice_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
rhs_value_node: &ast::Expr,
infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
emit_diagnostic: bool,
) -> bool {
/// Given a string literal or a union of string literals, return an iterator over the contained
@ -4019,6 +4019,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match object_ty {
Type::Union(union) => {
// TODO: Perform multi-inference here.
let slice_ty = infer_slice_ty(self, TypeContext::default());
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
// Note that we use a loop here instead of .all(…) to avoid short-circuiting.
// We need to keep iterating to emit all diagnostics.
let mut valid = true;
@ -4027,9 +4031,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target,
full_object_ty.or(Some(object_ty)),
*element_ty,
slice_ty,
&mut |_, _| slice_ty,
rhs_value_node,
rhs_value_ty,
&mut |_, _| rhs_value_ty,
emit_diagnostic,
);
}
@ -4037,16 +4041,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
Type::Intersection(intersection) => {
let check_positive_elements = |emit_diagnostic_and_short_circuit| {
// TODO: Perform multi-inference here.
let slice_ty = infer_slice_ty(self, TypeContext::default());
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
let mut check_positive_elements = |emit_diagnostic_and_short_circuit| {
let mut valid = false;
for element_ty in intersection.positive(db) {
valid |= self.validate_subscript_assignment_impl(
target,
full_object_ty.or(Some(object_ty)),
*element_ty,
slice_ty,
&mut |_, _| slice_ty,
rhs_value_node,
rhs_value_ty,
&mut |_, _| rhs_value_ty,
emit_diagnostic_and_short_circuit,
);
@ -4074,6 +4082,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// As an optimization, prevent calling `__setitem__` on (unions of) large `TypedDict`s, and
// validate the assignment ourselves. This also allows us to emit better diagnostics.
// TODO: Use type context here.
let slice_ty = infer_slice_ty(self, TypeContext::default());
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
let mut valid = true;
let Some(keys) = key_literals(db, slice_ty) else {
// Check if the key has a valid type. We only allow string literals, a union of string literals,
@ -4137,12 +4149,27 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
_ => {
match object_ty.try_call_dunder(
let ast_arguments = [
ArgOrKeyword::Arg(&target.slice),
ArgOrKeyword::Arg(rhs_value_node),
];
let mut call_arguments =
CallArguments::positional([Type::unknown(), Type::unknown()]);
let call_result = self.infer_and_try_call_dunder(
db,
object_ty,
"__setitem__",
CallArguments::positional([slice_ty, rhs_value_ty]),
ast_arguments,
&mut call_arguments,
TypeContext::default(),
) {
);
let [Some(slice_ty), Some(rhs_value_ty)] = call_arguments.types() else {
unreachable!();
};
match call_result {
Ok(_) => true,
Err(err) => match err {
CallDunderError::PossiblyUnbound { .. } => {
@ -4184,7 +4211,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
typed_dict,
full_object_ty,
key,
rhs_value_ty,
*rhs_value_ty,
target.value.as_ref(),
target.slice.as_ref(),
rhs_value_node,
@ -5065,11 +5092,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
ast::Expr::Subscript(subscript_expr) => {
let assigned_ty = infer_assigned_ty.map(|f| f(self, TypeContext::default()));
self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown()));
if let Some(infer_assigned_ty) = infer_assigned_ty {
let infer_assigned_ty = &mut |builder: &mut Self, tcx| {
let assigned_ty = infer_assigned_ty(builder, tcx);
builder.store_expression_type(target, assigned_ty);
assigned_ty
};
if let Some(assigned_ty) = assigned_ty {
self.validate_subscript_assignment(subscript_expr, value, assigned_ty);
self.validate_subscript_assignment(subscript_expr, value, infer_assigned_ty);
}
}
@ -6998,9 +7028,47 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
fn infer_and_check_argument_types(
fn infer_and_try_call_dunder<'a>(
&mut self,
ast_arguments: &ast::Arguments,
db: &'db dyn Db,
object: Type<'db>,
name: &str,
ast_arguments: impl IntoIterator<Item = ArgOrKeyword<'a>> + Clone,
argument_types: &mut CallArguments<'_, 'db>,
call_expression_tcx: TypeContext<'db>,
) -> Result<Bindings<'db>, CallDunderError<'db>> {
// Implicit calls to dunder methods never access instance members, so we pass
// `NO_INSTANCE_FALLBACK` here in addition to other policies:
match object
.member_lookup_with_policy(db, name.into(), MemberLookupPolicy::NO_INSTANCE_FALLBACK)
.place
{
Place::Defined(dunder_callable, _, boundness) => {
let mut bindings = dunder_callable
.bindings(db)
.match_parameters(db, argument_types);
if let Err(call_error) = self.infer_and_check_argument_types(
ast_arguments,
argument_types,
&mut bindings,
call_expression_tcx,
) {
return Err(CallDunderError::CallError(call_error, Box::new(bindings)));
}
if boundness == Definedness::PossiblyUndefined {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
}
Ok(bindings)
}
Place::Undefined => Err(CallDunderError::MethodNotAvailable),
}
}
fn infer_and_check_argument_types<'a>(
&mut self,
ast_arguments: impl IntoIterator<Item = ArgOrKeyword<'a>> + Clone,
argument_types: &mut CallArguments<'_, 'db>,
bindings: &mut Bindings<'db>,
call_expression_tcx: TypeContext<'db>,
@ -7033,7 +7101,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Attempt to infer the argument types using the narrowed type context.
self.infer_all_argument_types(
ast_arguments,
ast_arguments.clone(),
argument_types,
bindings,
narrowed_tcx,
@ -7073,7 +7141,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.context.set_multi_inference(was_in_multi_inference);
self.infer_all_argument_types(
ast_arguments,
ast_arguments.clone(),
argument_types,
bindings,
narrowed_tcx,
@ -7136,15 +7204,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// Note that this method may infer the type of a given argument expression multiple times with
/// distinct type context. The provided `MultiInferenceState` can be used to dictate multi-inference
/// behavior.
fn infer_all_argument_types(
fn infer_all_argument_types<'a>(
&mut self,
ast_arguments: &ast::Arguments,
ast_arguments: impl IntoIterator<Item = ArgOrKeyword<'a>>,
arguments_types: &mut CallArguments<'_, 'db>,
bindings: &Bindings<'db>,
call_expression_tcx: TypeContext<'db>,
multi_inference_state: MultiInferenceState,
) {
debug_assert_eq!(ast_arguments.len(), arguments_types.len());
debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len());
let db = self.db();
@ -7152,7 +7219,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
0..,
arguments_types.iter_mut(),
bindings.argument_forms().iter().copied(),
ast_arguments.arguments_source_order()
ast_arguments
);
let overloads_with_binding = bindings
@ -7262,14 +7329,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// If there is only a single binding and overload, we can infer the argument directly with
// the unique parameter type annotation.
if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() {
*argument_type = Some(self.infer_expression(
*argument_type = Some(self.infer_maybe_standalone_expression(
ast_argument,
TypeContext::new(parameter_type(overload, binding)),
));
} else {
// We perform inference once without any type context, emitting any diagnostics that are unrelated
// to bidirectional type inference.
*argument_type = Some(self.infer_expression(ast_argument, TypeContext::default()));
*argument_type = Some(
self.infer_maybe_standalone_expression(ast_argument, TypeContext::default()),
);
// We then silence any diagnostics emitted during multi-inference, as the type context is only
// used as a hint to infer a more assignable argument type, and should not lead to diagnostics
@ -7287,8 +7356,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if !seen.insert(parameter_type) {
continue;
}
let inferred_ty =
self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type)));
let inferred_ty = self.infer_maybe_standalone_expression(
ast_argument,
TypeContext::new(Some(parameter_type)),
);
// Ensure the inferred type is assignable to the declared type.
//
@ -8702,7 +8773,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(bindings) = bindings {
let bindings = bindings.match_parameters(self.db(), &call_arguments);
self.infer_all_argument_types(
arguments,
arguments.arguments_source_order(),
&mut call_arguments,
&bindings,
tcx,
@ -8729,8 +8800,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.bindings(self.db())
.match_parameters(self.db(), &call_arguments);
let bindings_result =
self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx);
let bindings_result = self.infer_and_check_argument_types(
arguments.arguments_source_order(),
&mut call_arguments,
&mut bindings,
tcx,
);
// Validate `TypedDict` constructor calls after argument type inference
if let Some(class_literal) = callable_type.as_class_literal() {