diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs index 2ea2701f82..d0992e11b7 100644 --- a/crates/ruff_python_ast/src/nodes.rs +++ b/crates/ruff_python_ast/src/nodes.rs @@ -3380,7 +3380,7 @@ impl Arguments { /// 2 /// {'4': 5} /// ``` - pub fn arguments_source_order(&self) -> impl Iterator> { + pub fn arguments_source_order(&self) -> impl Iterator> + Clone { let args = self.args.iter().map(ArgOrKeyword::Arg); let keywords = self.keywords.iter().map(ArgOrKeyword::Keyword); args.merge_by(keywords, |left, right| left.start() <= right.start()) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 17d31794ab..12014f7449 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -297,6 +297,28 @@ def _(flag: bool): reveal_type(x2) # revealed: list[int | None] ``` +## Dunder Calls + +The key and value parameters types are used as type context for `__setitem__` dunder calls: + +```py +from typing import TypedDict + +class Bar(TypedDict): + baz: float + +def _(x: dict[str, Bar]): + x["bar"] = reveal_type({"baz": 2}) # revealed: Bar + +class X: + def __setitem__(self, key: Bar, value: Bar): + ... + +def _(x: X): + # revealed: Bar + x[reveal_type({"baz": 1})] = reveal_type({"baz": 2}) # revealed: Bar +``` + ## Multi-inference diagnostics ```toml diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 28061e2e92..a816522545 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -7,7 +7,7 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_db::source::source_text; use ruff_python_ast::visitor::{Visitor, walk_expr}; use ruff_python_ast::{ - self as ast, AnyNodeRef, ExprContext, HasNodeIndex, NodeIndex, PythonVersion, + self as ast, AnyNodeRef, ArgOrKeyword, ExprContext, HasNodeIndex, NodeIndex, PythonVersion, }; use ruff_python_stdlib::builtins::version_builtin_was_added; use ruff_text_size::{Ranged, TextRange}; @@ -3951,7 +3951,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &mut self, target: &ast::ExprSubscript, rhs_value: &ast::Expr, - rhs_value_ty: Type<'db>, + infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>, ) -> bool { let ast::ExprSubscript { range: _, @@ -3962,28 +3962,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = target; let object_ty = self.infer_expression(object, TypeContext::default()); - let slice_ty = self.infer_expression(slice, TypeContext::default()); + let mut infer_slice_ty = |builder: &mut Self, tcx| builder.infer_expression(slice, tcx); self.validate_subscript_assignment_impl( target, None, object_ty, - slice_ty, + &mut infer_slice_ty, rhs_value, - rhs_value_ty, + infer_rhs_value, true, ) } #[expect(clippy::too_many_arguments)] fn validate_subscript_assignment_impl( - &self, - target: &'ast ast::ExprSubscript, + &mut self, + target: &ast::ExprSubscript, full_object_ty: Option>, object_ty: Type<'db>, - slice_ty: Type<'db>, - rhs_value_node: &'ast ast::Expr, - rhs_value_ty: Type<'db>, + infer_slice_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>, + rhs_value_node: &ast::Expr, + infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>, emit_diagnostic: bool, ) -> bool { /// Given a string literal or a union of string literals, return an iterator over the contained @@ -4019,6 +4019,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match object_ty { Type::Union(union) => { + // TODO: Perform multi-inference here. + let slice_ty = infer_slice_ty(self, TypeContext::default()); + let rhs_value_ty = infer_rhs_value(self, TypeContext::default()); + // Note that we use a loop here instead of .all(…) to avoid short-circuiting. // We need to keep iterating to emit all diagnostics. let mut valid = true; @@ -4027,9 +4031,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { target, full_object_ty.or(Some(object_ty)), *element_ty, - slice_ty, + &mut |_, _| slice_ty, rhs_value_node, - rhs_value_ty, + &mut |_, _| rhs_value_ty, emit_diagnostic, ); } @@ -4037,16 +4041,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } Type::Intersection(intersection) => { - let check_positive_elements = |emit_diagnostic_and_short_circuit| { + // TODO: Perform multi-inference here. + let slice_ty = infer_slice_ty(self, TypeContext::default()); + let rhs_value_ty = infer_rhs_value(self, TypeContext::default()); + + let mut check_positive_elements = |emit_diagnostic_and_short_circuit| { let mut valid = false; for element_ty in intersection.positive(db) { valid |= self.validate_subscript_assignment_impl( target, full_object_ty.or(Some(object_ty)), *element_ty, - slice_ty, + &mut |_, _| slice_ty, rhs_value_node, - rhs_value_ty, + &mut |_, _| rhs_value_ty, emit_diagnostic_and_short_circuit, ); @@ -4074,6 +4082,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // As an optimization, prevent calling `__setitem__` on (unions of) large `TypedDict`s, and // validate the assignment ourselves. This also allows us to emit better diagnostics. + // TODO: Use type context here. + let slice_ty = infer_slice_ty(self, TypeContext::default()); + let rhs_value_ty = infer_rhs_value(self, TypeContext::default()); + let mut valid = true; let Some(keys) = key_literals(db, slice_ty) else { // Check if the key has a valid type. We only allow string literals, a union of string literals, @@ -4137,12 +4149,27 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } _ => { - match object_ty.try_call_dunder( + let ast_arguments = [ + ArgOrKeyword::Arg(&target.slice), + ArgOrKeyword::Arg(rhs_value_node), + ]; + let mut call_arguments = + CallArguments::positional([Type::unknown(), Type::unknown()]); + + let call_result = self.infer_and_try_call_dunder( db, + object_ty, "__setitem__", - CallArguments::positional([slice_ty, rhs_value_ty]), + ast_arguments, + &mut call_arguments, TypeContext::default(), - ) { + ); + + let [Some(slice_ty), Some(rhs_value_ty)] = call_arguments.types() else { + unreachable!(); + }; + + match call_result { Ok(_) => true, Err(err) => match err { CallDunderError::PossiblyUnbound { .. } => { @@ -4184,7 +4211,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { typed_dict, full_object_ty, key, - rhs_value_ty, + *rhs_value_ty, target.value.as_ref(), target.slice.as_ref(), rhs_value_node, @@ -5065,11 +5092,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } ast::Expr::Subscript(subscript_expr) => { - let assigned_ty = infer_assigned_ty.map(|f| f(self, TypeContext::default())); - self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown())); + if let Some(infer_assigned_ty) = infer_assigned_ty { + let infer_assigned_ty = &mut |builder: &mut Self, tcx| { + let assigned_ty = infer_assigned_ty(builder, tcx); + builder.store_expression_type(target, assigned_ty); + assigned_ty + }; - if let Some(assigned_ty) = assigned_ty { - self.validate_subscript_assignment(subscript_expr, value, assigned_ty); + self.validate_subscript_assignment(subscript_expr, value, infer_assigned_ty); } } @@ -6998,9 +7028,47 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - fn infer_and_check_argument_types( + fn infer_and_try_call_dunder<'a>( &mut self, - ast_arguments: &ast::Arguments, + db: &'db dyn Db, + object: Type<'db>, + name: &str, + ast_arguments: impl IntoIterator> + Clone, + argument_types: &mut CallArguments<'_, 'db>, + call_expression_tcx: TypeContext<'db>, + ) -> Result, CallDunderError<'db>> { + // Implicit calls to dunder methods never access instance members, so we pass + // `NO_INSTANCE_FALLBACK` here in addition to other policies: + match object + .member_lookup_with_policy(db, name.into(), MemberLookupPolicy::NO_INSTANCE_FALLBACK) + .place + { + Place::Defined(dunder_callable, _, boundness) => { + let mut bindings = dunder_callable + .bindings(db) + .match_parameters(db, argument_types); + + if let Err(call_error) = self.infer_and_check_argument_types( + ast_arguments, + argument_types, + &mut bindings, + call_expression_tcx, + ) { + return Err(CallDunderError::CallError(call_error, Box::new(bindings))); + } + + if boundness == Definedness::PossiblyUndefined { + return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); + } + Ok(bindings) + } + Place::Undefined => Err(CallDunderError::MethodNotAvailable), + } + } + + fn infer_and_check_argument_types<'a>( + &mut self, + ast_arguments: impl IntoIterator> + Clone, argument_types: &mut CallArguments<'_, 'db>, bindings: &mut Bindings<'db>, call_expression_tcx: TypeContext<'db>, @@ -7033,7 +7101,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Attempt to infer the argument types using the narrowed type context. self.infer_all_argument_types( - ast_arguments, + ast_arguments.clone(), argument_types, bindings, narrowed_tcx, @@ -7073,7 +7141,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.context.set_multi_inference(was_in_multi_inference); self.infer_all_argument_types( - ast_arguments, + ast_arguments.clone(), argument_types, bindings, narrowed_tcx, @@ -7136,15 +7204,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// Note that this method may infer the type of a given argument expression multiple times with /// distinct type context. The provided `MultiInferenceState` can be used to dictate multi-inference /// behavior. - fn infer_all_argument_types( + fn infer_all_argument_types<'a>( &mut self, - ast_arguments: &ast::Arguments, + ast_arguments: impl IntoIterator>, arguments_types: &mut CallArguments<'_, 'db>, bindings: &Bindings<'db>, call_expression_tcx: TypeContext<'db>, multi_inference_state: MultiInferenceState, ) { - debug_assert_eq!(ast_arguments.len(), arguments_types.len()); debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len()); let db = self.db(); @@ -7152,7 +7219,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { 0.., arguments_types.iter_mut(), bindings.argument_forms().iter().copied(), - ast_arguments.arguments_source_order() + ast_arguments ); let overloads_with_binding = bindings @@ -7262,14 +7329,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // If there is only a single binding and overload, we can infer the argument directly with // the unique parameter type annotation. if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() { - *argument_type = Some(self.infer_expression( + *argument_type = Some(self.infer_maybe_standalone_expression( ast_argument, TypeContext::new(parameter_type(overload, binding)), )); } else { // We perform inference once without any type context, emitting any diagnostics that are unrelated // to bidirectional type inference. - *argument_type = Some(self.infer_expression(ast_argument, TypeContext::default())); + *argument_type = Some( + self.infer_maybe_standalone_expression(ast_argument, TypeContext::default()), + ); // We then silence any diagnostics emitted during multi-inference, as the type context is only // used as a hint to infer a more assignable argument type, and should not lead to diagnostics @@ -7287,8 +7356,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if !seen.insert(parameter_type) { continue; } - let inferred_ty = - self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type))); + let inferred_ty = self.infer_maybe_standalone_expression( + ast_argument, + TypeContext::new(Some(parameter_type)), + ); // Ensure the inferred type is assignable to the declared type. // @@ -8702,7 +8773,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(bindings) = bindings { let bindings = bindings.match_parameters(self.db(), &call_arguments); self.infer_all_argument_types( - arguments, + arguments.arguments_source_order(), &mut call_arguments, &bindings, tcx, @@ -8729,8 +8800,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - let bindings_result = - self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx); + let bindings_result = self.infer_and_check_argument_types( + arguments.arguments_source_order(), + &mut call_arguments, + &mut bindings, + tcx, + ); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.as_class_literal() {