diff --git a/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md b/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md new file mode 100644 index 0000000000..65dfcef23a --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md @@ -0,0 +1,35 @@ +# Augmented assignment + +## Basic + +```py +x = 3 +x -= 1 +reveal_type(x) # revealed: Literal[2] +``` + +## Dunder methods + +```py +class C: + def __isub__(self, other: int) -> str: + return "Hello, world!" + +x = C() +x -= 1 +reveal_type(x) # revealed: str +``` + +## Unsupported types + +```py +class C: + def __isub__(self, other: str) -> int: + return 42 + +x = C() +x -= 1 + +# TODO: should error, once operand type check is implemented +reveal_type(x) # revealed: int +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 4ee6faac8a..b8c900d6b3 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -955,6 +955,12 @@ where }; let symbol = self.add_symbol(id.clone()); + if is_use { + self.mark_symbol_used(symbol); + let use_id = self.current_ast_ids().record_use(expr); + self.current_use_def_map_mut().record_use(symbol, use_id); + } + if is_definition { match self.current_assignment().copied() { Some(CurrentAssignment::Assign { @@ -1018,12 +1024,6 @@ where } } - if is_use { - self.mark_symbol_used(symbol); - let use_id = self.current_ast_ids().record_use(expr); - self.current_use_def_map_mut().record_use(symbol, use_id); - } - walk_expr(self, expr); } ast::Expr::Named(node) => { diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index aed301e4d1..5080bd27c4 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -26,13 +26,14 @@ //! stringified annotations. We have a fourth Salsa query for inferring the deferred types //! associated with a particular definition. Scope-level inference infers deferred types for all //! definitions once the rest of the types in the scope have been inferred. -use itertools::Itertools; use std::borrow::Cow; use std::num::NonZeroU32; +use itertools::Itertools; use ruff_db::files::File; use ruff_db::parsed::parsed_module; -use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp}; +use ruff_python_ast::name::Name; +use ruff_python_ast::{self as ast, AnyNodeRef, Expr, ExprContext, Operator, UnaryOp}; use ruff_text_size::Ranged; use rustc_hash::FxHashMap; use salsa; @@ -1406,14 +1407,63 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::StmtAugAssign { range: _, target, - op: _, + op, value, } = assignment; - self.infer_expression(value); - self.infer_expression(target); - // TODO(dhruvmanila): Resolve the target type using the value type and the operator - Type::Todo + // Resolve the target type, assuming a load context. + let target_type = match &**target { + Expr::Name(name) => { + self.store_expression_type(target, Type::None); + self.infer_name_load(name) + } + Expr::Attribute(attr) => { + self.store_expression_type(target, Type::None); + self.infer_attribute_load(attr) + } + _ => self.infer_expression(target), + }; + let value_type = self.infer_expression(value); + + // TODO(charlie): Add remaining branches for different types of augmented assignments. + if let (Operator::Sub, Type::Instance(class)) = (*op, target_type) { + let class_member = class.class_member(self.db, "__isub__"); + let call = class_member.call(self.db, &[value_type]); + + return match call.return_ty_result(self.db, AnyNodeRef::StmtAugAssign(assignment), self) + { + Ok(t) => t, + Err(e) => { + self.add_diagnostic( + assignment.into(), + "unsupported-operator", + format_args!( + "Operator `{op}=` is unsupported for type `{}` with type `{}`", + target_type.display(self.db), + value_type.display(self.db) + ), + ); + e.return_ty() + } + }; + } + + let left_ty = target_type; + let right_ty = value_type; + + self.infer_binary_expression_type(left_ty, right_ty, *op) + .unwrap_or_else(|| { + self.add_diagnostic( + assignment.into(), + "unsupported-operator", + format_args!( + "Operator `{op}` is unsupported between objects of type `{}` and `{}`", + left_ty.display(self.db), + right_ty.display(self.db) + ), + ); + Type::Unknown + }) } fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) { @@ -1850,11 +1900,15 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Expr::IpyEscapeCommand(_) => todo!("Implement Ipy escape command support"), }; + self.store_expression_type(expression, ty); + + ty + } + + fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) { let expr_id = expression.scoped_ast_id(self.db, self.scope()); let previous = self.types.expressions.insert(expr_id, ty); assert_eq!(previous, None); - - ty } fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> { @@ -2391,75 +2445,100 @@ impl<'db> TypeInferenceBuilder<'db> { } } - fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { - let ast::ExprName { range: _, id, ctx } = name; + /// Infer the type of a [`ast::ExprName`] expression, assuming a load context. + fn infer_name_load(&mut self, name: &ast::ExprName) -> Type<'db> { + let ast::ExprName { + range: _, + id, + ctx: _, + } = name; + let file_scope_id = self.scope().file_scope_id(self.db); + let use_def = self.index.use_def_map(file_scope_id); + let symbol = self + .index + .symbol_table(file_scope_id) + .symbol_id_by_name(id) + .expect("Expected the symbol table to create a symbol for every Name node"); + // if we're inferring types of deferred expressions, always treat them as public symbols + let (definitions, may_be_unbound) = if self.is_deferred() { + ( + use_def.public_bindings(symbol), + use_def.public_may_be_unbound(symbol), + ) + } else { + let use_id = name.scoped_use_id(self.db, self.scope()); + ( + use_def.bindings_at_use(use_id), + use_def.use_may_be_unbound(use_id), + ) + }; - match ctx { - ExprContext::Load => { - let use_def = self.index.use_def_map(file_scope_id); - let symbol = self - .index - .symbol_table(file_scope_id) - .symbol_id_by_name(id) - .expect("Expected the symbol table to create a symbol for every Name node"); - // if we're inferring types of deferred expressions, always treat them as public symbols - let (definitions, may_be_unbound) = if self.is_deferred() { - ( - use_def.public_bindings(symbol), - use_def.public_may_be_unbound(symbol), - ) - } else { - let use_id = name.scoped_use_id(self.db, self.scope()); - ( - use_def.bindings_at_use(use_id), - use_def.use_may_be_unbound(use_id), - ) - }; + let unbound_ty = if may_be_unbound { + Some(self.lookup_name(name)) + } else { + None + }; - let unbound_ty = if may_be_unbound { - Some(self.lookup_name(name)) - } else { - None - }; - let ty = bindings_ty(self.db, definitions, unbound_ty); + let ty = bindings_ty(self.db, definitions, unbound_ty); + if ty.is_unbound() { + self.add_diagnostic( + name.into(), + "unresolved-reference", + format_args!("Name `{id}` used when not defined"), + ); + } else if ty.may_be_unbound(self.db) { + self.add_diagnostic( + name.into(), + "possibly-unresolved-reference", + format_args!("Name `{id}` used when possibly not defined"), + ); + } - if ty.is_unbound() { - self.add_diagnostic( - name.into(), - "unresolved-reference", - format_args!("Name `{id}` used when not defined"), - ); - } else if ty.may_be_unbound(self.db) { - self.add_diagnostic( - name.into(), - "possibly-unresolved-reference", - format_args!("Name `{id}` used when possibly not defined"), - ); - } + ty + } - ty - } + fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { + match name.ctx { + ExprContext::Load => self.infer_name_load(name), ExprContext::Store | ExprContext::Del => Type::None, ExprContext::Invalid => Type::Unknown, } } - fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { + /// Infer the type of a [`ast::ExprAttribute`] expression, assuming a load context. + fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { let ast::ExprAttribute { value, attr, range: _, - ctx, + ctx: _, } = attribute; let value_ty = self.infer_expression(value); - let member_ty = value_ty.member(self.db, &ast::name::Name::new(&attr.id)); + let member_ty = value_ty.member(self.db, &Name::new(&attr.id)); + + member_ty + } + + fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { + let ast::ExprAttribute { + value, + attr: _, + range: _, + ctx, + } = attribute; match ctx { - ExprContext::Load => member_ty, - ExprContext::Store | ExprContext::Del => Type::None, - ExprContext::Invalid => Type::Unknown, + ExprContext::Load => self.infer_attribute_load(attribute), + ExprContext::Store | ExprContext::Del => { + self.infer_expression(value); + Type::None + } + ExprContext::Invalid => { + self.infer_expression(value); + Type::Unknown + } } }