[red-knot] Use try_call_dunder for augmented assignment (#16717)

## Summary

Uses the `try_call_dunder` infrastructure for augmented assignment and
fixes the logic to work for types other than `Type::Instance(…)`. This
allows us to infer the correct type here:
```py
x = (1, 2)
x += (3, 4)
reveal_type(x)  # revealed: tuple[Literal[1], Literal[2], Literal[3], Literal[4]]
```
Or in this (extremely weird) scenario:
```py
class Meta(type):
    def __iadd__(cls, other: int) -> str:
        return ""

class C(metaclass=Meta): ...

cls = C
cls += 1

reveal_type(cls)  # revealed: str
```

Union and intersection handling could also be improved here, but I made
no attempt to do so in this PR.

## Test Plan

New MD tests
This commit is contained in:
David Peter 2025-03-14 20:36:09 +01:00 committed by GitHub
parent fe275725e0
commit ebcad6e641
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 60 additions and 74 deletions

View file

@ -10,6 +10,10 @@ reveal_type(x) # revealed: Literal[2]
x = 1.0
x /= 2
reveal_type(x) # revealed: int | float
x = (1, 2)
x += (3, 4)
reveal_type(x) # revealed: tuple[Literal[1], Literal[2], Literal[3], Literal[4]]
```
## Dunder methods
@ -161,3 +165,18 @@ def f(flag: bool, flag2: bool):
reveal_type(f) # revealed: int | str | float
```
## Implicit dunder calls on class objects
```py
class Meta(type):
def __iadd__(cls, other: int) -> str:
return ""
class C(metaclass=Meta): ...
cls = C
cls += 1
reveal_type(cls) # revealed: str
```

View file

@ -2752,86 +2752,53 @@ impl<'db> TypeInferenceBuilder<'db> {
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
// dunder.
let op = assignment.op;
match target_type {
Type::Union(union) => {
return union.map(self.db(), |&target_type| {
self.infer_augmented_op(assignment, target_type, value_type)
let db = self.db();
let report_unsupported_augmented_op = |ctx: &mut InferContext| {
ctx.report_lint(
&UNSUPPORTED_OPERATOR,
assignment,
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
target_type.display(db),
value_type.display(db)
),
);
};
// Fall back to non-augmented binary operator inference.
let mut binary_return_ty = || {
self.infer_binary_expression_type(target_type, value_type, op)
.unwrap_or_else(|| {
report_unsupported_augmented_op(&mut self.context);
Type::unknown()
})
}
Type::Instance(instance) => {
if let Symbol::Type(class_member, boundness) = instance
.class()
.class_member(self.db(), op.in_place_dunder())
.symbol
{
let call = class_member.try_call(
self.db(),
&CallArguments::positional([target_type, value_type]),
);
let augmented_return_ty = match call {
Ok(t) => t.return_type(self.db()),
Err(e) => {
self.context.report_lint(
&UNSUPPORTED_OPERATOR,
assignment,
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
target_type.display(self.db()),
value_type.display(self.db())
),
);
e.fallback_return_type(self.db())
}
};
};
return match boundness {
Boundness::Bound => augmented_return_ty,
Boundness::PossiblyUnbound => {
let left_ty = target_type;
let right_ty = value_type;
match target_type {
Type::Union(union) => union.map(db, |&elem_type| {
self.infer_augmented_op(assignment, elem_type, value_type)
}),
_ => {
let call = target_type.try_call_dunder(
db,
op.in_place_dunder(),
&CallArguments::positional([value_type]),
);
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, op)
.unwrap_or_else(|| {
self.context.report_lint(
&UNSUPPORTED_OPERATOR,
assignment,
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
left_ty.display(self.db()),
right_ty.display(self.db())
),
);
Type::unknown()
});
UnionType::from_elements(
self.db(),
[augmented_return_ty, binary_return_ty],
)
}
};
match call {
Ok(outcome) => outcome.return_type(db),
Err(CallDunderError::MethodNotAvailable) => binary_return_ty(),
Err(CallDunderError::PossiblyUnbound(outcome)) => {
UnionType::from_elements(db, [outcome.return_type(db), binary_return_ty()])
}
Err(CallDunderError::Call(call_error)) => {
report_unsupported_augmented_op(&mut self.context);
call_error.fallback_return_type(db)
}
}
}
_ => {}
}
// By default, fall back to non-augmented binary operator inference.
let left_ty = target_type;
let right_ty = value_type;
self.infer_binary_expression_type(left_ty, right_ty, op)
.unwrap_or_else(|| {
self.context.report_lint(
&UNSUPPORTED_OPERATOR,
assignment,
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
left_ty.display(self.db()),
right_ty.display(self.db())
),
);
Type::unknown()
})
}
fn infer_augment_assignment_definition(