[red-knot] infer attribute assignments bound in comprehensions (#17396)

## Summary

This PR is a follow-up to #16852.

Instance variables bound in comprehensions are recorded, allowing type
inference to work correctly.

This required adding support for unpacking in comprehension which
resolves https://github.com/astral-sh/ruff/issues/15369.

## Test Plan

One TODO in `mdtest/attributes.md` is now resolved, and some new test
cases are added.

---------

Co-authored-by: Dhruv Manilawala <dhruvmanila@gmail.com>
This commit is contained in:
Shunsuke Shibayama 2025-04-19 10:12:48 +09:00 committed by GitHub
parent 2a478ce1b2
commit da6b68cb58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 349 additions and 108 deletions

View file

@ -6,7 +6,9 @@ use ruff_db::parsed::parsed_module;
use ruff_db::system::{SystemPath, SystemPathBuf, TestSystem};
use ruff_python_ast::visitor::source_order;
use ruff_python_ast::visitor::source_order::SourceOrderVisitor;
use ruff_python_ast::{self as ast, Alias, Expr, Parameter, ParameterWithDefault, Stmt};
use ruff_python_ast::{
self as ast, Alias, Comprehension, Expr, Parameter, ParameterWithDefault, Stmt,
};
fn setup_db(project_root: &SystemPath, system: TestSystem) -> anyhow::Result<ProjectDatabase> {
let project = ProjectMetadata::discover(project_root, &system)?;
@ -258,6 +260,14 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> {
source_order::walk_expr(self, expr);
}
fn visit_comprehension(&mut self, comprehension: &Comprehension) {
self.visit_expr(&comprehension.iter);
self.visit_target(&comprehension.target);
for if_expr in &comprehension.ifs {
self.visit_expr(if_expr);
}
}
fn visit_parameter(&mut self, parameter: &Parameter) {
let _ty = parameter.inferred_type(&self.model);

View file

@ -397,15 +397,27 @@ class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
class TupleIterator:
def __next__(self) -> tuple[int, str]:
return (1, "a")
class TupleIterable:
def __iter__(self) -> TupleIterator:
return TupleIterator()
class C:
def __init__(self) -> None:
[... for self.a in IntIterable()]
[... for (self.b, self.c) in TupleIterable()]
[... for self.d in IntIterable() for self.e in IntIterable()]
c_instance = C()
# TODO: Should be `Unknown | int`
# error: [unresolved-attribute]
reveal_type(c_instance.a) # revealed: Unknown
reveal_type(c_instance.a) # revealed: Unknown | int
reveal_type(c_instance.b) # revealed: Unknown | int
reveal_type(c_instance.c) # revealed: Unknown | str
reveal_type(c_instance.d) # revealed: Unknown | int
reveal_type(c_instance.e) # revealed: Unknown | int
```
#### Conditionally declared / bound attributes

View file

@ -708,3 +708,95 @@ with ContextManager() as (a, b, c):
reveal_type(b) # revealed: Unknown
reveal_type(c) # revealed: Unknown
```
## Comprehension
Unpacking in a comprehension.
### Same types
```py
def _(arg: tuple[tuple[int, int], tuple[int, int]]):
# revealed: tuple[int, int]
[reveal_type((a, b)) for a, b in arg]
```
### Mixed types (1)
```py
def _(arg: tuple[tuple[int, int], tuple[int, str]]):
# revealed: tuple[int, int | str]
[reveal_type((a, b)) for a, b in arg]
```
### Mixed types (2)
```py
def _(arg: tuple[tuple[int, str], tuple[str, int]]):
# revealed: tuple[int | str, str | int]
[reveal_type((a, b)) for a, b in arg]
```
### Mixed types (3)
```py
def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]):
# revealed: tuple[int, int | str, int | bytes | str]
[reveal_type((a, b, c)) for a, b, c in arg]
```
### Same literal values
```py
# revealed: tuple[Literal[1, 3], Literal[2, 4]]
[reveal_type((a, b)) for a, b in ((1, 2), (3, 4))]
```
### Mixed literal values (1)
```py
# revealed: tuple[Literal[1, "a"], Literal[2, "b"]]
[reveal_type((a, b)) for a, b in ((1, 2), ("a", "b"))]
```
### Mixed literals values (2)
```py
# error: "Object of type `Literal[1]` is not iterable"
# error: "Object of type `Literal[2]` is not iterable"
# error: "Object of type `Literal[4]` is not iterable"
# error: [invalid-assignment] "Not enough values to unpack (expected 2, got 1)"
# revealed: tuple[Unknown | Literal[3, 5], Unknown | Literal["a", "b"]]
[reveal_type((a, b)) for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c")]
```
### Custom iterator (1)
```py
class Iterator:
def __next__(self) -> tuple[int, int]:
return (1, 2)
class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
# revealed: tuple[int, int]
[reveal_type((a, b)) for a, b in Iterable()]
```
### Custom iterator (2)
```py
class Iterator:
def __next__(self) -> bytes:
return b""
class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
def _(arg: tuple[tuple[int, str], Iterable]):
# revealed: tuple[int | bytes, str | bytes]
[reveal_type((a, b)) for a, b in arg]
```

View file

@ -940,7 +940,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
panic!("expected generator definition")
};
let target = comprehension.target();
let name = target.id().as_str();
let name = target.as_name_expr().unwrap().id().as_str();
assert_eq!(name, "x");
assert_eq!(target.range(), TextRange::new(23.into(), 24.into()));

View file

@ -18,11 +18,12 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIdsBuilder;
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionKind, AnnotatedAssignmentDefinitionNodeRef,
AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef,
Definition, DefinitionCategory, DefinitionKind, DefinitionNodeKey, DefinitionNodeRef,
Definitions, ExceptHandlerDefinitionNodeRef, ForStmtDefinitionKind, ForStmtDefinitionNodeRef,
ImportDefinitionNodeRef, ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef,
StarImportDefinitionNodeRef, TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef,
AssignmentDefinitionKind, AssignmentDefinitionNodeRef, ComprehensionDefinitionKind,
ComprehensionDefinitionNodeRef, Definition, DefinitionCategory, DefinitionKind,
DefinitionNodeKey, DefinitionNodeRef, Definitions, ExceptHandlerDefinitionNodeRef,
ForStmtDefinitionKind, ForStmtDefinitionNodeRef, ImportDefinitionNodeRef,
ImportFromDefinitionNodeRef, MatchPatternDefinitionNodeRef, StarImportDefinitionNodeRef,
TargetKind, WithItemDefinitionKind, WithItemDefinitionNodeRef,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::predicate::{
@ -850,31 +851,35 @@ impl<'db> SemanticIndexBuilder<'db> {
// The `iter` of the first generator is evaluated in the outer scope, while all subsequent
// nodes are evaluated in the inner scope.
self.add_standalone_expression(&generator.iter);
let value = self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);
self.push_scope(scope);
self.push_assignment(CurrentAssignment::Comprehension {
self.add_unpackable_assignment(
&Unpackable::Comprehension {
node: generator,
first: true,
});
self.visit_expr(&generator.target);
self.pop_assignment();
},
&generator.target,
value,
);
for expr in &generator.ifs {
self.visit_expr(expr);
}
for generator in generators_iter {
self.add_standalone_expression(&generator.iter);
let value = self.add_standalone_expression(&generator.iter);
self.visit_expr(&generator.iter);
self.push_assignment(CurrentAssignment::Comprehension {
self.add_unpackable_assignment(
&Unpackable::Comprehension {
node: generator,
first: false,
});
self.visit_expr(&generator.target);
self.pop_assignment();
},
&generator.target,
value,
);
for expr in &generator.ifs {
self.visit_expr(expr);
@ -933,9 +938,30 @@ impl<'db> SemanticIndexBuilder<'db> {
let current_assignment = match target {
ast::Expr::List(_) | ast::Expr::Tuple(_) => {
if matches!(unpackable, Unpackable::Comprehension { .. }) {
debug_assert_eq!(
self.scopes[self.current_scope()].node().scope_kind(),
ScopeKind::Comprehension
);
}
// The first iterator of the comprehension is evaluated in the outer scope, while all subsequent
// nodes are evaluated in the inner scope.
// SAFETY: The current scope is the comprehension, and the comprehension scope must have a parent scope.
let value_file_scope =
if let Unpackable::Comprehension { first: true, .. } = unpackable {
self.scope_stack
.iter()
.rev()
.nth(1)
.expect("The comprehension scope must have a parent scope")
.file_scope_id
} else {
self.current_scope()
};
let unpack = Some(Unpack::new(
self.db,
self.file,
value_file_scope,
self.current_scope(),
// SAFETY: `target` belongs to the `self.module` tree
#[allow(unsafe_code)]
@ -1804,7 +1830,7 @@ where
let node_key = NodeKey::from_node(expr);
match expr {
ast::Expr::Name(name_node @ ast::ExprName { id, ctx, .. }) => {
ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
let (is_use, is_definition) = match (ctx, self.current_assignment()) {
(ast::ExprContext::Store, Some(CurrentAssignment::AugAssign(_))) => {
// For augmented assignment, the target expression is also used.
@ -1867,12 +1893,17 @@ where
// implemented.
self.add_definition(symbol, named);
}
Some(CurrentAssignment::Comprehension { node, first }) => {
Some(CurrentAssignment::Comprehension {
unpack,
node,
first,
}) => {
self.add_definition(
symbol,
ComprehensionDefinitionNodeRef {
unpack,
iterable: &node.iter,
target: name_node,
target: expr,
first,
is_async: node.is_async,
},
@ -2143,14 +2174,37 @@ where
DefinitionKind::WithItem(assignment),
);
}
Some(CurrentAssignment::Comprehension { .. }) => {
// TODO:
Some(CurrentAssignment::Comprehension {
unpack,
node,
first,
}) => {
// SAFETY: `iter` and `expr` belong to the `self.module` tree
#[allow(unsafe_code)]
let assignment = ComprehensionDefinitionKind {
target_kind: TargetKind::from(unpack),
iterable: unsafe {
AstNodeRef::new(self.module.clone(), &node.iter)
},
target: unsafe { AstNodeRef::new(self.module.clone(), expr) },
first,
is_async: node.is_async,
};
// Temporarily move to the scope of the method to which the instance attribute is defined.
// SAFETY: `self.scope_stack` is not empty because the targets in comprehensions should always introduce a new scope.
let scope = self.scope_stack.pop().expect("The popped scope must be a comprehension, which must have a parent scope");
self.register_attribute_assignment(
object,
attr,
DefinitionKind::Comprehension(assignment),
);
self.scope_stack.push(scope);
}
Some(CurrentAssignment::AugAssign(_)) => {
// TODO:
}
Some(CurrentAssignment::Named(_)) => {
// TODO:
// A named expression whose target is an attribute is syntactically prohibited
}
None => {}
}
@ -2244,6 +2298,7 @@ enum CurrentAssignment<'a> {
Comprehension {
node: &'a ast::Comprehension,
first: bool,
unpack: Option<(UnpackPosition, Unpack<'a>)>,
},
WithItem {
item: &'a ast::WithItem,
@ -2257,11 +2312,9 @@ impl CurrentAssignment<'_> {
match self {
Self::Assign { unpack, .. }
| Self::For { unpack, .. }
| Self::WithItem { unpack, .. } => unpack.as_mut().map(|(position, _)| position),
Self::AnnAssign(_)
| Self::AugAssign(_)
| Self::Named(_)
| Self::Comprehension { .. } => None,
| Self::WithItem { unpack, .. }
| Self::Comprehension { unpack, .. } => unpack.as_mut().map(|(position, _)| position),
Self::AnnAssign(_) | Self::AugAssign(_) | Self::Named(_) => None,
}
}
}
@ -2316,13 +2369,17 @@ enum Unpackable<'a> {
item: &'a ast::WithItem,
is_async: bool,
},
Comprehension {
first: bool,
node: &'a ast::Comprehension,
},
}
impl<'a> Unpackable<'a> {
const fn kind(&self) -> UnpackKind {
match self {
Unpackable::Assign(_) => UnpackKind::Assign,
Unpackable::For(_) => UnpackKind::Iterable,
Unpackable::For(_) | Unpackable::Comprehension { .. } => UnpackKind::Iterable,
Unpackable::WithItem { .. } => UnpackKind::ContextManager,
}
}
@ -2337,6 +2394,11 @@ impl<'a> Unpackable<'a> {
is_async: *is_async,
unpack,
},
Unpackable::Comprehension { node, first } => CurrentAssignment::Comprehension {
node,
first: *first,
unpack,
},
}
}
}

View file

@ -281,8 +281,9 @@ pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> {
#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>,
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) target: &'a ast::Expr,
pub(crate) first: bool,
pub(crate) is_async: bool,
}
@ -374,11 +375,13 @@ impl<'db> DefinitionNodeRef<'db> {
is_async,
}),
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
unpack,
iterable,
target,
first,
is_async,
}) => DefinitionKind::Comprehension(ComprehensionDefinitionKind {
target_kind: TargetKind::from(unpack),
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
first,
@ -474,7 +477,9 @@ impl<'db> DefinitionNodeRef<'db> {
unpack: _,
is_async: _,
}) => DefinitionNodeKey(NodeKey::from_node(target)),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => {
DefinitionNodeKey(NodeKey::from_node(target))
}
Self::VariadicPositionalParameter(node) => node.into(),
Self::VariadicKeywordParameter(node) => node.into(),
Self::Parameter(node) => node.into(),
@ -550,7 +555,7 @@ pub enum DefinitionKind<'db> {
AnnotatedAssignment(AnnotatedAssignmentDefinitionKind),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind<'db>),
Comprehension(ComprehensionDefinitionKind),
Comprehension(ComprehensionDefinitionKind<'db>),
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
Parameter(AstNodeRef<ast::ParameterWithDefault>),
@ -749,19 +754,24 @@ impl MatchPatternDefinitionKind {
}
#[derive(Clone, Debug)]
pub struct ComprehensionDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
first: bool,
is_async: bool,
pub struct ComprehensionDefinitionKind<'db> {
pub(super) target_kind: TargetKind<'db>,
pub(super) iterable: AstNodeRef<ast::Expr>,
pub(super) target: AstNodeRef<ast::Expr>,
pub(super) first: bool,
pub(super) is_async: bool,
}
impl ComprehensionDefinitionKind {
impl<'db> ComprehensionDefinitionKind<'db> {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}
pub(crate) fn target(&self) -> &ast::ExprName {
pub(crate) fn target_kind(&self) -> TargetKind<'db> {
self.target_kind
}
pub(crate) fn target(&self) -> &ast::Expr {
self.target.node()
}

View file

@ -1416,14 +1416,42 @@ impl<'db> ClassLiteralType<'db> {
}
}
}
DefinitionKind::Comprehension(_) => {
// TODO:
DefinitionKind::Comprehension(comprehension) => {
match comprehension.target_kind() {
TargetKind::Sequence(_, unpack) => {
// We found an unpacking assignment like:
//
// [... for .., self.name, .. in <iterable>]
let unpacked = infer_unpack_types(db, unpack);
let target_ast_id = comprehension
.target()
.scoped_expression_id(db, unpack.target_scope(db));
let inferred_ty = unpacked.expression_type(target_ast_id);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
TargetKind::NameOrAttribute => {
// We found an attribute assignment like:
//
// [... for self.name in <iterable>]
let iterable_ty = infer_expression_type(
db,
index.expression(comprehension.iterable()),
);
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = iterable_ty.iterate(db);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
}
}
DefinitionKind::AugmentedAssignment(_) => {
// TODO:
}
DefinitionKind::NamedExpression(_) => {
// TODO:
// A named expression whose target is an attribute is syntactically prohibited
}
_ => {}
}

View file

@ -49,9 +49,9 @@ use crate::module_resolver::resolve_module;
use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, Definition, DefinitionKind,
DefinitionNodeKey, ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind,
WithItemDefinitionKind,
AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind,
Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind,
ForStmtDefinitionKind, TargetKind, WithItemDefinitionKind,
};
use crate::semantic_index::expression::{Expression, ExpressionKind};
use crate::semantic_index::symbol::{
@ -306,7 +306,7 @@ pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> U
let _span =
tracing::trace_span!("infer_unpack_types", range=?unpack.range(db), ?file).entered();
let mut unpacker = Unpacker::new(db, unpack.scope(db));
let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db));
unpacker.unpack(unpack.target(db), unpack.value(db));
unpacker.finish()
}
@ -946,13 +946,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_named_expression_definition(named_expression.node(), definition);
}
DefinitionKind::Comprehension(comprehension) => {
self.infer_comprehension_definition(
comprehension.iterable(),
comprehension.target(),
comprehension.is_first(),
comprehension.is_async(),
definition,
);
self.infer_comprehension_definition(comprehension, definition);
}
DefinitionKind::VariadicPositionalParameter(parameter) => {
self.infer_variadic_positional_parameter_definition(parameter, definition);
@ -1937,11 +1931,13 @@ impl<'db> TypeInferenceBuilder<'db> {
for item in items {
let target = item.optional_vars.as_deref();
if let Some(target) = target {
self.infer_target(target, &item.context_expr, |db, ctx_manager_ty| {
self.infer_target(target, &item.context_expr, |builder, context_expr| {
// TODO: `infer_with_statement_definition` reports a diagnostic if `ctx_manager_ty` isn't a context manager
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `with not_context_manager as a.x: ...
ctx_manager_ty.enter(db)
builder
.infer_standalone_expression(context_expr)
.enter(builder.db())
});
} else {
// Call into the context expression inference to validate that it evaluates
@ -2347,7 +2343,9 @@ impl<'db> TypeInferenceBuilder<'db> {
} = assignment;
for target in targets {
self.infer_target(target, value, |_, ty| ty);
self.infer_target(target, value, |builder, value_expr| {
builder.infer_standalone_expression(value_expr)
});
}
}
@ -2357,23 +2355,16 @@ impl<'db> TypeInferenceBuilder<'db> {
/// targets (unpacking). If `target` is an attribute expression, we check that the assignment
/// is valid. For 'target's that are definitions, this check happens elsewhere.
///
/// The `to_assigned_ty` function is used to convert the inferred type of the `value` expression
/// to the type that is eventually assigned to the `target`.
///
/// # Panics
///
/// If the `value` is not a standalone expression.
fn infer_target<F>(&mut self, target: &ast::Expr, value: &ast::Expr, to_assigned_ty: F)
/// The `infer_value_expr` function is used to infer the type of the `value` expression which
/// are not `Name` expressions. The returned type is the one that is eventually assigned to the
/// `target`.
fn infer_target<F>(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F)
where
F: Fn(&'db dyn Db, Type<'db>) -> Type<'db>,
F: Fn(&mut TypeInferenceBuilder<'db>, &ast::Expr) -> Type<'db>,
{
let assigned_ty = match target {
ast::Expr::Name(_) => None,
_ => {
let value_ty = self.infer_standalone_expression(value);
Some(to_assigned_ty(self.db(), value_ty))
}
_ => Some(infer_value_expr(self, value)),
};
self.infer_target_impl(target, assigned_ty);
}
@ -3126,11 +3117,13 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async: _,
} = for_statement;
self.infer_target(target, iter, |db, iter_ty| {
self.infer_target(target, iter, |builder, iter_expr| {
// TODO: `infer_for_statement_definition` reports a diagnostic if `iter_ty` isn't iterable
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `for a.x in not_iterable: ...
iter_ty.iterate(db)
builder
.infer_standalone_expression(iter_expr)
.iterate(builder.db())
});
self.infer_body(body);
@ -3959,15 +3952,17 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async: _,
} = comprehension;
if !is_first {
self.infer_standalone_expression(iter);
}
// TODO more complex assignment targets
if let ast::Expr::Name(name) = target {
self.infer_definition(name);
self.infer_target(target, iter, |builder, iter_expr| {
// TODO: `infer_comprehension_definition` reports a diagnostic if `iter_ty` isn't iterable
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `[... for a.x in not_iterable]
if is_first {
infer_same_file_expression_type(builder.db(), builder.index.expression(iter_expr))
} else {
self.infer_expression(target);
builder.infer_standalone_expression(iter_expr)
}
.iterate(builder.db())
});
for expr in ifs {
self.infer_expression(expr);
}
@ -3975,12 +3970,12 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_comprehension_definition(
&mut self,
iterable: &ast::Expr,
target: &ast::ExprName,
is_first: bool,
is_async: bool,
comprehension: &ComprehensionDefinitionKind<'db>,
definition: Definition<'db>,
) {
let iterable = comprehension.iterable();
let target = comprehension.target();
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db(), expression);
@ -3990,7 +3985,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// (2) We must *not* call `self.extend()` on the result of the type inference,
// because `ScopedExpressionId`s are only meaningful within their own scope, so
// we'd add types for random wrong expressions in the current scope
let iterable_type = if is_first {
let iterable_type = if comprehension.is_first() {
let lookup_scope = self
.index
.parent_scope_id(self.scope().file_scope_id(self.db()))
@ -4002,14 +3997,26 @@ impl<'db> TypeInferenceBuilder<'db> {
result.expression_type(iterable.scoped_expression_id(self.db(), self.scope()))
};
let target_type = if is_async {
let target_type = if comprehension.is_async() {
// TODO: async iterables/iterators! -- Alex
todo_type!("async iterables/iterators")
} else {
match comprehension.target_kind() {
TargetKind::Sequence(unpack_position, unpack) => {
let unpacked = infer_unpack_types(self.db(), unpack);
if unpack_position == UnpackPosition::First {
self.context.extend(unpacked.diagnostics());
}
let target_ast_id = target.scoped_expression_id(self.db(), self.scope());
unpacked.expression_type(target_ast_id)
}
TargetKind::NameOrAttribute => {
iterable_type.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, iterable_type, iterable.into());
err.fallback_element_type(self.db())
})
}
}
};
self.types.expressions.insert(

View file

@ -18,16 +18,22 @@ use super::{TupleType, UnionType};
/// Unpacks the value expression type to their respective targets.
pub(crate) struct Unpacker<'db> {
context: InferContext<'db>,
scope: ScopeId<'db>,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
targets: FxHashMap<ScopedExpressionId, Type<'db>>,
}
impl<'db> Unpacker<'db> {
pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self {
pub(crate) fn new(
db: &'db dyn Db,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
) -> Self {
Self {
context: InferContext::new(db, scope),
context: InferContext::new(db, target_scope),
targets: FxHashMap::default(),
scope,
target_scope,
value_scope,
}
}
@ -43,7 +49,7 @@ impl<'db> Unpacker<'db> {
);
let value_type = infer_expression_types(self.db(), value.expression())
.expression_type(value.scoped_expression_id(self.db(), self.scope));
.expression_type(value.scoped_expression_id(self.db(), self.value_scope));
let value_type = match value.kind() {
UnpackKind::Assign => {
@ -79,8 +85,10 @@ impl<'db> Unpacker<'db> {
) {
match target {
ast::Expr::Name(_) | ast::Expr::Attribute(_) => {
self.targets
.insert(target.scoped_expression_id(self.db(), self.scope), value_ty);
self.targets.insert(
target.scoped_expression_id(self.db(), self.target_scope),
value_ty,
);
}
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
self.unpack_inner(value, value_expr, value_ty);

View file

@ -30,7 +30,9 @@ use crate::Db;
pub(crate) struct Unpack<'db> {
pub(crate) file: File,
pub(crate) file_scope: FileScopeId,
pub(crate) value_file_scope: FileScopeId,
pub(crate) target_file_scope: FileScopeId,
/// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target
/// expression is `(a, b)`.
@ -47,9 +49,19 @@ pub(crate) struct Unpack<'db> {
}
impl<'db> Unpack<'db> {
/// Returns the scope where the unpacking is happening.
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
/// Returns the scope in which the unpack value expression belongs.
///
/// The scope in which the target and value expression belongs to are usually the same
/// except in generator expressions and comprehensions (list/dict/set), where the value
/// expression of the first generator is evaluated in the outer scope, while the ones in the subsequent
/// generators are evaluated in the comprehension scope.
pub(crate) fn value_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.value_file_scope(db).to_scope_id(db, self.file(db))
}
/// Returns the scope where the unpack target expression belongs to.
pub(crate) fn target_scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.target_file_scope(db).to_scope_id(db, self.file(db))
}
/// Returns the range of the unpack target expression.