Handle unions in augmented assignments (#14045)

## Summary

Removing more TODOs from the augmented assignment test suite. Now, if
the _target_ is a union, we correctly infer the union of results:

```python
if flag:
    f = Foo()
else:
    f = 42.0
f += 12
```
This commit is contained in:
Charlie Marsh 2024-11-01 15:49:18 -04:00 committed by GitHub
parent 34a5d7cb7f
commit 70bdde4085
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 126 additions and 75 deletions

View file

@ -1902,7 +1902,7 @@ impl<'db> UnionType<'db> {
pub fn map(
&self,
db: &'db dyn Db,
transform_fn: impl Fn(&Type<'db>) -> Type<'db>,
transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().map(transform_fn))
}

View file

@ -1516,6 +1516,96 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
fn infer_augmented_op(
&mut self,
assignment: &ast::StmtAugAssign,
target_type: Type<'db>,
value_type: Type<'db>,
) -> Type<'db> {
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
// dunder.
let op = assignment.op;
match target_type {
Type::Union(union) => {
return union.map(self.db, |&target_type| {
self.infer_augmented_op(assignment, target_type, value_type)
})
}
Type::Instance(class) => {
if let Symbol::Type(class_member, boundness) =
class.class_member(self.db, op.in_place_dunder())
{
let call = class_member.call(self.db, &[target_type, value_type]);
let augmented_return_ty = match call.return_ty_result(
self.db,
AnyNodeRef::StmtAugAssign(assignment),
&mut self.diagnostics,
) {
Ok(t) => t,
Err(e) => {
self.diagnostics.add(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
target_type.display(self.db),
value_type.display(self.db)
),
);
e.return_ty()
}
};
return match boundness {
Boundness::Bound => augmented_return_ty,
Boundness::MayBeUnbound => {
let left_ty = target_type;
let right_ty = value_type;
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, op)
.unwrap_or_else(|| {
self.diagnostics.add(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
left_ty.display(self.db),
right_ty.display(self.db)
),
);
Type::Unknown
});
UnionType::from_elements(
self.db,
[augmented_return_ty, binary_return_ty],
)
}
};
}
}
_ => {}
}
// By default, fall back to non-augmented binary operator inference.
let left_ty = target_type;
let right_ty = value_type;
self.infer_binary_expression_type(left_ty, right_ty, op)
.unwrap_or_else(|| {
self.diagnostics.add(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
left_ty.display(self.db),
right_ty.display(self.db)
),
);
Type::Unknown
})
}
fn infer_augment_assignment_definition(
&mut self,
assignment: &ast::StmtAugAssign,
@ -1529,7 +1619,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let ast::StmtAugAssign {
range: _,
target,
op,
op: _,
value,
} = assignment;
@ -1547,75 +1637,7 @@ impl<'db> TypeInferenceBuilder<'db> {
};
let value_type = self.infer_expression(value);
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
// dunder.
if let Type::Instance(class) = target_type {
if let Symbol::Type(class_member, boundness) =
class.class_member(self.db, op.in_place_dunder())
{
let call = class_member.call(self.db, &[target_type, value_type]);
let augmented_return_ty = match call.return_ty_result(
self.db,
AnyNodeRef::StmtAugAssign(assignment),
&mut self.diagnostics,
) {
Ok(t) => t,
Err(e) => {
self.diagnostics.add(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
target_type.display(self.db),
value_type.display(self.db)
),
);
e.return_ty()
}
};
return match boundness {
Boundness::Bound => augmented_return_ty,
Boundness::MayBeUnbound => {
let left_ty = target_type;
let right_ty = value_type;
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, *op)
.unwrap_or_else(|| {
self.diagnostics.add(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
left_ty.display(self.db),
right_ty.display(self.db)
),
);
Type::Unknown
});
UnionType::from_elements(self.db, [augmented_return_ty, binary_return_ty])
}
};
}
}
let left_ty = target_type;
let right_ty = value_type;
self.infer_binary_expression_type(left_ty, right_ty, *op)
.unwrap_or_else(|| {
self.diagnostics.add(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
left_ty.display(self.db),
right_ty.display(self.db)
),
);
Type::Unknown
})
self.infer_augmented_op(assignment, target_type, value_type)
}
fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {