mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-04 10:48:32 +00:00
[ty] Argument type expansion for overload call evaluation (#18382)
## Summary
Part of astral-sh/ty#104, closes: astral-sh/ty#468
This PR implements the argument type expansion which is step 3 of the
overload call evaluation algorithm.
Specifically, this step needs to be taken if type checking resolves to
no matching overload and there are argument types that can be expanded.
## Test Plan
Add new test cases.
## Ecosystem analysis
This PR removes 174 `no-matching-overload` false positives -- I looked
at a lot of them and they all are false positives.
One thing that I'm not able to understand is that in
2b7e3adf27/sphinx/ext/autodoc/preserve_defaults.py (L179)
the inferred type of `value` is `str | None` by ty and Pyright, which is
correct, but it's only ty that raises `invalid-argument-type` error
while Pyright doesn't. The constructor method of `DefaultValue` has
declared type of `str` which is invalid.
There are few cases of false positives resulting due to the fact that ty
doesn't implement narrowing on attribute expressions.
This commit is contained in:
parent
0079cc6817
commit
7ea773daf2
3 changed files with 955 additions and 25 deletions
401
crates/ty_python_semantic/resources/mdtest/call/overloads.md
Normal file
401
crates/ty_python_semantic/resources/mdtest/call/overloads.md
Normal file
|
@ -0,0 +1,401 @@
|
|||
# Overloads
|
||||
|
||||
When ty evaluates the call of an overloaded function, it attempts to "match" the supplied arguments
|
||||
with one or more overloads. This document describes the algorithm that it uses for overload
|
||||
matching, which is the same as the one mentioned in the
|
||||
[spec](https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation).
|
||||
|
||||
## Arity check
|
||||
|
||||
The first step is to perform arity check. The non-overloaded cases are described in the
|
||||
[function](./function.md) document.
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
def f() -> None: ...
|
||||
@overload
|
||||
def f(x: int) -> int: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import f
|
||||
|
||||
# These match a single overload
|
||||
reveal_type(f()) # revealed: None
|
||||
reveal_type(f(1)) # revealed: int
|
||||
|
||||
# error: [no-matching-overload] "No overload of function `f` matches arguments"
|
||||
reveal_type(f("a", "b")) # revealed: Unknown
|
||||
```
|
||||
|
||||
## Type checking
|
||||
|
||||
The second step is to perform type checking. This is done for all the overloads that passed the
|
||||
arity check.
|
||||
|
||||
### Single match
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
def f(x: int) -> int: ...
|
||||
@overload
|
||||
def f(x: str) -> str: ...
|
||||
@overload
|
||||
def f(x: bytes) -> bytes: ...
|
||||
```
|
||||
|
||||
Here, all of the calls below pass the arity check for all overloads, so we proceed to type checking
|
||||
which filters out all but the matching overload:
|
||||
|
||||
```py
|
||||
from overloaded import f
|
||||
|
||||
reveal_type(f(1)) # revealed: int
|
||||
reveal_type(f("a")) # revealed: str
|
||||
reveal_type(f(b"b")) # revealed: bytes
|
||||
```
|
||||
|
||||
### Single match error
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
@overload
|
||||
def f() -> None: ...
|
||||
@overload
|
||||
def f(x: int) -> int: ...
|
||||
```
|
||||
|
||||
If the arity check only matches a single overload, it should be evaluated as a regular
|
||||
(non-overloaded) function call. This means that any diagnostics resulted during type checking that
|
||||
call should be reported directly and not as a `no-matching-overload` error.
|
||||
|
||||
```py
|
||||
from overloaded import f
|
||||
|
||||
reveal_type(f()) # revealed: None
|
||||
|
||||
# TODO: This should be `invalid-argument-type` instead
|
||||
# error: [no-matching-overload]
|
||||
reveal_type(f("a")) # revealed: Unknown
|
||||
```
|
||||
|
||||
### Multiple matches
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
class A: ...
|
||||
class B(A): ...
|
||||
|
||||
@overload
|
||||
def f(x: A) -> A: ...
|
||||
@overload
|
||||
def f(x: B, y: int = 0) -> B: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import A, B, f
|
||||
|
||||
# These calls pass the arity check, and type checking matches both overloads:
|
||||
reveal_type(f(A())) # revealed: A
|
||||
reveal_type(f(B())) # revealed: A
|
||||
|
||||
# But, in this case, the arity check filters out the first overload, so we only have one match:
|
||||
reveal_type(f(B(), 1)) # revealed: B
|
||||
```
|
||||
|
||||
## Argument type expansion
|
||||
|
||||
This step is performed only if the previous steps resulted in **no matches**.
|
||||
|
||||
In this case, the algorithm would perform
|
||||
[argument type expansion](https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion)
|
||||
and loops over from the type checking step, evaluating the argument lists.
|
||||
|
||||
### Expanding the only argument
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
|
||||
@overload
|
||||
def f(x: A) -> A: ...
|
||||
@overload
|
||||
def f(x: B) -> B: ...
|
||||
@overload
|
||||
def f(x: C) -> C: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import A, B, C, f
|
||||
|
||||
def _(ab: A | B, ac: A | C, bc: B | C):
|
||||
reveal_type(f(ab)) # revealed: A | B
|
||||
reveal_type(f(bc)) # revealed: B | C
|
||||
reveal_type(f(ac)) # revealed: A | C
|
||||
```
|
||||
|
||||
### Expanding first argument
|
||||
|
||||
If the set of argument lists created by expanding the first argument evaluates successfully, the
|
||||
algorithm shouldn't expand the second argument.
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import Literal, overload
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
class D: ...
|
||||
|
||||
@overload
|
||||
def f(x: A, y: C) -> A: ...
|
||||
@overload
|
||||
def f(x: A, y: D) -> B: ...
|
||||
@overload
|
||||
def f(x: B, y: C) -> C: ...
|
||||
@overload
|
||||
def f(x: B, y: D) -> D: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import A, B, C, D, f
|
||||
|
||||
def _(a_b: A | B):
|
||||
reveal_type(f(a_b, C())) # revealed: A | C
|
||||
reveal_type(f(a_b, D())) # revealed: B | D
|
||||
|
||||
# But, if it doesn't, it should expand the second argument and try again:
|
||||
def _(a_b: A | B, c_d: C | D):
|
||||
reveal_type(f(a_b, c_d)) # revealed: A | B | C | D
|
||||
```
|
||||
|
||||
### Expanding second argument
|
||||
|
||||
If the first argument cannot be expanded, the algorithm should move on to the second argument,
|
||||
keeping the first argument as is.
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
class D: ...
|
||||
|
||||
@overload
|
||||
def f(x: A, y: B) -> B: ...
|
||||
@overload
|
||||
def f(x: A, y: C) -> C: ...
|
||||
@overload
|
||||
def f(x: B, y: D) -> D: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import A, B, C, D, f
|
||||
|
||||
def _(a: A, bc: B | C, cd: C | D):
|
||||
# This also tests that partial matching works correctly as the argument type expansion results
|
||||
# in matching the first and second overloads, but not the third one.
|
||||
reveal_type(f(a, bc)) # revealed: B | C
|
||||
|
||||
# error: [no-matching-overload] "No overload of function `f` matches arguments"
|
||||
reveal_type(f(a, cd)) # revealed: Unknown
|
||||
```
|
||||
|
||||
### Generics (legacy)
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import TypeVar, overload
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
|
||||
@overload
|
||||
def f(x: A) -> A: ...
|
||||
@overload
|
||||
def f(x: _T) -> _T: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import A, f
|
||||
|
||||
def _(x: int, y: A | int):
|
||||
reveal_type(f(x)) # revealed: int
|
||||
reveal_type(f(y)) # revealed: A | int
|
||||
```
|
||||
|
||||
### Generics (PEP 695)
|
||||
|
||||
```toml
|
||||
[environment]
|
||||
python-version = "3.12"
|
||||
```
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
|
||||
@overload
|
||||
def f(x: B) -> B: ...
|
||||
@overload
|
||||
def f[T](x: T) -> T: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import B, f
|
||||
|
||||
def _(x: int, y: B | int):
|
||||
reveal_type(f(x)) # revealed: int
|
||||
reveal_type(f(y)) # revealed: B | int
|
||||
```
|
||||
|
||||
### Expanding `bool`
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import Literal, overload
|
||||
|
||||
class T: ...
|
||||
class F: ...
|
||||
|
||||
@overload
|
||||
def f(x: Literal[True]) -> T: ...
|
||||
@overload
|
||||
def f(x: Literal[False]) -> F: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import f
|
||||
|
||||
def _(flag: bool):
|
||||
reveal_type(f(True)) # revealed: T
|
||||
reveal_type(f(False)) # revealed: F
|
||||
reveal_type(f(flag)) # revealed: T | F
|
||||
```
|
||||
|
||||
### Expanding `tuple`
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import Literal, overload
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
class D: ...
|
||||
|
||||
@overload
|
||||
def f(x: tuple[A, int], y: tuple[int, Literal[True]]) -> A: ...
|
||||
@overload
|
||||
def f(x: tuple[A, int], y: tuple[int, Literal[False]]) -> B: ...
|
||||
@overload
|
||||
def f(x: tuple[B, int], y: tuple[int, Literal[True]]) -> C: ...
|
||||
@overload
|
||||
def f(x: tuple[B, int], y: tuple[int, Literal[False]]) -> D: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import A, B, f
|
||||
|
||||
def _(x: tuple[A | B, int], y: tuple[int, bool]):
|
||||
reveal_type(f(x, y)) # revealed: A | B | C | D
|
||||
```
|
||||
|
||||
### Expanding `type`
|
||||
|
||||
There's no special handling for expanding `type[A | B]` type because ty stores this type in it's
|
||||
distributed form, which is `type[A] | type[B]`.
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from typing import overload
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
|
||||
@overload
|
||||
def f(x: type[A]) -> A: ...
|
||||
@overload
|
||||
def f(x: type[B]) -> B: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import A, B, f
|
||||
|
||||
def _(x: type[A | B]):
|
||||
reveal_type(x) # revealed: type[A] | type[B]
|
||||
reveal_type(f(x)) # revealed: A | B
|
||||
```
|
||||
|
||||
### Expanding enums
|
||||
|
||||
`overloaded.pyi`:
|
||||
|
||||
```pyi
|
||||
from enum import Enum
|
||||
from typing import Literal, overload
|
||||
|
||||
class SomeEnum(Enum):
|
||||
A = 1
|
||||
B = 2
|
||||
C = 3
|
||||
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
|
||||
@overload
|
||||
def f(x: Literal[SomeEnum.A]) -> A: ...
|
||||
@overload
|
||||
def f(x: Literal[SomeEnum.B]) -> B: ...
|
||||
@overload
|
||||
def f(x: Literal[SomeEnum.C]) -> C: ...
|
||||
```
|
||||
|
||||
```py
|
||||
from overloaded import SomeEnum, A, B, C, f
|
||||
|
||||
def _(x: SomeEnum):
|
||||
reveal_type(f(SomeEnum.A)) # revealed: A
|
||||
# TODO: This should be `B` once enums are supported and are expanded
|
||||
reveal_type(f(SomeEnum.B)) # revealed: A
|
||||
# TODO: This should be `C` once enums are supported and are expanded
|
||||
reveal_type(f(SomeEnum.C)) # revealed: A
|
||||
# TODO: This should be `A | B | C` once enums are supported and are expanded
|
||||
reveal_type(f(x)) # revealed: A
|
||||
```
|
|
@ -1,6 +1,11 @@
|
|||
use std::borrow::Cow;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use itertools::{Either, Itertools};
|
||||
|
||||
use crate::Db;
|
||||
use crate::types::{KnownClass, TupleType};
|
||||
|
||||
use super::Type;
|
||||
|
||||
/// Arguments for a single call, in source order.
|
||||
|
@ -86,6 +91,10 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> {
|
|||
Self { arguments, types }
|
||||
}
|
||||
|
||||
pub(crate) fn types(&self) -> &[Type<'db>] {
|
||||
&self.types
|
||||
}
|
||||
|
||||
/// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front
|
||||
/// of this argument list. (If `bound_self` is none, we return the argument list
|
||||
/// unmodified.)
|
||||
|
@ -108,6 +117,72 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> {
|
|||
pub(crate) fn iter(&self) -> impl Iterator<Item = (Argument<'a>, Type<'db>)> + '_ {
|
||||
self.arguments.iter().zip(self.types.iter().copied())
|
||||
}
|
||||
|
||||
/// Returns an iterator on performing [argument type expansion].
|
||||
///
|
||||
/// Each element of the iterator represents a set of argument lists, where each argument list
|
||||
/// contains the same arguments, but with one or more of the argument types expanded.
|
||||
///
|
||||
/// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
|
||||
pub(crate) fn expand(&self, db: &'db dyn Db) -> impl Iterator<Item = Vec<Vec<Type<'db>>>> + '_ {
|
||||
/// Represents the state of the expansion process.
|
||||
///
|
||||
/// This is useful to avoid cloning the initial types vector if none of the types can be
|
||||
/// expanded.
|
||||
enum State<'a, 'db> {
|
||||
Initial(&'a Vec<Type<'db>>),
|
||||
Expanded(Vec<Vec<Type<'db>>>),
|
||||
}
|
||||
|
||||
impl<'db> State<'_, 'db> {
|
||||
fn len(&self) -> usize {
|
||||
match self {
|
||||
State::Initial(_) => 1,
|
||||
State::Expanded(expanded) => expanded.len(),
|
||||
}
|
||||
}
|
||||
|
||||
fn iter(&self) -> impl Iterator<Item = &Vec<Type<'db>>> + '_ {
|
||||
match self {
|
||||
State::Initial(types) => std::slice::from_ref(*types).iter(),
|
||||
State::Expanded(expanded) => expanded.iter(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut index = 0;
|
||||
|
||||
std::iter::successors(Some(State::Initial(&self.types)), move |previous| {
|
||||
// Find the next type that can be expanded.
|
||||
let expanded_types = loop {
|
||||
let arg_type = self.types.get(index)?;
|
||||
if let Some(expanded_types) = expand_type(db, *arg_type) {
|
||||
break expanded_types;
|
||||
}
|
||||
index += 1;
|
||||
};
|
||||
|
||||
let mut expanded_arg_types = Vec::with_capacity(expanded_types.len() * previous.len());
|
||||
|
||||
for pre_expanded_types in previous.iter() {
|
||||
for subtype in &expanded_types {
|
||||
let mut new_expanded_types = pre_expanded_types.clone();
|
||||
new_expanded_types[index] = *subtype;
|
||||
expanded_arg_types.push(new_expanded_types);
|
||||
}
|
||||
}
|
||||
|
||||
// Increment the index to move to the next argument type for the next iteration.
|
||||
index += 1;
|
||||
|
||||
Some(State::Expanded(expanded_arg_types))
|
||||
})
|
||||
.skip(1) // Skip the initial state, which has no expanded types.
|
||||
.map(|state| match state {
|
||||
State::Initial(_) => unreachable!("initial state should be skipped"),
|
||||
State::Expanded(expanded) => expanded,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Deref for CallArgumentTypes<'a, '_> {
|
||||
|
@ -122,3 +197,138 @@ impl<'a> DerefMut for CallArgumentTypes<'a, '_> {
|
|||
&mut self.arguments
|
||||
}
|
||||
}
|
||||
|
||||
/// Expands a type into its possible subtypes, if applicable.
|
||||
///
|
||||
/// Returns [`None`] if the type cannot be expanded.
|
||||
fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
|
||||
// TODO: Expand enums to their variants
|
||||
match ty {
|
||||
Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => {
|
||||
Some(vec![
|
||||
Type::BooleanLiteral(true),
|
||||
Type::BooleanLiteral(false),
|
||||
])
|
||||
}
|
||||
Type::Tuple(tuple) => {
|
||||
// Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]`
|
||||
// should not be expanded here.
|
||||
let expanded = tuple
|
||||
.iter(db)
|
||||
.map(|element| {
|
||||
if let Some(expanded) = expand_type(db, element) {
|
||||
Either::Left(expanded.into_iter())
|
||||
} else {
|
||||
Either::Right(std::iter::once(element))
|
||||
}
|
||||
})
|
||||
.multi_cartesian_product()
|
||||
.map(|types| TupleType::from_elements(db, types))
|
||||
.collect::<Vec<_>>();
|
||||
if expanded.len() == 1 {
|
||||
// There are no elements in the tuple type that can be expanded.
|
||||
None
|
||||
} else {
|
||||
Some(expanded)
|
||||
}
|
||||
}
|
||||
Type::Union(union) => Some(union.iter(db).copied().collect()),
|
||||
// We don't handle `type[A | B]` here because it's already stored in the expanded form
|
||||
// i.e., `type[A] | type[B]` which is handled by the `Type::Union` case.
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::db::tests::setup_db;
|
||||
use crate::types::{KnownClass, TupleType, Type, UnionType};
|
||||
|
||||
use super::expand_type;
|
||||
|
||||
#[test]
|
||||
fn expand_union_type() {
|
||||
let db = setup_db();
|
||||
let types = [
|
||||
KnownClass::Int.to_instance(&db),
|
||||
KnownClass::Str.to_instance(&db),
|
||||
KnownClass::Bytes.to_instance(&db),
|
||||
];
|
||||
let union_type = UnionType::from_elements(&db, types);
|
||||
let expanded = expand_type(&db, union_type).unwrap();
|
||||
assert_eq!(expanded.len(), types.len());
|
||||
assert_eq!(expanded, types);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_bool_type() {
|
||||
let db = setup_db();
|
||||
let bool_instance = KnownClass::Bool.to_instance(&db);
|
||||
let expanded = expand_type(&db, bool_instance).unwrap();
|
||||
let expected_types = [Type::BooleanLiteral(true), Type::BooleanLiteral(false)];
|
||||
assert_eq!(expanded.len(), expected_types.len());
|
||||
assert_eq!(expanded, expected_types);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_tuple_type() {
|
||||
let db = setup_db();
|
||||
|
||||
let int_ty = KnownClass::Int.to_instance(&db);
|
||||
let str_ty = KnownClass::Str.to_instance(&db);
|
||||
let bytes_ty = KnownClass::Bytes.to_instance(&db);
|
||||
let bool_ty = KnownClass::Bool.to_instance(&db);
|
||||
let true_ty = Type::BooleanLiteral(true);
|
||||
let false_ty = Type::BooleanLiteral(false);
|
||||
|
||||
// Empty tuple
|
||||
let empty_tuple = TupleType::empty(&db);
|
||||
let expanded = expand_type(&db, empty_tuple);
|
||||
assert!(expanded.is_none());
|
||||
|
||||
// None of the elements can be expanded.
|
||||
let tuple_type1 = TupleType::from_elements(&db, [int_ty, str_ty]);
|
||||
let expanded = expand_type(&db, tuple_type1);
|
||||
assert!(expanded.is_none());
|
||||
|
||||
// All elements can be expanded.
|
||||
let tuple_type2 = TupleType::from_elements(
|
||||
&db,
|
||||
[
|
||||
bool_ty,
|
||||
UnionType::from_elements(&db, [int_ty, str_ty, bytes_ty]),
|
||||
],
|
||||
);
|
||||
let expected_types = [
|
||||
TupleType::from_elements(&db, [true_ty, int_ty]),
|
||||
TupleType::from_elements(&db, [true_ty, str_ty]),
|
||||
TupleType::from_elements(&db, [true_ty, bytes_ty]),
|
||||
TupleType::from_elements(&db, [false_ty, int_ty]),
|
||||
TupleType::from_elements(&db, [false_ty, str_ty]),
|
||||
TupleType::from_elements(&db, [false_ty, bytes_ty]),
|
||||
];
|
||||
let expanded = expand_type(&db, tuple_type2).unwrap();
|
||||
assert_eq!(expanded.len(), expected_types.len());
|
||||
assert_eq!(expanded, expected_types);
|
||||
|
||||
// Mixed set of elements where some can be expanded while others cannot be.
|
||||
let tuple_type3 = TupleType::from_elements(
|
||||
&db,
|
||||
[
|
||||
bool_ty,
|
||||
int_ty,
|
||||
UnionType::from_elements(&db, [str_ty, bytes_ty]),
|
||||
str_ty,
|
||||
],
|
||||
);
|
||||
let expected_types = [
|
||||
TupleType::from_elements(&db, [true_ty, int_ty, str_ty, str_ty]),
|
||||
TupleType::from_elements(&db, [true_ty, int_ty, bytes_ty, str_ty]),
|
||||
TupleType::from_elements(&db, [false_ty, int_ty, str_ty, str_ty]),
|
||||
TupleType::from_elements(&db, [false_ty, int_ty, bytes_ty, str_ty]),
|
||||
];
|
||||
let expanded = expand_type(&db, tuple_type3).unwrap();
|
||||
assert_eq!(expanded.len(), expected_types.len());
|
||||
assert_eq!(expanded, expected_types);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1012,6 +1012,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
|
|||
signature_type,
|
||||
dunder_call_is_possibly_unbound: false,
|
||||
bound_type: None,
|
||||
return_type: None,
|
||||
overloads: smallvec![from],
|
||||
};
|
||||
Bindings {
|
||||
|
@ -1030,14 +1031,9 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
|
|||
/// If the callable has multiple overloads, the first one that matches is used as the overall
|
||||
/// binding match.
|
||||
///
|
||||
/// TODO: Implement the call site evaluation algorithm in the [proposed updated typing
|
||||
/// spec][overloads], which is much more subtle than “first match wins”.
|
||||
///
|
||||
/// If the arguments cannot be matched to formal parameters, we store information about the
|
||||
/// specific errors that occurred when trying to match them up. If the callable has multiple
|
||||
/// overloads, we store this error information for each overload.
|
||||
///
|
||||
/// [overloads]: https://github.com/python/typing/pull/1839
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CallableBinding<'db> {
|
||||
/// The type that is (hopefully) callable.
|
||||
|
@ -1055,6 +1051,14 @@ pub(crate) struct CallableBinding<'db> {
|
|||
/// The type of the bound `self` or `cls` parameter if this signature is for a bound method.
|
||||
pub(crate) bound_type: Option<Type<'db>>,
|
||||
|
||||
/// The return type of this callable.
|
||||
///
|
||||
/// This is only `Some` if it's an overloaded callable, "argument type expansion" was
|
||||
/// performed, and one of the expansion evaluated successfully for all of the argument lists.
|
||||
/// This type is then the union of all the return types of the matched overloads for the
|
||||
/// expanded argument lists.
|
||||
return_type: Option<Type<'db>>,
|
||||
|
||||
/// The bindings of each overload of this callable. Will be empty if the type is not callable.
|
||||
///
|
||||
/// By using `SmallVec`, we avoid an extra heap allocation for the common case of a
|
||||
|
@ -1076,6 +1080,7 @@ impl<'db> CallableBinding<'db> {
|
|||
signature_type,
|
||||
dunder_call_is_possibly_unbound: false,
|
||||
bound_type: None,
|
||||
return_type: None,
|
||||
overloads,
|
||||
}
|
||||
}
|
||||
|
@ -1086,6 +1091,7 @@ impl<'db> CallableBinding<'db> {
|
|||
signature_type,
|
||||
dunder_call_is_possibly_unbound: false,
|
||||
bound_type: None,
|
||||
return_type: None,
|
||||
overloads: smallvec![],
|
||||
}
|
||||
}
|
||||
|
@ -1114,12 +1120,6 @@ impl<'db> CallableBinding<'db> {
|
|||
// before checking.
|
||||
let arguments = arguments.with_self(self.bound_type);
|
||||
|
||||
// TODO: This checks every overload. In the proposed more detailed call checking spec [1],
|
||||
// arguments are checked for arity first, and are only checked for type assignability against
|
||||
// the matching overloads. Make sure to implement that as part of separating call binding into
|
||||
// two phases.
|
||||
//
|
||||
// [1] https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
|
||||
for overload in &mut self.overloads {
|
||||
overload.match_parameters(arguments.as_ref(), argument_forms, conflicting_forms);
|
||||
}
|
||||
|
@ -1129,9 +1129,154 @@ impl<'db> CallableBinding<'db> {
|
|||
// If this callable is a bound method, prepend the self instance onto the arguments list
|
||||
// before checking.
|
||||
let argument_types = argument_types.with_self(self.bound_type);
|
||||
for overload in &mut self.overloads {
|
||||
overload.check_types(db, argument_types.as_ref());
|
||||
|
||||
// Step 1: Check the result of the arity check which is done by `match_parameters`
|
||||
let matching_overload_indexes = match self.matching_overload_index() {
|
||||
MatchingOverloadIndex::None => {
|
||||
// If no candidate overloads remain from the arity check, we can stop here. We
|
||||
// still perform type checking for non-overloaded function to provide better user
|
||||
// experience.
|
||||
if let [overload] = self.overloads.as_mut_slice() {
|
||||
overload.check_types(db, argument_types.as_ref(), argument_types.types());
|
||||
}
|
||||
return;
|
||||
}
|
||||
MatchingOverloadIndex::Single(index) => {
|
||||
// If only one candidate overload remains, it is the winning match.
|
||||
// TODO: Evaluate it as a regular (non-overloaded) call. This means that any
|
||||
// diagnostics reported in this check should be reported directly instead of
|
||||
// reporting it as `no-matching-overload`.
|
||||
self.overloads[index].check_types(
|
||||
db,
|
||||
argument_types.as_ref(),
|
||||
argument_types.types(),
|
||||
);
|
||||
return;
|
||||
}
|
||||
MatchingOverloadIndex::Multiple(indexes) => {
|
||||
// If two or more candidate overloads remain, proceed to step 2.
|
||||
indexes
|
||||
}
|
||||
};
|
||||
|
||||
let snapshotter = MatchingOverloadsSnapshotter::new(matching_overload_indexes);
|
||||
|
||||
// State of the bindings _before_ evaluating (type checking) the matching overloads using
|
||||
// the non-expanded argument types.
|
||||
let pre_evaluation_snapshot = snapshotter.take(self);
|
||||
|
||||
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
|
||||
// whether it is compatible with the supplied argument list.
|
||||
for (_, overload) in self.matching_overloads_mut() {
|
||||
overload.check_types(db, argument_types.as_ref(), argument_types.types());
|
||||
}
|
||||
|
||||
match self.matching_overload_index() {
|
||||
MatchingOverloadIndex::None => {
|
||||
// If all overloads result in errors, proceed to step 3.
|
||||
}
|
||||
MatchingOverloadIndex::Single(_) => {
|
||||
// If only one overload evaluates without error, it is the winning match.
|
||||
return;
|
||||
}
|
||||
MatchingOverloadIndex::Multiple(_) => {
|
||||
// If two or more candidate overloads remain, proceed to step 4.
|
||||
// TODO: Step 4 and Step 5 goes here...
|
||||
// We're returning here because this shouldn't lead to argument type expansion.
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Perform "argument type expansion". Reference:
|
||||
// https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
|
||||
let mut expansions = argument_types.expand(db).peekable();
|
||||
|
||||
if expansions.peek().is_none() {
|
||||
// Return early if there are no argument types to expand.
|
||||
return;
|
||||
}
|
||||
|
||||
// State of the bindings _after_ evaluating (type checking) the matching overloads using
|
||||
// the non-expanded argument types.
|
||||
let post_evaluation_snapshot = snapshotter.take(self);
|
||||
|
||||
// Restore the bindings state to the one prior to the type checking step in preparation
|
||||
// for evaluating the expanded argument lists.
|
||||
snapshotter.restore(self, pre_evaluation_snapshot);
|
||||
|
||||
for expanded_argument_lists in expansions {
|
||||
// This is the merged state of the bindings after evaluating all of the expanded
|
||||
// argument lists. This will be the final state to restore the bindings to if all of
|
||||
// the expanded argument lists evaluated successfully.
|
||||
let mut merged_evaluation_state: Option<MatchingOverloadsSnapshot<'db>> = None;
|
||||
|
||||
let mut return_types = Vec::new();
|
||||
|
||||
for expanded_argument_types in &expanded_argument_lists {
|
||||
let pre_evaluation_snapshot = snapshotter.take(self);
|
||||
|
||||
for (_, overload) in self.matching_overloads_mut() {
|
||||
overload.check_types(db, argument_types.as_ref(), expanded_argument_types);
|
||||
}
|
||||
|
||||
let return_type = match self.matching_overload_index() {
|
||||
MatchingOverloadIndex::None => None,
|
||||
MatchingOverloadIndex::Single(index) => {
|
||||
Some(self.overloads[index].return_type())
|
||||
}
|
||||
MatchingOverloadIndex::Multiple(index) => {
|
||||
// TODO: Step 4 and Step 5 goes here... but for now we just use the return
|
||||
// type of the first matched overload.
|
||||
Some(self.overloads[index[0]].return_type())
|
||||
}
|
||||
};
|
||||
|
||||
// This split between initializing and updating the merged evaluation state is
|
||||
// required because otherwise it's difficult to differentiate between the
|
||||
// following:
|
||||
// 1. An initial unmatched overload becomes a matched overload when evaluating the
|
||||
// first argument list
|
||||
// 2. An unmatched overload after evaluating the first argument list becomes a
|
||||
// matched overload when evaluating the second argument list
|
||||
if let Some(merged_evaluation_state) = merged_evaluation_state.as_mut() {
|
||||
merged_evaluation_state.update(self);
|
||||
} else {
|
||||
merged_evaluation_state = Some(snapshotter.take(self));
|
||||
}
|
||||
|
||||
// Restore the bindings state before evaluating the next argument list.
|
||||
snapshotter.restore(self, pre_evaluation_snapshot);
|
||||
|
||||
if let Some(return_type) = return_type {
|
||||
return_types.push(return_type);
|
||||
} else {
|
||||
// No need to check the remaining argument lists if the current argument list
|
||||
// doesn't evaluate successfully. Move on to expanding the next argument type.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if return_types.len() == expanded_argument_lists.len() {
|
||||
// If the number of return types is equal to the number of expanded argument lists,
|
||||
// they all evaluated successfully. So, we need to combine their return types by
|
||||
// union to determine the final return type.
|
||||
self.return_type = Some(UnionType::from_elements(db, return_types));
|
||||
|
||||
// Restore the bindings state to the one that merges the bindings state evaluating
|
||||
// each of the expanded argument list.
|
||||
if let Some(merged_evaluation_state) = merged_evaluation_state {
|
||||
snapshotter.restore(self, merged_evaluation_state);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If the type expansion didn't yield any successful return type, we need to restore the
|
||||
// bindings state back to the one after the type checking step using the non-expanded
|
||||
// argument types. This is necessary because we restore the state to the pre-evaluation
|
||||
// snapshot when processing the expanded argument lists.
|
||||
snapshotter.restore(self, post_evaluation_snapshot);
|
||||
}
|
||||
|
||||
fn as_result(&self) -> Result<(), CallErrorKind> {
|
||||
|
@ -1160,6 +1305,25 @@ impl<'db> CallableBinding<'db> {
|
|||
self.matching_overloads().next().is_none()
|
||||
}
|
||||
|
||||
/// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
|
||||
fn matching_overload_index(&self) -> MatchingOverloadIndex {
|
||||
let mut matching_overloads = self.matching_overloads();
|
||||
match matching_overloads.next() {
|
||||
None => MatchingOverloadIndex::None,
|
||||
Some((first, _)) => {
|
||||
if let Some((second, _)) = matching_overloads.next() {
|
||||
let mut indexes = vec![first, second];
|
||||
for (index, _) in matching_overloads {
|
||||
indexes.push(index);
|
||||
}
|
||||
MatchingOverloadIndex::Multiple(indexes)
|
||||
} else {
|
||||
MatchingOverloadIndex::Single(first)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over all the overloads that matched for this call binding.
|
||||
pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
|
||||
self.overloads
|
||||
|
@ -1178,16 +1342,20 @@ impl<'db> CallableBinding<'db> {
|
|||
.filter(|(_, overload)| overload.as_result().is_ok())
|
||||
}
|
||||
|
||||
/// Returns the return type of this call. For a valid call, this is the return type of the
|
||||
/// first overload that the arguments matched against. For an invalid call to a non-overloaded
|
||||
/// function, this is the return type of the function. For an invalid call to an overloaded
|
||||
/// function, we return `Type::unknown`, since we cannot make any useful conclusions about
|
||||
/// which overload was intended to be called.
|
||||
/// Returns the return type of this call.
|
||||
///
|
||||
/// For a valid call, this is the return type of either a successful argument type expansion of
|
||||
/// an overloaded function, or the return type of the first overload that the arguments matched
|
||||
/// against.
|
||||
///
|
||||
/// For an invalid call to a non-overloaded function, this is the return type of the function.
|
||||
///
|
||||
/// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot
|
||||
/// make any useful conclusions about which overload was intended to be called.
|
||||
pub(crate) fn return_type(&self) -> Type<'db> {
|
||||
// TODO: Implement the overload call evaluation algorithm as mentioned in the spec [1] to
|
||||
// get the matching overload and use that to get the return type.
|
||||
//
|
||||
// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
|
||||
if let Some(return_type) = self.return_type {
|
||||
return return_type;
|
||||
}
|
||||
if let Some((_, first_overload)) = self.matching_overloads().next() {
|
||||
return first_overload.return_type();
|
||||
}
|
||||
|
@ -1336,6 +1504,18 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum MatchingOverloadIndex {
|
||||
/// No matching overloads found.
|
||||
None,
|
||||
|
||||
/// Exactly one matching overload found at the given index.
|
||||
Single(usize),
|
||||
|
||||
/// Multiple matching overloads found at the given indexes.
|
||||
Multiple(Vec<usize>),
|
||||
}
|
||||
|
||||
/// Binding information for one of the overloads of a callable.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Binding<'db> {
|
||||
|
@ -1510,7 +1690,12 @@ impl<'db> Binding<'db> {
|
|||
self.parameter_tys = vec![None; parameters.len()].into_boxed_slice();
|
||||
}
|
||||
|
||||
fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>) {
|
||||
fn check_types(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
arguments: &CallArguments<'_>,
|
||||
argument_types: &[Type<'db>],
|
||||
) {
|
||||
let mut num_synthetic_args = 0;
|
||||
let get_argument_index = |argument_index: usize, num_synthetic_args: usize| {
|
||||
if argument_index >= num_synthetic_args {
|
||||
|
@ -1524,13 +1709,20 @@ impl<'db> Binding<'db> {
|
|||
}
|
||||
};
|
||||
|
||||
let enumerate_argument_types = || {
|
||||
arguments
|
||||
.iter()
|
||||
.zip(argument_types.iter().copied())
|
||||
.enumerate()
|
||||
};
|
||||
|
||||
// If this overload is generic, first see if we can infer a specialization of the function
|
||||
// from the arguments that were passed in.
|
||||
let signature = &self.signature;
|
||||
let parameters = signature.parameters();
|
||||
if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() {
|
||||
let mut builder = SpecializationBuilder::new(db);
|
||||
for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() {
|
||||
for (argument_index, (argument, argument_type)) in enumerate_argument_types() {
|
||||
if matches!(argument, Argument::Synthetic) {
|
||||
num_synthetic_args += 1;
|
||||
}
|
||||
|
@ -1562,7 +1754,7 @@ impl<'db> Binding<'db> {
|
|||
}
|
||||
|
||||
num_synthetic_args = 0;
|
||||
for (argument_index, (argument, mut argument_type)) in argument_types.iter().enumerate() {
|
||||
for (argument_index, (argument, mut argument_type)) in enumerate_argument_types() {
|
||||
if matches!(argument, Argument::Synthetic) {
|
||||
num_synthetic_args += 1;
|
||||
}
|
||||
|
@ -1665,6 +1857,133 @@ impl<'db> Binding<'db> {
|
|||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn snapshot(&self) -> BindingSnapshot<'db> {
|
||||
BindingSnapshot {
|
||||
return_ty: self.return_ty,
|
||||
specialization: self.specialization,
|
||||
inherited_specialization: self.inherited_specialization,
|
||||
argument_parameters: self.argument_parameters.clone(),
|
||||
parameter_tys: self.parameter_tys.clone(),
|
||||
errors: self.errors.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn restore(&mut self, snapshot: BindingSnapshot<'db>) {
|
||||
let BindingSnapshot {
|
||||
return_ty,
|
||||
specialization,
|
||||
inherited_specialization,
|
||||
argument_parameters,
|
||||
parameter_tys,
|
||||
errors,
|
||||
} = snapshot;
|
||||
|
||||
self.return_ty = return_ty;
|
||||
self.specialization = specialization;
|
||||
self.inherited_specialization = inherited_specialization;
|
||||
self.argument_parameters = argument_parameters;
|
||||
self.parameter_tys = parameter_tys;
|
||||
self.errors = errors;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct BindingSnapshot<'db> {
|
||||
return_ty: Type<'db>,
|
||||
specialization: Option<Specialization<'db>>,
|
||||
inherited_specialization: Option<Specialization<'db>>,
|
||||
argument_parameters: Box<[Option<usize>]>,
|
||||
parameter_tys: Box<[Option<Type<'db>>]>,
|
||||
errors: Vec<BindingError<'db>>,
|
||||
}
|
||||
|
||||
/// Represents the snapshot of the matched overload bindings.
|
||||
///
|
||||
/// The reason that this only contains the matched overloads are:
|
||||
/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
|
||||
/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the
|
||||
/// expanded argument lists
|
||||
#[derive(Clone, Debug)]
|
||||
struct MatchingOverloadsSnapshot<'db>(Vec<(usize, BindingSnapshot<'db>)>);
|
||||
|
||||
impl<'db> MatchingOverloadsSnapshot<'db> {
|
||||
/// Update the state of the matched overload bindings in this snapshot with the current
|
||||
/// state in the given `binding`.
|
||||
fn update(&mut self, binding: &CallableBinding<'db>) {
|
||||
// Here, the `snapshot` is the state of this binding for the previous argument list and
|
||||
// `binding` would contain the state after evaluating the current argument list.
|
||||
for (snapshot, binding) in self
|
||||
.0
|
||||
.iter_mut()
|
||||
.map(|(index, snapshot)| (snapshot, &binding.overloads[*index]))
|
||||
{
|
||||
if binding.errors.is_empty() {
|
||||
// If the binding has no errors, this means that the current argument list was
|
||||
// evaluated successfully and this is the matching overload.
|
||||
//
|
||||
// Clear the errors from the snapshot of this overload to signal this change ...
|
||||
snapshot.errors.clear();
|
||||
|
||||
// ... and update the snapshot with the current state of the binding.
|
||||
snapshot.return_ty = binding.return_ty;
|
||||
snapshot.specialization = binding.specialization;
|
||||
snapshot.inherited_specialization = binding.inherited_specialization;
|
||||
snapshot
|
||||
.argument_parameters
|
||||
.clone_from(&binding.argument_parameters);
|
||||
snapshot.parameter_tys.clone_from(&binding.parameter_tys);
|
||||
}
|
||||
|
||||
// If the errors in the snapshot was empty, then this binding is the matching overload
|
||||
// for a previously evaluated argument list. This means that we don't need to change
|
||||
// any information for an already matched overload binding.
|
||||
//
|
||||
// If it does have errors, we could extend it with the errors from evaluating the
|
||||
// current argument list. Arguably, this isn't required, since the errors in the
|
||||
// snapshot should already signal that this is an unmatched overload which is why we
|
||||
// don't do it. Similarly, due to this being an unmatched overload, there's no point in
|
||||
// updating the binding state.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A helper to take snapshots of the matched overload bindings for the current state of the
|
||||
/// bindings.
|
||||
struct MatchingOverloadsSnapshotter(Vec<usize>);
|
||||
|
||||
impl MatchingOverloadsSnapshotter {
|
||||
/// Creates a new snapshotter for the given indexes of the matched overloads.
|
||||
fn new(indexes: Vec<usize>) -> Self {
|
||||
debug_assert!(indexes.len() > 1);
|
||||
MatchingOverloadsSnapshotter(indexes)
|
||||
}
|
||||
|
||||
/// Takes a snapshot of the current state of the matched overload bindings.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the indexes of the matched overloads are not valid for the given binding.
|
||||
fn take<'db>(&self, binding: &CallableBinding<'db>) -> MatchingOverloadsSnapshot<'db> {
|
||||
MatchingOverloadsSnapshot(
|
||||
self.0
|
||||
.iter()
|
||||
.map(|index| (*index, binding.overloads[*index].snapshot()))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Restores the state of the matched overload bindings from the given snapshot.
|
||||
fn restore<'db>(
|
||||
&self,
|
||||
binding: &mut CallableBinding<'db>,
|
||||
snapshot: MatchingOverloadsSnapshot<'db>,
|
||||
) {
|
||||
debug_assert_eq!(self.0.len(), snapshot.0.len());
|
||||
for (index, snapshot) in snapshot.0 {
|
||||
binding.overloads[index].restore(snapshot);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes a callable for the purposes of diagnostics.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue