[ty] Filter overloads using variadic parameters (#20547)

## Summary

Closes: https://github.com/astral-sh/ty/issues/551

This PR adds support for step 4 of the overload call evaluation
algorithm which states that:

> If the argument list is compatible with two or more overloads,
determine whether one or more of the overloads has a variadic parameter
(either `*args` or `**kwargs`) that maps to a corresponding argument
that supplies an indeterminate number of positional or keyword
arguments. If so, eliminate overloads that do not have a variadic
parameter.

And, with that, the overload call evaluation algorithm has been
implemented completely end to end as stated in the typing spec.

## Test Plan

Expand the overload call test suite.
This commit is contained in:
Dhruv Manilawala 2025-09-25 20:28:00 +05:30 committed by GitHub
parent b0bdf0334e
commit 35ed55ec8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 238 additions and 16 deletions

View file

@ -159,6 +159,8 @@ def _(args: list[int]) -> None:
takes_zero(*args) takes_zero(*args)
takes_one(*args) takes_one(*args)
takes_two(*args) takes_two(*args)
takes_two(*b"ab")
takes_two(*b"abc") # error: [too-many-positional-arguments]
takes_two_positional_only(*args) takes_two_positional_only(*args)
takes_two_different(*args) # error: [invalid-argument-type] takes_two_different(*args) # error: [invalid-argument-type]
takes_two_different_positional_only(*args) # error: [invalid-argument-type] takes_two_different_positional_only(*args) # error: [invalid-argument-type]

View file

@ -931,6 +931,134 @@ def _(t: tuple[int, str] | tuple[int, str, int]) -> None:
f(*t) # error: [no-matching-overload] f(*t) # error: [no-matching-overload]
``` ```
## Filtering based on variaidic arguments
This is step 4 of the overload call evaluation algorithm which specifies that:
> If the argument list is compatible with two or more overloads, determine whether one or more of
> the overloads has a variadic parameter (either `*args` or `**kwargs`) that maps to a corresponding
> argument that supplies an indeterminate number of positional or keyword arguments. If so,
> eliminate overloads that do not have a variadic parameter.
This is only performed if the previous step resulted in more than one matching overload.
### Simple `*args`
`overloaded.pyi`:
```pyi
from typing import overload
@overload
def f(x1: int) -> tuple[int]: ...
@overload
def f(x1: int, x2: int) -> tuple[int, int]: ...
@overload
def f(*args: int) -> int: ...
```
```py
from overloaded import f
def _(x1: int, x2: int, args: list[int]):
reveal_type(f(x1)) # revealed: tuple[int]
reveal_type(f(x1, x2)) # revealed: tuple[int, int]
reveal_type(f(*(x1, x2))) # revealed: tuple[int, int]
# Step 4 should filter out all but the last overload.
reveal_type(f(*args)) # revealed: int
```
### Variable `*args`
```toml
[environment]
python-version = "3.11"
```
`overloaded.pyi`:
```pyi
from typing import overload
@overload
def f(x1: int) -> tuple[int]: ...
@overload
def f(x1: int, x2: int) -> tuple[int, int]: ...
@overload
def f(x1: int, *args: int) -> tuple[int, ...]: ...
```
```py
from overloaded import f
def _(x1: int, x2: int, args1: list[int], args2: tuple[int, *tuple[int, ...]]):
reveal_type(f(x1, x2)) # revealed: tuple[int, int]
reveal_type(f(*(x1, x2))) # revealed: tuple[int, int]
# Step 4 should filter out all but the last overload.
reveal_type(f(x1, *args1)) # revealed: tuple[int, ...]
reveal_type(f(*args2)) # revealed: tuple[int, ...]
```
### Simple `**kwargs`
`overloaded.pyi`:
```pyi
from typing import overload
@overload
def f(*, x1: int) -> int: ...
@overload
def f(*, x1: int, x2: int) -> tuple[int, int]: ...
@overload
def f(**kwargs: int) -> int: ...
```
```py
from overloaded import f
def _(x1: int, x2: int, kwargs: dict[str, int]):
reveal_type(f(x1=x1)) # revealed: int
reveal_type(f(x1=x1, x2=x2)) # revealed: tuple[int, int]
# Step 4 should filter out all but the last overload.
reveal_type(f(**{"x1": x1, "x2": x2})) # revealed: int
reveal_type(f(**kwargs)) # revealed: int
```
### `TypedDict`
The keys in a `TypedDict` are static so there's no variable part to it, so step 4 shouldn't filter
out any overloads.
`overloaded.pyi`:
```pyi
from typing import TypedDict, overload
@overload
def f(*, x: int) -> int: ...
@overload
def f(*, x: int, y: int) -> tuple[int, int]: ...
@overload
def f(**kwargs: int) -> tuple[int, ...]: ...
```
```py
from typing import TypedDict
from overloaded import f
class Foo(TypedDict):
x: int
y: int
def _(foo: Foo, kwargs: dict[str, int]):
reveal_type(f(**foo)) # revealed: tuple[int, int]
reveal_type(f(**kwargs)) # revealed: tuple[int, ...]
```
## Filtering based on `Any` / `Unknown` ## Filtering based on `Any` / `Unknown`
This is the step 5 of the overload call evaluation algorithm which specifies that: This is the step 5 of the overload call evaluation algorithm which specifies that:

View file

@ -1333,10 +1333,30 @@ impl<'db> CallableBinding<'db> {
} }
MatchingOverloadIndex::Multiple(indexes) => { MatchingOverloadIndex::Multiple(indexes) => {
// If two or more candidate overloads remain, proceed to step 4. // If two or more candidate overloads remain, proceed to step 4.
// TODO: Step 4 self.filter_overloads_containing_variadic(&indexes);
// Step 5 match self.matching_overload_index() {
self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes); MatchingOverloadIndex::None => {
// This shouldn't be possible because step 4 can only filter out overloads
// when there _is_ a matching variadic argument.
tracing::debug!("All overloads have been filtered out in step 4");
return None;
}
MatchingOverloadIndex::Single(index) => {
// If only one candidate overload remains, it is the winning match.
// Evaluate it as a regular (non-overloaded) call.
self.matching_overload_index = Some(index);
return None;
}
MatchingOverloadIndex::Multiple(indexes) => {
// If two or more candidate overloads remain, proceed to step 5.
self.filter_overloads_using_any_or_unknown(
db,
argument_types.as_ref(),
&indexes,
);
}
}
// This shouldn't lead to argument type expansion. // This shouldn't lead to argument type expansion.
return None; return None;
@ -1446,15 +1466,28 @@ impl<'db> CallableBinding<'db> {
Some(self.overloads[index].return_type()) Some(self.overloads[index].return_type())
} }
MatchingOverloadIndex::Multiple(matching_overload_indexes) => { MatchingOverloadIndex::Multiple(matching_overload_indexes) => {
// TODO: Step 4 self.filter_overloads_containing_variadic(&matching_overload_indexes);
self.filter_overloads_using_any_or_unknown( match self.matching_overload_index() {
db, MatchingOverloadIndex::None => {
expanded_arguments, tracing::debug!(
&matching_overload_indexes, "All overloads have been filtered out in step 4 during argument type expansion"
); );
None
Some(self.return_type()) }
MatchingOverloadIndex::Single(index) => {
self.matching_overload_index = Some(index);
Some(self.return_type())
}
MatchingOverloadIndex::Multiple(indexes) => {
self.filter_overloads_using_any_or_unknown(
db,
expanded_arguments,
&indexes,
);
Some(self.return_type())
}
}
} }
}; };
@ -1511,6 +1544,32 @@ impl<'db> CallableBinding<'db> {
None None
} }
/// Filter overloads based on variadic argument to variadic parameter match.
///
/// This is the step 4 of the [overload call evaluation algorithm][1].
///
/// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
fn filter_overloads_containing_variadic(&mut self, matching_overload_indexes: &[usize]) {
let variadic_matching_overloads = matching_overload_indexes
.iter()
.filter(|&&overload_index| {
self.overloads[overload_index].variadic_argument_matched_to_variadic_parameter
})
.collect::<HashSet<_>>();
if variadic_matching_overloads.is_empty()
|| variadic_matching_overloads.len() == matching_overload_indexes.len()
{
return;
}
for overload_index in matching_overload_indexes {
if !variadic_matching_overloads.contains(overload_index) {
self.overloads[*overload_index].mark_as_unmatched_overload();
}
}
}
/// Filter overloads based on [`Any`] or [`Unknown`] argument types. /// Filter overloads based on [`Any`] or [`Unknown`] argument types.
/// ///
/// This is the step 5 of the [overload call evaluation algorithm][1]. /// This is the step 5 of the [overload call evaluation algorithm][1].
@ -1995,6 +2054,7 @@ struct ArgumentMatcher<'a, 'db> {
next_positional: usize, next_positional: usize,
first_excess_positional: Option<usize>, first_excess_positional: Option<usize>,
num_synthetic_args: usize, num_synthetic_args: usize,
variadic_argument_matched_to_variadic_parameter: bool,
} }
impl<'a, 'db> ArgumentMatcher<'a, 'db> { impl<'a, 'db> ArgumentMatcher<'a, 'db> {
@ -2014,6 +2074,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
next_positional: 0, next_positional: 0,
first_excess_positional: None, first_excess_positional: None,
num_synthetic_args: 0, num_synthetic_args: 0,
variadic_argument_matched_to_variadic_parameter: false,
} }
} }
@ -2029,6 +2090,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
} }
} }
#[expect(clippy::too_many_arguments)]
fn assign_argument( fn assign_argument(
&mut self, &mut self,
argument_index: usize, argument_index: usize,
@ -2037,6 +2099,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
parameter_index: usize, parameter_index: usize,
parameter: &Parameter<'db>, parameter: &Parameter<'db>,
positional: bool, positional: bool,
variable_argument_length: bool,
) { ) {
if !matches!(argument, Argument::Synthetic) { if !matches!(argument, Argument::Synthetic) {
let adjusted_argument_index = argument_index - self.num_synthetic_args; let adjusted_argument_index = argument_index - self.num_synthetic_args;
@ -2057,6 +2120,15 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
}); });
} }
} }
if variable_argument_length
&& matches!(
(argument, parameter.kind()),
(Argument::Variadic, ParameterKind::Variadic { .. })
| (Argument::Keywords, ParameterKind::KeywordVariadic { .. })
)
{
self.variadic_argument_matched_to_variadic_parameter = true;
}
let matched_argument = &mut self.argument_matches[argument_index]; let matched_argument = &mut self.argument_matches[argument_index];
matched_argument.parameters.push(parameter_index); matched_argument.parameters.push(parameter_index);
matched_argument.types.push(argument_type); matched_argument.types.push(argument_type);
@ -2069,6 +2141,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
argument_index: usize, argument_index: usize,
argument: Argument<'a>, argument: Argument<'a>,
argument_type: Option<Type<'db>>, argument_type: Option<Type<'db>>,
variable_argument_length: bool,
) -> Result<(), ()> { ) -> Result<(), ()> {
if matches!(argument, Argument::Synthetic) { if matches!(argument, Argument::Synthetic) {
self.num_synthetic_args += 1; self.num_synthetic_args += 1;
@ -2091,6 +2164,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
parameter_index, parameter_index,
parameter, parameter,
!parameter.is_variadic(), !parameter.is_variadic(),
variable_argument_length,
); );
Ok(()) Ok(())
} }
@ -2131,6 +2205,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
parameter_index, parameter_index,
parameter, parameter,
false, false,
false,
); );
Ok(()) Ok(())
} }
@ -2157,6 +2232,8 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
), ),
}; };
let is_variable = length.is_variable();
// We must be able to match up the fixed-length portion of the argument with positional // We must be able to match up the fixed-length portion of the argument with positional
// parameters, so we pass on any errors that occur. // parameters, so we pass on any errors that occur.
for _ in 0..length.minimum() { for _ in 0..length.minimum() {
@ -2164,12 +2241,13 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
argument_index, argument_index,
argument, argument,
argument_types.next().or(variable_element), argument_types.next().or(variable_element),
is_variable,
)?; )?;
} }
// If the tuple is variable-length, we assume that it will soak up all remaining positional // If the tuple is variable-length, we assume that it will soak up all remaining positional
// parameters. // parameters.
if length.is_variable() { if is_variable {
while self while self
.parameters .parameters
.get_positional(self.next_positional) .get_positional(self.next_positional)
@ -2179,6 +2257,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
argument_index, argument_index,
argument, argument,
argument_types.next().or(variable_element), argument_types.next().or(variable_element),
is_variable,
)?; )?;
} }
} }
@ -2189,9 +2268,14 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
// raise a false positive as "too many arguments". // raise a false positive as "too many arguments".
if self.parameters.variadic().is_some() { if self.parameters.variadic().is_some() {
if let Some(argument_type) = argument_types.next().or(variable_element) { if let Some(argument_type) = argument_types.next().or(variable_element) {
self.match_positional(argument_index, argument, Some(argument_type))?; self.match_positional(argument_index, argument, Some(argument_type), is_variable)?;
for argument_type in argument_types { for argument_type in argument_types {
self.match_positional(argument_index, argument, Some(argument_type))?; self.match_positional(
argument_index,
argument,
Some(argument_type),
is_variable,
)?;
} }
} }
} }
@ -2248,6 +2332,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
parameter_index, parameter_index,
parameter, parameter,
false, false,
true,
); );
} }
} }
@ -2670,6 +2755,10 @@ pub(crate) struct Binding<'db> {
/// order. /// order.
argument_matches: Box<[MatchedArgument<'db>]>, argument_matches: Box<[MatchedArgument<'db>]>,
/// Whether an argument that supplies an indeterminate number of positional or keyword
/// arguments is mapped to a variadic parameter (`*args` or `**kwargs`).
variadic_argument_matched_to_variadic_parameter: bool,
/// Bound types for parameters, in parameter source order, or `None` if no argument was matched /// Bound types for parameters, in parameter source order, or `None` if no argument was matched
/// to that parameter. /// to that parameter.
parameter_tys: Box<[Option<Type<'db>>]>, parameter_tys: Box<[Option<Type<'db>>]>,
@ -2688,6 +2777,7 @@ impl<'db> Binding<'db> {
specialization: None, specialization: None,
inherited_specialization: None, inherited_specialization: None,
argument_matches: Box::from([]), argument_matches: Box::from([]),
variadic_argument_matched_to_variadic_parameter: false,
parameter_tys: Box::from([]), parameter_tys: Box::from([]),
errors: vec![], errors: vec![],
} }
@ -2712,7 +2802,7 @@ impl<'db> Binding<'db> {
for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() { for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() {
match argument { match argument {
Argument::Positional | Argument::Synthetic => { Argument::Positional | Argument::Synthetic => {
let _ = matcher.match_positional(argument_index, argument, None); let _ = matcher.match_positional(argument_index, argument, None, false);
} }
Argument::Keyword(name) => { Argument::Keyword(name) => {
let _ = matcher.match_keyword(argument_index, argument, None, name); let _ = matcher.match_keyword(argument_index, argument, None, name);
@ -2730,6 +2820,8 @@ impl<'db> Binding<'db> {
} }
self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown()); self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown());
self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice();
self.variadic_argument_matched_to_variadic_parameter =
matcher.variadic_argument_matched_to_variadic_parameter;
self.argument_matches = matcher.finish(); self.argument_matches = matcher.finish();
} }

View file

@ -44,7 +44,7 @@ impl TupleLength {
TupleLength::Variable(0, 0) TupleLength::Variable(0, 0)
} }
pub(crate) fn is_variable(self) -> bool { pub(crate) const fn is_variable(self) -> bool {
matches!(self, TupleLength::Variable(_, _)) matches!(self, TupleLength::Variable(_, _))
} }