diff --git a/crates/red_knot_project/tests/check.rs b/crates/red_knot_project/tests/check.rs index 091908cf0c..276612c3d8 100644 --- a/crates/red_knot_project/tests/check.rs +++ b/crates/red_knot_project/tests/check.rs @@ -6,7 +6,9 @@ use ruff_db::parsed::parsed_module; use ruff_db::system::{SystemPath, SystemPathBuf, TestSystem}; use ruff_python_ast::visitor::source_order; use ruff_python_ast::visitor::source_order::SourceOrderVisitor; -use ruff_python_ast::{self as ast, Alias, Expr, Parameter, ParameterWithDefault, Stmt}; +use ruff_python_ast::{ + self as ast, Alias, Comprehension, Expr, Parameter, ParameterWithDefault, Stmt, +}; fn setup_db(project_root: &SystemPath, system: TestSystem) -> anyhow::Result { let project = ProjectMetadata::discover(project_root, &system)?; @@ -258,6 +260,14 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> { source_order::walk_expr(self, expr); } + fn visit_comprehension(&mut self, comprehension: &Comprehension) { + self.visit_expr(&comprehension.iter); + self.visit_target(&comprehension.target); + for if_expr in &comprehension.ifs { + self.visit_expr(if_expr); + } + } + fn visit_parameter(&mut self, parameter: &Parameter) { let _ty = parameter.inferred_type(&self.model); diff --git a/crates/red_knot_python_semantic/resources/mdtest/attributes.md b/crates/red_knot_python_semantic/resources/mdtest/attributes.md index 37f0961461..5077b0eb8f 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/attributes.md +++ b/crates/red_knot_python_semantic/resources/mdtest/attributes.md @@ -397,15 +397,27 @@ class IntIterable: def __iter__(self) -> IntIterator: return IntIterator() +class TupleIterator: + def __next__(self) -> tuple[int, str]: + return (1, "a") + +class TupleIterable: + def __iter__(self) -> TupleIterator: + return TupleIterator() + class C: def __init__(self) -> None: [... for self.a in IntIterable()] + [... for (self.b, self.c) in TupleIterable()] + [... for self.d in IntIterable() for self.e in IntIterable()] c_instance = C() -# TODO: Should be `Unknown | int` -# error: [unresolved-attribute] -reveal_type(c_instance.a) # revealed: Unknown +reveal_type(c_instance.a) # revealed: Unknown | int +reveal_type(c_instance.b) # revealed: Unknown | int +reveal_type(c_instance.c) # revealed: Unknown | str +reveal_type(c_instance.d) # revealed: Unknown | int +reveal_type(c_instance.e) # revealed: Unknown | int ``` #### Conditionally declared / bound attributes diff --git a/crates/red_knot_python_semantic/resources/mdtest/unpacking.md b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md index 50a7a64388..6a0f375737 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/unpacking.md +++ b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md @@ -708,3 +708,95 @@ with ContextManager() as (a, b, c): reveal_type(b) # revealed: Unknown reveal_type(c) # revealed: Unknown ``` + +## Comprehension + +Unpacking in a comprehension. + +### Same types + +```py +def _(arg: tuple[tuple[int, int], tuple[int, int]]): + # revealed: tuple[int, int] + [reveal_type((a, b)) for a, b in arg] +``` + +### Mixed types (1) + +```py +def _(arg: tuple[tuple[int, int], tuple[int, str]]): + # revealed: tuple[int, int | str] + [reveal_type((a, b)) for a, b in arg] +``` + +### Mixed types (2) + +```py +def _(arg: tuple[tuple[int, str], tuple[str, int]]): + # revealed: tuple[int | str, str | int] + [reveal_type((a, b)) for a, b in arg] +``` + +### Mixed types (3) + +```py +def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]): + # revealed: tuple[int, int | str, int | bytes | str] + [reveal_type((a, b, c)) for a, b, c in arg] +``` + +### Same literal values + +```py +# revealed: tuple[Literal[1, 3], Literal[2, 4]] +[reveal_type((a, b)) for a, b in ((1, 2), (3, 4))] +``` + +### Mixed literal values (1) + +```py +# revealed: tuple[Literal[1, "a"], Literal[2, "b"]] +[reveal_type((a, b)) for a, b in ((1, 2), ("a", "b"))] +``` + +### Mixed literals values (2) + +```py +# error: "Object of type `Literal[1]` is not iterable" +# error: "Object of type `Literal[2]` is not iterable" +# error: "Object of type `Literal[4]` is not iterable" +# error: [invalid-assignment] "Not enough values to unpack (expected 2, got 1)" +# revealed: tuple[Unknown | Literal[3, 5], Unknown | Literal["a", "b"]] +[reveal_type((a, b)) for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c")] +``` + +### Custom iterator (1) + +```py +class Iterator: + def __next__(self) -> tuple[int, int]: + return (1, 2) + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + +# revealed: tuple[int, int] +[reveal_type((a, b)) for a, b in Iterable()] +``` + +### Custom iterator (2) + +```py +class Iterator: + def __next__(self) -> bytes: + return b"" + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + +def _(arg: tuple[tuple[int, str], Iterable]): + # revealed: tuple[int | bytes, str | bytes] + [reveal_type((a, b)) for a, b in arg] +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 7bcfa969fc..0af319bb63 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -940,7 +940,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): panic!("expected generator definition") }; let target = comprehension.target(); - let name = target.id().as_str(); + let name = target.as_name_expr().unwrap().id().as_str(); assert_eq!(name, "x"); assert_eq!(target.range(), TextRange::new(23.into(), 24.into())); diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index e1543428dc..8098c8789f 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -18,11 +18,12 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::AstIdsBuilder; use crate::semantic_index::definition::{ AnnotatedAssignmentDefinitionKind, AnnotatedAssignmentDefinitionNodeRef, - AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, - Definition, DefinitionCategory, DefinitionKind, DefinitionNodeKey, DefinitionNodeRef, - Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionKind, ForStmtDefinitionNodeRef, - ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef, - StarImportDefinitionNodeRef, TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef, + AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionKind, + ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionKind, + DefinitionNodeKey, DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef, + ForStmtDefinitionKind, ForStmtDefinitionNodeRef, ImportDefinitionNodeRef, + ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef, + TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::predicate::{ @@ -850,31 +851,35 @@ impl<'db> SemanticIndexBuilder<'db> { // The `iter` of the first generator is evaluated in the outer scope, while all subsequent // nodes are evaluated in the inner scope. - self.add_standalone_expression(&generator.iter); + let value = self.add_standalone_expression(&generator.iter); self.visit_expr(&generator.iter); self.push_scope(scope); - self.push_assignment(CurrentAssignment::Comprehension { - node: generator, - first: true, - }); - self.visit_expr(&generator.target); - self.pop_assignment(); + self.add_unpackable_assignment( + &Unpackable::Comprehension { + node: generator, + first: true, + }, + &generator.target, + value, + ); for expr in &generator.ifs { self.visit_expr(expr); } for generator in generators_iter { - self.add_standalone_expression(&generator.iter); + let value = self.add_standalone_expression(&generator.iter); self.visit_expr(&generator.iter); - self.push_assignment(CurrentAssignment::Comprehension { - node: generator, - first: false, - }); - self.visit_expr(&generator.target); - self.pop_assignment(); + self.add_unpackable_assignment( + &Unpackable::Comprehension { + node: generator, + first: false, + }, + &generator.target, + value, + ); for expr in &generator.ifs { self.visit_expr(expr); @@ -933,9 +938,30 @@ impl<'db> SemanticIndexBuilder<'db> { let current_assignment = match target { ast::Expr::List(_) | ast::Expr::Tuple(_) => { + if matches!(unpackable, Unpackable::Comprehension { .. }) { + debug_assert_eq!( + self.scopes[self.current_scope()].node().scope_kind(), + ScopeKind::Comprehension + ); + } + // The first iterator of the comprehension is evaluated in the outer scope, while all subsequent + // nodes are evaluated in the inner scope. + // SAFETY: The current scope is the comprehension, and the comprehension scope must have a parent scope. + let value_file_scope = + if let Unpackable::Comprehension { first: true, .. } = unpackable { + self.scope_stack + .iter() + .rev() + .nth(1) + .expect("The comprehension scope must have a parent scope") + .file_scope_id + } else { + self.current_scope() + }; let unpack = Some(Unpack::new( self.db, self.file, + value_file_scope, self.current_scope(), // SAFETY: `target` belongs to the `self.module` tree #[allow(unsafe_code)] @@ -1804,7 +1830,7 @@ where let node_key = NodeKey::from_node(expr); match expr { - ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => { + ast::Expr::Name(ast::ExprName { id, ctx, .. }) => { let (is_use, is_definition) = match (ctx, self.current_assignment()) { (ast::ExprContext::Store, Some(CurrentAssignment::AugAssign(_))) => { // For augmented assignment, the target expression is also used. @@ -1867,12 +1893,17 @@ where // implemented. self.add_definition(symbol, named); } - Some(CurrentAssignment::Comprehension { node, first }) => { + Some(CurrentAssignment::Comprehension { + unpack, + node, + first, + }) => { self.add_definition( symbol, ComprehensionDefinitionNodeRef { + unpack, iterable: &node.iter, - target: name_node, + target: expr, first, is_async: node.is_async, }, @@ -2143,14 +2174,37 @@ where DefinitionKind::WithItem(assignment), ); } - Some(CurrentAssignment::Comprehension { .. }) => { - // TODO: + Some(CurrentAssignment::Comprehension { + unpack, + node, + first, + }) => { + // SAFETY: `iter` and `expr` belong to the `self.module` tree + #[allow(unsafe_code)] + let assignment = ComprehensionDefinitionKind { + target_kind: TargetKind::from(unpack), + iterable: unsafe { + AstNodeRef::new(self.module.clone(), &node.iter) + }, + target: unsafe { AstNodeRef::new(self.module.clone(), expr) }, + first, + is_async: node.is_async, + }; + // Temporarily move to the scope of the method to which the instance attribute is defined. + // SAFETY: `self.scope_stack` is not empty because the targets in comprehensions should always introduce a new scope. + let scope = self.scope_stack.pop().expect("The popped scope must be a comprehension, which must have a parent scope"); + self.register_attribute_assignment( + object, + attr, + DefinitionKind::Comprehension(assignment), + ); + self.scope_stack.push(scope); } Some(CurrentAssignment::AugAssign(_)) => { // TODO: } Some(CurrentAssignment::Named(_)) => { - // TODO: + // A named expression whose target is an attribute is syntactically prohibited } None => {} } @@ -2244,6 +2298,7 @@ enum CurrentAssignment<'a> { Comprehension { node: &'a ast::Comprehension, first: bool, + unpack: Option<(UnpackPosition, Unpack<'a>)>, }, WithItem { item: &'a ast::WithItem, @@ -2257,11 +2312,9 @@ impl CurrentAssignment<'_> { match self { Self::Assign { unpack, .. } | Self::For { unpack, .. } - | Self::WithItem { unpack, .. } => unpack.as_mut().map(|(position, _)| position), - Self::AnnAssign(_) - | Self::AugAssign(_) - | Self::Named(_) - | Self::Comprehension { .. } => None, + | Self::WithItem { unpack, .. } + | Self::Comprehension { unpack, .. } => unpack.as_mut().map(|(position, _)| position), + Self::AnnAssign(_) | Self::AugAssign(_) | Self::Named(_) => None, } } } @@ -2316,13 +2369,17 @@ enum Unpackable<'a> { item: &'a ast::WithItem, is_async: bool, }, + Comprehension { + first: bool, + node: &'a ast::Comprehension, + }, } impl<'a> Unpackable<'a> { const fn kind(&self) -> UnpackKind { match self { Unpackable::Assign(_) => UnpackKind::Assign, - Unpackable::For(_) => UnpackKind::Iterable, + Unpackable::For(_) | Unpackable::Comprehension { .. } => UnpackKind::Iterable, Unpackable::WithItem { .. } => UnpackKind::ContextManager, } } @@ -2337,6 +2394,11 @@ impl<'a> Unpackable<'a> { is_async: *is_async, unpack, }, + Unpackable::Comprehension { node, first } => CurrentAssignment::Comprehension { + node, + first: *first, + unpack, + }, } } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 9334ac6981..145e4b205a 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -281,8 +281,9 @@ pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> { #[derive(Copy, Clone, Debug)] pub(crate) struct ComprehensionDefinitionNodeRef<'a> { + pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>, pub(crate) iterable: &'a ast::Expr, - pub(crate) target: &'a ast::ExprName, + pub(crate) target: &'a ast::Expr, pub(crate) first: bool, pub(crate) is_async: bool, } @@ -374,11 +375,13 @@ impl<'db> DefinitionNodeRef<'db> { is_async, }), DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { + unpack, iterable, target, first, is_async, }) => DefinitionKind::Comprehension(ComprehensionDefinitionKind { + target_kind: TargetKind::from(unpack), iterable: AstNodeRef::new(parsed.clone(), iterable), target: AstNodeRef::new(parsed, target), first, @@ -474,7 +477,9 @@ impl<'db> DefinitionNodeRef<'db> { unpack: _, is_async: _, }) => DefinitionNodeKey(NodeKey::from_node(target)), - Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(), + Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => { + DefinitionNodeKey(NodeKey::from_node(target)) + } Self::VariadicPositionalParameter(node) => node.into(), Self::VariadicKeywordParameter(node) => node.into(), Self::Parameter(node) => node.into(), @@ -550,7 +555,7 @@ pub enum DefinitionKind<'db> { AnnotatedAssignment(AnnotatedAssignmentDefinitionKind), AugmentedAssignment(AstNodeRef), For(ForStmtDefinitionKind<'db>), - Comprehension(ComprehensionDefinitionKind), + Comprehension(ComprehensionDefinitionKind<'db>), VariadicPositionalParameter(AstNodeRef), VariadicKeywordParameter(AstNodeRef), Parameter(AstNodeRef), @@ -749,19 +754,24 @@ impl MatchPatternDefinitionKind { } #[derive(Clone, Debug)] -pub struct ComprehensionDefinitionKind { - iterable: AstNodeRef, - target: AstNodeRef, - first: bool, - is_async: bool, +pub struct ComprehensionDefinitionKind<'db> { + pub(super) target_kind: TargetKind<'db>, + pub(super) iterable: AstNodeRef, + pub(super) target: AstNodeRef, + pub(super) first: bool, + pub(super) is_async: bool, } -impl ComprehensionDefinitionKind { +impl<'db> ComprehensionDefinitionKind<'db> { pub(crate) fn iterable(&self) -> &ast::Expr { self.iterable.node() } - pub(crate) fn target(&self) -> &ast::ExprName { + pub(crate) fn target_kind(&self) -> TargetKind<'db> { + self.target_kind + } + + pub(crate) fn target(&self) -> &ast::Expr { self.target.node() } diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 2f57261650..8865a60950 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -1416,14 +1416,42 @@ impl<'db> ClassLiteralType<'db> { } } } - DefinitionKind::Comprehension(_) => { - // TODO: + DefinitionKind::Comprehension(comprehension) => { + match comprehension.target_kind() { + TargetKind::Sequence(_, unpack) => { + // We found an unpacking assignment like: + // + // [... for .., self.name, .. in ] + + let unpacked = infer_unpack_types(db, unpack); + let target_ast_id = comprehension + .target() + .scoped_expression_id(db, unpack.target_scope(db)); + let inferred_ty = unpacked.expression_type(target_ast_id); + + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + } + TargetKind::NameOrAttribute => { + // We found an attribute assignment like: + // + // [... for self.name in ] + + let iterable_ty = infer_expression_type( + db, + index.expression(comprehension.iterable()), + ); + // TODO: Potential diagnostics resulting from the iterable are currently not reported. + let inferred_ty = iterable_ty.iterate(db); + + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + } + } } DefinitionKind::AugmentedAssignment(_) => { // TODO: } DefinitionKind::NamedExpression(_) => { - // TODO: + // A named expression whose target is an attribute is syntactically prohibited } _ => {} } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index e8f4ac2c6e..125655fff9 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -49,9 +49,9 @@ use crate::module_resolver::resolve_module; use crate::node_key::NodeKey; use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId, ScopedExpressionId}; use crate::semantic_index::definition::{ - AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, Definition, DefinitionKind, - DefinitionNodeKey, ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind, - WithItemDefinitionKind, + AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind, + Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind, + ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::symbol::{ @@ -306,7 +306,7 @@ pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> U let _span = tracing::trace_span!("infer_unpack_types", range=?unpack.range(db), ?file).entered(); - let mut unpacker = Unpacker::new(db, unpack.scope(db)); + let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db)); unpacker.unpack(unpack.target(db), unpack.value(db)); unpacker.finish() } @@ -946,13 +946,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_named_expression_definition(named_expression.node(), definition); } DefinitionKind::Comprehension(comprehension) => { - self.infer_comprehension_definition( - comprehension.iterable(), - comprehension.target(), - comprehension.is_first(), - comprehension.is_async(), - definition, - ); + self.infer_comprehension_definition(comprehension, definition); } DefinitionKind::VariadicPositionalParameter(parameter) => { self.infer_variadic_positional_parameter_definition(parameter, definition); @@ -1937,11 +1931,13 @@ impl<'db> TypeInferenceBuilder<'db> { for item in items { let target = item.optional_vars.as_deref(); if let Some(target) = target { - self.infer_target(target, &item.context_expr, |db, ctx_manager_ty| { + self.infer_target(target, &item.context_expr, |builder, context_expr| { // TODO: `infer_with_statement_definition` reports a diagnostic if `ctx_manager_ty` isn't a context manager // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `with not_context_manager as a.x: ... - ctx_manager_ty.enter(db) + builder + .infer_standalone_expression(context_expr) + .enter(builder.db()) }); } else { // Call into the context expression inference to validate that it evaluates @@ -2347,7 +2343,9 @@ impl<'db> TypeInferenceBuilder<'db> { } = assignment; for target in targets { - self.infer_target(target, value, |_, ty| ty); + self.infer_target(target, value, |builder, value_expr| { + builder.infer_standalone_expression(value_expr) + }); } } @@ -2357,23 +2355,16 @@ impl<'db> TypeInferenceBuilder<'db> { /// targets (unpacking). If `target` is an attribute expression, we check that the assignment /// is valid. For 'target's that are definitions, this check happens elsewhere. /// - /// The `to_assigned_ty` function is used to convert the inferred type of the `value` expression - /// to the type that is eventually assigned to the `target`. - /// - /// # Panics - /// - /// If the `value` is not a standalone expression. - fn infer_target(&mut self, target: &ast::Expr, value: &ast::Expr, to_assigned_ty: F) + /// The `infer_value_expr` function is used to infer the type of the `value` expression which + /// are not `Name` expressions. The returned type is the one that is eventually assigned to the + /// `target`. + fn infer_target(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F) where - F: Fn(&'db dyn Db, Type<'db>) -> Type<'db>, + F: Fn(&mut TypeInferenceBuilder<'db>, &ast::Expr) -> Type<'db>, { let assigned_ty = match target { ast::Expr::Name(_) => None, - _ => { - let value_ty = self.infer_standalone_expression(value); - - Some(to_assigned_ty(self.db(), value_ty)) - } + _ => Some(infer_value_expr(self, value)), }; self.infer_target_impl(target, assigned_ty); } @@ -3126,11 +3117,13 @@ impl<'db> TypeInferenceBuilder<'db> { is_async: _, } = for_statement; - self.infer_target(target, iter, |db, iter_ty| { + self.infer_target(target, iter, |builder, iter_expr| { // TODO: `infer_for_statement_definition` reports a diagnostic if `iter_ty` isn't iterable // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `for a.x in not_iterable: ... - iter_ty.iterate(db) + builder + .infer_standalone_expression(iter_expr) + .iterate(builder.db()) }); self.infer_body(body); @@ -3959,15 +3952,17 @@ impl<'db> TypeInferenceBuilder<'db> { is_async: _, } = comprehension; - if !is_first { - self.infer_standalone_expression(iter); - } - // TODO more complex assignment targets - if let ast::Expr::Name(name) = target { - self.infer_definition(name); - } else { - self.infer_expression(target); - } + self.infer_target(target, iter, |builder, iter_expr| { + // TODO: `infer_comprehension_definition` reports a diagnostic if `iter_ty` isn't iterable + // but only if the target is a name. We should report a diagnostic here if the target isn't a name: + // `[... for a.x in not_iterable] + if is_first { + infer_same_file_expression_type(builder.db(), builder.index.expression(iter_expr)) + } else { + builder.infer_standalone_expression(iter_expr) + } + .iterate(builder.db()) + }); for expr in ifs { self.infer_expression(expr); } @@ -3975,12 +3970,12 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_comprehension_definition( &mut self, - iterable: &ast::Expr, - target: &ast::ExprName, - is_first: bool, - is_async: bool, + comprehension: &ComprehensionDefinitionKind<'db>, definition: Definition<'db>, ) { + let iterable = comprehension.iterable(); + let target = comprehension.target(); + let expression = self.index.expression(iterable); let result = infer_expression_types(self.db(), expression); @@ -3990,7 +3985,7 @@ impl<'db> TypeInferenceBuilder<'db> { // (2) We must *not* call `self.extend()` on the result of the type inference, // because `ScopedExpressionId`s are only meaningful within their own scope, so // we'd add types for random wrong expressions in the current scope - let iterable_type = if is_first { + let iterable_type = if comprehension.is_first() { let lookup_scope = self .index .parent_scope_id(self.scope().file_scope_id(self.db())) @@ -4002,14 +3997,26 @@ impl<'db> TypeInferenceBuilder<'db> { result.expression_type(iterable.scoped_expression_id(self.db(), self.scope())) }; - let target_type = if is_async { + let target_type = if comprehension.is_async() { // TODO: async iterables/iterators! -- Alex todo_type!("async iterables/iterators") } else { - iterable_type.try_iterate(self.db()).unwrap_or_else(|err| { - err.report_diagnostic(&self.context, iterable_type, iterable.into()); - err.fallback_element_type(self.db()) - }) + match comprehension.target_kind() { + TargetKind::Sequence(unpack_position, unpack) => { + let unpacked = infer_unpack_types(self.db(), unpack); + if unpack_position == UnpackPosition::First { + self.context.extend(unpacked.diagnostics()); + } + let target_ast_id = target.scoped_expression_id(self.db(), self.scope()); + unpacked.expression_type(target_ast_id) + } + TargetKind::NameOrAttribute => { + iterable_type.try_iterate(self.db()).unwrap_or_else(|err| { + err.report_diagnostic(&self.context, iterable_type, iterable.into()); + err.fallback_element_type(self.db()) + }) + } + } }; self.types.expressions.insert( diff --git a/crates/red_knot_python_semantic/src/types/unpacker.rs b/crates/red_knot_python_semantic/src/types/unpacker.rs index a711357c87..ecaf1bdeab 100644 --- a/crates/red_knot_python_semantic/src/types/unpacker.rs +++ b/crates/red_knot_python_semantic/src/types/unpacker.rs @@ -18,16 +18,22 @@ use super::{TupleType, UnionType}; /// Unpacks the value expression type to their respective targets. pub(crate) struct Unpacker<'db> { context: InferContext<'db>, - scope: ScopeId<'db>, + target_scope: ScopeId<'db>, + value_scope: ScopeId<'db>, targets: FxHashMap>, } impl<'db> Unpacker<'db> { - pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { + pub(crate) fn new( + db: &'db dyn Db, + target_scope: ScopeId<'db>, + value_scope: ScopeId<'db>, + ) -> Self { Self { - context: InferContext::new(db, scope), + context: InferContext::new(db, target_scope), targets: FxHashMap::default(), - scope, + target_scope, + value_scope, } } @@ -43,7 +49,7 @@ impl<'db> Unpacker<'db> { ); let value_type = infer_expression_types(self.db(), value.expression()) - .expression_type(value.scoped_expression_id(self.db(), self.scope)); + .expression_type(value.scoped_expression_id(self.db(), self.value_scope)); let value_type = match value.kind() { UnpackKind::Assign => { @@ -79,8 +85,10 @@ impl<'db> Unpacker<'db> { ) { match target { ast::Expr::Name(_) | ast::Expr::Attribute(_) => { - self.targets - .insert(target.scoped_expression_id(self.db(), self.scope), value_ty); + self.targets.insert( + target.scoped_expression_id(self.db(), self.target_scope), + value_ty, + ); } ast::Expr::Starred(ast::ExprStarred { value, .. }) => { self.unpack_inner(value, value_expr, value_ty); diff --git a/crates/red_knot_python_semantic/src/unpack.rs b/crates/red_knot_python_semantic/src/unpack.rs index 4dadab3397..20081fe83f 100644 --- a/crates/red_knot_python_semantic/src/unpack.rs +++ b/crates/red_knot_python_semantic/src/unpack.rs @@ -30,7 +30,9 @@ use crate::Db; pub(crate) struct Unpack<'db> { pub(crate) file: File, - pub(crate) file_scope: FileScopeId, + pub(crate) value_file_scope: FileScopeId, + + pub(crate) target_file_scope: FileScopeId, /// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target /// expression is `(a, b)`. @@ -47,9 +49,19 @@ pub(crate) struct Unpack<'db> { } impl<'db> Unpack<'db> { - /// Returns the scope where the unpacking is happening. - pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { - self.file_scope(db).to_scope_id(db, self.file(db)) + /// Returns the scope in which the unpack value expression belongs. + /// + /// The scope in which the target and value expression belongs to are usually the same + /// except in generator expressions and comprehensions (list/dict/set), where the value + /// expression of the first generator is evaluated in the outer scope, while the ones in the subsequent + /// generators are evaluated in the comprehension scope. + pub(crate) fn value_scope(self, db: &'db dyn Db) -> ScopeId<'db> { + self.value_file_scope(db).to_scope_id(db, self.file(db)) + } + + /// Returns the scope where the unpack target expression belongs to. + pub(crate) fn target_scope(self, db: &'db dyn Db) -> ScopeId<'db> { + self.target_file_scope(db).to_scope_id(db, self.file(db)) } /// Returns the range of the unpack target expression.