[red-knot] Add definition for augmented assignment (#12892)

## Summary

This PR adds definition for augmented assignment. This is similar to
annotated assignment in terms of implementation.

An augmented assignment should also record a use of the variable but
that's a TODO for now.

## Test Plan

Add test case to validate that a definition is added.
This commit is contained in:
Dhruv Manilawala 2024-08-20 10:33:55 +05:30 committed by GitHub
parent df09045176
commit aefaddeae7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 96 additions and 3 deletions

View file

@ -463,6 +463,25 @@ mod tests {
)); ));
} }
#[test]
fn augmented_assignment() {
let TestCase { db, file } = test_case("x += 1");
let scope = global_scope(&db, file);
let global_table = symbol_table(&db, scope);
assert_eq!(names(&global_table), vec!["x"]);
let use_def = use_def_map(&db, scope);
let definition = use_def
.first_public_definition(global_table.symbol_id_by_name("x").unwrap())
.unwrap();
assert!(matches!(
definition.node(&db),
DefinitionKind::AugmentedAssignment(_)
));
}
#[test] #[test]
fn class_scope() { fn class_scope() {
let TestCase { db, file } = test_case( let TestCase { db, file } = test_case(

View file

@ -495,6 +495,20 @@ where
self.visit_expr(&node.target); self.visit_expr(&node.target);
self.current_assignment = None; self.current_assignment = None;
} }
ast::Stmt::AugAssign(
aug_assign @ ast::StmtAugAssign {
range: _,
target,
op: _,
value,
},
) => {
debug_assert!(self.current_assignment.is_none());
self.visit_expr(value);
self.current_assignment = Some(aug_assign.into());
self.visit_expr(target);
self.current_assignment = None;
}
ast::Stmt::If(node) => { ast::Stmt::If(node) => {
self.visit_expr(&node.test); self.visit_expr(&node.test);
let pre_if = self.flow_snapshot(); let pre_if = self.flow_snapshot();
@ -563,12 +577,21 @@ where
match expr { match expr {
ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => { ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => {
let flags = match ctx { let mut flags = match ctx {
ast::ExprContext::Load => SymbolFlags::IS_USED, ast::ExprContext::Load => SymbolFlags::IS_USED,
ast::ExprContext::Store => SymbolFlags::IS_DEFINED, ast::ExprContext::Store => SymbolFlags::IS_DEFINED,
ast::ExprContext::Del => SymbolFlags::IS_DEFINED, ast::ExprContext::Del => SymbolFlags::IS_DEFINED,
ast::ExprContext::Invalid => SymbolFlags::empty(), ast::ExprContext::Invalid => SymbolFlags::empty(),
}; };
if matches!(
self.current_assignment,
Some(CurrentAssignment::AugAssign(_))
) && !ctx.is_invalid()
{
// For augmented assignment, the target expression is also used, so we should
// record that as a use.
flags |= SymbolFlags::IS_USED;
}
let symbol = self.add_or_update_symbol(id.clone(), flags); let symbol = self.add_or_update_symbol(id.clone(), flags);
if flags.contains(SymbolFlags::IS_DEFINED) { if flags.contains(SymbolFlags::IS_DEFINED) {
match self.current_assignment { match self.current_assignment {
@ -584,6 +607,9 @@ where
Some(CurrentAssignment::AnnAssign(ann_assign)) => { Some(CurrentAssignment::AnnAssign(ann_assign)) => {
self.add_definition(symbol, ann_assign); self.add_definition(symbol, ann_assign);
} }
Some(CurrentAssignment::AugAssign(aug_assign)) => {
self.add_definition(symbol, aug_assign);
}
Some(CurrentAssignment::Named(named)) => { Some(CurrentAssignment::Named(named)) => {
// TODO(dhruvmanila): If the current scope is a comprehension, then the // TODO(dhruvmanila): If the current scope is a comprehension, then the
// named expression is implicitly nonlocal. This is yet to be // named expression is implicitly nonlocal. This is yet to be
@ -727,6 +753,7 @@ where
enum CurrentAssignment<'a> { enum CurrentAssignment<'a> {
Assign(&'a ast::StmtAssign), Assign(&'a ast::StmtAssign),
AnnAssign(&'a ast::StmtAnnAssign), AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
Named(&'a ast::ExprNamed), Named(&'a ast::ExprNamed),
Comprehension { Comprehension {
node: &'a ast::Comprehension, node: &'a ast::Comprehension,
@ -746,6 +773,12 @@ impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
} }
} }
impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAugAssign) -> Self {
Self::AugAssign(value)
}
}
impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
fn from(value: &'a ast::ExprNamed) -> Self { fn from(value: &'a ast::ExprNamed) -> Self {
Self::Named(value) Self::Named(value)

View file

@ -44,6 +44,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
NamedExpression(&'a ast::ExprNamed), NamedExpression(&'a ast::ExprNamed),
Assignment(AssignmentDefinitionNodeRef<'a>), Assignment(AssignmentDefinitionNodeRef<'a>),
AnnotatedAssignment(&'a ast::StmtAnnAssign), AnnotatedAssignment(&'a ast::StmtAnnAssign),
AugmentedAssignment(&'a ast::StmtAugAssign),
Comprehension(ComprehensionDefinitionNodeRef<'a>), Comprehension(ComprehensionDefinitionNodeRef<'a>),
Parameter(ast::AnyParameterRef<'a>), Parameter(ast::AnyParameterRef<'a>),
} }
@ -72,6 +73,12 @@ impl<'a> From<&'a ast::StmtAnnAssign> for DefinitionNodeRef<'a> {
} }
} }
impl<'a> From<&'a ast::StmtAugAssign> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtAugAssign) -> Self {
Self::AugmentedAssignment(node)
}
}
impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> { impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> {
fn from(node_ref: &'a ast::Alias) -> Self { fn from(node_ref: &'a ast::Alias) -> Self {
Self::Import(node_ref) Self::Import(node_ref)
@ -151,6 +158,9 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::AnnotatedAssignment(assign) => { DefinitionNodeRef::AnnotatedAssignment(assign) => {
DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign))
} }
DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => {
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
}
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => { DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => {
DefinitionKind::Comprehension(ComprehensionDefinitionKind { DefinitionKind::Comprehension(ComprehensionDefinitionKind {
node: AstNodeRef::new(parsed, node), node: AstNodeRef::new(parsed, node),
@ -182,6 +192,7 @@ impl DefinitionNodeRef<'_> {
target, target,
}) => target.into(), }) => target.into(),
Self::AnnotatedAssignment(node) => node.into(), Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(), Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
Self::Parameter(node) => match node { Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(), ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
@ -200,6 +211,7 @@ pub enum DefinitionKind {
NamedExpression(AstNodeRef<ast::ExprNamed>), NamedExpression(AstNodeRef<ast::ExprNamed>),
Assignment(AssignmentDefinitionKind), Assignment(AssignmentDefinitionKind),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>), AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
Comprehension(ComprehensionDefinitionKind), Comprehension(ComprehensionDefinitionKind),
Parameter(AstNodeRef<ast::Parameter>), Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>), ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
@ -293,6 +305,12 @@ impl From<&ast::StmtAnnAssign> for DefinitionNodeKey {
} }
} }
impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
fn from(node: &ast::StmtAugAssign) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::Comprehension> for DefinitionNodeKey { impl From<&ast::Comprehension> for DefinitionNodeKey {
fn from(node: &ast::Comprehension) -> Self { fn from(node: &ast::Comprehension) -> Self {
Self(NodeKey::from_node(node)) Self(NodeKey::from_node(node))

View file

@ -303,6 +303,9 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::AnnotatedAssignment(annotated_assignment) => { DefinitionKind::AnnotatedAssignment(annotated_assignment) => {
self.infer_annotated_assignment_definition(annotated_assignment.node(), definition); self.infer_annotated_assignment_definition(annotated_assignment.node(), definition);
} }
DefinitionKind::AugmentedAssignment(augmented_assignment) => {
self.infer_augment_assignment_definition(augmented_assignment.node(), definition);
}
DefinitionKind::NamedExpression(named_expression) => { DefinitionKind::NamedExpression(named_expression) => {
self.infer_named_expression_definition(named_expression.node(), definition); self.infer_named_expression_definition(named_expression.node(), definition);
} }
@ -763,15 +766,35 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
fn infer_augmented_assignment_statement(&mut self, assignment: &ast::StmtAugAssign) { fn infer_augmented_assignment_statement(&mut self, assignment: &ast::StmtAugAssign) {
// TODO this should be a Definition if assignment.target.is_name_expr() {
self.infer_definition(assignment);
} else {
// TODO currently we don't consider assignments to non-Names to be Definitions
self.infer_augment_assignment(assignment);
}
}
fn infer_augment_assignment_definition(
&mut self,
assignment: &ast::StmtAugAssign,
definition: Definition<'db>,
) {
let target_ty = self.infer_augment_assignment(assignment);
self.types.definitions.insert(definition, target_ty);
}
fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> {
let ast::StmtAugAssign { let ast::StmtAugAssign {
range: _, range: _,
target, target,
op: _, op: _,
value, value,
} = assignment; } = assignment;
self.infer_expression(target);
self.infer_expression(value); self.infer_expression(value);
self.infer_expression(target);
// TODO(dhruvmanila): Resolve the target type using the value type and the operator
Type::Unknown
} }
fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) { fn infer_type_alias_statement(&mut self, type_alias_statement: &ast::StmtTypeAlias) {