Add augmented assignment inference for -= operator (#13981)

## Summary

See: https://github.com/astral-sh/ruff/issues/12699
This commit is contained in:
Charlie Marsh 2024-10-29 22:14:27 -04:00 committed by GitHub
parent 39cf46ecd6
commit c6b82151dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 179 additions and 65 deletions

View file

@ -0,0 +1,35 @@
# Augmented assignment
## Basic
```py
x = 3
x -= 1
reveal_type(x) # revealed: Literal[2]
```
## Dunder methods
```py
class C:
def __isub__(self, other: int) -> str:
return "Hello, world!"
x = C()
x -= 1
reveal_type(x) # revealed: str
```
## Unsupported types
```py
class C:
def __isub__(self, other: str) -> int:
return 42
x = C()
x -= 1
# TODO: should error, once operand type check is implemented
reveal_type(x) # revealed: int
```

View file

@ -955,6 +955,12 @@ where
}; };
let symbol = self.add_symbol(id.clone()); let symbol = self.add_symbol(id.clone());
if is_use {
self.mark_symbol_used(symbol);
let use_id = self.current_ast_ids().record_use(expr);
self.current_use_def_map_mut().record_use(symbol, use_id);
}
if is_definition { if is_definition {
match self.current_assignment().copied() { match self.current_assignment().copied() {
Some(CurrentAssignment::Assign { Some(CurrentAssignment::Assign {
@ -1018,12 +1024,6 @@ where
} }
} }
if is_use {
self.mark_symbol_used(symbol);
let use_id = self.current_ast_ids().record_use(expr);
self.current_use_def_map_mut().record_use(symbol, use_id);
}
walk_expr(self, expr); walk_expr(self, expr);
} }
ast::Expr::Named(node) => { ast::Expr::Named(node) => {

View file

@ -26,13 +26,14 @@
//! stringified annotations. We have a fourth Salsa query for inferring the deferred types //! stringified annotations. We have a fourth Salsa query for inferring the deferred types
//! associated with a particular definition. Scope-level inference infers deferred types for all //! associated with a particular definition. Scope-level inference infers deferred types for all
//! definitions once the rest of the types in the scope have been inferred. //! definitions once the rest of the types in the scope have been inferred.
use itertools::Itertools;
use std::borrow::Cow; use std::borrow::Cow;
use std::num::NonZeroU32; use std::num::NonZeroU32;
use itertools::Itertools;
use ruff_db::files::File; use ruff_db::files::File;
use ruff_db::parsed::parsed_module; use ruff_db::parsed::parsed_module;
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp}; use ruff_python_ast::name::Name;
use ruff_python_ast::{self as ast, AnyNodeRef, Expr, ExprContext, Operator, UnaryOp};
use ruff_text_size::Ranged; use ruff_text_size::Ranged;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use salsa; use salsa;
@ -1406,14 +1407,63 @@ impl<'db> TypeInferenceBuilder<'db> {
let ast::StmtAugAssign { let ast::StmtAugAssign {
range: _, range: _,
target, target,
op: _, op,
value, value,
} = assignment; } = assignment;
self.infer_expression(value);
self.infer_expression(target);
// TODO(dhruvmanila): Resolve the target type using the value type and the operator // Resolve the target type, assuming a load context.
Type::Todo let target_type = match &**target {
Expr::Name(name) => {
self.store_expression_type(target, Type::None);
self.infer_name_load(name)
}
Expr::Attribute(attr) => {
self.store_expression_type(target, Type::None);
self.infer_attribute_load(attr)
}
_ => self.infer_expression(target),
};
let value_type = self.infer_expression(value);
// TODO(charlie): Add remaining branches for different types of augmented assignments.
if let (Operator::Sub, Type::Instance(class)) = (*op, target_type) {
let class_member = class.class_member(self.db, "__isub__");
let call = class_member.call(self.db, &[value_type]);
return match call.return_ty_result(self.db, AnyNodeRef::StmtAugAssign(assignment), self)
{
Ok(t) => t,
Err(e) => {
self.add_diagnostic(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}=` is unsupported for type `{}` with type `{}`",
target_type.display(self.db),
value_type.display(self.db)
),
);
e.return_ty()
}
};
}
let left_ty = target_type;
let right_ty = value_type;
self.infer_binary_expression_type(left_ty, right_ty, *op)
.unwrap_or_else(|| {
self.add_diagnostic(
assignment.into(),
"unsupported-operator",
format_args!(
"Operator `{op}` is unsupported between objects of type `{}` and `{}`",
left_ty.display(self.db),
right_ty.display(self.db)
),
);
Type::Unknown
})
} }
fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) { fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {
@ -1850,11 +1900,15 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"),
}; };
self.store_expression_type(expression, ty);
ty
}
fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) {
let expr_id = expression.scoped_ast_id(self.db, self.scope()); let expr_id = expression.scoped_ast_id(self.db, self.scope());
let previous = self.types.expressions.insert(expr_id, ty); let previous = self.types.expressions.insert(expr_id, ty);
assert_eq!(previous, None); assert_eq!(previous, None);
ty
} }
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> { fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> {
@ -2391,12 +2445,15 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
} }
fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { /// Infer the type of a [`ast::ExprName`] expression, assuming a load context.
let ast::ExprName { range: _, id, ctx } = name; fn infer_name_load(&mut self, name: &ast::ExprName) -> Type<'db> {
let file_scope_id = self.scope().file_scope_id(self.db); let ast::ExprName {
range: _,
id,
ctx: _,
} = name;
match ctx { let file_scope_id = self.scope().file_scope_id(self.db);
ExprContext::Load => {
let use_def = self.index.use_def_map(file_scope_id); let use_def = self.index.use_def_map(file_scope_id);
let symbol = self let symbol = self
.index .index
@ -2422,8 +2479,8 @@ impl<'db> TypeInferenceBuilder<'db> {
} else { } else {
None None
}; };
let ty = bindings_ty(self.db, definitions, unbound_ty);
let ty = bindings_ty(self.db, definitions, unbound_ty);
if ty.is_unbound() { if ty.is_unbound() {
self.add_diagnostic( self.add_diagnostic(
name.into(), name.into(),
@ -2440,26 +2497,48 @@ impl<'db> TypeInferenceBuilder<'db> {
ty ty
} }
fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> {
match name.ctx {
ExprContext::Load => self.infer_name_load(name),
ExprContext::Store | ExprContext::Del => Type::None, ExprContext::Store | ExprContext::Del => Type::None,
ExprContext::Invalid => Type::Unknown, ExprContext::Invalid => Type::Unknown,
} }
} }
/// Infer the type of a [`ast::ExprAttribute`] expression, assuming a load context.
fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> {
let ast::ExprAttribute {
value,
attr,
range: _,
ctx: _,
} = attribute;
let value_ty = self.infer_expression(value);
let member_ty = value_ty.member(self.db, &Name::new(&attr.id));
member_ty
}
fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> {
let ast::ExprAttribute { let ast::ExprAttribute {
value, value,
attr, attr: _,
range: _, range: _,
ctx, ctx,
} = attribute; } = attribute;
let value_ty = self.infer_expression(value);
let member_ty = value_ty.member(self.db, &ast::name::Name::new(&attr.id));
match ctx { match ctx {
ExprContext::Load => member_ty, ExprContext::Load => self.infer_attribute_load(attribute),
ExprContext::Store | ExprContext::Del => Type::None, ExprContext::Store | ExprContext::Del => {
ExprContext::Invalid => Type::Unknown, self.infer_expression(value);
Type::None
}
ExprContext::Invalid => {
self.infer_expression(value);
Type::Unknown
}
} }
} }