From b16f665a811f2a76cc55fc5cf4937e35a82d7537 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Wed, 16 Oct 2024 00:37:11 +0530 Subject: [PATCH] [red-knot] Infer target types for unpacked tuple assignment (#13316) ## Summary This PR adds support for unpacking tuple expression in an assignment statement where the target expression can be a tuple or a list (the allowed sequence targets). The implementation introduces a new `infer_assignment_target` which can then be used for other targets like the ones in for loops as well. This delegates it to the `infer_definition`. The final implementation uses a recursive function that visits the target expression in source order and compares the variable node that corresponds to the definition. At the same time, it keeps track of where it is on the assignment value type. The logic also accounts for the number of elements on both sides such that it matches even if there's a gap in between. For example, if there's a starred expression like `(a, *b, c) = (1, 2, 3)`, then the type of `a` will be `Literal[1]` and the type of `b` will be `Literal[2]`. There are a couple of follow-ups that can be done: * Use this logic for other target positions like `for` loop * Add diagnostics for mis-match length between LHS and RHS ## Test Plan Add various test cases using the new markdown test framework. Validate that existing test cases pass. --------- Co-authored-by: Carl Meyer --- .../resources/mdtest/unpacking.md | 273 ++++++++++++++++++ .../src/semantic_index.rs | 2 +- .../src/semantic_index/builder.rs | 43 ++- .../src/semantic_index/definition.rs | 54 +++- crates/red_knot_python_semantic/src/types.rs | 16 + .../src/types/infer.rs | 163 ++++++++++- crates/red_knot_workspace/tests/check.rs | 20 +- 7 files changed, 525 insertions(+), 46 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/unpacking.md diff --git a/crates/red_knot_python_semantic/resources/mdtest/unpacking.md b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md new file mode 100644 index 0000000000..974d672b48 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md @@ -0,0 +1,273 @@ +# Unpacking + +## Tuple + +### Simple tuple + +```py +(a, b, c) = (1, 2, 3) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +reveal_type(c) # revealed: Literal[3] +``` + +### Simple list + +```py +[a, b, c] = (1, 2, 3) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +reveal_type(c) # revealed: Literal[3] +``` + +### Simple mixed + +```py +[a, (b, c), d] = (1, (2, 3), 4) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +reveal_type(c) # revealed: Literal[3] +reveal_type(d) # revealed: Literal[4] +``` + +### Multiple assignment + +```py +a, b = c = 1, 2 +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +reveal_type(c) # revealed: tuple[Literal[1], Literal[2]] +``` + +### Nested tuple with unpacking + +```py +(a, (b, c), d) = (1, (2, 3), 4) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +reveal_type(c) # revealed: Literal[3] +reveal_type(d) # revealed: Literal[4] +``` + +### Nested tuple without unpacking + +```py +(a, b, c) = (1, (2, 3), 4) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: tuple[Literal[2], Literal[3]] +reveal_type(c) # revealed: Literal[4] +``` + +### Uneven unpacking (1) + +```py +# TODO: Add diagnostic (there aren't enough values to unpack) +(a, b, c) = (1, 2) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +reveal_type(c) # revealed: Unknown +``` + +### Uneven unpacking (2) + +```py +# TODO: Add diagnostic (too many values to unpack) +(a, b) = (1, 2, 3) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +``` + +### Starred expression (1) + +```py +# TODO: Add diagnostic (need more values to unpack) +# TODO: Remove 'not-iterable' diagnostic +[a, *b, c, d] = (1, 2) # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: Literal[1] +# TODO: Should be list[Any] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: Literal[2] +reveal_type(d) # revealed: Unknown +``` + +### Starred expression (2) + +```py +[a, *b, c] = (1, 2) # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: Literal[1] +# TODO: Should be list[Any] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: Literal[2] +``` + +### Starred expression (3) + +```py +# TODO: Remove 'not-iterable' diagnostic +[a, *b, c] = (1, 2, 3) # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: Literal[1] +# TODO: Should be list[int] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: Literal[3] +``` + +### Starred expression (4) + +```py +# TODO: Remove 'not-iterable' diagnostic +[a, *b, c, d] = (1, 2, 3, 4, 5, 6) # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: Literal[1] +# TODO: Should be list[int] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: Literal[5] +reveal_type(d) # revealed: Literal[6] +``` + +### Starred expression (5) + +```py +# TODO: Remove 'not-iterable' diagnostic +[a, b, *c] = (1, 2, 3, 4) # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: Literal[2] +# TODO: Should be list[int] once support for assigning to starred expression is added +reveal_type(c) # revealed: @Todo +``` + +### Non-iterable unpacking + +TODO: Remove duplicate diagnostics. This is happening because for a sequence-like +assignment target, multiple definitions are created and the inference engine runs +on each of them which results in duplicate diagnostics. + +```py +# error: "Object of type `Literal[1]` is not iterable" +# error: "Object of type `Literal[1]` is not iterable" +a, b = 1 +reveal_type(a) # revealed: Unknown +reveal_type(b) # revealed: Unknown +``` + +### Custom iterator unpacking + +```py +class Iterator: + def __next__(self) -> int: + return 42 + + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + + +(a, b) = Iterable() +reveal_type(a) # revealed: int +reveal_type(b) # revealed: int +``` + +### Custom iterator unpacking nested + +```py +class Iterator: + def __next__(self) -> int: + return 42 + + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + + +(a, (b, c), d) = (1, Iterable(), 2) +reveal_type(a) # revealed: Literal[1] +reveal_type(b) # revealed: int +reveal_type(c) # revealed: int +reveal_type(d) # revealed: Literal[2] +``` + +## String + +### Simple unpacking + +```py +a, b = 'ab' +reveal_type(a) # revealed: LiteralString +reveal_type(b) # revealed: LiteralString +``` + +### Uneven unpacking (1) + +```py +# TODO: Add diagnostic (there aren't enough values to unpack) +a, b, c = 'ab' +reveal_type(a) # revealed: LiteralString +reveal_type(b) # revealed: LiteralString +reveal_type(c) # revealed: Unknown +``` + +### Uneven unpacking (2) + +```py +# TODO: Add diagnostic (too many values to unpack) +a, b = 'abc' +reveal_type(a) # revealed: LiteralString +reveal_type(b) # revealed: LiteralString +``` + +### Starred expression (1) + +```py +# TODO: Add diagnostic (need more values to unpack) +# TODO: Remove 'not-iterable' diagnostic +(a, *b, c, d) = "ab" # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: LiteralString +# TODO: Should be list[LiteralString] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: LiteralString +reveal_type(d) # revealed: Unknown +``` + +### Starred expression (2) + +```py +(a, *b, c) = "ab" # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: LiteralString +# TODO: Should be list[Any] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: LiteralString +``` + +### Starred expression (3) + +```py +# TODO: Remove 'not-iterable' diagnostic +(a, *b, c) = "abc" # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: LiteralString +# TODO: Should be list[LiteralString] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: LiteralString +``` + +### Starred expression (4) + +```py +# TODO: Remove 'not-iterable' diagnostic +(a, *b, c, d) = "abcdef" # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: LiteralString +# TODO: Should be list[LiteralString] once support for assigning to starred expression is added +reveal_type(b) # revealed: @Todo +reveal_type(c) # revealed: LiteralString +reveal_type(d) # revealed: LiteralString +``` + +### Starred expression (5) + +```py +# TODO: Remove 'not-iterable' diagnostic +(a, b, *c) = "abcd" # error: "Object of type `None` is not iterable" +reveal_type(a) # revealed: LiteralString +reveal_type(b) # revealed: LiteralString +# TODO: Should be list[int] once support for assigning to starred expression is added +reveal_type(c) # revealed: @Todo +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index ac5463053d..083779dc0c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -994,7 +994,7 @@ class C[T]: let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { value: ast::Number::Int(num), .. - }) = &*assignment.assignment().value + }) = assignment.value() else { panic!("should be a number literal") }; diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 56df1c44d9..1c5c03e79a 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -28,8 +28,8 @@ use crate::Db; use super::constraint::{Constraint, PatternConstraint}; use super::definition::{ - DefinitionCategory, ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef, - WithItemDefinitionNodeRef, + AssignmentKind, DefinitionCategory, ExceptHandlerDefinitionNodeRef, + MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef, }; pub(super) struct SemanticIndexBuilder<'db> { @@ -566,11 +566,22 @@ where debug_assert!(self.current_assignment.is_none()); self.visit_expr(&node.value); self.add_standalone_expression(&node.value); - self.current_assignment = Some(node.into()); - for target in &node.targets { + for (target_index, target) in node.targets.iter().enumerate() { + let kind = match target { + ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(AssignmentKind::Sequence), + ast::Expr::Name(_) => Some(AssignmentKind::Name), + _ => None, + }; + if let Some(kind) = kind { + self.current_assignment = Some(CurrentAssignment::Assign { + assignment: node, + target_index, + kind, + }); + } self.visit_expr(target); + self.current_assignment = None; } - self.current_assignment = None; } ast::Stmt::AnnAssign(node) => { debug_assert!(self.current_assignment.is_none()); @@ -815,12 +826,18 @@ where let symbol = self.add_symbol(id.clone()); if is_definition { match self.current_assignment { - Some(CurrentAssignment::Assign(assignment)) => { + Some(CurrentAssignment::Assign { + assignment, + target_index, + kind, + }) => { self.add_definition( symbol, AssignmentDefinitionNodeRef { assignment, - target: name_node, + target_index, + name: name_node, + kind, }, ); } @@ -1045,7 +1062,11 @@ where #[derive(Copy, Clone, Debug)] enum CurrentAssignment<'a> { - Assign(&'a ast::StmtAssign), + Assign { + assignment: &'a ast::StmtAssign, + target_index: usize, + kind: AssignmentKind, + }, AnnAssign(&'a ast::StmtAnnAssign), AugAssign(&'a ast::StmtAugAssign), For(&'a ast::StmtFor), @@ -1057,12 +1078,6 @@ enum CurrentAssignment<'a> { WithItem(&'a ast::WithItem), } -impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> { - fn from(value: &'a ast::StmtAssign) -> Self { - Self::Assign(value) - } -} - impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> { fn from(value: &'a ast::StmtAnnAssign) -> Self { Self::AnnAssign(value) diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 0104515af8..20a6647b3c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -161,7 +161,9 @@ pub(crate) struct ImportFromDefinitionNodeRef<'a> { #[derive(Copy, Clone, Debug)] pub(crate) struct AssignmentDefinitionNodeRef<'a> { pub(crate) assignment: &'a ast::StmtAssign, - pub(crate) target: &'a ast::ExprName, + pub(crate) target_index: usize, + pub(crate) name: &'a ast::ExprName, + pub(crate) kind: AssignmentKind, } #[derive(Copy, Clone, Debug)] @@ -224,12 +226,17 @@ impl DefinitionNodeRef<'_> { DefinitionNodeRef::NamedExpression(named) => { DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named)) } - DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => { - DefinitionKind::Assignment(AssignmentDefinitionKind { - assignment: AstNodeRef::new(parsed.clone(), assignment), - target: AstNodeRef::new(parsed, target), - }) - } + DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { + assignment, + target_index, + name, + kind, + }) => DefinitionKind::Assignment(AssignmentDefinitionKind { + assignment: AstNodeRef::new(parsed.clone(), assignment), + target_index, + name: AstNodeRef::new(parsed, name), + kind, + }), DefinitionNodeRef::AnnotatedAssignment(assign) => { DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) } @@ -300,8 +307,10 @@ impl DefinitionNodeRef<'_> { Self::NamedExpression(node) => node.into(), Self::Assignment(AssignmentDefinitionNodeRef { assignment: _, - target, - }) => target.into(), + target_index: _, + name, + kind: _, + }) => name.into(), Self::AnnotatedAssignment(node) => node.into(), Self::AugmentedAssignment(node) => node.into(), Self::For(ForStmtDefinitionNodeRef { @@ -485,17 +494,34 @@ impl ImportFromDefinitionKind { #[derive(Clone, Debug)] pub struct AssignmentDefinitionKind { assignment: AstNodeRef, - target: AstNodeRef, + target_index: usize, + name: AstNodeRef, + kind: AssignmentKind, } impl AssignmentDefinitionKind { - pub(crate) fn assignment(&self) -> &ast::StmtAssign { - self.assignment.node() + pub(crate) fn value(&self) -> &ast::Expr { + &self.assignment.node().value } - pub(crate) fn target(&self) -> &ast::ExprName { - self.target.node() + pub(crate) fn target(&self) -> &ast::Expr { + &self.assignment.node().targets[self.target_index] } + + pub(crate) fn name(&self) -> &ast::ExprName { + self.name.node() + } + + pub(crate) fn kind(&self) -> AssignmentKind { + self.kind + } +} + +/// The kind of assignment target expression. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AssignmentKind { + Sequence, + Name, } #[derive(Clone, Debug)] diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 3aa6fe097d..447212d384 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1510,6 +1510,12 @@ pub struct StringLiteralType<'db> { value: Box, } +impl<'db> StringLiteralType<'db> { + pub fn len(&self, db: &'db dyn Db) -> usize { + self.value(db).len() + } +} + #[salsa::interned] pub struct BytesLiteralType<'db> { #[return_ref] @@ -1522,6 +1528,16 @@ pub struct TupleType<'db> { elements: Box<[Type<'db>]>, } +impl<'db> TupleType<'db> { + pub fn get(&self, db: &'db dyn Db, index: usize) -> Option> { + self.elements(db).get(index).copied() + } + + pub fn len(&self, db: &'db dyn Db) -> usize { + self.elements(db).len() + } +} + #[cfg(test)] mod tests { use super::{ diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 24d292af63..d83b5a8ba6 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -27,6 +27,7 @@ //! associated with a particular definition. Scope-level inference infers deferred types for all //! definitions once the rest of the types in the scope have been inferred. use itertools::Itertools; +use std::borrow::Cow; use std::num::NonZeroU32; use ruff_db::files::File; @@ -41,7 +42,7 @@ use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module}; use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; use crate::semantic_index::definition::{ - Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind, + AssignmentKind, Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind, }; use crate::semantic_index::expression::Expression; use crate::semantic_index::semantic_index; @@ -415,7 +416,9 @@ impl<'db> TypeInferenceBuilder<'db> { DefinitionKind::Assignment(assignment) => { self.infer_assignment_definition( assignment.target(), - assignment.assignment(), + assignment.value(), + assignment.name(), + assignment.kind(), definition, ); } @@ -1151,13 +1154,23 @@ impl<'db> TypeInferenceBuilder<'db> { } = assignment; for target in targets { - if let ast::Expr::Name(name) = target { - self.infer_definition(name); - } else { - // TODO infer definitions in unpacking assignment. When we do, this duplication of - // the "get `Expression`, call `infer_expression_types` on it, `self.extend`" dance - // will be removed; it'll all happen in `infer_assignment_definition` instead. - let expression = self.index.expression(value.as_ref()); + self.infer_assignment_target(target, value); + } + } + + // TODO: Remove the `value` argument once we handle all possible assignment targets. + fn infer_assignment_target(&mut self, target: &ast::Expr, value: &ast::Expr) { + match target { + ast::Expr::Name(name) => self.infer_definition(name), + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { + for element in elts { + self.infer_assignment_target(element, value); + } + } + _ => { + // TODO: Remove this once we handle all possible assignment targets. + let expression = self.index.expression(value); self.extend(infer_expression_types(self.db, expression)); self.infer_expression(target); } @@ -1166,18 +1179,138 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_assignment_definition( &mut self, - target: &ast::ExprName, - assignment: &ast::StmtAssign, + target: &ast::Expr, + value: &ast::Expr, + name: &ast::ExprName, + kind: AssignmentKind, definition: Definition<'db>, ) { - let expression = self.index.expression(assignment.value.as_ref()); + let expression = self.index.expression(value); let result = infer_expression_types(self.db, expression); self.extend(result); - let value_ty = self.expression_ty(&assignment.value); - self.add_binding(assignment.into(), definition, value_ty); + + let value_ty = self.expression_ty(value); + + let target_ty = match kind { + AssignmentKind::Sequence => self.infer_sequence_unpacking(target, value_ty, name), + AssignmentKind::Name => value_ty, + }; + + self.add_binding(name.into(), definition, target_ty); self.types .expressions - .insert(target.scoped_ast_id(self.db, self.scope), value_ty); + .insert(name.scoped_ast_id(self.db, self.scope), target_ty); + } + + fn infer_sequence_unpacking( + &mut self, + target: &ast::Expr, + value_ty: Type<'db>, + name: &ast::ExprName, + ) -> Type<'db> { + // The inner function is recursive and only differs in the return type which is an `Option` + // where if the variable is found, the corresponding type is returned otherwise `None`. + fn inner<'db>( + builder: &mut TypeInferenceBuilder<'db>, + target: &ast::Expr, + value_ty: Type<'db>, + name: &ast::ExprName, + ) -> Option> { + match target { + ast::Expr::Name(target_name) if target_name == name => { + return Some(value_ty); + } + ast::Expr::Starred(ast::ExprStarred { value, .. }) => { + return inner(builder, value, value_ty, name); + } + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => match value_ty { + Type::Tuple(tuple_ty) => { + let starred_index = elts.iter().position(ast::Expr::is_starred_expr); + + let element_types = if let Some(starred_index) = starred_index { + if tuple_ty.len(builder.db) >= elts.len() - 1 { + let mut element_types = Vec::with_capacity(elts.len()); + element_types.extend_from_slice( + // SAFETY: Safe because of the length check above. + &tuple_ty.elements(builder.db)[..starred_index], + ); + + // E.g., in `(a, *b, c, d) = ...`, the index of starred element `b` + // is 1 and the remaining elements after that are 2. + let remaining = elts.len() - (starred_index + 1); + // This index represents the type of the last element that belongs + // to the starred expression, in an exclusive manner. + let starred_end_index = tuple_ty.len(builder.db) - remaining; + // SAFETY: Safe because of the length check above. + let _starred_element_types = &tuple_ty.elements(builder.db) + [starred_index..starred_end_index]; + // TODO: Combine the types into a list type. If the + // starred_element_types is empty, then it should be `List[Any]`. + // combine_types(starred_element_types); + element_types.push(Type::Todo); + + element_types.extend_from_slice( + // SAFETY: Safe because of the length check above. + &tuple_ty.elements(builder.db)[starred_end_index..], + ); + Cow::Owned(element_types) + } else { + let mut element_types = tuple_ty.elements(builder.db).to_vec(); + element_types.insert(starred_index, Type::Todo); + Cow::Owned(element_types) + } + } else { + Cow::Borrowed(tuple_ty.elements(builder.db).as_ref()) + }; + + for (index, element) in elts.iter().enumerate() { + if let Some(ty) = inner( + builder, + element, + element_types.get(index).copied().unwrap_or(Type::Unknown), + name, + ) { + return Some(ty); + } + } + } + Type::StringLiteral(string_literal_ty) => { + // Deconstruct the string literal to delegate the inference back to the + // tuple type for correct handling of starred expressions. We could go + // further and deconstruct to an array of `StringLiteral` with each + // individual character, instead of just an array of `LiteralString`, but + // there would be a cost and it's not clear that it's worth it. + let value_ty = Type::Tuple(TupleType::new( + builder.db, + vec![Type::LiteralString; string_literal_ty.len(builder.db)] + .into_boxed_slice(), + )); + if let Some(ty) = inner(builder, target, value_ty, name) { + return Some(ty); + } + } + _ => { + let value_ty = if matches!(value_ty, Type::LiteralString) { + Type::LiteralString + } else { + value_ty + .iterate(builder.db) + .unwrap_with_diagnostic(AnyNodeRef::from(target), builder) + }; + for element in elts { + if let Some(ty) = inner(builder, element, value_ty, name) { + return Some(ty); + } + } + } + }, + _ => {} + } + None + } + + inner(self, target, value_ty, name).unwrap_or(Type::Unknown) } fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { diff --git a/crates/red_knot_workspace/tests/check.rs b/crates/red_knot_workspace/tests/check.rs index cf0404c3d1..6097247590 100644 --- a/crates/red_knot_workspace/tests/check.rs +++ b/crates/red_knot_workspace/tests/check.rs @@ -9,7 +9,7 @@ use ruff_db::parsed::parsed_module; use ruff_db::system::{OsSystem, SystemPath, SystemPathBuf}; use ruff_python_ast::visitor::source_order; use ruff_python_ast::visitor::source_order::SourceOrderVisitor; -use ruff_python_ast::{Alias, Expr, Parameter, ParameterWithDefault, Stmt}; +use ruff_python_ast::{self as ast, Alias, Expr, Parameter, ParameterWithDefault, Stmt}; fn setup_db(workspace_root: &SystemPath) -> anyhow::Result { let system = OsSystem::new(workspace_root); @@ -65,6 +65,17 @@ impl<'db> PullTypesVisitor<'db> { model: SemanticModel::new(db, file), } } + + fn visit_assign_target(&mut self, target: &Expr) { + match target { + Expr::List(ast::ExprList { elts, .. }) | Expr::Tuple(ast::ExprTuple { elts, .. }) => { + for element in elts { + self.visit_assign_target(element); + } + } + _ => self.visit_expr(target), + } + } } impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> { @@ -76,10 +87,15 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> { Stmt::ClassDef(class) => { let _ty = class.ty(&self.model); } + Stmt::Assign(assign) => { + for target in &assign.targets { + self.visit_assign_target(target); + } + return; + } Stmt::AnnAssign(_) | Stmt::Return(_) | Stmt::Delete(_) - | Stmt::Assign(_) | Stmt::AugAssign(_) | Stmt::TypeAlias(_) | Stmt::For(_)