mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-03 18:28:24 +00:00
[red-knot] Add support for unpacking union types (#15052)
Some checks are pending
CI / cargo fmt (push) Waiting to run
CI / cargo build (release) (push) Waiting to run
CI / cargo shear (push) Blocked by required conditions
CI / Determine changes (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / benchmarks (push) Blocked by required conditions
Some checks are pending
CI / cargo fmt (push) Waiting to run
CI / cargo build (release) (push) Waiting to run
CI / cargo shear (push) Blocked by required conditions
CI / Determine changes (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / benchmarks (push) Blocked by required conditions
## Summary Refer: https://github.com/astral-sh/ruff/issues/13773#issuecomment-2548020368 This PR adds support for unpacking union types. Unpacking a union type requires us to first distribute the types for all the targets that are involved in an unpacking. For example, if there are two targets and a union type that needs to be unpacked, each target will get a type from each element in the union type. For example, if the type is `tuple[int, int] | tuple[int, str]` and the target has two elements `(a, b)`, then * The type of `a` will be a union of `int` and `int` which are at index 0 in the first and second tuple respectively which resolves to an `int`. * Similarly, the type of `b` will be a union of `int` and `str` which are at index 1 in the first and second tuple respectively which will be `int | str`. ### Refactors There are couple of refactors that are added in this PR: * Add a `debug_assertion` to validate that the unpack target is a list or a tuple * Add a separate method to handle starred expression ## Test Plan Update `unpacking.md` with additional test cases that uses union types. This is done using parameter type hints style.
This commit is contained in:
parent
089a98e904
commit
d47fba1e4a
4 changed files with 309 additions and 80 deletions
|
@ -306,3 +306,169 @@ reveal_type(b) # revealed: Unknown
|
|||
reveal_type(a) # revealed: LiteralString
|
||||
reveal_type(b) # revealed: LiteralString
|
||||
```
|
||||
|
||||
## Union
|
||||
|
||||
### Same types
|
||||
|
||||
Union of two tuples of equal length and each element is of the same type.
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, int] | tuple[int, int]):
|
||||
(a, b) = arg
|
||||
reveal_type(a) # revealed: int
|
||||
reveal_type(b) # revealed: int
|
||||
```
|
||||
|
||||
### Mixed types (1)
|
||||
|
||||
Union of two tuples of equal length and one element differs in its type.
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, int] | tuple[int, str]):
|
||||
a, b = arg
|
||||
reveal_type(a) # revealed: int
|
||||
reveal_type(b) # revealed: int | str
|
||||
```
|
||||
|
||||
### Mixed types (2)
|
||||
|
||||
Union of two tuples of equal length and both the element types are different.
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, str] | tuple[str, int]):
|
||||
a, b = arg
|
||||
reveal_type(a) # revealed: int | str
|
||||
reveal_type(b) # revealed: str | int
|
||||
```
|
||||
|
||||
### Mixed types (3)
|
||||
|
||||
Union of three tuples of equal length and various combination of element types:
|
||||
|
||||
1. All same types
|
||||
1. One different type
|
||||
1. All different types
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, int, int] | tuple[int, str, bytes] | tuple[int, int, str]):
|
||||
a, b, c = arg
|
||||
reveal_type(a) # revealed: int
|
||||
reveal_type(b) # revealed: int | str
|
||||
reveal_type(c) # revealed: int | bytes | str
|
||||
```
|
||||
|
||||
### Nested
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, tuple[str, bytes]] | tuple[tuple[int, bytes], Literal["ab"]]):
|
||||
a, (b, c) = arg
|
||||
reveal_type(a) # revealed: int | tuple[int, bytes]
|
||||
reveal_type(b) # revealed: str
|
||||
reveal_type(c) # revealed: bytes | LiteralString
|
||||
```
|
||||
|
||||
### Starred expression
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):
|
||||
a, *b, c = arg
|
||||
reveal_type(a) # revealed: int
|
||||
# TODO: Should be `list[bytes | int | str]`
|
||||
reveal_type(b) # revealed: @Todo(starred unpacking)
|
||||
reveal_type(c) # revealed: int | bytes
|
||||
```
|
||||
|
||||
### Size mismatch (1)
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):
|
||||
# TODO: Add diagnostic (too many values to unpack)
|
||||
a, b = arg
|
||||
reveal_type(a) # revealed: int
|
||||
reveal_type(b) # revealed: bytes | int
|
||||
```
|
||||
|
||||
### Size mismatch (2)
|
||||
|
||||
```py
|
||||
def _(arg: tuple[int, bytes] | tuple[int, str]):
|
||||
# TODO: Add diagnostic (there aren't enough values to unpack)
|
||||
a, b, c = arg
|
||||
reveal_type(a) # revealed: int
|
||||
reveal_type(b) # revealed: bytes | str
|
||||
reveal_type(c) # revealed: Unknown
|
||||
```
|
||||
|
||||
### Same literal types
|
||||
|
||||
```py
|
||||
def _(flag: bool):
|
||||
if flag:
|
||||
value = (1, 2)
|
||||
else:
|
||||
value = (3, 4)
|
||||
|
||||
a, b = value
|
||||
reveal_type(a) # revealed: Literal[1, 3]
|
||||
reveal_type(b) # revealed: Literal[2, 4]
|
||||
```
|
||||
|
||||
### Mixed literal types
|
||||
|
||||
```py
|
||||
def _(flag: bool):
|
||||
if flag:
|
||||
value = (1, 2)
|
||||
else:
|
||||
value = ("a", "b")
|
||||
|
||||
a, b = value
|
||||
reveal_type(a) # revealed: Literal[1] | Literal["a"]
|
||||
reveal_type(b) # revealed: Literal[2] | Literal["b"]
|
||||
```
|
||||
|
||||
### Typing literal
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
def _(arg: tuple[int, int] | Literal["ab"]):
|
||||
a, b = arg
|
||||
reveal_type(a) # revealed: int | LiteralString
|
||||
reveal_type(b) # revealed: int | LiteralString
|
||||
```
|
||||
|
||||
### Custom iterator (1)
|
||||
|
||||
```py
|
||||
class Iterator:
|
||||
def __next__(self) -> tuple[int, int] | tuple[int, str]:
|
||||
return (1, 2)
|
||||
|
||||
class Iterable:
|
||||
def __iter__(self) -> Iterator:
|
||||
return Iterator()
|
||||
|
||||
((a, b), c) = Iterable()
|
||||
reveal_type(a) # revealed: int
|
||||
reveal_type(b) # revealed: int | str
|
||||
reveal_type(c) # revealed: tuple[int, int] | tuple[int, str]
|
||||
```
|
||||
|
||||
### Custom iterator (2)
|
||||
|
||||
```py
|
||||
class Iterator:
|
||||
def __next__(self) -> bytes:
|
||||
return b""
|
||||
|
||||
class Iterable:
|
||||
def __iter__(self) -> Iterator:
|
||||
return Iterator()
|
||||
|
||||
def _(arg: tuple[int, str] | Iterable):
|
||||
a, b = arg
|
||||
reveal_type(a) # revealed: int | bytes
|
||||
reveal_type(b) # revealed: str | bytes
|
||||
```
|
||||
|
|
|
@ -580,6 +580,13 @@ impl<'db> Type<'db> {
|
|||
.expect("Expected a Type::KnownInstance variant")
|
||||
}
|
||||
|
||||
pub const fn into_tuple(self) -> Option<TupleType<'db>> {
|
||||
match self {
|
||||
Type::Tuple(tuple_type) => Some(tuple_type),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn is_boolean_literal(&self) -> bool {
|
||||
matches!(self, Type::BooleanLiteral(..))
|
||||
}
|
||||
|
|
|
@ -207,8 +207,8 @@ fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult
|
|||
let result = infer_expression_types(db, value);
|
||||
let value_ty = result.expression_ty(value.node_ref(db).scoped_expression_id(db, scope));
|
||||
|
||||
let mut unpacker = Unpacker::new(db, file);
|
||||
unpacker.unpack(unpack.target(db), value_ty, scope);
|
||||
let mut unpacker = Unpacker::new(db, scope);
|
||||
unpacker.unpack(unpack.target(db), value_ty);
|
||||
unpacker.finish()
|
||||
}
|
||||
|
||||
|
|
|
@ -1,27 +1,30 @@
|
|||
use std::borrow::Cow;
|
||||
|
||||
use ruff_db::files::File;
|
||||
use ruff_python_ast::{self as ast, AnyNodeRef};
|
||||
use rustc_hash::FxHashMap;
|
||||
|
||||
use ruff_python_ast::{self as ast, AnyNodeRef};
|
||||
|
||||
use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId};
|
||||
use crate::semantic_index::symbol::ScopeId;
|
||||
use crate::types::{todo_type, Type, TypeCheckDiagnostics};
|
||||
use crate::Db;
|
||||
|
||||
use super::context::{InferContext, WithDiagnostics};
|
||||
use super::{TupleType, UnionType};
|
||||
|
||||
/// Unpacks the value expression type to their respective targets.
|
||||
pub(crate) struct Unpacker<'db> {
|
||||
context: InferContext<'db>,
|
||||
scope: ScopeId<'db>,
|
||||
targets: FxHashMap<ScopedExpressionId, Type<'db>>,
|
||||
}
|
||||
|
||||
impl<'db> Unpacker<'db> {
|
||||
pub(crate) fn new(db: &'db dyn Db, file: File) -> Self {
|
||||
pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self {
|
||||
Self {
|
||||
context: InferContext::new(db, file),
|
||||
context: InferContext::new(db, scope.file(db)),
|
||||
targets: FxHashMap::default(),
|
||||
scope,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -29,98 +32,151 @@ impl<'db> Unpacker<'db> {
|
|||
self.context.db()
|
||||
}
|
||||
|
||||
pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>, scope: ScopeId<'db>) {
|
||||
/// Unpack the value type to the target expression.
|
||||
pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>) {
|
||||
debug_assert!(
|
||||
matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)),
|
||||
"Unpacking target must be a list or tuple expression"
|
||||
);
|
||||
|
||||
self.unpack_inner(target, value_ty);
|
||||
}
|
||||
|
||||
fn unpack_inner(&mut self, target: &ast::Expr, value_ty: Type<'db>) {
|
||||
match target {
|
||||
ast::Expr::Name(target_name) => {
|
||||
self.targets
|
||||
.insert(target_name.scoped_expression_id(self.db(), scope), value_ty);
|
||||
self.targets.insert(
|
||||
target_name.scoped_expression_id(self.db(), self.scope),
|
||||
value_ty,
|
||||
);
|
||||
}
|
||||
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
|
||||
self.unpack(value, value_ty, scope);
|
||||
self.unpack_inner(value, value_ty);
|
||||
}
|
||||
ast::Expr::List(ast::ExprList { elts, .. })
|
||||
| ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => match value_ty {
|
||||
Type::Tuple(tuple_ty) => {
|
||||
let starred_index = elts.iter().position(ast::Expr::is_starred_expr);
|
||||
| ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => {
|
||||
// Initialize the vector of target types, one for each target.
|
||||
//
|
||||
// This is mainly useful for the union type where the target type at index `n` is
|
||||
// going to be a union of types from every union type element at index `n`.
|
||||
//
|
||||
// For example, if the type is `tuple[int, int] | tuple[int, str]` and the target
|
||||
// has two elements `(a, b)`, then
|
||||
// * The type of `a` will be a union of `int` and `int` which are at index 0 in the
|
||||
// first and second tuple respectively which resolves to an `int`.
|
||||
// * Similarly, the type of `b` will be a union of `int` and `str` which are at
|
||||
// index 1 in the first and second tuple respectively which will be `int | str`.
|
||||
let mut target_types = vec![vec![]; elts.len()];
|
||||
|
||||
let element_types = if let Some(starred_index) = starred_index {
|
||||
if tuple_ty.len(self.db()) >= elts.len() - 1 {
|
||||
let mut element_types = Vec::with_capacity(elts.len());
|
||||
element_types.extend_from_slice(
|
||||
// SAFETY: Safe because of the length check above.
|
||||
&tuple_ty.elements(self.db())[..starred_index],
|
||||
);
|
||||
let unpack_types = match value_ty {
|
||||
Type::Union(union_ty) => union_ty.elements(self.db()),
|
||||
_ => std::slice::from_ref(&value_ty),
|
||||
};
|
||||
|
||||
// E.g., in `(a, *b, c, d) = ...`, the index of starred element `b`
|
||||
// is 1 and the remaining elements after that are 2.
|
||||
let remaining = elts.len() - (starred_index + 1);
|
||||
// This index represents the type of the last element that belongs
|
||||
// to the starred expression, in an exclusive manner.
|
||||
let starred_end_index = tuple_ty.len(self.db()) - remaining;
|
||||
// SAFETY: Safe because of the length check above.
|
||||
let _starred_element_types =
|
||||
&tuple_ty.elements(self.db())[starred_index..starred_end_index];
|
||||
// TODO: Combine the types into a list type. If the
|
||||
// starred_element_types is empty, then it should be `List[Any]`.
|
||||
// combine_types(starred_element_types);
|
||||
element_types.push(todo_type!("starred unpacking"));
|
||||
for ty in unpack_types.iter().copied() {
|
||||
// Deconstruct certain types to delegate the inference back to the tuple type
|
||||
// for correct handling of starred expressions.
|
||||
let ty = match ty {
|
||||
Type::StringLiteral(string_literal_ty) => {
|
||||
// We could go further and deconstruct to an array of `StringLiteral`
|
||||
// with each individual character, instead of just an array of
|
||||
// `LiteralString`, but there would be a cost and it's not clear that
|
||||
// it's worth it.
|
||||
Type::tuple(
|
||||
self.db(),
|
||||
std::iter::repeat(Type::LiteralString)
|
||||
.take(string_literal_ty.python_len(self.db())),
|
||||
)
|
||||
}
|
||||
_ => ty,
|
||||
};
|
||||
|
||||
element_types.extend_from_slice(
|
||||
// SAFETY: Safe because of the length check above.
|
||||
&tuple_ty.elements(self.db())[starred_end_index..],
|
||||
);
|
||||
Cow::Owned(element_types)
|
||||
} else {
|
||||
let mut element_types = tuple_ty.elements(self.db()).to_vec();
|
||||
// Subtract 1 to insert the starred expression type at the correct
|
||||
// index.
|
||||
element_types.resize(elts.len() - 1, Type::Unknown);
|
||||
// TODO: This should be `list[Unknown]`
|
||||
element_types.insert(starred_index, todo_type!("starred unpacking"));
|
||||
Cow::Owned(element_types)
|
||||
if let Some(tuple_ty) = ty.into_tuple() {
|
||||
let tuple_ty_elements = self.tuple_ty_elements(elts, tuple_ty);
|
||||
|
||||
// TODO: Add diagnostic for length mismatch
|
||||
|
||||
for (index, ty) in tuple_ty_elements.iter().enumerate() {
|
||||
if let Some(element_types) = target_types.get_mut(index) {
|
||||
element_types.push(*ty);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Cow::Borrowed(tuple_ty.elements(self.db()).as_ref())
|
||||
};
|
||||
let ty = if ty.is_literal_string() {
|
||||
Type::LiteralString
|
||||
} else {
|
||||
ty.iterate(self.db())
|
||||
.unwrap_with_diagnostic(&self.context, AnyNodeRef::from(target))
|
||||
};
|
||||
for target_type in &mut target_types {
|
||||
target_type.push(ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (index, element) in elts.iter().enumerate() {
|
||||
self.unpack(
|
||||
element,
|
||||
element_types.get(index).copied().unwrap_or(Type::Unknown),
|
||||
scope,
|
||||
);
|
||||
}
|
||||
}
|
||||
Type::StringLiteral(string_literal_ty) => {
|
||||
// Deconstruct the string literal to delegate the inference back to the
|
||||
// tuple type for correct handling of starred expressions. We could go
|
||||
// further and deconstruct to an array of `StringLiteral` with each
|
||||
// individual character, instead of just an array of `LiteralString`, but
|
||||
// there would be a cost and it's not clear that it's worth it.
|
||||
let value_ty = Type::tuple(
|
||||
self.db(),
|
||||
std::iter::repeat(Type::LiteralString)
|
||||
.take(string_literal_ty.python_len(self.db())),
|
||||
);
|
||||
self.unpack(target, value_ty, scope);
|
||||
}
|
||||
_ => {
|
||||
let value_ty = if value_ty.is_literal_string() {
|
||||
Type::LiteralString
|
||||
} else {
|
||||
value_ty
|
||||
.iterate(self.db())
|
||||
.unwrap_with_diagnostic(&self.context, AnyNodeRef::from(target))
|
||||
for (index, element) in elts.iter().enumerate() {
|
||||
// SAFETY: `target_types` is initialized with the same length as `elts`.
|
||||
let element_ty = match target_types[index].as_slice() {
|
||||
[] => Type::Unknown,
|
||||
types => UnionType::from_elements(self.db(), types),
|
||||
};
|
||||
for element in elts {
|
||||
self.unpack(element, value_ty, scope);
|
||||
}
|
||||
self.unpack_inner(element, element_ty);
|
||||
}
|
||||
},
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`Type`] elements inside the given [`TupleType`] taking into account that there
|
||||
/// can be a starred expression in the `elements`.
|
||||
fn tuple_ty_elements(
|
||||
&mut self,
|
||||
targets: &[ast::Expr],
|
||||
tuple_ty: TupleType<'db>,
|
||||
) -> Cow<'_, [Type<'db>]> {
|
||||
// If there is a starred expression, it will consume all of the entries at that location.
|
||||
let Some(starred_index) = targets.iter().position(ast::Expr::is_starred_expr) else {
|
||||
// Otherwise, the types will be unpacked 1-1 to the elements.
|
||||
return Cow::Borrowed(tuple_ty.elements(self.db()).as_ref());
|
||||
};
|
||||
|
||||
if tuple_ty.len(self.db()) >= targets.len() - 1 {
|
||||
let mut element_types = Vec::with_capacity(targets.len());
|
||||
element_types.extend_from_slice(
|
||||
// SAFETY: Safe because of the length check above.
|
||||
&tuple_ty.elements(self.db())[..starred_index],
|
||||
);
|
||||
|
||||
// E.g., in `(a, *b, c, d) = ...`, the index of starred element `b`
|
||||
// is 1 and the remaining elements after that are 2.
|
||||
let remaining = targets.len() - (starred_index + 1);
|
||||
// This index represents the type of the last element that belongs
|
||||
// to the starred expression, in an exclusive manner.
|
||||
let starred_end_index = tuple_ty.len(self.db()) - remaining;
|
||||
// SAFETY: Safe because of the length check above.
|
||||
let _starred_element_types =
|
||||
&tuple_ty.elements(self.db())[starred_index..starred_end_index];
|
||||
// TODO: Combine the types into a list type. If the
|
||||
// starred_element_types is empty, then it should be `List[Any]`.
|
||||
// combine_types(starred_element_types);
|
||||
element_types.push(todo_type!("starred unpacking"));
|
||||
|
||||
element_types.extend_from_slice(
|
||||
// SAFETY: Safe because of the length check above.
|
||||
&tuple_ty.elements(self.db())[starred_end_index..],
|
||||
);
|
||||
Cow::Owned(element_types)
|
||||
} else {
|
||||
let mut element_types = tuple_ty.elements(self.db()).to_vec();
|
||||
// Subtract 1 to insert the starred expression type at the correct
|
||||
// index.
|
||||
element_types.resize(targets.len() - 1, Type::Unknown);
|
||||
// TODO: This should be `list[Unknown]`
|
||||
element_types.insert(starred_index, todo_type!("starred unpacking"));
|
||||
Cow::Owned(element_types)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn finish(mut self) -> UnpackResult<'db> {
|
||||
self.targets.shrink_to_fit();
|
||||
UnpackResult {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue