[ty] Argument type expansion for overload call evaluation (#18382)

## Summary

Part of astral-sh/ty#104, closes: astral-sh/ty#468

This PR implements the argument type expansion which is step 3 of the
overload call evaluation algorithm.

Specifically, this step needs to be taken if type checking resolves to
no matching overload and there are argument types that can be expanded.

## Test Plan

Add new test cases.

## Ecosystem analysis

This PR removes 174 `no-matching-overload` false positives -- I looked
at a lot of them and they all are false positives.

One thing that I'm not able to understand is that in
2b7e3adf27/sphinx/ext/autodoc/preserve_defaults.py (L179)
the inferred type of `value` is `str | None` by ty and Pyright, which is
correct, but it's only ty that raises `invalid-argument-type` error
while Pyright doesn't. The constructor method of `DefaultValue` has
declared type of `str` which is invalid.

There are few cases of false positives resulting due to the fact that ty
doesn't implement narrowing on attribute expressions.
This commit is contained in:
Dhruv Manilawala 2025-06-04 07:42:00 +05:30 committed by GitHub
parent 0079cc6817
commit 7ea773daf2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 955 additions and 25 deletions

View file

@ -0,0 +1,401 @@
# Overloads
When ty evaluates the call of an overloaded function, it attempts to "match" the supplied arguments
with one or more overloads. This document describes the algorithm that it uses for overload
matching, which is the same as the one mentioned in the
[spec](https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation).
## Arity check
The first step is to perform arity check. The non-overloaded cases are described in the
[function](./function.md) document.
`overloaded.pyi`:
```pyi
from typing import overload
@overload
def f() -> None: ...
@overload
def f(x: int) -> int: ...
```
```py
from overloaded import f
# These match a single overload
reveal_type(f()) # revealed: None
reveal_type(f(1)) # revealed: int
# error: [no-matching-overload] "No overload of function `f` matches arguments"
reveal_type(f("a", "b")) # revealed: Unknown
```
## Type checking
The second step is to perform type checking. This is done for all the overloads that passed the
arity check.
### Single match
`overloaded.pyi`:
```pyi
from typing import overload
@overload
def f(x: int) -> int: ...
@overload
def f(x: str) -> str: ...
@overload
def f(x: bytes) -> bytes: ...
```
Here, all of the calls below pass the arity check for all overloads, so we proceed to type checking
which filters out all but the matching overload:
```py
from overloaded import f
reveal_type(f(1)) # revealed: int
reveal_type(f("a")) # revealed: str
reveal_type(f(b"b")) # revealed: bytes
```
### Single match error
`overloaded.pyi`:
```pyi
from typing import overload
@overload
def f() -> None: ...
@overload
def f(x: int) -> int: ...
```
If the arity check only matches a single overload, it should be evaluated as a regular
(non-overloaded) function call. This means that any diagnostics resulted during type checking that
call should be reported directly and not as a `no-matching-overload` error.
```py
from overloaded import f
reveal_type(f()) # revealed: None
# TODO: This should be `invalid-argument-type` instead
# error: [no-matching-overload]
reveal_type(f("a")) # revealed: Unknown
```
### Multiple matches
`overloaded.pyi`:
```pyi
from typing import overload
class A: ...
class B(A): ...
@overload
def f(x: A) -> A: ...
@overload
def f(x: B, y: int = 0) -> B: ...
```
```py
from overloaded import A, B, f
# These calls pass the arity check, and type checking matches both overloads:
reveal_type(f(A())) # revealed: A
reveal_type(f(B())) # revealed: A
# But, in this case, the arity check filters out the first overload, so we only have one match:
reveal_type(f(B(), 1)) # revealed: B
```
## Argument type expansion
This step is performed only if the previous steps resulted in **no matches**.
In this case, the algorithm would perform
[argument type expansion](https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion)
and loops over from the type checking step, evaluating the argument lists.
### Expanding the only argument
`overloaded.pyi`:
```pyi
from typing import overload
class A: ...
class B: ...
class C: ...
@overload
def f(x: A) -> A: ...
@overload
def f(x: B) -> B: ...
@overload
def f(x: C) -> C: ...
```
```py
from overloaded import A, B, C, f
def _(ab: A | B, ac: A | C, bc: B | C):
reveal_type(f(ab)) # revealed: A | B
reveal_type(f(bc)) # revealed: B | C
reveal_type(f(ac)) # revealed: A | C
```
### Expanding first argument
If the set of argument lists created by expanding the first argument evaluates successfully, the
algorithm shouldn't expand the second argument.
`overloaded.pyi`:
```pyi
from typing import Literal, overload
class A: ...
class B: ...
class C: ...
class D: ...
@overload
def f(x: A, y: C) -> A: ...
@overload
def f(x: A, y: D) -> B: ...
@overload
def f(x: B, y: C) -> C: ...
@overload
def f(x: B, y: D) -> D: ...
```
```py
from overloaded import A, B, C, D, f
def _(a_b: A | B):
reveal_type(f(a_b, C())) # revealed: A | C
reveal_type(f(a_b, D())) # revealed: B | D
# But, if it doesn't, it should expand the second argument and try again:
def _(a_b: A | B, c_d: C | D):
reveal_type(f(a_b, c_d)) # revealed: A | B | C | D
```
### Expanding second argument
If the first argument cannot be expanded, the algorithm should move on to the second argument,
keeping the first argument as is.
`overloaded.pyi`:
```pyi
from typing import overload
class A: ...
class B: ...
class C: ...
class D: ...
@overload
def f(x: A, y: B) -> B: ...
@overload
def f(x: A, y: C) -> C: ...
@overload
def f(x: B, y: D) -> D: ...
```
```py
from overloaded import A, B, C, D, f
def _(a: A, bc: B | C, cd: C | D):
# This also tests that partial matching works correctly as the argument type expansion results
# in matching the first and second overloads, but not the third one.
reveal_type(f(a, bc)) # revealed: B | C
# error: [no-matching-overload] "No overload of function `f` matches arguments"
reveal_type(f(a, cd)) # revealed: Unknown
```
### Generics (legacy)
`overloaded.pyi`:
```pyi
from typing import TypeVar, overload
_T = TypeVar("_T")
class A: ...
class B: ...
@overload
def f(x: A) -> A: ...
@overload
def f(x: _T) -> _T: ...
```
```py
from overloaded import A, f
def _(x: int, y: A | int):
reveal_type(f(x)) # revealed: int
reveal_type(f(y)) # revealed: A | int
```
### Generics (PEP 695)
```toml
[environment]
python-version = "3.12"
```
`overloaded.pyi`:
```pyi
from typing import overload
class A: ...
class B: ...
@overload
def f(x: B) -> B: ...
@overload
def f[T](x: T) -> T: ...
```
```py
from overloaded import B, f
def _(x: int, y: B | int):
reveal_type(f(x)) # revealed: int
reveal_type(f(y)) # revealed: B | int
```
### Expanding `bool`
`overloaded.pyi`:
```pyi
from typing import Literal, overload
class T: ...
class F: ...
@overload
def f(x: Literal[True]) -> T: ...
@overload
def f(x: Literal[False]) -> F: ...
```
```py
from overloaded import f
def _(flag: bool):
reveal_type(f(True)) # revealed: T
reveal_type(f(False)) # revealed: F
reveal_type(f(flag)) # revealed: T | F
```
### Expanding `tuple`
`overloaded.pyi`:
```pyi
from typing import Literal, overload
class A: ...
class B: ...
class C: ...
class D: ...
@overload
def f(x: tuple[A, int], y: tuple[int, Literal[True]]) -> A: ...
@overload
def f(x: tuple[A, int], y: tuple[int, Literal[False]]) -> B: ...
@overload
def f(x: tuple[B, int], y: tuple[int, Literal[True]]) -> C: ...
@overload
def f(x: tuple[B, int], y: tuple[int, Literal[False]]) -> D: ...
```
```py
from overloaded import A, B, f
def _(x: tuple[A | B, int], y: tuple[int, bool]):
reveal_type(f(x, y)) # revealed: A | B | C | D
```
### Expanding `type`
There's no special handling for expanding `type[A | B]` type because ty stores this type in it's
distributed form, which is `type[A] | type[B]`.
`overloaded.pyi`:
```pyi
from typing import overload
class A: ...
class B: ...
@overload
def f(x: type[A]) -> A: ...
@overload
def f(x: type[B]) -> B: ...
```
```py
from overloaded import A, B, f
def _(x: type[A | B]):
reveal_type(x) # revealed: type[A] | type[B]
reveal_type(f(x)) # revealed: A | B
```
### Expanding enums
`overloaded.pyi`:
```pyi
from enum import Enum
from typing import Literal, overload
class SomeEnum(Enum):
A = 1
B = 2
C = 3
class A: ...
class B: ...
class C: ...
@overload
def f(x: Literal[SomeEnum.A]) -> A: ...
@overload
def f(x: Literal[SomeEnum.B]) -> B: ...
@overload
def f(x: Literal[SomeEnum.C]) -> C: ...
```
```py
from overloaded import SomeEnum, A, B, C, f
def _(x: SomeEnum):
reveal_type(f(SomeEnum.A)) # revealed: A
# TODO: This should be `B` once enums are supported and are expanded
reveal_type(f(SomeEnum.B)) # revealed: A
# TODO: This should be `C` once enums are supported and are expanded
reveal_type(f(SomeEnum.C)) # revealed: A
# TODO: This should be `A | B | C` once enums are supported and are expanded
reveal_type(f(x)) # revealed: A
```

View file

@ -1,6 +1,11 @@
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};
use itertools::{Either, Itertools};
use crate::Db;
use crate::types::{KnownClass, TupleType};
use super::Type;
/// Arguments for a single call, in source order.
@ -86,6 +91,10 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> {
Self { arguments, types }
}
pub(crate) fn types(&self) -> &[Type<'db>] {
&self.types
}
/// Prepend an optional extra synthetic argument (for a `self` or `cls` parameter) to the front
/// of this argument list. (If `bound_self` is none, we return the argument list
/// unmodified.)
@ -108,6 +117,72 @@ impl<'a, 'db> CallArgumentTypes<'a, 'db> {
pub(crate) fn iter(&self) -> impl Iterator<Item = (Argument<'a>, Type<'db>)> + '_ {
self.arguments.iter().zip(self.types.iter().copied())
}
/// Returns an iterator on performing [argument type expansion].
///
/// Each element of the iterator represents a set of argument lists, where each argument list
/// contains the same arguments, but with one or more of the argument types expanded.
///
/// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
pub(crate) fn expand(&self, db: &'db dyn Db) -> impl Iterator<Item = Vec<Vec<Type<'db>>>> + '_ {
/// Represents the state of the expansion process.
///
/// This is useful to avoid cloning the initial types vector if none of the types can be
/// expanded.
enum State<'a, 'db> {
Initial(&'a Vec<Type<'db>>),
Expanded(Vec<Vec<Type<'db>>>),
}
impl<'db> State<'_, 'db> {
fn len(&self) -> usize {
match self {
State::Initial(_) => 1,
State::Expanded(expanded) => expanded.len(),
}
}
fn iter(&self) -> impl Iterator<Item = &Vec<Type<'db>>> + '_ {
match self {
State::Initial(types) => std::slice::from_ref(*types).iter(),
State::Expanded(expanded) => expanded.iter(),
}
}
}
let mut index = 0;
std::iter::successors(Some(State::Initial(&self.types)), move |previous| {
// Find the next type that can be expanded.
let expanded_types = loop {
let arg_type = self.types.get(index)?;
if let Some(expanded_types) = expand_type(db, *arg_type) {
break expanded_types;
}
index += 1;
};
let mut expanded_arg_types = Vec::with_capacity(expanded_types.len() * previous.len());
for pre_expanded_types in previous.iter() {
for subtype in &expanded_types {
let mut new_expanded_types = pre_expanded_types.clone();
new_expanded_types[index] = *subtype;
expanded_arg_types.push(new_expanded_types);
}
}
// Increment the index to move to the next argument type for the next iteration.
index += 1;
Some(State::Expanded(expanded_arg_types))
})
.skip(1) // Skip the initial state, which has no expanded types.
.map(|state| match state {
State::Initial(_) => unreachable!("initial state should be skipped"),
State::Expanded(expanded) => expanded,
})
}
}
impl<'a> Deref for CallArgumentTypes<'a, '_> {
@ -122,3 +197,138 @@ impl<'a> DerefMut for CallArgumentTypes<'a, '_> {
&mut self.arguments
}
}
/// Expands a type into its possible subtypes, if applicable.
///
/// Returns [`None`] if the type cannot be expanded.
fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
// TODO: Expand enums to their variants
match ty {
Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => {
Some(vec![
Type::BooleanLiteral(true),
Type::BooleanLiteral(false),
])
}
Type::Tuple(tuple) => {
// Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]`
// should not be expanded here.
let expanded = tuple
.iter(db)
.map(|element| {
if let Some(expanded) = expand_type(db, element) {
Either::Left(expanded.into_iter())
} else {
Either::Right(std::iter::once(element))
}
})
.multi_cartesian_product()
.map(|types| TupleType::from_elements(db, types))
.collect::<Vec<_>>();
if expanded.len() == 1 {
// There are no elements in the tuple type that can be expanded.
None
} else {
Some(expanded)
}
}
Type::Union(union) => Some(union.iter(db).copied().collect()),
// We don't handle `type[A | B]` here because it's already stored in the expanded form
// i.e., `type[A] | type[B]` which is handled by the `Type::Union` case.
_ => None,
}
}
#[cfg(test)]
mod tests {
use crate::db::tests::setup_db;
use crate::types::{KnownClass, TupleType, Type, UnionType};
use super::expand_type;
#[test]
fn expand_union_type() {
let db = setup_db();
let types = [
KnownClass::Int.to_instance(&db),
KnownClass::Str.to_instance(&db),
KnownClass::Bytes.to_instance(&db),
];
let union_type = UnionType::from_elements(&db, types);
let expanded = expand_type(&db, union_type).unwrap();
assert_eq!(expanded.len(), types.len());
assert_eq!(expanded, types);
}
#[test]
fn expand_bool_type() {
let db = setup_db();
let bool_instance = KnownClass::Bool.to_instance(&db);
let expanded = expand_type(&db, bool_instance).unwrap();
let expected_types = [Type::BooleanLiteral(true), Type::BooleanLiteral(false)];
assert_eq!(expanded.len(), expected_types.len());
assert_eq!(expanded, expected_types);
}
#[test]
fn expand_tuple_type() {
let db = setup_db();
let int_ty = KnownClass::Int.to_instance(&db);
let str_ty = KnownClass::Str.to_instance(&db);
let bytes_ty = KnownClass::Bytes.to_instance(&db);
let bool_ty = KnownClass::Bool.to_instance(&db);
let true_ty = Type::BooleanLiteral(true);
let false_ty = Type::BooleanLiteral(false);
// Empty tuple
let empty_tuple = TupleType::empty(&db);
let expanded = expand_type(&db, empty_tuple);
assert!(expanded.is_none());
// None of the elements can be expanded.
let tuple_type1 = TupleType::from_elements(&db, [int_ty, str_ty]);
let expanded = expand_type(&db, tuple_type1);
assert!(expanded.is_none());
// All elements can be expanded.
let tuple_type2 = TupleType::from_elements(
&db,
[
bool_ty,
UnionType::from_elements(&db, [int_ty, str_ty, bytes_ty]),
],
);
let expected_types = [
TupleType::from_elements(&db, [true_ty, int_ty]),
TupleType::from_elements(&db, [true_ty, str_ty]),
TupleType::from_elements(&db, [true_ty, bytes_ty]),
TupleType::from_elements(&db, [false_ty, int_ty]),
TupleType::from_elements(&db, [false_ty, str_ty]),
TupleType::from_elements(&db, [false_ty, bytes_ty]),
];
let expanded = expand_type(&db, tuple_type2).unwrap();
assert_eq!(expanded.len(), expected_types.len());
assert_eq!(expanded, expected_types);
// Mixed set of elements where some can be expanded while others cannot be.
let tuple_type3 = TupleType::from_elements(
&db,
[
bool_ty,
int_ty,
UnionType::from_elements(&db, [str_ty, bytes_ty]),
str_ty,
],
);
let expected_types = [
TupleType::from_elements(&db, [true_ty, int_ty, str_ty, str_ty]),
TupleType::from_elements(&db, [true_ty, int_ty, bytes_ty, str_ty]),
TupleType::from_elements(&db, [false_ty, int_ty, str_ty, str_ty]),
TupleType::from_elements(&db, [false_ty, int_ty, bytes_ty, str_ty]),
];
let expanded = expand_type(&db, tuple_type3).unwrap();
assert_eq!(expanded.len(), expected_types.len());
assert_eq!(expanded, expected_types);
}
}

View file

@ -1012,6 +1012,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
signature_type,
dunder_call_is_possibly_unbound: false,
bound_type: None,
return_type: None,
overloads: smallvec![from],
};
Bindings {
@ -1030,14 +1031,9 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
/// If the callable has multiple overloads, the first one that matches is used as the overall
/// binding match.
///
/// TODO: Implement the call site evaluation algorithm in the [proposed updated typing
/// spec][overloads], which is much more subtle than “first match wins”.
///
/// If the arguments cannot be matched to formal parameters, we store information about the
/// specific errors that occurred when trying to match them up. If the callable has multiple
/// overloads, we store this error information for each overload.
///
/// [overloads]: https://github.com/python/typing/pull/1839
#[derive(Debug)]
pub(crate) struct CallableBinding<'db> {
/// The type that is (hopefully) callable.
@ -1055,6 +1051,14 @@ pub(crate) struct CallableBinding<'db> {
/// The type of the bound `self` or `cls` parameter if this signature is for a bound method.
pub(crate) bound_type: Option<Type<'db>>,
/// The return type of this callable.
///
/// This is only `Some` if it's an overloaded callable, "argument type expansion" was
/// performed, and one of the expansion evaluated successfully for all of the argument lists.
/// This type is then the union of all the return types of the matched overloads for the
/// expanded argument lists.
return_type: Option<Type<'db>>,
/// The bindings of each overload of this callable. Will be empty if the type is not callable.
///
/// By using `SmallVec`, we avoid an extra heap allocation for the common case of a
@ -1076,6 +1080,7 @@ impl<'db> CallableBinding<'db> {
signature_type,
dunder_call_is_possibly_unbound: false,
bound_type: None,
return_type: None,
overloads,
}
}
@ -1086,6 +1091,7 @@ impl<'db> CallableBinding<'db> {
signature_type,
dunder_call_is_possibly_unbound: false,
bound_type: None,
return_type: None,
overloads: smallvec![],
}
}
@ -1114,12 +1120,6 @@ impl<'db> CallableBinding<'db> {
// before checking.
let arguments = arguments.with_self(self.bound_type);
// TODO: This checks every overload. In the proposed more detailed call checking spec [1],
// arguments are checked for arity first, and are only checked for type assignability against
// the matching overloads. Make sure to implement that as part of separating call binding into
// two phases.
//
// [1] https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
for overload in &mut self.overloads {
overload.match_parameters(arguments.as_ref(), argument_forms, conflicting_forms);
}
@ -1129,9 +1129,154 @@ impl<'db> CallableBinding<'db> {
// If this callable is a bound method, prepend the self instance onto the arguments list
// before checking.
let argument_types = argument_types.with_self(self.bound_type);
for overload in &mut self.overloads {
overload.check_types(db, argument_types.as_ref());
// Step 1: Check the result of the arity check which is done by `match_parameters`
let matching_overload_indexes = match self.matching_overload_index() {
MatchingOverloadIndex::None => {
// If no candidate overloads remain from the arity check, we can stop here. We
// still perform type checking for non-overloaded function to provide better user
// experience.
if let [overload] = self.overloads.as_mut_slice() {
overload.check_types(db, argument_types.as_ref(), argument_types.types());
}
return;
}
MatchingOverloadIndex::Single(index) => {
// If only one candidate overload remains, it is the winning match.
// TODO: Evaluate it as a regular (non-overloaded) call. This means that any
// diagnostics reported in this check should be reported directly instead of
// reporting it as `no-matching-overload`.
self.overloads[index].check_types(
db,
argument_types.as_ref(),
argument_types.types(),
);
return;
}
MatchingOverloadIndex::Multiple(indexes) => {
// If two or more candidate overloads remain, proceed to step 2.
indexes
}
};
let snapshotter = MatchingOverloadsSnapshotter::new(matching_overload_indexes);
// State of the bindings _before_ evaluating (type checking) the matching overloads using
// the non-expanded argument types.
let pre_evaluation_snapshot = snapshotter.take(self);
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
// whether it is compatible with the supplied argument list.
for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, argument_types.as_ref(), argument_types.types());
}
match self.matching_overload_index() {
MatchingOverloadIndex::None => {
// If all overloads result in errors, proceed to step 3.
}
MatchingOverloadIndex::Single(_) => {
// If only one overload evaluates without error, it is the winning match.
return;
}
MatchingOverloadIndex::Multiple(_) => {
// If two or more candidate overloads remain, proceed to step 4.
// TODO: Step 4 and Step 5 goes here...
// We're returning here because this shouldn't lead to argument type expansion.
return;
}
}
// Step 3: Perform "argument type expansion". Reference:
// https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
let mut expansions = argument_types.expand(db).peekable();
if expansions.peek().is_none() {
// Return early if there are no argument types to expand.
return;
}
// State of the bindings _after_ evaluating (type checking) the matching overloads using
// the non-expanded argument types.
let post_evaluation_snapshot = snapshotter.take(self);
// Restore the bindings state to the one prior to the type checking step in preparation
// for evaluating the expanded argument lists.
snapshotter.restore(self, pre_evaluation_snapshot);
for expanded_argument_lists in expansions {
// This is the merged state of the bindings after evaluating all of the expanded
// argument lists. This will be the final state to restore the bindings to if all of
// the expanded argument lists evaluated successfully.
let mut merged_evaluation_state: Option<MatchingOverloadsSnapshot<'db>> = None;
let mut return_types = Vec::new();
for expanded_argument_types in &expanded_argument_lists {
let pre_evaluation_snapshot = snapshotter.take(self);
for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, argument_types.as_ref(), expanded_argument_types);
}
let return_type = match self.matching_overload_index() {
MatchingOverloadIndex::None => None,
MatchingOverloadIndex::Single(index) => {
Some(self.overloads[index].return_type())
}
MatchingOverloadIndex::Multiple(index) => {
// TODO: Step 4 and Step 5 goes here... but for now we just use the return
// type of the first matched overload.
Some(self.overloads[index[0]].return_type())
}
};
// This split between initializing and updating the merged evaluation state is
// required because otherwise it's difficult to differentiate between the
// following:
// 1. An initial unmatched overload becomes a matched overload when evaluating the
// first argument list
// 2. An unmatched overload after evaluating the first argument list becomes a
// matched overload when evaluating the second argument list
if let Some(merged_evaluation_state) = merged_evaluation_state.as_mut() {
merged_evaluation_state.update(self);
} else {
merged_evaluation_state = Some(snapshotter.take(self));
}
// Restore the bindings state before evaluating the next argument list.
snapshotter.restore(self, pre_evaluation_snapshot);
if let Some(return_type) = return_type {
return_types.push(return_type);
} else {
// No need to check the remaining argument lists if the current argument list
// doesn't evaluate successfully. Move on to expanding the next argument type.
break;
}
}
if return_types.len() == expanded_argument_lists.len() {
// If the number of return types is equal to the number of expanded argument lists,
// they all evaluated successfully. So, we need to combine their return types by
// union to determine the final return type.
self.return_type = Some(UnionType::from_elements(db, return_types));
// Restore the bindings state to the one that merges the bindings state evaluating
// each of the expanded argument list.
if let Some(merged_evaluation_state) = merged_evaluation_state {
snapshotter.restore(self, merged_evaluation_state);
}
return;
}
}
// If the type expansion didn't yield any successful return type, we need to restore the
// bindings state back to the one after the type checking step using the non-expanded
// argument types. This is necessary because we restore the state to the pre-evaluation
// snapshot when processing the expanded argument lists.
snapshotter.restore(self, post_evaluation_snapshot);
}
fn as_result(&self) -> Result<(), CallErrorKind> {
@ -1160,6 +1305,25 @@ impl<'db> CallableBinding<'db> {
self.matching_overloads().next().is_none()
}
/// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
fn matching_overload_index(&self) -> MatchingOverloadIndex {
let mut matching_overloads = self.matching_overloads();
match matching_overloads.next() {
None => MatchingOverloadIndex::None,
Some((first, _)) => {
if let Some((second, _)) = matching_overloads.next() {
let mut indexes = vec![first, second];
for (index, _) in matching_overloads {
indexes.push(index);
}
MatchingOverloadIndex::Multiple(indexes)
} else {
MatchingOverloadIndex::Single(first)
}
}
}
}
/// Returns an iterator over all the overloads that matched for this call binding.
pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
self.overloads
@ -1178,16 +1342,20 @@ impl<'db> CallableBinding<'db> {
.filter(|(_, overload)| overload.as_result().is_ok())
}
/// Returns the return type of this call. For a valid call, this is the return type of the
/// first overload that the arguments matched against. For an invalid call to a non-overloaded
/// function, this is the return type of the function. For an invalid call to an overloaded
/// function, we return `Type::unknown`, since we cannot make any useful conclusions about
/// which overload was intended to be called.
/// Returns the return type of this call.
///
/// For a valid call, this is the return type of either a successful argument type expansion of
/// an overloaded function, or the return type of the first overload that the arguments matched
/// against.
///
/// For an invalid call to a non-overloaded function, this is the return type of the function.
///
/// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot
/// make any useful conclusions about which overload was intended to be called.
pub(crate) fn return_type(&self) -> Type<'db> {
// TODO: Implement the overload call evaluation algorithm as mentioned in the spec [1] to
// get the matching overload and use that to get the return type.
//
// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
if let Some(return_type) = self.return_type {
return return_type;
}
if let Some((_, first_overload)) = self.matching_overloads().next() {
return first_overload.return_type();
}
@ -1336,6 +1504,18 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
}
}
#[derive(Debug)]
enum MatchingOverloadIndex {
/// No matching overloads found.
None,
/// Exactly one matching overload found at the given index.
Single(usize),
/// Multiple matching overloads found at the given indexes.
Multiple(Vec<usize>),
}
/// Binding information for one of the overloads of a callable.
#[derive(Debug)]
pub(crate) struct Binding<'db> {
@ -1510,7 +1690,12 @@ impl<'db> Binding<'db> {
self.parameter_tys = vec![None; parameters.len()].into_boxed_slice();
}
fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>) {
fn check_types(
&mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_>,
argument_types: &[Type<'db>],
) {
let mut num_synthetic_args = 0;
let get_argument_index = |argument_index: usize, num_synthetic_args: usize| {
if argument_index >= num_synthetic_args {
@ -1524,13 +1709,20 @@ impl<'db> Binding<'db> {
}
};
let enumerate_argument_types = || {
arguments
.iter()
.zip(argument_types.iter().copied())
.enumerate()
};
// If this overload is generic, first see if we can infer a specialization of the function
// from the arguments that were passed in.
let signature = &self.signature;
let parameters = signature.parameters();
if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() {
let mut builder = SpecializationBuilder::new(db);
for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() {
for (argument_index, (argument, argument_type)) in enumerate_argument_types() {
if matches!(argument, Argument::Synthetic) {
num_synthetic_args += 1;
}
@ -1562,7 +1754,7 @@ impl<'db> Binding<'db> {
}
num_synthetic_args = 0;
for (argument_index, (argument, mut argument_type)) in argument_types.iter().enumerate() {
for (argument_index, (argument, mut argument_type)) in enumerate_argument_types() {
if matches!(argument, Argument::Synthetic) {
num_synthetic_args += 1;
}
@ -1665,6 +1857,133 @@ impl<'db> Binding<'db> {
}
Ok(())
}
fn snapshot(&self) -> BindingSnapshot<'db> {
BindingSnapshot {
return_ty: self.return_ty,
specialization: self.specialization,
inherited_specialization: self.inherited_specialization,
argument_parameters: self.argument_parameters.clone(),
parameter_tys: self.parameter_tys.clone(),
errors: self.errors.clone(),
}
}
fn restore(&mut self, snapshot: BindingSnapshot<'db>) {
let BindingSnapshot {
return_ty,
specialization,
inherited_specialization,
argument_parameters,
parameter_tys,
errors,
} = snapshot;
self.return_ty = return_ty;
self.specialization = specialization;
self.inherited_specialization = inherited_specialization;
self.argument_parameters = argument_parameters;
self.parameter_tys = parameter_tys;
self.errors = errors;
}
}
#[derive(Clone, Debug)]
struct BindingSnapshot<'db> {
return_ty: Type<'db>,
specialization: Option<Specialization<'db>>,
inherited_specialization: Option<Specialization<'db>>,
argument_parameters: Box<[Option<usize>]>,
parameter_tys: Box<[Option<Type<'db>>]>,
errors: Vec<BindingError<'db>>,
}
/// Represents the snapshot of the matched overload bindings.
///
/// The reason that this only contains the matched overloads are:
/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the
/// expanded argument lists
#[derive(Clone, Debug)]
struct MatchingOverloadsSnapshot<'db>(Vec<(usize, BindingSnapshot<'db>)>);
impl<'db> MatchingOverloadsSnapshot<'db> {
/// Update the state of the matched overload bindings in this snapshot with the current
/// state in the given `binding`.
fn update(&mut self, binding: &CallableBinding<'db>) {
// Here, the `snapshot` is the state of this binding for the previous argument list and
// `binding` would contain the state after evaluating the current argument list.
for (snapshot, binding) in self
.0
.iter_mut()
.map(|(index, snapshot)| (snapshot, &binding.overloads[*index]))
{
if binding.errors.is_empty() {
// If the binding has no errors, this means that the current argument list was
// evaluated successfully and this is the matching overload.
//
// Clear the errors from the snapshot of this overload to signal this change ...
snapshot.errors.clear();
// ... and update the snapshot with the current state of the binding.
snapshot.return_ty = binding.return_ty;
snapshot.specialization = binding.specialization;
snapshot.inherited_specialization = binding.inherited_specialization;
snapshot
.argument_parameters
.clone_from(&binding.argument_parameters);
snapshot.parameter_tys.clone_from(&binding.parameter_tys);
}
// If the errors in the snapshot was empty, then this binding is the matching overload
// for a previously evaluated argument list. This means that we don't need to change
// any information for an already matched overload binding.
//
// If it does have errors, we could extend it with the errors from evaluating the
// current argument list. Arguably, this isn't required, since the errors in the
// snapshot should already signal that this is an unmatched overload which is why we
// don't do it. Similarly, due to this being an unmatched overload, there's no point in
// updating the binding state.
}
}
}
/// A helper to take snapshots of the matched overload bindings for the current state of the
/// bindings.
struct MatchingOverloadsSnapshotter(Vec<usize>);
impl MatchingOverloadsSnapshotter {
/// Creates a new snapshotter for the given indexes of the matched overloads.
fn new(indexes: Vec<usize>) -> Self {
debug_assert!(indexes.len() > 1);
MatchingOverloadsSnapshotter(indexes)
}
/// Takes a snapshot of the current state of the matched overload bindings.
///
/// # Panics
///
/// Panics if the indexes of the matched overloads are not valid for the given binding.
fn take<'db>(&self, binding: &CallableBinding<'db>) -> MatchingOverloadsSnapshot<'db> {
MatchingOverloadsSnapshot(
self.0
.iter()
.map(|index| (*index, binding.overloads[*index].snapshot()))
.collect(),
)
}
/// Restores the state of the matched overload bindings from the given snapshot.
fn restore<'db>(
&self,
binding: &mut CallableBinding<'db>,
snapshot: MatchingOverloadsSnapshot<'db>,
) {
debug_assert_eq!(self.0.len(), snapshot.0.len());
for (index, snapshot) in snapshot.0 {
binding.overloads[index].restore(snapshot);
}
}
}
/// Describes a callable for the purposes of diagnostics.