[ty] Limit argument expansion size for overload call evaluation (#20041)

## Summary

This PR limits the argument type expansion size for an overload call
evaluation to 512.

The limit is arbitrary: I took Pyright's limit of 256 as the reference
point and doubled it to start with.

Initially, I tried to refactor the entire argument type expansion to be
lazy. Currently, expanding a single argument at any position eagerly
creates the combinations (argument lists) and returns them as a
`Vec<CallArguments>`. I thought we could make this lazier by changing the
return type of `expand` from `Iterator<Item = Vec<CallArguments>>` to
`Iterator<Item = Iterator<Item = CallArguments>>`, but that is proving
difficult to implement, mainly because we **need** to keep the previous
expansion around to generate the next one, which is the main reason for
using `std::iter::successors` in the first place.
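
To illustrate why each step needs the previous one, here is a minimal sketch of
stepwise expansion. It is not ty's implementation: it assumes a hypothetical
model in which an argument type is a list of union member names and an argument
list is a `Vec<String>`, and the `expand` function below only shares its name
with the real method.

```rust
/// Hypothetical model: each argument is a list of union members, e.g.
/// `vec!["int", "None"]` stands for `int | None`.
///
/// Yields one `Vec` of argument lists per expandable argument position.
fn expand<'a>(args: &'a [Vec<&'a str>]) -> impl Iterator<Item = Vec<Vec<String>>> + 'a {
    // The unexpanded argument list, with unions rendered as `A | B`.
    let initial: Vec<String> = args.iter().map(|members| members.join(" | ")).collect();
    let mut index = 0;

    std::iter::successors(Some(vec![initial]), move |previous| {
        // Find the next argument that is a union of two or more members.
        // Running out of arguments ends the iteration via `?`.
        while args.get(index)?.len() < 2 {
            index += 1;
        }
        let members = &args[index];

        // Every list from the previous step is combined with every member,
        // so the previous step has to be available here.
        let mut next = Vec::with_capacity(previous.len() * members.len());
        for list in previous {
            for member in members {
                let mut expanded = list.clone();
                expanded[index] = member.to_string();
                next.push(expanded);
            }
        }

        index += 1;
        Some(next)
    })
    .skip(1) // Skip the initial, unexpanded state.
}
```

Each item of the outer iterator is built from the item before it, which is what
makes a fully lazy inner iterator awkward in this shape.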

Another approach would be to eagerly expand all the argument types and
then use `combinations` from `itertools` to generate the combinations,
but we would need to find the "boundary" between the argument lists
produced by expanding the argument at position 1 and those produced by
expanding the argument at position 2, because that boundary is important
for the algorithm.

Closes: https://github.com/astral-sh/ty/issues/868

## Test Plan

Add a test case that demonstrates the limit, along with a diagnostic
snapshot stating that the limit has been reached.
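
As a rough check of why the snapshot reports argument 10, the sketch below
counts argument lists the way the limit check does. It assumes that each
`int | None` argument expands into two members and that `A()` is not
expandable; `limit_reached_at` is a hypothetical helper, not part of ty.

```rust
/// Mirrors the value of `MAX_EXPANSIONS` introduced in this PR.
const MAX_EXPANSIONS: usize = 512;

/// Takes the number of members each argument type expands to (1 means the
/// argument is not expandable) and returns the index of the first argument
/// whose expansion would push the number of argument lists past the limit.
fn limit_reached_at(members_per_argument: &[usize]) -> Option<usize> {
    let mut lists = 1usize;
    for (index, &members) in members_per_argument.iter().enumerate() {
        if members < 2 {
            // Nothing to expand at this position.
            continue;
        }
        let expansion_size = lists * members;
        if expansion_size > MAX_EXPANSIONS {
            return Some(index);
        }
        lists = expansion_size;
    }
    None
}
```

With one non-expandable positional argument followed by 30 two-member unions,
the count reaches 2^9 = 512 after index 9, and expanding index 10 would produce
1024 lists, so the helper returns `Some(10)` under these assumptions.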
Dhruv Manilawala 2025-08-25 15:13:04 +05:30 committed by GitHub
parent b57cc5be33
commit 376e3ff395
4 changed files with 373 additions and 44 deletions


@@ -660,9 +660,9 @@ class Foo:
 ```py
 from overloaded import A, B, C, Foo, f
-from typing_extensions import reveal_type
+from typing_extensions import Any, reveal_type
-def _(ab: A | B, a=1):
+def _(ab: A | B, a: int | Any):
     reveal_type(f(a1=a, a2=a, a3=a)) # revealed: C
     reveal_type(f(A(), a1=a, a2=a, a3=a)) # revealed: A
     reveal_type(f(B(), a1=a, a2=a, a3=a)) # revealed: B
@@ -750,7 +750,7 @@ def _(ab: A | B, a=1):
         )
     )
-def _(foo: Foo, ab: A | B, a=1):
+def _(foo: Foo, ab: A | B, a: int | Any):
     reveal_type(foo.f(a1=a, a2=a, a3=a)) # revealed: C
     reveal_type(foo.f(A(), a1=a, a2=a, a3=a)) # revealed: A
     reveal_type(foo.f(B(), a1=a, a2=a, a3=a)) # revealed: B
@@ -831,6 +831,76 @@ def _(foo: Foo, ab: A | B, a=1):
     )
 ```
+### Optimization: Limit expansion size
+<!-- snapshot-diagnostics -->
+To prevent combinatorial explosion, ty limits the number of argument lists created by expanding a
+single argument.
+`overloaded.pyi`:
+```pyi
+from typing import overload
+class A: ...
+class B: ...
+class C: ...
+@overload
+def f() -> None: ...
+@overload
+def f(**kwargs: int) -> C: ...
+@overload
+def f(x: A, /, **kwargs: int) -> A: ...
+@overload
+def f(x: B, /, **kwargs: int) -> B: ...
+```
+```py
+from overloaded import A, B, f
+from typing_extensions import reveal_type
+def _(a: int | None):
+    reveal_type(
+        # error: [no-matching-overload]
+        # revealed: Unknown
+        f(
+            A(),
+            a1=a,
+            a2=a,
+            a3=a,
+            a4=a,
+            a5=a,
+            a6=a,
+            a7=a,
+            a8=a,
+            a9=a,
+            a10=a,
+            a11=a,
+            a12=a,
+            a13=a,
+            a14=a,
+            a15=a,
+            a16=a,
+            a17=a,
+            a18=a,
+            a19=a,
+            a20=a,
+            a21=a,
+            a22=a,
+            a23=a,
+            a24=a,
+            a25=a,
+            a26=a,
+            a27=a,
+            a28=a,
+            a29=a,
+            a30=a,
+        )
+    )
+```
 ## Filtering based on `Any` / `Unknown`
 This is the step 5 of the overload call evaluation algorithm which specifies that:


@@ -0,0 +1,183 @@
---
source: crates/ty_test/src/lib.rs
expression: snapshot
---
---
mdtest name: overloads.md - Overloads - Argument type expansion - Optimization: Limit expansion size
mdtest path: crates/ty_python_semantic/resources/mdtest/call/overloads.md
---
# Python source files
## overloaded.pyi
```
1 | from typing import overload
2 |
3 | class A: ...
4 | class B: ...
5 | class C: ...
6 |
7 | @overload
8 | def f() -> None: ...
9 | @overload
10 | def f(**kwargs: int) -> C: ...
11 | @overload
12 | def f(x: A, /, **kwargs: int) -> A: ...
13 | @overload
14 | def f(x: B, /, **kwargs: int) -> B: ...
```
## mdtest_snippet.py
```
1 | from overloaded import A, B, f
2 | from typing_extensions import reveal_type
3 |
4 | def _(a: int | None):
5 | reveal_type(
6 | # error: [no-matching-overload]
7 | # revealed: Unknown
8 | f(
9 | A(),
10 | a1=a,
11 | a2=a,
12 | a3=a,
13 | a4=a,
14 | a5=a,
15 | a6=a,
16 | a7=a,
17 | a8=a,
18 | a9=a,
19 | a10=a,
20 | a11=a,
21 | a12=a,
22 | a13=a,
23 | a14=a,
24 | a15=a,
25 | a16=a,
26 | a17=a,
27 | a18=a,
28 | a19=a,
29 | a20=a,
30 | a21=a,
31 | a22=a,
32 | a23=a,
33 | a24=a,
34 | a25=a,
35 | a26=a,
36 | a27=a,
37 | a28=a,
38 | a29=a,
39 | a30=a,
40 | )
41 | )
```
# Diagnostics
```
error[no-matching-overload]: No overload of function `f` matches arguments
--> src/mdtest_snippet.py:8:9
|
6 | # error: [no-matching-overload]
7 | # revealed: Unknown
8 | / f(
9 | | A(),
10 | | a1=a,
11 | | a2=a,
12 | | a3=a,
13 | | a4=a,
14 | | a5=a,
15 | | a6=a,
16 | | a7=a,
17 | | a8=a,
18 | | a9=a,
19 | | a10=a,
20 | | a11=a,
21 | | a12=a,
22 | | a13=a,
23 | | a14=a,
24 | | a15=a,
25 | | a16=a,
26 | | a17=a,
27 | | a18=a,
28 | | a19=a,
29 | | a20=a,
30 | | a21=a,
31 | | a22=a,
32 | | a23=a,
33 | | a24=a,
34 | | a25=a,
35 | | a26=a,
36 | | a27=a,
37 | | a28=a,
38 | | a29=a,
39 | | a30=a,
40 | | )
| |_________^
41 | )
|
info: Limit of argument type expansion reached at argument 10
info: First overload defined here
--> src/overloaded.pyi:8:5
|
7 | @overload
8 | def f() -> None: ...
| ^^^^^^^^^^^
9 | @overload
10 | def f(**kwargs: int) -> C: ...
|
info: Possible overloads for function `f`:
info: () -> None
info: (**kwargs: int) -> C
info: (x: A, /, **kwargs: int) -> A
info: (x: B, /, **kwargs: int) -> B
info: rule `no-matching-overload` is enabled by default
```
```
info[revealed-type]: Revealed type
--> src/mdtest_snippet.py:8:9
|
6 | # error: [no-matching-overload]
7 | # revealed: Unknown
8 | / f(
9 | | A(),
10 | | a1=a,
11 | | a2=a,
12 | | a3=a,
13 | | a4=a,
14 | | a5=a,
15 | | a6=a,
16 | | a7=a,
17 | | a8=a,
18 | | a9=a,
19 | | a10=a,
20 | | a11=a,
21 | | a12=a,
22 | | a13=a,
23 | | a14=a,
24 | | a15=a,
25 | | a16=a,
26 | | a17=a,
27 | | a18=a,
28 | | a19=a,
29 | | a20=a,
30 | | a21=a,
31 | | a22=a,
32 | | a23=a,
33 | | a24=a,
34 | | a25=a,
35 | | a26=a,
36 | | a27=a,
37 | | a28=a,
38 | | a29=a,
39 | | a30=a,
40 | | )
| |_________^ `Unknown`
41 | )
|
```


@@ -127,31 +127,39 @@ impl<'a, 'db> CallArguments<'a, 'db> {
     /// contains the same arguments, but with one or more of the argument types expanded.
     ///
     /// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
-    pub(crate) fn expand(
-        &self,
-        db: &'db dyn Db,
-    ) -> impl Iterator<Item = Vec<CallArguments<'a, 'db>>> + '_ {
+    pub(super) fn expand(&self, db: &'db dyn Db) -> impl Iterator<Item = Expansion<'a, 'db>> + '_ {
+        /// Maximum number of argument lists that can be generated in a single expansion step.
+        static MAX_EXPANSIONS: usize = 512;
         /// Represents the state of the expansion process.
+        enum State<'a, 'b, 'db> {
+            LimitReached(usize),
+            Expanding(ExpandingState<'a, 'b, 'db>),
+        }
+        /// Represents the expanding state with either the initial types or the expanded types.
         ///
         /// This is useful to avoid cloning the initial types vector if none of the types can be
         /// expanded.
-        enum State<'a, 'b, 'db> {
+        enum ExpandingState<'a, 'b, 'db> {
             Initial(&'b Vec<Option<Type<'db>>>),
             Expanded(Vec<CallArguments<'a, 'db>>),
         }
-        impl<'db> State<'_, '_, 'db> {
+        impl<'db> ExpandingState<'_, '_, 'db> {
             fn len(&self) -> usize {
                 match self {
-                    State::Initial(_) => 1,
-                    State::Expanded(expanded) => expanded.len(),
+                    ExpandingState::Initial(_) => 1,
+                    ExpandingState::Expanded(expanded) => expanded.len(),
                 }
             }
             fn iter(&self) -> impl Iterator<Item = &[Option<Type<'db>>]> + '_ {
                 match self {
-                    State::Initial(types) => Either::Left(std::iter::once(types.as_slice())),
-                    State::Expanded(expanded) => {
+                    ExpandingState::Initial(types) => {
+                        Either::Left(std::iter::once(types.as_slice()))
+                    }
+                    ExpandingState::Expanded(expanded) => {
                         Either::Right(expanded.iter().map(CallArguments::types))
                     }
                 }
@@ -160,7 +168,14 @@ impl<'a, 'db> CallArguments<'a, 'db> {
         let mut index = 0;
-        std::iter::successors(Some(State::Initial(&self.types)), move |previous| {
+        std::iter::successors(
+            Some(State::Expanding(ExpandingState::Initial(&self.types))),
+            move |previous| {
+                let state = match previous {
+                    State::LimitReached(index) => return Some(State::LimitReached(*index)),
+                    State::Expanding(expanding_state) => expanding_state,
+                };
                 // Find the next type that can be expanded.
                 let expanded_types = loop {
                     let arg_type = self.types.get(index)?;
@@ -172,9 +187,18 @@ impl<'a, 'db> CallArguments<'a, 'db> {
                     index += 1;
                 };
-                let mut expanded_arguments = Vec::with_capacity(expanded_types.len() * previous.len());
+                let expansion_size = expanded_types.len() * state.len();
+                if expansion_size > MAX_EXPANSIONS {
+                    tracing::debug!(
+                        "Skipping argument type expansion as it would exceed the \
+                        maximum number of expansions ({MAX_EXPANSIONS})"
+                    );
+                    return Some(State::LimitReached(index));
+                }
-                for pre_expanded_types in previous.iter() {
+                let mut expanded_arguments = Vec::with_capacity(expansion_size);
+                for pre_expanded_types in state.iter() {
                     for subtype in &expanded_types {
                         let mut new_expanded_types = pre_expanded_types.to_vec();
                         new_expanded_types[index] = Some(*subtype);
@@ -188,16 +212,38 @@ impl<'a, 'db> CallArguments<'a, 'db> {
                 // Increment the index to move to the next argument type for the next iteration.
                 index += 1;
-            Some(State::Expanded(expanded_arguments))
-        })
+                Some(State::Expanding(ExpandingState::Expanded(
+                    expanded_arguments,
+                )))
+            },
+        )
         .skip(1) // Skip the initial state, which has no expanded types.
         .map(|state| match state {
-            State::Initial(_) => unreachable!("initial state should be skipped"),
-            State::Expanded(expanded) => expanded,
+            State::LimitReached(index) => Expansion::LimitReached(index),
+            State::Expanding(ExpandingState::Initial(_)) => {
+                unreachable!("initial state should be skipped")
+            }
+            State::Expanding(ExpandingState::Expanded(expanded)) => Expansion::Expanded(expanded),
         })
     }
 }
+/// Represents a single element of the expansion process for argument types for [`expand`].
+///
+/// [`expand`]: CallArguments::expand
+pub(super) enum Expansion<'a, 'db> {
+    /// Indicates that the expansion process has reached the maximum number of argument lists
+    /// that can be generated in a single step.
+    ///
+    /// The contained `usize` is the index of the argument type which would have been expanded
+    /// next, if not for the limit.
+    LimitReached(usize),
+    /// Contains the expanded argument lists, where each list contains the same arguments, but with
+    /// one or more of the argument types expanded.
+    Expanded(Vec<CallArguments<'a, 'db>>),
+}
 impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<'a, 'db> {
     fn from_iter<T>(iter: T) -> Self
     where


@@ -16,7 +16,7 @@ use crate::Program;
 use crate::db::Db;
 use crate::dunder_all::dunder_all_names;
 use crate::place::{Boundness, Place};
-use crate::types::call::arguments::is_expandable_type;
+use crate::types::call::arguments::{Expansion, is_expandable_type};
 use crate::types::diagnostic::{
     CALL_NON_CALLABLE, CONFLICTING_ARGUMENT_FORMS, INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT,
     NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS,
@@ -1368,7 +1368,18 @@ impl<'db> CallableBinding<'db> {
             }
         }
-        for expanded_argument_lists in expansions {
+        for expansion in expansions {
+            let expanded_argument_lists = match expansion {
+                Expansion::LimitReached(index) => {
+                    snapshotter.restore(self, post_evaluation_snapshot);
+                    self.overload_call_return_type = Some(
+                        OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index),
+                    );
+                    return;
+                }
+                Expansion::Expanded(argument_lists) => argument_lists,
+            };
             // This is the merged state of the bindings after evaluating all of the expanded
             // argument lists. This will be the final state to restore the bindings to if all of
             // the expanded argument lists evaluated successfully.
@@ -1667,7 +1678,8 @@ impl<'db> CallableBinding<'db> {
         if let Some(overload_call_return_type) = self.overload_call_return_type {
             return match overload_call_return_type {
                 OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type,
-                OverloadCallReturnType::Ambiguous => Type::unknown(),
+                OverloadCallReturnType::ArgumentTypeExpansionLimitReached(_)
+                | OverloadCallReturnType::Ambiguous => Type::unknown(),
             };
         }
         if let Some((_, first_overload)) = self.matching_overloads().next() {
@@ -1778,6 +1790,23 @@ impl<'db> CallableBinding<'db> {
                         String::new()
                     }
                 ));
+                if let Some(index) =
+                    self.overload_call_return_type
+                        .and_then(
+                            |overload_call_return_type| match overload_call_return_type {
+                                OverloadCallReturnType::ArgumentTypeExpansionLimitReached(
+                                    index,
+                                ) => Some(index),
+                                _ => None,
+                            },
+                        )
+                {
+                    diag.info(format_args!(
+                        "Limit of argument type expansion reached at argument {index}"
+                    ));
+                }
                 if let Some((kind, function)) = function_type_and_kind {
                     let (overloads, implementation) =
                         function.overloads_and_implementation(context.db());
@@ -1844,6 +1873,7 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
 #[derive(Debug, Copy, Clone)]
 enum OverloadCallReturnType<'db> {
     ArgumentTypeExpansion(Type<'db>),
+    ArgumentTypeExpansionLimitReached(usize),
     Ambiguous,
 }
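
The interaction between the producer and the consumer of `Expansion` can be
sketched in isolation. This is a simplified model under stated assumptions, not
the code from this PR: types are plain strings, the limit is shrunk to 4 so the
behaviour is easy to trace, and `Outcome`, `evaluate`, and the `accepts`
callback are hypothetical names standing in for the overload evaluation.

```rust
/// Deliberately tiny for the sketch; the PR uses 512.
const MAX_EXPANSIONS: usize = 4;

#[derive(Debug, PartialEq)]
enum Expansion {
    /// Index of the argument that would have been expanded next.
    LimitReached(usize),
    Expanded(Vec<Vec<String>>),
}

/// Hypothetical result of evaluating a call against the expansions.
#[derive(Debug, PartialEq)]
enum Outcome {
    /// Every argument list of some step was accepted; holds the list count.
    Matched(usize),
    LimitReached(usize),
    NoMatch,
}

fn expand<'a>(args: &'a [Vec<&'a str>]) -> impl Iterator<Item = Expansion> + 'a {
    let initial: Vec<String> = args.iter().map(|members| members.join(" | ")).collect();
    let mut index = 0;
    std::iter::successors(Some(Expansion::Expanded(vec![initial])), move |previous| {
        let lists = match previous {
            // Once the limit is hit, the same state is repeated forever.
            Expansion::LimitReached(at) => return Some(Expansion::LimitReached(*at)),
            Expansion::Expanded(lists) => lists,
        };
        // Find the next union; running out of arguments ends the iteration.
        while args.get(index)?.len() < 2 {
            index += 1;
        }
        let members = &args[index];
        let expansion_size = lists.len() * members.len();
        if expansion_size > MAX_EXPANSIONS {
            return Some(Expansion::LimitReached(index));
        }
        let mut next = Vec::with_capacity(expansion_size);
        for list in lists {
            for member in members {
                let mut expanded = list.clone();
                expanded[index] = member.to_string();
                next.push(expanded);
            }
        }
        index += 1;
        Some(Expansion::Expanded(next))
    })
    .skip(1) // Skip the initial, unexpanded state.
}

fn evaluate<'a>(args: &'a [Vec<&'a str>], accepts: impl Fn(&[String]) -> bool) -> Outcome {
    for expansion in expand(args) {
        let lists = match expansion {
            // Must return here: the iterator never ends after this state.
            Expansion::LimitReached(index) => return Outcome::LimitReached(index),
            Expansion::Expanded(lists) => lists,
        };
        if lists.iter().all(|list| accepts(list.as_slice())) {
            return Outcome::Matched(lists.len());
        }
    }
    Outcome::NoMatch
}
```

One property carried over from the diff is that the limit state is sticky: the
`successors` closure keeps returning `LimitReached`, so the consumer has to
stop at the first one, which is what the early `return` in the `for expansion
in expansions` loop does.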