[red-knot] Add support for unpacking union types (#15052)
Some checks are pending
CI / cargo fmt (push) Waiting to run
CI / cargo build (release) (push) Waiting to run
CI / cargo shear (push) Blocked by required conditions
CI / Determine changes (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / benchmarks (push) Blocked by required conditions

## Summary

Refer:
https://github.com/astral-sh/ruff/issues/13773#issuecomment-2548020368

This PR adds support for unpacking union types. 

Unpacking a union type requires us to first distribute the types for all
the targets that are involved in an unpacking. For example, if there are
two targets and a union type that needs to be unpacked, each target will
get a type from each element in the union type.

For example, if the type is `tuple[int, int] | tuple[int, str]` and the
target has two elements `(a, b)`, then
* The type of `a` will be a union of `int` and `int` which are at index
0 in the first and second tuple respectively which resolves to an `int`.
* Similarly, the type of `b` will be a union of `int` and `str` which
are at index 1 in the first and second tuple respectively which will be
`int | str`.

### Refactors

There are couple of refactors that are added in this PR:
* Add a `debug_assertion` to validate that the unpack target is a list
or a tuple
* Add a separate method to handle starred expression

## Test Plan

Update `unpacking.md` with additional test cases that uses union types.
This is done using parameter type hints style.
This commit is contained in:
Dhruv Manilawala 2024-12-20 16:31:15 +05:30 committed by GitHub
parent 089a98e904
commit d47fba1e4a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 309 additions and 80 deletions

View file

@ -306,3 +306,169 @@ reveal_type(b) # revealed: Unknown
reveal_type(a) # revealed: LiteralString
reveal_type(b) # revealed: LiteralString
```
## Union
### Same types
Union of two tuples of equal length and each element is of the same type.
```py
def _(arg: tuple[int, int] | tuple[int, int]):
(a, b) = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
```
### Mixed types (1)
Union of two tuples of equal length and one element differs in its type.
```py
def _(arg: tuple[int, int] | tuple[int, str]):
a, b = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
```
### Mixed types (2)
Union of two tuples of equal length and both the element types are different.
```py
def _(arg: tuple[int, str] | tuple[str, int]):
a, b = arg
reveal_type(a) # revealed: int | str
reveal_type(b) # revealed: str | int
```
### Mixed types (3)
Union of three tuples of equal length and various combination of element types:
1. All same types
1. One different type
1. All different types
```py
def _(arg: tuple[int, int, int] | tuple[int, str, bytes] | tuple[int, int, str]):
a, b, c = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
reveal_type(c) # revealed: int | bytes | str
```
### Nested
```py
def _(arg: tuple[int, tuple[str, bytes]] | tuple[tuple[int, bytes], Literal["ab"]]):
a, (b, c) = arg
reveal_type(a) # revealed: int | tuple[int, bytes]
reveal_type(b) # revealed: str
reveal_type(c) # revealed: bytes | LiteralString
```
### Starred expression
```py
def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):
a, *b, c = arg
reveal_type(a) # revealed: int
# TODO: Should be `list[bytes | int | str]`
reveal_type(b) # revealed: @Todo(starred unpacking)
reveal_type(c) # revealed: int | bytes
```
### Size mismatch (1)
```py
def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):
# TODO: Add diagnostic (too many values to unpack)
a, b = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: bytes | int
```
### Size mismatch (2)
```py
def _(arg: tuple[int, bytes] | tuple[int, str]):
# TODO: Add diagnostic (there aren't enough values to unpack)
a, b, c = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: bytes | str
reveal_type(c) # revealed: Unknown
```
### Same literal types
```py
def _(flag: bool):
if flag:
value = (1, 2)
else:
value = (3, 4)
a, b = value
reveal_type(a) # revealed: Literal[1, 3]
reveal_type(b) # revealed: Literal[2, 4]
```
### Mixed literal types
```py
def _(flag: bool):
if flag:
value = (1, 2)
else:
value = ("a", "b")
a, b = value
reveal_type(a) # revealed: Literal[1] | Literal["a"]
reveal_type(b) # revealed: Literal[2] | Literal["b"]
```
### Typing literal
```py
from typing import Literal
def _(arg: tuple[int, int] | Literal["ab"]):
a, b = arg
reveal_type(a) # revealed: int | LiteralString
reveal_type(b) # revealed: int | LiteralString
```
### Custom iterator (1)
```py
class Iterator:
def __next__(self) -> tuple[int, int] | tuple[int, str]:
return (1, 2)
class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
((a, b), c) = Iterable()
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
reveal_type(c) # revealed: tuple[int, int] | tuple[int, str]
```
### Custom iterator (2)
```py
class Iterator:
def __next__(self) -> bytes:
return b""
class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
def _(arg: tuple[int, str] | Iterable):
a, b = arg
reveal_type(a) # revealed: int | bytes
reveal_type(b) # revealed: str | bytes
```

View file

@ -580,6 +580,13 @@ impl<'db> Type<'db> {
.expect("Expected a Type::KnownInstance variant")
}
pub const fn into_tuple(self) -> Option<TupleType<'db>> {
match self {
Type::Tuple(tuple_type) => Some(tuple_type),
_ => None,
}
}
pub const fn is_boolean_literal(&self) -> bool {
matches!(self, Type::BooleanLiteral(..))
}

View file

@ -207,8 +207,8 @@ fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult
let result = infer_expression_types(db, value);
let value_ty = result.expression_ty(value.node_ref(db).scoped_expression_id(db, scope));
let mut unpacker = Unpacker::new(db, file);
unpacker.unpack(unpack.target(db), value_ty, scope);
let mut unpacker = Unpacker::new(db, scope);
unpacker.unpack(unpack.target(db), value_ty);
unpacker.finish()
}

View file

@ -1,27 +1,30 @@
use std::borrow::Cow;
use ruff_db::files::File;
use ruff_python_ast::{self as ast, AnyNodeRef};
use rustc_hash::FxHashMap;
use ruff_python_ast::{self as ast, AnyNodeRef};
use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId};
use crate::semantic_index::symbol::ScopeId;
use crate::types::{todo_type, Type, TypeCheckDiagnostics};
use crate::Db;
use super::context::{InferContext, WithDiagnostics};
use super::{TupleType, UnionType};
/// Unpacks the value expression type to their respective targets.
pub(crate) struct Unpacker<'db> {
context: InferContext<'db>,
scope: ScopeId<'db>,
targets: FxHashMap<ScopedExpressionId, Type<'db>>,
}
impl<'db> Unpacker<'db> {
pub(crate) fn new(db: &'db dyn Db, file: File) -> Self {
pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self {
Self {
context: InferContext::new(db, file),
context: InferContext::new(db, scope.file(db)),
targets: FxHashMap::default(),
scope,
}
}
@ -29,98 +32,151 @@ impl<'db> Unpacker<'db> {
self.context.db()
}
pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>, scope: ScopeId<'db>) {
/// Unpack the value type to the target expression.
pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>) {
debug_assert!(
matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)),
"Unpacking target must be a list or tuple expression"
);
self.unpack_inner(target, value_ty);
}
fn unpack_inner(&mut self, target: &ast::Expr, value_ty: Type<'db>) {
match target {
ast::Expr::Name(target_name) => {
self.targets
.insert(target_name.scoped_expression_id(self.db(), scope), value_ty);
self.targets.insert(
target_name.scoped_expression_id(self.db(), self.scope),
value_ty,
);
}
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
self.unpack(value, value_ty, scope);
self.unpack_inner(value, value_ty);
}
ast::Expr::List(ast::ExprList { elts, .. })
| ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => match value_ty {
Type::Tuple(tuple_ty) => {
let starred_index = elts.iter().position(ast::Expr::is_starred_expr);
| ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => {
// Initialize the vector of target types, one for each target.
//
// This is mainly useful for the union type where the target type at index `n` is
// going to be a union of types from every union type element at index `n`.
//
// For example, if the type is `tuple[int, int] | tuple[int, str]` and the target
// has two elements `(a, b)`, then
// * The type of `a` will be a union of `int` and `int` which are at index 0 in the
// first and second tuple respectively which resolves to an `int`.
// * Similarly, the type of `b` will be a union of `int` and `str` which are at
// index 1 in the first and second tuple respectively which will be `int | str`.
let mut target_types = vec![vec![]; elts.len()];
let element_types = if let Some(starred_index) = starred_index {
if tuple_ty.len(self.db()) >= elts.len() - 1 {
let mut element_types = Vec::with_capacity(elts.len());
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[..starred_index],
);
let unpack_types = match value_ty {
Type::Union(union_ty) => union_ty.elements(self.db()),
_ => std::slice::from_ref(&value_ty),
};
// E.g., in `(a, *b, c, d) = ...`, the index of starred element `b`
// is 1 and the remaining elements after that are 2.
let remaining = elts.len() - (starred_index + 1);
// This index represents the type of the last element that belongs
// to the starred expression, in an exclusive manner.
let starred_end_index = tuple_ty.len(self.db()) - remaining;
// SAFETY: Safe because of the length check above.
let _starred_element_types =
&tuple_ty.elements(self.db())[starred_index..starred_end_index];
// TODO: Combine the types into a list type. If the
// starred_element_types is empty, then it should be `List[Any]`.
// combine_types(starred_element_types);
element_types.push(todo_type!("starred unpacking"));
for ty in unpack_types.iter().copied() {
// Deconstruct certain types to delegate the inference back to the tuple type
// for correct handling of starred expressions.
let ty = match ty {
Type::StringLiteral(string_literal_ty) => {
// We could go further and deconstruct to an array of `StringLiteral`
// with each individual character, instead of just an array of
// `LiteralString`, but there would be a cost and it's not clear that
// it's worth it.
Type::tuple(
self.db(),
std::iter::repeat(Type::LiteralString)
.take(string_literal_ty.python_len(self.db())),
)
}
_ => ty,
};
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[starred_end_index..],
);
Cow::Owned(element_types)
} else {
let mut element_types = tuple_ty.elements(self.db()).to_vec();
// Subtract 1 to insert the starred expression type at the correct
// index.
element_types.resize(elts.len() - 1, Type::Unknown);
// TODO: This should be `list[Unknown]`
element_types.insert(starred_index, todo_type!("starred unpacking"));
Cow::Owned(element_types)
if let Some(tuple_ty) = ty.into_tuple() {
let tuple_ty_elements = self.tuple_ty_elements(elts, tuple_ty);
// TODO: Add diagnostic for length mismatch
for (index, ty) in tuple_ty_elements.iter().enumerate() {
if let Some(element_types) = target_types.get_mut(index) {
element_types.push(*ty);
}
}
} else {
Cow::Borrowed(tuple_ty.elements(self.db()).as_ref())
};
let ty = if ty.is_literal_string() {
Type::LiteralString
} else {
ty.iterate(self.db())
.unwrap_with_diagnostic(&self.context, AnyNodeRef::from(target))
};
for target_type in &mut target_types {
target_type.push(ty);
}
}
}
for (index, element) in elts.iter().enumerate() {
self.unpack(
element,
element_types.get(index).copied().unwrap_or(Type::Unknown),
scope,
);
}
}
Type::StringLiteral(string_literal_ty) => {
// Deconstruct the string literal to delegate the inference back to the
// tuple type for correct handling of starred expressions. We could go
// further and deconstruct to an array of `StringLiteral` with each
// individual character, instead of just an array of `LiteralString`, but
// there would be a cost and it's not clear that it's worth it.
let value_ty = Type::tuple(
self.db(),
std::iter::repeat(Type::LiteralString)
.take(string_literal_ty.python_len(self.db())),
);
self.unpack(target, value_ty, scope);
}
_ => {
let value_ty = if value_ty.is_literal_string() {
Type::LiteralString
} else {
value_ty
.iterate(self.db())
.unwrap_with_diagnostic(&self.context, AnyNodeRef::from(target))
for (index, element) in elts.iter().enumerate() {
// SAFETY: `target_types` is initialized with the same length as `elts`.
let element_ty = match target_types[index].as_slice() {
[] => Type::Unknown,
types => UnionType::from_elements(self.db(), types),
};
for element in elts {
self.unpack(element, value_ty, scope);
}
self.unpack_inner(element, element_ty);
}
},
}
_ => {}
}
}
/// Returns the [`Type`] elements inside the given [`TupleType`] taking into account that there
/// can be a starred expression in the `elements`.
fn tuple_ty_elements(
&mut self,
targets: &[ast::Expr],
tuple_ty: TupleType<'db>,
) -> Cow<'_, [Type<'db>]> {
// If there is a starred expression, it will consume all of the entries at that location.
let Some(starred_index) = targets.iter().position(ast::Expr::is_starred_expr) else {
// Otherwise, the types will be unpacked 1-1 to the elements.
return Cow::Borrowed(tuple_ty.elements(self.db()).as_ref());
};
if tuple_ty.len(self.db()) >= targets.len() - 1 {
let mut element_types = Vec::with_capacity(targets.len());
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[..starred_index],
);
// E.g., in `(a, *b, c, d) = ...`, the index of starred element `b`
// is 1 and the remaining elements after that are 2.
let remaining = targets.len() - (starred_index + 1);
// This index represents the type of the last element that belongs
// to the starred expression, in an exclusive manner.
let starred_end_index = tuple_ty.len(self.db()) - remaining;
// SAFETY: Safe because of the length check above.
let _starred_element_types =
&tuple_ty.elements(self.db())[starred_index..starred_end_index];
// TODO: Combine the types into a list type. If the
// starred_element_types is empty, then it should be `List[Any]`.
// combine_types(starred_element_types);
element_types.push(todo_type!("starred unpacking"));
element_types.extend_from_slice(
// SAFETY: Safe because of the length check above.
&tuple_ty.elements(self.db())[starred_end_index..],
);
Cow::Owned(element_types)
} else {
let mut element_types = tuple_ty.elements(self.db()).to_vec();
// Subtract 1 to insert the starred expression type at the correct
// index.
element_types.resize(targets.len() - 1, Type::Unknown);
// TODO: This should be `list[Unknown]`
element_types.insert(starred_index, todo_type!("starred unpacking"));
Cow::Owned(element_types)
}
}
pub(crate) fn finish(mut self) -> UnpackResult<'db> {
self.targets.shrink_to_fit();
UnpackResult {