mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 05:15:12 +00:00
Handle unions in augmented assignments (#14045)
## Summary Removing more TODOs from the augmented assignment test suite. Now, if the _target_ is a union, we correctly infer the union of results: ```python if flag: f = Foo() else: f = 42.0 f += 12 ```
This commit is contained in:
parent
34a5d7cb7f
commit
70bdde4085
3 changed files with 126 additions and 75 deletions
|
@ -1902,7 +1902,7 @@ impl<'db> UnionType<'db> {
|
|||
pub fn map(
|
||||
&self,
|
||||
db: &'db dyn Db,
|
||||
transform_fn: impl Fn(&Type<'db>) -> Type<'db>,
|
||||
transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
|
||||
) -> Type<'db> {
|
||||
Self::from_elements(db, self.elements(db).iter().map(transform_fn))
|
||||
}
|
||||
|
|
|
@ -1516,6 +1516,96 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn infer_augmented_op(
|
||||
&mut self,
|
||||
assignment: &ast::StmtAugAssign,
|
||||
target_type: Type<'db>,
|
||||
value_type: Type<'db>,
|
||||
) -> Type<'db> {
|
||||
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
|
||||
// dunder.
|
||||
let op = assignment.op;
|
||||
match target_type {
|
||||
Type::Union(union) => {
|
||||
return union.map(self.db, |&target_type| {
|
||||
self.infer_augmented_op(assignment, target_type, value_type)
|
||||
})
|
||||
}
|
||||
Type::Instance(class) => {
|
||||
if let Symbol::Type(class_member, boundness) =
|
||||
class.class_member(self.db, op.in_place_dunder())
|
||||
{
|
||||
let call = class_member.call(self.db, &[target_type, value_type]);
|
||||
let augmented_return_ty = match call.return_ty_result(
|
||||
self.db,
|
||||
AnyNodeRef::StmtAugAssign(assignment),
|
||||
&mut self.diagnostics,
|
||||
) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
self.diagnostics.add(
|
||||
assignment.into(),
|
||||
"unsupported-operator",
|
||||
format_args!(
|
||||
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
|
||||
target_type.display(self.db),
|
||||
value_type.display(self.db)
|
||||
),
|
||||
);
|
||||
e.return_ty()
|
||||
}
|
||||
};
|
||||
|
||||
return match boundness {
|
||||
Boundness::Bound => augmented_return_ty,
|
||||
Boundness::MayBeUnbound => {
|
||||
let left_ty = target_type;
|
||||
let right_ty = value_type;
|
||||
|
||||
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, op)
|
||||
.unwrap_or_else(|| {
|
||||
self.diagnostics.add(
|
||||
assignment.into(),
|
||||
"unsupported-operator",
|
||||
format_args!(
|
||||
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
|
||||
left_ty.display(self.db),
|
||||
right_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
Type::Unknown
|
||||
});
|
||||
|
||||
UnionType::from_elements(
|
||||
self.db,
|
||||
[augmented_return_ty, binary_return_ty],
|
||||
)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// By default, fall back to non-augmented binary operator inference.
|
||||
let left_ty = target_type;
|
||||
let right_ty = value_type;
|
||||
|
||||
self.infer_binary_expression_type(left_ty, right_ty, op)
|
||||
.unwrap_or_else(|| {
|
||||
self.diagnostics.add(
|
||||
assignment.into(),
|
||||
"unsupported-operator",
|
||||
format_args!(
|
||||
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
|
||||
left_ty.display(self.db),
|
||||
right_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
Type::Unknown
|
||||
})
|
||||
}
|
||||
|
||||
fn infer_augment_assignment_definition(
|
||||
&mut self,
|
||||
assignment: &ast::StmtAugAssign,
|
||||
|
@ -1529,7 +1619,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
let ast::StmtAugAssign {
|
||||
range: _,
|
||||
target,
|
||||
op,
|
||||
op: _,
|
||||
value,
|
||||
} = assignment;
|
||||
|
||||
|
@ -1547,75 +1637,7 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
};
|
||||
let value_type = self.infer_expression(value);
|
||||
|
||||
// If the target defines, e.g., `__iadd__`, infer the augmented assignment as a call to that
|
||||
// dunder.
|
||||
if let Type::Instance(class) = target_type {
|
||||
if let Symbol::Type(class_member, boundness) =
|
||||
class.class_member(self.db, op.in_place_dunder())
|
||||
{
|
||||
let call = class_member.call(self.db, &[target_type, value_type]);
|
||||
let augmented_return_ty = match call.return_ty_result(
|
||||
self.db,
|
||||
AnyNodeRef::StmtAugAssign(assignment),
|
||||
&mut self.diagnostics,
|
||||
) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
self.diagnostics.add(
|
||||
assignment.into(),
|
||||
"unsupported-operator",
|
||||
format_args!(
|
||||
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
|
||||
target_type.display(self.db),
|
||||
value_type.display(self.db)
|
||||
),
|
||||
);
|
||||
e.return_ty()
|
||||
}
|
||||
};
|
||||
|
||||
return match boundness {
|
||||
Boundness::Bound => augmented_return_ty,
|
||||
Boundness::MayBeUnbound => {
|
||||
let left_ty = target_type;
|
||||
let right_ty = value_type;
|
||||
|
||||
let binary_return_ty = self.infer_binary_expression_type(left_ty, right_ty, *op)
|
||||
.unwrap_or_else(|| {
|
||||
self.diagnostics.add(
|
||||
assignment.into(),
|
||||
"unsupported-operator",
|
||||
format_args!(
|
||||
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
|
||||
left_ty.display(self.db),
|
||||
right_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
Type::Unknown
|
||||
});
|
||||
|
||||
UnionType::from_elements(self.db, [augmented_return_ty, binary_return_ty])
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let left_ty = target_type;
|
||||
let right_ty = value_type;
|
||||
|
||||
self.infer_binary_expression_type(left_ty, right_ty, *op)
|
||||
.unwrap_or_else(|| {
|
||||
self.diagnostics.add(
|
||||
assignment.into(),
|
||||
"unsupported-operator",
|
||||
format_args!(
|
||||
"Operator `{op}=` is unsupported between objects of type `{}` and `{}`",
|
||||
left_ty.display(self.db),
|
||||
right_ty.display(self.db)
|
||||
),
|
||||
);
|
||||
Type::Unknown
|
||||
})
|
||||
self.infer_augmented_op(assignment, target_type, value_type)
|
||||
}
|
||||
|
||||
fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue