[ty] Support variable-length tuples in unpacking assignments (#18948)

This PR updates our unpacking assignment logic to use the new tuple
machinery. As a result, we can now unpack variable-length tuples
correctly.

As part of this, the `TupleSpec` classes have been renamed to `Tuple`,
and can now contain any element (Rust) type, not just `Type<'db>`. The
unpacker uses a tuple of `UnionBuilder`s to maintain the types that will
be assigned to each target, as we iterate through potentially many union
elements on the rhs. We also add a new consuming iterator for tuples,
and update the `all_elements` methods to wrap the result in an enum
(similar to `itertools::Position`) letting you know which part of the
tuple each element appears in. I also added a new
`UnionBuilder::try_build`, which lets you specify a different fallback
type if the union contains no elements.
This commit is contained in:
Douglas Creager 2025-06-27 15:29:04 -04:00 committed by GitHub
parent a50a993b9c
commit c60e590b4c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 779 additions and 423 deletions

View file

@ -24,7 +24,7 @@ error[invalid-assignment]: Not enough values to unpack
1 | [a, *b, c, d] = (1, 2) # error: [invalid-assignment] 1 | [a, *b, c, d] = (1, 2) # error: [invalid-assignment]
| ^^^^^^^^^^^^^ ------ Got 2 | ^^^^^^^^^^^^^ ------ Got 2
| | | |
| Expected 3 or more | Expected at least 3
| |
info: rule `invalid-assignment` is enabled by default info: rule `invalid-assignment` is enabled by default

View file

@ -106,7 +106,7 @@ reveal_type(d) # revealed: Literal[5]
### Starred expression (1) ### Starred expression (1)
```py ```py
# error: [invalid-assignment] "Not enough values to unpack: Expected 3 or more" # error: [invalid-assignment] "Not enough values to unpack: Expected at least 3"
[a, *b, c, d] = (1, 2) [a, *b, c, d] = (1, 2)
reveal_type(a) # revealed: Unknown reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: list[Unknown] reveal_type(b) # revealed: list[Unknown]
@ -119,7 +119,7 @@ reveal_type(d) # revealed: Unknown
```py ```py
[a, *b, c] = (1, 2) [a, *b, c] = (1, 2)
reveal_type(a) # revealed: Literal[1] reveal_type(a) # revealed: Literal[1]
reveal_type(b) # revealed: list[Unknown] reveal_type(b) # revealed: list[Never]
reveal_type(c) # revealed: Literal[2] reveal_type(c) # revealed: Literal[2]
``` ```
@ -154,7 +154,7 @@ reveal_type(c) # revealed: list[Literal[3, 4]]
### Starred expression (6) ### Starred expression (6)
```py ```py
# error: [invalid-assignment] "Not enough values to unpack: Expected 5 or more" # error: [invalid-assignment] "Not enough values to unpack: Expected at least 5"
(a, b, c, *d, e, f) = (1,) (a, b, c, *d, e, f) = (1,)
reveal_type(a) # revealed: Unknown reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: Unknown reveal_type(b) # revealed: Unknown
@ -258,6 +258,155 @@ def _(value: list[int]):
reveal_type(c) # revealed: int reveal_type(c) # revealed: int
``` ```
## Homogeneous tuples
### Simple unpacking
```py
def _(value: tuple[int, ...]):
a, b = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
```
### Nested unpacking
```py
def _(value: tuple[tuple[int, ...], ...]):
a, (b, c) = value
reveal_type(a) # revealed: tuple[int, ...]
reveal_type(b) # revealed: int
reveal_type(c) # revealed: int
```
### Invalid nested unpacking
```py
def _(value: tuple[int, ...]):
# error: [not-iterable] "Object of type `int` is not iterable"
a, (b, c) = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: Unknown
reveal_type(c) # revealed: Unknown
```
### Starred expression
```py
def _(value: tuple[int, ...]):
a, *b, c = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: list[int]
reveal_type(c) # revealed: int
```
## Mixed tuples
```toml
[environment]
python-version = "3.11"
```
### Simple unpacking (1)
```py
def _(value: tuple[int, *tuple[str, ...]]):
a, b = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: str
```
### Simple unpacking (2)
```py
def _(value: tuple[int, int, *tuple[str, ...]]):
a, b = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
```
### Simple unpacking (3)
```py
def _(value: tuple[int, *tuple[str, ...], int]):
a, b, c = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: str
reveal_type(c) # revealed: int
```
### Invalid unpacked
```py
def _(value: tuple[int, int, int, *tuple[str, ...]]):
# error: [invalid-assignment] "Too many values to unpack: Expected 2"
a, b = value
reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: Unknown
```
### Nested unpacking
```py
def _(value: tuple[str, *tuple[tuple[int, ...], ...]]):
a, (b, c) = value
reveal_type(a) # revealed: str
reveal_type(b) # revealed: int
reveal_type(c) # revealed: int
```
### Invalid nested unpacking
```py
def _(value: tuple[str, *tuple[int, ...]]):
# error: [not-iterable] "Object of type `int` is not iterable"
a, (b, c) = value
reveal_type(a) # revealed: str
reveal_type(b) # revealed: Unknown
reveal_type(c) # revealed: Unknown
```
### Starred expression (1)
```py
def _(value: tuple[int, *tuple[str, ...]]):
a, *b, c = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: list[str]
reveal_type(c) # revealed: str
```
### Starred expression (2)
```py
def _(value: tuple[int, *tuple[str, ...], int]):
a, *b, c = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: list[str]
reveal_type(c) # revealed: int
```
### Starred expression (3)
```py
def _(value: tuple[int, *tuple[str, ...], int]):
a, *b, c, d = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: list[str]
reveal_type(c) # revealed: str
reveal_type(d) # revealed: int
```
### Starred expression (4)
```py
def _(value: tuple[int, int, *tuple[str, ...], int]):
a, *b, c = value
reveal_type(a) # revealed: int
reveal_type(b) # revealed: list[int | str]
reveal_type(c) # revealed: int
```
## String ## String
### Simple unpacking ### Simple unpacking
@ -290,7 +439,7 @@ reveal_type(b) # revealed: Unknown
### Starred expression (1) ### Starred expression (1)
```py ```py
# error: [invalid-assignment] "Not enough values to unpack: Expected 3 or more" # error: [invalid-assignment] "Not enough values to unpack: Expected at least 3"
(a, *b, c, d) = "ab" (a, *b, c, d) = "ab"
reveal_type(a) # revealed: Unknown reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: list[Unknown] reveal_type(b) # revealed: list[Unknown]
@ -299,7 +448,7 @@ reveal_type(d) # revealed: Unknown
``` ```
```py ```py
# error: [invalid-assignment] "Not enough values to unpack: Expected 3 or more" # error: [invalid-assignment] "Not enough values to unpack: Expected at least 3"
(a, b, *c, d) = "a" (a, b, *c, d) = "a"
reveal_type(a) # revealed: Unknown reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: Unknown reveal_type(b) # revealed: Unknown
@ -312,7 +461,7 @@ reveal_type(d) # revealed: Unknown
```py ```py
(a, *b, c) = "ab" (a, *b, c) = "ab"
reveal_type(a) # revealed: LiteralString reveal_type(a) # revealed: LiteralString
reveal_type(b) # revealed: list[Unknown] reveal_type(b) # revealed: list[Never]
reveal_type(c) # revealed: LiteralString reveal_type(c) # revealed: LiteralString
``` ```

View file

@ -726,7 +726,7 @@ impl<'db> Type<'db> {
.map(|ty| ty.materialize(db, variance.flip())), .map(|ty| ty.materialize(db, variance.flip())),
) )
.build(), .build(),
Type::Tuple(tuple_type) => Type::tuple(db, tuple_type.materialize(db, variance)), Type::Tuple(tuple_type) => Type::tuple(tuple_type.materialize(db, variance)),
Type::TypeVar(type_var) => Type::TypeVar(type_var.materialize(db, variance)), Type::TypeVar(type_var) => Type::TypeVar(type_var.materialize(db, variance)),
Type::TypeIs(type_is) => { Type::TypeIs(type_is) => {
type_is.with_type(db, type_is.return_type(db).materialize(db, variance)) type_is.with_type(db, type_is.return_type(db).materialize(db, variance))
@ -1141,7 +1141,7 @@ impl<'db> Type<'db> {
match self { match self {
Type::Union(union) => Type::Union(union.normalized(db)), Type::Union(union) => Type::Union(union.normalized(db)),
Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)), Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)),
Type::Tuple(tuple) => Type::tuple(db, tuple.normalized(db)), Type::Tuple(tuple) => Type::tuple(tuple.normalized(db)),
Type::Callable(callable) => Type::Callable(callable.normalized(db)), Type::Callable(callable) => Type::Callable(callable.normalized(db)),
Type::ProtocolInstance(protocol) => protocol.normalized(db), Type::ProtocolInstance(protocol) => protocol.normalized(db),
Type::NominalInstance(instance) => Type::NominalInstance(instance.normalized(db)), Type::NominalInstance(instance) => Type::NominalInstance(instance.normalized(db)),
@ -3458,7 +3458,7 @@ impl<'db> Type<'db> {
Type::BooleanLiteral(bool) => Truthiness::from(*bool), Type::BooleanLiteral(bool) => Truthiness::from(*bool),
Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()), Type::StringLiteral(str) => Truthiness::from(!str.value(db).is_empty()),
Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()), Type::BytesLiteral(bytes) => Truthiness::from(!bytes.value(db).is_empty()),
Type::Tuple(tuple) => match tuple.tuple(db).size_hint() { Type::Tuple(tuple) => match tuple.tuple(db).len().size_hint() {
// The tuple type is AlwaysFalse if it contains only the empty tuple // The tuple type is AlwaysFalse if it contains only the empty tuple
(_, Some(0)) => Truthiness::AlwaysFalse, (_, Some(0)) => Truthiness::AlwaysFalse,
// The tuple type is AlwaysTrue if its inhabitants must always have length >=1 // The tuple type is AlwaysTrue if its inhabitants must always have length >=1
@ -4312,7 +4312,7 @@ impl<'db> Type<'db> {
let mut parameter = let mut parameter =
Parameter::positional_only(Some(Name::new_static("iterable"))) Parameter::positional_only(Some(Name::new_static("iterable")))
.with_annotated_type(instantiated); .with_annotated_type(instantiated);
if matches!(spec.size_hint().1, Some(0)) { if matches!(spec.len().maximum(), Some(0)) {
parameter = parameter.with_default_type(TupleType::empty(db)); parameter = parameter.with_default_type(TupleType::empty(db));
} }
Parameters::new([parameter]) Parameters::new([parameter])
@ -5350,7 +5350,7 @@ impl<'db> Type<'db> {
} }
builder.build() builder.build()
} }
Type::Tuple(tuple) => Type::Tuple(tuple.apply_type_mapping(db, type_mapping)), Type::Tuple(tuple) => Type::tuple(tuple.apply_type_mapping(db, type_mapping)),
Type::TypeIs(type_is) => type_is.with_type(db, type_is.return_type(db).apply_type_mapping(db, type_mapping)), Type::TypeIs(type_is) => type_is.with_type(db, type_is.return_type(db).apply_type_mapping(db, type_mapping)),

View file

@ -444,6 +444,10 @@ impl<'db> UnionBuilder<'db> {
} }
pub(crate) fn build(self) -> Type<'db> { pub(crate) fn build(self) -> Type<'db> {
self.try_build().unwrap_or(Type::Never)
}
pub(crate) fn try_build(self) -> Option<Type<'db>> {
let mut types = vec![]; let mut types = vec![];
for element in self.elements { for element in self.elements {
match element { match element {
@ -460,9 +464,12 @@ impl<'db> UnionBuilder<'db> {
} }
} }
match types.len() { match types.len() {
0 => Type::Never, 0 => None,
1 => types[0], 1 => Some(types[0]),
_ => Type::Union(UnionType::new(self.db, types.into_boxed_slice())), _ => Some(Type::Union(UnionType::new(
self.db,
types.into_boxed_slice(),
))),
} }
} }
} }

View file

@ -221,10 +221,10 @@ fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
let expanded = tuple let expanded = tuple
.all_elements() .all_elements()
.map(|element| { .map(|element| {
if let Some(expanded) = expand_type(db, element) { if let Some(expanded) = expand_type(db, *element) {
Either::Left(expanded.into_iter()) Either::Left(expanded.into_iter())
} else { } else {
Either::Right(std::iter::once(element)) Either::Right(std::iter::once(*element))
} }
}) })
.multi_cartesian_product() .multi_cartesian_product()

View file

@ -286,9 +286,13 @@ impl<'db> Specialization<'db> {
return tuple; return tuple;
} }
if let [element_type] = self.types(db) { if let [element_type] = self.types(db) {
return TupleType::new(db, TupleSpec::homogeneous(*element_type)).tuple(db); if let Some(tuple) = TupleType::new(db, TupleSpec::homogeneous(*element_type)) {
return tuple.tuple(db);
} }
TupleType::new(db, TupleSpec::homogeneous(Type::unknown())).tuple(db) }
TupleType::new(db, TupleSpec::homogeneous(Type::unknown()))
.expect("tuple[Unknown, ...] should never contain Never")
.tuple(db)
} }
/// Returns the type that a typevar is mapped to, or None if the typevar isn't part of this /// Returns the type that a typevar is mapped to, or None if the typevar isn't part of this
@ -330,7 +334,7 @@ impl<'db> Specialization<'db> {
.collect(); .collect();
let tuple_inner = self let tuple_inner = self
.tuple_inner(db) .tuple_inner(db)
.map(|tuple| tuple.apply_type_mapping(db, type_mapping)); .and_then(|tuple| tuple.apply_type_mapping(db, type_mapping));
Specialization::new(db, self.generic_context(db), types, tuple_inner) Specialization::new(db, self.generic_context(db), types, tuple_inner)
} }
@ -374,7 +378,7 @@ impl<'db> Specialization<'db> {
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { pub(crate) fn normalized(self, db: &'db dyn Db) -> Self {
let types: Box<[_]> = self.types(db).iter().map(|ty| ty.normalized(db)).collect(); let types: Box<[_]> = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
let tuple_inner = self.tuple_inner(db).map(|tuple| tuple.normalized(db)); let tuple_inner = self.tuple_inner(db).and_then(|tuple| tuple.normalized(db));
Self::new(db, self.generic_context(db), types, tuple_inner) Self::new(db, self.generic_context(db), types, tuple_inner)
} }
@ -394,7 +398,7 @@ impl<'db> Specialization<'db> {
vartype.materialize(db, variance) vartype.materialize(db, variance)
}) })
.collect(); .collect();
let tuple_inner = self.tuple_inner(db).map(|tuple| { let tuple_inner = self.tuple_inner(db).and_then(|tuple| {
// Tuples are immutable, so tuple element types are always in covariant position. // Tuples are immutable, so tuple element types are always in covariant position.
tuple.materialize(db, variance) tuple.materialize(db, variance)
}); });
@ -637,7 +641,7 @@ impl<'db> SpecializationBuilder<'db> {
(TupleSpec::Fixed(formal_tuple), TupleSpec::Fixed(actual_tuple)) => { (TupleSpec::Fixed(formal_tuple), TupleSpec::Fixed(actual_tuple)) => {
if formal_tuple.len() == actual_tuple.len() { if formal_tuple.len() == actual_tuple.len() {
for (formal_element, actual_element) in formal_tuple.elements().zip(actual_tuple.elements()) { for (formal_element, actual_element) in formal_tuple.elements().zip(actual_tuple.elements()) {
self.infer(formal_element, actual_element)?; self.infer(*formal_element, *actual_element)?;
} }
} }
} }

View file

@ -2835,7 +2835,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// it will actually be the type of the generic parameters to `BaseExceptionGroup` or `ExceptionGroup`. // it will actually be the type of the generic parameters to `BaseExceptionGroup` or `ExceptionGroup`.
let symbol_ty = if let Type::Tuple(tuple) = node_ty { let symbol_ty = if let Type::Tuple(tuple) = node_ty {
let mut builder = UnionBuilder::new(self.db()); let mut builder = UnionBuilder::new(self.db());
for element in tuple.tuple(self.db()).all_elements() { for element in tuple.tuple(self.db()).all_elements().copied() {
builder = builder.add( builder = builder.add(
if element.is_assignable_to(self.db(), type_base_exception) { if element.is_assignable_to(self.db(), type_base_exception) {
element.to_instance(self.db()).expect( element.to_instance(self.db()).expect(
@ -3701,7 +3701,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::List(ast::ExprList { elts, .. }) ast::Expr::List(ast::ExprList { elts, .. })
| ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => {
let mut assigned_tys = match assigned_ty { let mut assigned_tys = match assigned_ty {
Some(Type::Tuple(tuple)) => Either::Left(tuple.tuple(self.db()).all_elements()), Some(Type::Tuple(tuple)) => {
Either::Left(tuple.tuple(self.db()).all_elements().copied())
}
Some(_) | None => Either::Right(std::iter::empty()), Some(_) | None => Either::Right(std::iter::empty()),
}; };
@ -6485,13 +6487,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
op, op,
), ),
(Type::Tuple(lhs), Type::Tuple(rhs), ast::Operator::Add) => Some(Type::tuple( (Type::Tuple(lhs), Type::Tuple(rhs), ast::Operator::Add) => {
self.db(), Some(Type::tuple(TupleType::new(
TupleType::new(
self.db(), self.db(),
lhs.tuple(self.db()).concat(self.db(), rhs.tuple(self.db())), lhs.tuple(self.db()).concat(self.db(), rhs.tuple(self.db())),
), )))
)), }
// We've handled all of the special cases that we support for literals, so we need to // We've handled all of the special cases that we support for literals, so we need to
// fall back on looking for dunder methods on one of the operand types. // fall back on looking for dunder methods on one of the operand types.
@ -6948,14 +6949,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// tuples. // tuples.
// //
// Ref: https://github.com/astral-sh/ruff/pull/18251#discussion_r2115909311 // Ref: https://github.com/astral-sh/ruff/pull/18251#discussion_r2115909311
let (minimum_length, _) = tuple.tuple(self.db()).size_hint(); let (minimum_length, _) = tuple.tuple(self.db()).len().size_hint();
if minimum_length > 1 << 12 { if minimum_length > 1 << 12 {
return None; return None;
} }
let mut definitely_true = false; let mut definitely_true = false;
let mut definitely_false = true; let mut definitely_false = true;
for element in tuple.tuple(self.db()).all_elements() { for element in tuple.tuple(self.db()).all_elements().copied() {
if element.is_string_literal() { if element.is_string_literal() {
if literal == element { if literal == element {
definitely_true = true; definitely_true = true;
@ -7238,7 +7239,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut any_eq = false; let mut any_eq = false;
let mut any_ambiguous = false; let mut any_ambiguous = false;
for ty in rhs_tuple.all_elements() { for ty in rhs_tuple.all_elements().copied() {
let eq_result = self.infer_binary_type_comparison( let eq_result = self.infer_binary_type_comparison(
Type::Tuple(lhs), Type::Tuple(lhs),
ast::CmpOp::Eq, ast::CmpOp::Eq,
@ -7450,8 +7451,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return Ok(Type::unknown()); return Ok(Type::unknown());
}; };
let left_iter = left.elements(); let left_iter = left.elements().copied();
let right_iter = right.elements(); let right_iter = right.elements().copied();
let mut builder = UnionBuilder::new(self.db()); let mut builder = UnionBuilder::new(self.db());
@ -7695,7 +7696,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
"tuple", "tuple",
value_node.into(), value_node.into(),
value_ty, value_ty,
tuple.display_minimum_length(), tuple.len().display_minimum(),
int, int,
); );
Type::unknown() Type::unknown()
@ -8856,7 +8857,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let ty = if return_todo { let ty = if return_todo {
todo_type!("PEP 646") todo_type!("PEP 646")
} else { } else {
Type::tuple(self.db(), TupleType::new(self.db(), element_types)) Type::tuple(TupleType::new(self.db(), element_types))
}; };
// Here, we store the type for the inner `int, str` tuple-expression, // Here, we store the type for the inner `int, str` tuple-expression,

View file

@ -19,7 +19,7 @@ impl<'db> Type<'db> {
TupleType::homogeneous(db, Type::unknown()) TupleType::homogeneous(db, Type::unknown())
} }
(ClassType::Generic(alias), Some(KnownClass::Tuple)) => { (ClassType::Generic(alias), Some(KnownClass::Tuple)) => {
Self::tuple(db, TupleType::new(db, alias.specialization(db).tuple(db))) Self::tuple(TupleType::new(db, alias.specialization(db).tuple(db)))
} }
_ if class.class_literal(db).0.is_protocol(db) => { _ if class.class_literal(db).0.is_protocol(db) => {
Self::ProtocolInstance(ProtocolInstanceType::from_class(class)) Self::ProtocolInstance(ProtocolInstanceType::from_class(class))

View file

@ -184,6 +184,7 @@ impl ClassInfoConstraintFunction {
tuple tuple
.tuple(db) .tuple(db)
.all_elements() .all_elements()
.copied()
.map(|element| self.generate_constraint(db, element)), .map(|element| self.generate_constraint(db, element)),
), ),
Type::ClassLiteral(class_literal) => { Type::ClassLiteral(class_literal) => {

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,4 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::cmp::Ordering;
use ruff_db::parsed::ParsedModuleRef; use ruff_db::parsed::ParsedModuleRef;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
@ -9,13 +8,12 @@ use ruff_python_ast::{self as ast, AnyNodeRef};
use crate::Db; use crate::Db;
use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId}; use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId};
use crate::semantic_index::place::ScopeId; use crate::semantic_index::place::ScopeId;
use crate::types::tuple::{FixedLengthTupleSpec, TupleSpec, TupleType}; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleUnpacker};
use crate::types::{Type, TypeCheckDiagnostics, infer_expression_types, todo_type}; use crate::types::{Type, TypeCheckDiagnostics, infer_expression_types};
use crate::unpack::{UnpackKind, UnpackValue}; use crate::unpack::{UnpackKind, UnpackValue};
use super::context::InferContext; use super::context::InferContext;
use super::diagnostic::INVALID_ASSIGNMENT; use super::diagnostic::INVALID_ASSIGNMENT;
use super::{KnownClass, UnionType};
/// Unpacks the value expression type to their respective targets. /// Unpacks the value expression type to their respective targets.
pub(crate) struct Unpacker<'db, 'ast> { pub(crate) struct Unpacker<'db, 'ast> {
@ -115,18 +113,13 @@ impl<'db, 'ast> Unpacker<'db, 'ast> {
} }
ast::Expr::List(ast::ExprList { elts, .. }) ast::Expr::List(ast::ExprList { elts, .. })
| ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => { | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => {
// Initialize the vector of target types, one for each target. let target_len = match elts.iter().position(ast::Expr::is_starred_expr) {
// Some(starred_index) => {
// This is mainly useful for the union type where the target type at index `n` is TupleLength::Variable(starred_index, elts.len() - (starred_index + 1))
// going to be a union of types from every union type element at index `n`. }
// None => TupleLength::Fixed(elts.len()),
// For example, if the type is `tuple[int, int] | tuple[int, str]` and the target };
// has two elements `(a, b)`, then let mut unpacker = TupleUnpacker::new(self.db(), target_len);
// * The type of `a` will be a union of `int` and `int` which are at index 0 in the
// first and second tuple respectively which resolves to an `int`.
// * Similarly, the type of `b` will be a union of `int` and `str` which are at
// index 1 in the first and second tuple respectively which will be `int | str`.
let mut target_types = vec![vec![]; elts.len()];
let unpack_types = match value_ty { let unpack_types = match value_ty {
Type::Union(union_ty) => union_ty.elements(self.db()), Type::Union(union_ty) => union_ty.elements(self.db()),
@ -134,205 +127,75 @@ impl<'db, 'ast> Unpacker<'db, 'ast> {
}; };
for ty in unpack_types.iter().copied() { for ty in unpack_types.iter().copied() {
// Deconstruct certain types to delegate the inference back to the tuple type let tuple = match ty {
// for correct handling of starred expressions. Type::Tuple(tuple_ty) => Cow::Borrowed(tuple_ty.tuple(self.db())),
let ty = match ty {
Type::StringLiteral(string_literal_ty) => { Type::StringLiteral(string_literal_ty) => {
// We could go further and deconstruct to an array of `StringLiteral` // We could go further and deconstruct to an array of `StringLiteral`
// with each individual character, instead of just an array of // with each individual character, instead of just an array of
// `LiteralString`, but there would be a cost and it's not clear that // `LiteralString`, but there would be a cost and it's not clear that
// it's worth it. // it's worth it.
TupleType::from_elements( Cow::Owned(Tuple::from_elements(std::iter::repeat_n(
self.db(),
std::iter::repeat_n(
Type::LiteralString, Type::LiteralString,
string_literal_ty.python_len(self.db()), string_literal_ty.python_len(self.db()),
), )))
) }
Type::LiteralString => Cow::Owned(Tuple::homogeneous(Type::LiteralString)),
_ => {
// TODO: Update our iterator protocol machinery to return a tuple
// describing the returned values in more detail, when we can.
Cow::Owned(Tuple::homogeneous(
ty.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, ty, value_expr);
err.fallback_element_type(self.db())
}),
))
} }
_ => ty,
}; };
if let Type::Tuple(tuple_ty) = ty { if let Err(err) = unpacker.unpack_tuple(tuple.as_ref()) {
let tuple = self.tuple_ty_elements(target, elts, tuple_ty, value_expr); unpacker
.unpack_tuple(&Tuple::homogeneous(Type::unknown()))
let length_mismatch = match elts.len().cmp(&tuple.len()) { .expect("adding a homogeneous tuple should always succeed");
Ordering::Less => { if let Some(builder) = self.context.report_lint(&INVALID_ASSIGNMENT, target)
if let Some(builder) =
self.context.report_lint(&INVALID_ASSIGNMENT, target)
{ {
match err {
ResizeTupleError::TooManyValues => {
let mut diag = let mut diag =
builder.into_diagnostic("Too many values to unpack"); builder.into_diagnostic("Too many values to unpack");
diag.set_primary_message(format_args!( diag.set_primary_message(format_args!(
"Expected {}", "Expected {}",
elts.len(), target_len.display_minimum(),
));
diag.annotate(self.context.secondary(value_expr).message(
format_args!("Got {}", tuple.len().display_minimum()),
)); ));
diag.annotate(
self.context
.secondary(value_expr)
.message(format_args!("Got {}", tuple.len())),
);
} }
true ResizeTupleError::TooFewValues => {
}
Ordering::Greater => {
if let Some(builder) =
self.context.report_lint(&INVALID_ASSIGNMENT, target)
{
let mut diag = let mut diag =
builder.into_diagnostic("Not enough values to unpack"); builder.into_diagnostic("Not enough values to unpack");
diag.set_primary_message(format_args!( diag.set_primary_message(format_args!(
"Expected {}", "Expected {}",
elts.len(), target_len.display_minimum(),
));
diag.annotate(self.context.secondary(value_expr).message(
format_args!("Got {}", tuple.len().display_maximum()),
)); ));
diag.annotate(
self.context
.secondary(value_expr)
.message(format_args!("Got {}", tuple.len())),
);
} }
true
}
Ordering::Equal => false,
};
for (index, ty) in tuple.elements().enumerate() {
if let Some(element_types) = target_types.get_mut(index) {
if length_mismatch {
element_types.push(Type::unknown());
} else {
element_types.push(ty);
}
}
}
} else {
let ty = if ty.is_literal_string() {
Type::LiteralString
} else {
ty.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, ty, value_expr);
err.fallback_element_type(self.db())
})
};
// Both `elts` and `target_types` are guaranteed to have the same length.
for (element, target_type) in elts.iter().zip(&mut target_types) {
if element.is_starred_expr() {
target_type.push(
KnownClass::List.to_specialized_instance(self.db(), [ty]),
);
} else {
target_type.push(ty);
} }
} }
} }
} }
for (index, element) in elts.iter().enumerate() { // We constructed unpacker above using the length of elts, so the zip should
// SAFETY: `target_types` is initialized with the same length as `elts`. // consume the same number of elements from each.
let element_ty = match target_types[index].as_slice() { for (target, value_ty) in elts.iter().zip(unpacker.into_types()) {
[] => Type::unknown(), self.unpack_inner(target, value_expr, value_ty);
types => UnionType::from_elements(self.db(), types),
};
self.unpack_inner(element, value_expr, element_ty);
} }
} }
_ => {} _ => {}
} }
} }
/// Returns the [`Type`] elements inside the given [`TupleType`] taking into account that there
/// can be a starred expression in the `elements`.
///
/// `value_expr` is an AST reference to the value being unpacked. It is
/// only used for diagnostics.
fn tuple_ty_elements(
&self,
expr: &ast::Expr,
targets: &[ast::Expr],
tuple_ty: TupleType<'db>,
value_expr: AnyNodeRef<'_>,
) -> Cow<'_, FixedLengthTupleSpec<'db>> {
let TupleSpec::Fixed(tuple) = tuple_ty.tuple(self.db()) else {
let todo = todo_type!("Unpack variable-length tuple");
return Cow::Owned(FixedLengthTupleSpec::from_elements(targets.iter().map(
|target| {
if target.is_starred_expr() {
KnownClass::List.to_specialized_instance(self.db(), [todo])
} else {
todo
}
},
)));
};
// If there is a starred expression, it will consume all of the types at that location.
let Some(starred_index) = targets.iter().position(ast::Expr::is_starred_expr) else {
// Otherwise, the types will be unpacked 1-1 to the targets.
return Cow::Borrowed(tuple);
};
if tuple.len() >= targets.len() - 1 {
// This branch is only taken when there are enough elements in the tuple type to
// combine for the starred expression. So, the arithmetic and indexing operations are
// safe to perform.
let mut element_types = FixedLengthTupleSpec::with_capacity(targets.len());
let tuple_elements = tuple.elements_slice();
// Insert all the elements before the starred expression.
// SAFETY: Safe because of the length check above.
element_types.extend_from_slice(&tuple_elements[..starred_index]);
// The number of target expressions that are remaining after the starred expression.
// For example, in `(a, *b, c, d) = ...`, the index of starred element `b` is 1 and the
// remaining elements after that are 2.
let remaining = targets.len() - (starred_index + 1);
// This index represents the position of the last element that belongs to the starred
// expression, in an exclusive manner. For example, in `(a, *b, c) = (1, 2, 3, 4)`, the
// starred expression `b` will consume the elements `Literal[2]` and `Literal[3]` and
// the index value would be 3.
let starred_end_index = tuple.len() - remaining;
// SAFETY: Safe because of the length check above.
let starred_element_types = &tuple_elements[starred_index..starred_end_index];
element_types.push(KnownClass::List.to_specialized_instance(
self.db(),
[if starred_element_types.is_empty() {
Type::unknown()
} else {
UnionType::from_elements(self.db(), starred_element_types)
}],
));
// Insert the types remaining that aren't consumed by the starred expression.
// SAFETY: Safe because of the length check above.
element_types.extend_from_slice(&tuple_elements[starred_end_index..]);
Cow::Owned(element_types)
} else {
if let Some(builder) = self.context.report_lint(&INVALID_ASSIGNMENT, expr) {
let mut diag = builder.into_diagnostic("Not enough values to unpack");
diag.set_primary_message(format_args!("Expected {} or more", targets.len() - 1));
diag.annotate(
self.context
.secondary(value_expr)
.message(format_args!("Got {}", tuple.len())),
);
}
Cow::Owned(FixedLengthTupleSpec::from_elements(targets.iter().map(
|target| {
if target.is_starred_expr() {
KnownClass::List.to_specialized_instance(self.db(), [Type::unknown()])
} else {
Type::unknown()
}
},
)))
}
}
pub(crate) fn finish(mut self) -> UnpackResult<'db> { pub(crate) fn finish(mut self) -> UnpackResult<'db> {
self.targets.shrink_to_fit(); self.targets.shrink_to_fit();
UnpackResult { UnpackResult {