Handle unions in augmented assignments (#14045)

## Summary

Removing more TODOs from the augmented assignment test suite. Now, if
the _target_ is a union, we correctly infer the union of results:

```python
if flag:
    f = Foo()
else:
    f = 42.0
f += 12
```
This commit is contained in:
Charlie Marsh 2024-11-01 15:49:18 -04:00 committed by GitHub
parent 34a5d7cb7f
commit 70bdde4085
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 126 additions and 75 deletions

View file

@ -126,8 +126,7 @@ else:
f = 42.0 f = 42.0
f += 12 f += 12
# TODO(charlie): This should be `str | int | float` reveal_type(f) # revealed: int | str | float
reveal_type(f) # revealed: @Todo
``` ```
## Target union ## Target union
@ -148,6 +147,36 @@ else:
f = 42.0 f = 42.0
f += 12 f += 12
# TODO(charlie): This should be `str | float`. reveal_type(f) # revealed: str | float
reveal_type(f) # revealed: @Todo ```
## Partially bound target union with `__add__`
```py
def bool_instance() -> bool:
return True
flag = bool_instance()
class Foo:
def __add__(self, other: int) -> str:
return "Hello, world!"
if bool_instance():
def __iadd__(self, other: int) -> int:
return 42
class Bar:
def __add__(self, other: int) -> bytes:
return b"Hello, world!"
def __iadd__(self, other: int) -> float:
return 42.0
if flag:
f = Foo()
else:
f = Bar()
f += 12
reveal_type(f) # revealed: int | str | float
``` ```

View file

@ -1902,7 +1902,7 @@ impl<'db> UnionType<'db> {
pub fn map( pub fn map(
&self, &self,
db: &'db dyn Db, db: &'db dyn Db,
transform_fn: impl Fn(&Type<'db>) -> Type<'db>, transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
) -> Type<'db> { ) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().map(transform_fn)) Self::from_elements(db, self.elements(db).iter().map(transform_fn))
} }

View file

@ -1516,40 +1516,22 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
} }
fn infer_augment_assignment_definition( fn infer_augmented_op(
&mut self, &mut self,
assignment: &ast::StmtAugAssign, assignment: &ast::StmtAugAssign,
definition: Definition<'db>, target_type: Type<'db>,
) { value_type: Type<'db>,
let target_ty = self.infer_augment_assignment(assignment); ) -> Type<'db> {
self.add_binding(assignment.into(), definition, target_ty);
}
fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> {
let ast::StmtAugAssign {
range: _,
target,
op,
value,
} = assignment;
// Resolve the target type, assuming a load context.
let target_type = match &**target {
Expr::Name(name) => {
self.store_expression_type(target, Type::Never);
self.infer_name_load(name)
}
Expr::Attribute(attr) => {
self.store_expression_type(target, Type::Never);
self.infer_attribute_load(attr)
}
_ => self.infer_expression(target),
};
let value_type = self.infer_expression(value);
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that // If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
// dunder. // dunder.
if let Type::Instance(class) = target_type { let op = assignment.op;
match target_type {
Type::Union(union) => {
return union.map(self.db, |&target_type| {
self.infer_augmented_op(assignment, target_type, value_type)
})
}
Type::Instance(class) => {
if let Symbol::Type(class_member, boundness) = if let Symbol::Type(class_member, boundness) =
class.class_member(self.db, op.in_place_dunder()) class.class_member(self.db, op.in_place_dunder())
{ {
@ -1580,7 +1562,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let left_ty = target_type; let left_ty = target_type;
let right_ty = value_type; let right_ty = value_type;
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, *op) let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, op)
.unwrap_or_else(|| { .unwrap_or_else(|| {
self.diagnostics.add( self.diagnostics.add(
assignment.into(), assignment.into(),
@ -1594,16 +1576,22 @@ impl<'db> TypeInferenceBuilder<'db> {
Type::Unknown Type::Unknown
}); });
UnionType::from_elements(self.db, [augmented_return_ty, binary_return_ty]) UnionType::from_elements(
self.db,
[augmented_return_ty, binary_return_ty],
)
} }
}; };
} }
} }
_ => {}
}
// By default, fall back to non-augmented binary operator inference.
let left_ty = target_type; let left_ty = target_type;
let right_ty = value_type; let right_ty = value_type;
self.infer_binary_expression_type(left_ty, right_ty, *op) self.infer_binary_expression_type(left_ty, right_ty, op)
.unwrap_or_else(|| { .unwrap_or_else(|| {
self.diagnostics.add( self.diagnostics.add(
assignment.into(), assignment.into(),
@ -1618,6 +1606,40 @@ impl<'db> TypeInferenceBuilder<'db> {
}) })
} }
fn infer_augment_assignment_definition(
&mut self,
assignment: &ast::StmtAugAssign,
definition: Definition<'db>,
) {
let target_ty = self.infer_augment_assignment(assignment);
self.add_binding(assignment.into(), definition, target_ty);
}
fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> {
let ast::StmtAugAssign {
range: _,
target,
op: _,
value,
} = assignment;
// Resolve the target type, assuming a load context.
let target_type = match &**target {
Expr::Name(name) => {
self.store_expression_type(target, Type::Never);
self.infer_name_load(name)
}
Expr::Attribute(attr) => {
self.store_expression_type(target, Type::Never);
self.infer_attribute_load(attr)
}
_ => self.infer_expression(target),
};
let value_type = self.infer_expression(value);
self.infer_augmented_op(assignment, target_type, value_type)
}
fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) { fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {
let ast::StmtTypeAlias { let ast::StmtTypeAlias {
range: _, range: _,