mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 22:01:47 +00:00
Handle unions in augmented assignments (#14045)
## Summary Removing more TODOs from the augmented assignment test suite. Now, if the _target_ is a union, we correctly infer the union of results: ```python if flag: f = Foo() else: f = 42.0 f += 12 ```
This commit is contained in:
parent
34a5d7cb7f
commit
70bdde4085
3 changed files with 126 additions and 75 deletions
|
@ -126,8 +126,7 @@ else:
|
||||||
f = 42.0
|
f = 42.0
|
||||||
f += 12
|
f += 12
|
||||||
|
|
||||||
# TODO(charlie): This should be `str | int | float`
|
reveal_type(f) # revealed: int | str | float
|
||||||
reveal_type(f) # revealed: @Todo
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Target union
|
## Target union
|
||||||
|
@ -148,6 +147,36 @@ else:
|
||||||
f = 42.0
|
f = 42.0
|
||||||
f += 12
|
f += 12
|
||||||
|
|
||||||
# TODO(charlie): This should be `str | float`.
|
reveal_type(f) # revealed: str | float
|
||||||
reveal_type(f) # revealed: @Todo
|
```
|
||||||
|
|
||||||
|
## Partially bound target union with `__add__`
|
||||||
|
|
||||||
|
```py
|
||||||
|
def bool_instance() -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
flag = bool_instance()
|
||||||
|
|
||||||
|
class Foo:
|
||||||
|
def __add__(self, other: int) -> str:
|
||||||
|
return "Hello, world!"
|
||||||
|
if bool_instance():
|
||||||
|
def __iadd__(self, other: int) -> int:
|
||||||
|
return 42
|
||||||
|
|
||||||
|
class Bar:
|
||||||
|
def __add__(self, other: int) -> bytes:
|
||||||
|
return b"Hello, world!"
|
||||||
|
|
||||||
|
def __iadd__(self, other: int) -> float:
|
||||||
|
return 42.0
|
||||||
|
|
||||||
|
if flag:
|
||||||
|
f = Foo()
|
||||||
|
else:
|
||||||
|
f = Bar()
|
||||||
|
f += 12
|
||||||
|
|
||||||
|
reveal_type(f) # revealed: int | str | float
|
||||||
```
|
```
|
||||||
|
|
|
@ -1902,7 +1902,7 @@ impl<'db> UnionType<'db> {
|
||||||
pub fn map(
|
pub fn map(
|
||||||
&self,
|
&self,
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
transform_fn: impl Fn(&Type<'db>) -> Type<'db>,
|
transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
|
||||||
) -> Type<'db> {
|
) -> Type<'db> {
|
||||||
Self::from_elements(db, self.elements(db).iter().map(transform_fn))
|
Self::from_elements(db, self.elements(db).iter().map(transform_fn))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1516,40 +1516,22 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_augment_assignment_definition(
|
fn infer_augmented_op(
|
||||||
&mut self,
|
&mut self,
|
||||||
assignment: &ast::StmtAugAssign,
|
assignment: &ast::StmtAugAssign,
|
||||||
definition: Definition<'db>,
|
target_type: Type<'db>,
|
||||||
) {
|
value_type: Type<'db>,
|
||||||
let target_ty = self.infer_augment_assignment(assignment);
|
) -> Type<'db> {
|
||||||
self.add_binding(assignment.into(), definition, target_ty);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> {
|
|
||||||
let ast::StmtAugAssign {
|
|
||||||
range: _,
|
|
||||||
target,
|
|
||||||
op,
|
|
||||||
value,
|
|
||||||
} = assignment;
|
|
||||||
|
|
||||||
// Resolve the target type, assuming a load context.
|
|
||||||
let target_type = match &**target {
|
|
||||||
Expr::Name(name) => {
|
|
||||||
self.store_expression_type(target, Type::Never);
|
|
||||||
self.infer_name_load(name)
|
|
||||||
}
|
|
||||||
Expr::Attribute(attr) => {
|
|
||||||
self.store_expression_type(target, Type::Never);
|
|
||||||
self.infer_attribute_load(attr)
|
|
||||||
}
|
|
||||||
_ => self.infer_expression(target),
|
|
||||||
};
|
|
||||||
let value_type = self.infer_expression(value);
|
|
||||||
|
|
||||||
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
|
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
|
||||||
// dunder.
|
// dunder.
|
||||||
if let Type::Instance(class) = target_type {
|
let op = assignment.op;
|
||||||
|
match target_type {
|
||||||
|
Type::Union(union) => {
|
||||||
|
return union.map(self.db, |&target_type| {
|
||||||
|
self.infer_augmented_op(assignment, target_type, value_type)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Type::Instance(class) => {
|
||||||
if let Symbol::Type(class_member, boundness) =
|
if let Symbol::Type(class_member, boundness) =
|
||||||
class.class_member(self.db, op.in_place_dunder())
|
class.class_member(self.db, op.in_place_dunder())
|
||||||
{
|
{
|
||||||
|
@ -1580,7 +1562,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
let left_ty = target_type;
|
let left_ty = target_type;
|
||||||
let right_ty = value_type;
|
let right_ty = value_type;
|
||||||
|
|
||||||
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, *op)
|
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, op)
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
self.diagnostics.add(
|
self.diagnostics.add(
|
||||||
assignment.into(),
|
assignment.into(),
|
||||||
|
@ -1594,16 +1576,22 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
Type::Unknown
|
Type::Unknown
|
||||||
});
|
});
|
||||||
|
|
||||||
UnionType::from_elements(self.db, [augmented_return_ty, binary_return_ty])
|
UnionType::from_elements(
|
||||||
|
self.db,
|
||||||
|
[augmented_return_ty, binary_return_ty],
|
||||||
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// By default, fall back to non-augmented binary operator inference.
|
||||||
let left_ty = target_type;
|
let left_ty = target_type;
|
||||||
let right_ty = value_type;
|
let right_ty = value_type;
|
||||||
|
|
||||||
self.infer_binary_expression_type(left_ty, right_ty, *op)
|
self.infer_binary_expression_type(left_ty, right_ty, op)
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
self.diagnostics.add(
|
self.diagnostics.add(
|
||||||
assignment.into(),
|
assignment.into(),
|
||||||
|
@ -1618,6 +1606,40 @@ impl<'db> TypeInferenceBuilder<'db> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_augment_assignment_definition(
|
||||||
|
&mut self,
|
||||||
|
assignment: &ast::StmtAugAssign,
|
||||||
|
definition: Definition<'db>,
|
||||||
|
) {
|
||||||
|
let target_ty = self.infer_augment_assignment(assignment);
|
||||||
|
self.add_binding(assignment.into(), definition, target_ty);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> {
|
||||||
|
let ast::StmtAugAssign {
|
||||||
|
range: _,
|
||||||
|
target,
|
||||||
|
op: _,
|
||||||
|
value,
|
||||||
|
} = assignment;
|
||||||
|
|
||||||
|
// Resolve the target type, assuming a load context.
|
||||||
|
let target_type = match &**target {
|
||||||
|
Expr::Name(name) => {
|
||||||
|
self.store_expression_type(target, Type::Never);
|
||||||
|
self.infer_name_load(name)
|
||||||
|
}
|
||||||
|
Expr::Attribute(attr) => {
|
||||||
|
self.store_expression_type(target, Type::Never);
|
||||||
|
self.infer_attribute_load(attr)
|
||||||
|
}
|
||||||
|
_ => self.infer_expression(target),
|
||||||
|
};
|
||||||
|
let value_type = self.infer_expression(value);
|
||||||
|
|
||||||
|
self.infer_augmented_op(assignment, target_type, value_type)
|
||||||
|
}
|
||||||
|
|
||||||
fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {
|
fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {
|
||||||
let ast::StmtTypeAlias {
|
let ast::StmtTypeAlias {
|
||||||
range: _,
|
range: _,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue