mirror of
https://github.com/astral-sh/ruff.git
synced 2025-12-23 09:19:58 +00:00
use key and value parameter types as type context for __setitem__ dunder calls
This commit is contained in:
parent
664686bdbc
commit
ca9593eec0
3 changed files with 136 additions and 39 deletions
|
|
@ -3380,7 +3380,7 @@ impl Arguments {
|
|||
/// 2
|
||||
/// {'4': 5}
|
||||
/// ```
|
||||
pub fn arguments_source_order(&self) -> impl Iterator<Item = ArgOrKeyword<'_>> {
|
||||
pub fn arguments_source_order(&self) -> impl Iterator<Item = ArgOrKeyword<'_>> + Clone {
|
||||
let args = self.args.iter().map(ArgOrKeyword::Arg);
|
||||
let keywords = self.keywords.iter().map(ArgOrKeyword::Keyword);
|
||||
args.merge_by(keywords, |left, right| left.start() <= right.start())
|
||||
|
|
|
|||
|
|
@ -297,6 +297,28 @@ def _(flag: bool):
|
|||
reveal_type(x2) # revealed: list[int | None]
|
||||
```
|
||||
|
||||
## Dunder Calls
|
||||
|
||||
The key and value parameters types are used as type context for `__setitem__` dunder calls:
|
||||
|
||||
```py
|
||||
from typing import TypedDict
|
||||
|
||||
class Bar(TypedDict):
|
||||
baz: float
|
||||
|
||||
def _(x: dict[str, Bar]):
|
||||
x["bar"] = reveal_type({"baz": 2}) # revealed: Bar
|
||||
|
||||
class X:
|
||||
def __setitem__(self, key: Bar, value: Bar):
|
||||
...
|
||||
|
||||
def _(x: X):
|
||||
# revealed: Bar
|
||||
x[reveal_type({"baz": 1})] = reveal_type({"baz": 2}) # revealed: Bar
|
||||
```
|
||||
|
||||
## Multi-inference diagnostics
|
||||
|
||||
```toml
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use ruff_db::parsed::{ParsedModuleRef, parsed_module};
|
|||
use ruff_db::source::source_text;
|
||||
use ruff_python_ast::visitor::{Visitor, walk_expr};
|
||||
use ruff_python_ast::{
|
||||
self as ast, AnyNodeRef, ExprContext, HasNodeIndex, NodeIndex, PythonVersion,
|
||||
self as ast, AnyNodeRef, ArgOrKeyword, ExprContext, HasNodeIndex, NodeIndex, PythonVersion,
|
||||
};
|
||||
use ruff_python_stdlib::builtins::version_builtin_was_added;
|
||||
use ruff_text_size::{Ranged, TextRange};
|
||||
|
|
@ -3951,7 +3951,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
&mut self,
|
||||
target: &ast::ExprSubscript,
|
||||
rhs_value: &ast::Expr,
|
||||
rhs_value_ty: Type<'db>,
|
||||
infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
|
||||
) -> bool {
|
||||
let ast::ExprSubscript {
|
||||
range: _,
|
||||
|
|
@ -3962,28 +3962,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
} = target;
|
||||
|
||||
let object_ty = self.infer_expression(object, TypeContext::default());
|
||||
let slice_ty = self.infer_expression(slice, TypeContext::default());
|
||||
let mut infer_slice_ty = |builder: &mut Self, tcx| builder.infer_expression(slice, tcx);
|
||||
|
||||
self.validate_subscript_assignment_impl(
|
||||
target,
|
||||
None,
|
||||
object_ty,
|
||||
slice_ty,
|
||||
&mut infer_slice_ty,
|
||||
rhs_value,
|
||||
rhs_value_ty,
|
||||
infer_rhs_value,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
#[expect(clippy::too_many_arguments)]
|
||||
fn validate_subscript_assignment_impl(
|
||||
&self,
|
||||
target: &'ast ast::ExprSubscript,
|
||||
&mut self,
|
||||
target: &ast::ExprSubscript,
|
||||
full_object_ty: Option<Type<'db>>,
|
||||
object_ty: Type<'db>,
|
||||
slice_ty: Type<'db>,
|
||||
rhs_value_node: &'ast ast::Expr,
|
||||
rhs_value_ty: Type<'db>,
|
||||
infer_slice_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
|
||||
rhs_value_node: &ast::Expr,
|
||||
infer_rhs_value: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
|
||||
emit_diagnostic: bool,
|
||||
) -> bool {
|
||||
/// Given a string literal or a union of string literals, return an iterator over the contained
|
||||
|
|
@ -4019,6 +4019,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
match object_ty {
|
||||
Type::Union(union) => {
|
||||
// TODO: Perform multi-inference here.
|
||||
let slice_ty = infer_slice_ty(self, TypeContext::default());
|
||||
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
|
||||
|
||||
// Note that we use a loop here instead of .all(…) to avoid short-circuiting.
|
||||
// We need to keep iterating to emit all diagnostics.
|
||||
let mut valid = true;
|
||||
|
|
@ -4027,9 +4031,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
target,
|
||||
full_object_ty.or(Some(object_ty)),
|
||||
*element_ty,
|
||||
slice_ty,
|
||||
&mut |_, _| slice_ty,
|
||||
rhs_value_node,
|
||||
rhs_value_ty,
|
||||
&mut |_, _| rhs_value_ty,
|
||||
emit_diagnostic,
|
||||
);
|
||||
}
|
||||
|
|
@ -4037,16 +4041,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
|
||||
Type::Intersection(intersection) => {
|
||||
let check_positive_elements = |emit_diagnostic_and_short_circuit| {
|
||||
// TODO: Perform multi-inference here.
|
||||
let slice_ty = infer_slice_ty(self, TypeContext::default());
|
||||
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
|
||||
|
||||
let mut check_positive_elements = |emit_diagnostic_and_short_circuit| {
|
||||
let mut valid = false;
|
||||
for element_ty in intersection.positive(db) {
|
||||
valid |= self.validate_subscript_assignment_impl(
|
||||
target,
|
||||
full_object_ty.or(Some(object_ty)),
|
||||
*element_ty,
|
||||
slice_ty,
|
||||
&mut |_, _| slice_ty,
|
||||
rhs_value_node,
|
||||
rhs_value_ty,
|
||||
&mut |_, _| rhs_value_ty,
|
||||
emit_diagnostic_and_short_circuit,
|
||||
);
|
||||
|
||||
|
|
@ -4074,6 +4082,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// As an optimization, prevent calling `__setitem__` on (unions of) large `TypedDict`s, and
|
||||
// validate the assignment ourselves. This also allows us to emit better diagnostics.
|
||||
|
||||
// TODO: Use type context here.
|
||||
let slice_ty = infer_slice_ty(self, TypeContext::default());
|
||||
let rhs_value_ty = infer_rhs_value(self, TypeContext::default());
|
||||
|
||||
let mut valid = true;
|
||||
let Some(keys) = key_literals(db, slice_ty) else {
|
||||
// Check if the key has a valid type. We only allow string literals, a union of string literals,
|
||||
|
|
@ -4137,12 +4149,27 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
|
||||
_ => {
|
||||
match object_ty.try_call_dunder(
|
||||
let ast_arguments = [
|
||||
ArgOrKeyword::Arg(&target.slice),
|
||||
ArgOrKeyword::Arg(rhs_value_node),
|
||||
];
|
||||
let mut call_arguments =
|
||||
CallArguments::positional([Type::unknown(), Type::unknown()]);
|
||||
|
||||
let call_result = self.infer_and_try_call_dunder(
|
||||
db,
|
||||
object_ty,
|
||||
"__setitem__",
|
||||
CallArguments::positional([slice_ty, rhs_value_ty]),
|
||||
ast_arguments,
|
||||
&mut call_arguments,
|
||||
TypeContext::default(),
|
||||
) {
|
||||
);
|
||||
|
||||
let [Some(slice_ty), Some(rhs_value_ty)] = call_arguments.types() else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
match call_result {
|
||||
Ok(_) => true,
|
||||
Err(err) => match err {
|
||||
CallDunderError::PossiblyUnbound { .. } => {
|
||||
|
|
@ -4184,7 +4211,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
typed_dict,
|
||||
full_object_ty,
|
||||
key,
|
||||
rhs_value_ty,
|
||||
*rhs_value_ty,
|
||||
target.value.as_ref(),
|
||||
target.slice.as_ref(),
|
||||
rhs_value_node,
|
||||
|
|
@ -5065,11 +5092,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
ast::Expr::Subscript(subscript_expr) => {
|
||||
let assigned_ty = infer_assigned_ty.map(|f| f(self, TypeContext::default()));
|
||||
self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown()));
|
||||
if let Some(infer_assigned_ty) = infer_assigned_ty {
|
||||
let infer_assigned_ty = &mut |builder: &mut Self, tcx| {
|
||||
let assigned_ty = infer_assigned_ty(builder, tcx);
|
||||
builder.store_expression_type(target, assigned_ty);
|
||||
assigned_ty
|
||||
};
|
||||
|
||||
if let Some(assigned_ty) = assigned_ty {
|
||||
self.validate_subscript_assignment(subscript_expr, value, assigned_ty);
|
||||
self.validate_subscript_assignment(subscript_expr, value, infer_assigned_ty);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -6998,9 +7028,47 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
fn infer_and_check_argument_types(
|
||||
fn infer_and_try_call_dunder<'a>(
|
||||
&mut self,
|
||||
ast_arguments: &ast::Arguments,
|
||||
db: &'db dyn Db,
|
||||
object: Type<'db>,
|
||||
name: &str,
|
||||
ast_arguments: impl IntoIterator<Item = ArgOrKeyword<'a>> + Clone,
|
||||
argument_types: &mut CallArguments<'_, 'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
) -> Result<Bindings<'db>, CallDunderError<'db>> {
|
||||
// Implicit calls to dunder methods never access instance members, so we pass
|
||||
// `NO_INSTANCE_FALLBACK` here in addition to other policies:
|
||||
match object
|
||||
.member_lookup_with_policy(db, name.into(), MemberLookupPolicy::NO_INSTANCE_FALLBACK)
|
||||
.place
|
||||
{
|
||||
Place::Defined(dunder_callable, _, boundness) => {
|
||||
let mut bindings = dunder_callable
|
||||
.bindings(db)
|
||||
.match_parameters(db, argument_types);
|
||||
|
||||
if let Err(call_error) = self.infer_and_check_argument_types(
|
||||
ast_arguments,
|
||||
argument_types,
|
||||
&mut bindings,
|
||||
call_expression_tcx,
|
||||
) {
|
||||
return Err(CallDunderError::CallError(call_error, Box::new(bindings)));
|
||||
}
|
||||
|
||||
if boundness == Definedness::PossiblyUndefined {
|
||||
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
|
||||
}
|
||||
Ok(bindings)
|
||||
}
|
||||
Place::Undefined => Err(CallDunderError::MethodNotAvailable),
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_and_check_argument_types<'a>(
|
||||
&mut self,
|
||||
ast_arguments: impl IntoIterator<Item = ArgOrKeyword<'a>> + Clone,
|
||||
argument_types: &mut CallArguments<'_, 'db>,
|
||||
bindings: &mut Bindings<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
|
|
@ -7033,7 +7101,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
// Attempt to infer the argument types using the narrowed type context.
|
||||
self.infer_all_argument_types(
|
||||
ast_arguments,
|
||||
ast_arguments.clone(),
|
||||
argument_types,
|
||||
bindings,
|
||||
narrowed_tcx,
|
||||
|
|
@ -7073,7 +7141,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
self.context.set_multi_inference(was_in_multi_inference);
|
||||
|
||||
self.infer_all_argument_types(
|
||||
ast_arguments,
|
||||
ast_arguments.clone(),
|
||||
argument_types,
|
||||
bindings,
|
||||
narrowed_tcx,
|
||||
|
|
@ -7136,15 +7204,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
/// Note that this method may infer the type of a given argument expression multiple times with
|
||||
/// distinct type context. The provided `MultiInferenceState` can be used to dictate multi-inference
|
||||
/// behavior.
|
||||
fn infer_all_argument_types(
|
||||
fn infer_all_argument_types<'a>(
|
||||
&mut self,
|
||||
ast_arguments: &ast::Arguments,
|
||||
ast_arguments: impl IntoIterator<Item = ArgOrKeyword<'a>>,
|
||||
arguments_types: &mut CallArguments<'_, 'db>,
|
||||
bindings: &Bindings<'db>,
|
||||
call_expression_tcx: TypeContext<'db>,
|
||||
multi_inference_state: MultiInferenceState,
|
||||
) {
|
||||
debug_assert_eq!(ast_arguments.len(), arguments_types.len());
|
||||
debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len());
|
||||
|
||||
let db = self.db();
|
||||
|
|
@ -7152,7 +7219,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
0..,
|
||||
arguments_types.iter_mut(),
|
||||
bindings.argument_forms().iter().copied(),
|
||||
ast_arguments.arguments_source_order()
|
||||
ast_arguments
|
||||
);
|
||||
|
||||
let overloads_with_binding = bindings
|
||||
|
|
@ -7262,14 +7329,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// If there is only a single binding and overload, we can infer the argument directly with
|
||||
// the unique parameter type annotation.
|
||||
if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() {
|
||||
*argument_type = Some(self.infer_expression(
|
||||
*argument_type = Some(self.infer_maybe_standalone_expression(
|
||||
ast_argument,
|
||||
TypeContext::new(parameter_type(overload, binding)),
|
||||
));
|
||||
} else {
|
||||
// We perform inference once without any type context, emitting any diagnostics that are unrelated
|
||||
// to bidirectional type inference.
|
||||
*argument_type = Some(self.infer_expression(ast_argument, TypeContext::default()));
|
||||
*argument_type = Some(
|
||||
self.infer_maybe_standalone_expression(ast_argument, TypeContext::default()),
|
||||
);
|
||||
|
||||
// We then silence any diagnostics emitted during multi-inference, as the type context is only
|
||||
// used as a hint to infer a more assignable argument type, and should not lead to diagnostics
|
||||
|
|
@ -7287,8 +7356,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
if !seen.insert(parameter_type) {
|
||||
continue;
|
||||
}
|
||||
let inferred_ty =
|
||||
self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type)));
|
||||
let inferred_ty = self.infer_maybe_standalone_expression(
|
||||
ast_argument,
|
||||
TypeContext::new(Some(parameter_type)),
|
||||
);
|
||||
|
||||
// Ensure the inferred type is assignable to the declared type.
|
||||
//
|
||||
|
|
@ -8702,7 +8773,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
if let Some(bindings) = bindings {
|
||||
let bindings = bindings.match_parameters(self.db(), &call_arguments);
|
||||
self.infer_all_argument_types(
|
||||
arguments,
|
||||
arguments.arguments_source_order(),
|
||||
&mut call_arguments,
|
||||
&bindings,
|
||||
tcx,
|
||||
|
|
@ -8729,8 +8800,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
.bindings(self.db())
|
||||
.match_parameters(self.db(), &call_arguments);
|
||||
|
||||
let bindings_result =
|
||||
self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx);
|
||||
let bindings_result = self.infer_and_check_argument_types(
|
||||
arguments.arguments_source_order(),
|
||||
&mut call_arguments,
|
||||
&mut bindings,
|
||||
tcx,
|
||||
);
|
||||
|
||||
// Validate `TypedDict` constructor calls after argument type inference
|
||||
if let Some(class_literal) = callable_type.as_class_literal() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue