[ty] Limit argument expansion size for overload call evaluation (#20041)

## Summary

This PR limits the argument type expansion size for an overload call
evaluation to 512.

The limit is arbitrary: I took Pyright's limit of 256 as a reference
point and doubled it as a starting value.

Initially, I started out by trying to refactor the entire argument type
expansion to be lazy. Currently, expanding a single argument at any
position eagerly creates the combinations (argument lists) and returns
them as a `Vec<CallArguments>`. I thought we could make this lazier by
changing the return type of `expand` from
`Iterator<Item = Vec<CallArguments>>` to
`Iterator<Item = Iterator<Item = CallArguments>>`, but that is proving
difficult to implement, mainly because we **need** to keep the previous
expansion around to generate the next one, which is the main reason for
using `std::iter::successors` in the first place.

Another approach would be to eagerly expand all the argument types and
then use `combinations` from `itertools` to generate the combinations,
but we would then need to find the "boundary" between the argument lists
produced by expanding the argument at position 1 and those produced by
expanding the argument at position 2, because that boundary is important
for the algorithm.

Closes: https://github.com/astral-sh/ty/issues/868

## Test Plan

Add a test case that demonstrates the limit, along with a diagnostic
snapshot stating that the limit has been reached.
This commit is contained in:
Dhruv Manilawala 2025-08-25 15:13:04 +05:30 committed by GitHub
parent b57cc5be33
commit 376e3ff395
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 373 additions and 44 deletions

View file

@ -127,31 +127,39 @@ impl<'a, 'db> CallArguments<'a, 'db> {
/// contains the same arguments, but with one or more of the argument types expanded.
///
/// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
pub(crate) fn expand(
&self,
db: &'db dyn Db,
) -> impl Iterator<Item = Vec<CallArguments<'a, 'db>>> + '_ {
pub(super) fn expand(&self, db: &'db dyn Db) -> impl Iterator<Item = Expansion<'a, 'db>> + '_ {
/// Maximum number of argument lists that can be generated in a single expansion step.
static MAX_EXPANSIONS: usize = 512;
/// Represents the state of the expansion process.
enum State<'a, 'b, 'db> {
LimitReached(usize),
Expanding(ExpandingState<'a, 'b, 'db>),
}
/// Represents the expanding state with either the initial types or the expanded types.
///
/// This is useful to avoid cloning the initial types vector if none of the types can be
/// expanded.
enum State<'a, 'b, 'db> {
enum ExpandingState<'a, 'b, 'db> {
Initial(&'b Vec<Option<Type<'db>>>),
Expanded(Vec<CallArguments<'a, 'db>>),
}
impl<'db> State<'_, '_, 'db> {
impl<'db> ExpandingState<'_, '_, 'db> {
fn len(&self) -> usize {
match self {
State::Initial(_) => 1,
State::Expanded(expanded) => expanded.len(),
ExpandingState::Initial(_) => 1,
ExpandingState::Expanded(expanded) => expanded.len(),
}
}
fn iter(&self) -> impl Iterator<Item = &[Option<Type<'db>>]> + '_ {
match self {
State::Initial(types) => Either::Left(std::iter::once(types.as_slice())),
State::Expanded(expanded) => {
ExpandingState::Initial(types) => {
Either::Left(std::iter::once(types.as_slice()))
}
ExpandingState::Expanded(expanded) => {
Either::Right(expanded.iter().map(CallArguments::types))
}
}
@ -160,44 +168,82 @@ impl<'a, 'db> CallArguments<'a, 'db> {
let mut index = 0;
std::iter::successors(Some(State::Initial(&self.types)), move |previous| {
// Find the next type that can be expanded.
let expanded_types = loop {
let arg_type = self.types.get(index)?;
if let Some(arg_type) = arg_type {
if let Some(expanded_types) = expand_type(db, *arg_type) {
break expanded_types;
std::iter::successors(
Some(State::Expanding(ExpandingState::Initial(&self.types))),
move |previous| {
let state = match previous {
State::LimitReached(index) => return Some(State::LimitReached(*index)),
State::Expanding(expanding_state) => expanding_state,
};
// Find the next type that can be expanded.
let expanded_types = loop {
let arg_type = self.types.get(index)?;
if let Some(arg_type) = arg_type {
if let Some(expanded_types) = expand_type(db, *arg_type) {
break expanded_types;
}
}
index += 1;
};
let expansion_size = expanded_types.len() * state.len();
if expansion_size > MAX_EXPANSIONS {
tracing::debug!(
"Skipping argument type expansion as it would exceed the \
maximum number of expansions ({MAX_EXPANSIONS})"
);
return Some(State::LimitReached(index));
}
let mut expanded_arguments = Vec::with_capacity(expansion_size);
for pre_expanded_types in state.iter() {
for subtype in &expanded_types {
let mut new_expanded_types = pre_expanded_types.to_vec();
new_expanded_types[index] = Some(*subtype);
expanded_arguments.push(CallArguments::new(
self.arguments.clone(),
new_expanded_types,
));
}
}
// Increment the index to move to the next argument type for the next iteration.
index += 1;
};
let mut expanded_arguments = Vec::with_capacity(expanded_types.len() * previous.len());
for pre_expanded_types in previous.iter() {
for subtype in &expanded_types {
let mut new_expanded_types = pre_expanded_types.to_vec();
new_expanded_types[index] = Some(*subtype);
expanded_arguments.push(CallArguments::new(
self.arguments.clone(),
new_expanded_types,
));
}
}
// Increment the index to move to the next argument type for the next iteration.
index += 1;
Some(State::Expanded(expanded_arguments))
})
Some(State::Expanding(ExpandingState::Expanded(
expanded_arguments,
)))
},
)
.skip(1) // Skip the initial state, which has no expanded types.
.map(|state| match state {
State::Initial(_) => unreachable!("initial state should be skipped"),
State::Expanded(expanded) => expanded,
State::LimitReached(index) => Expansion::LimitReached(index),
State::Expanding(ExpandingState::Initial(_)) => {
unreachable!("initial state should be skipped")
}
State::Expanding(ExpandingState::Expanded(expanded)) => Expansion::Expanded(expanded),
})
}
}
/// A single item yielded by the iterator returned from [`expand`].
///
/// Each item is either the argument lists produced by one expansion step, or a marker saying
/// that the step was not performed because it would have produced too many argument lists.
///
/// [`expand`]: CallArguments::expand
pub(super) enum Expansion<'a, 'db> {
/// Indicates that the expansion process has reached the maximum number of argument lists
/// that can be generated in a single step (`MAX_EXPANSIONS` in [`expand`]).
///
/// This is produced when the number of argument lists from the previous step multiplied by
/// the number of types the next argument expands into would exceed that maximum. Once the
/// limit has been reached, every subsequent item yielded by [`expand`] is this same variant
/// with the same index.
///
/// The contained `usize` is the index of the argument type which would have been expanded
/// next, if not for the limit.
///
/// [`expand`]: CallArguments::expand
LimitReached(usize),
/// Contains the expanded argument lists, where each list contains the same arguments, but with
/// one or more of the argument types expanded.
Expanded(Vec<CallArguments<'a, 'db>>),
}
impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<'a, 'db> {
fn from_iter<T>(iter: T) -> Self
where

View file

@ -16,7 +16,7 @@ use crate::Program;
use crate::db::Db;
use crate::dunder_all::dunder_all_names;
use crate::place::{Boundness, Place};
use crate::types::call::arguments::is_expandable_type;
use crate::types::call::arguments::{Expansion, is_expandable_type};
use crate::types::diagnostic::{
CALL_NON_CALLABLE, CONFLICTING_ARGUMENT_FORMS, INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT,
NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS,
@ -1368,7 +1368,18 @@ impl<'db> CallableBinding<'db> {
}
}
for expanded_argument_lists in expansions {
for expansion in expansions {
let expanded_argument_lists = match expansion {
Expansion::LimitReached(index) => {
snapshotter.restore(self, post_evaluation_snapshot);
self.overload_call_return_type = Some(
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index),
);
return;
}
Expansion::Expanded(argument_lists) => argument_lists,
};
// This is the merged state of the bindings after evaluating all of the expanded
// argument lists. This will be the final state to restore the bindings to if all of
// the expanded argument lists evaluated successfully.
@ -1667,7 +1678,8 @@ impl<'db> CallableBinding<'db> {
if let Some(overload_call_return_type) = self.overload_call_return_type {
return match overload_call_return_type {
OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type,
OverloadCallReturnType::Ambiguous => Type::unknown(),
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(_)
| OverloadCallReturnType::Ambiguous => Type::unknown(),
};
}
if let Some((_, first_overload)) = self.matching_overloads().next() {
@ -1778,6 +1790,23 @@ impl<'db> CallableBinding<'db> {
String::new()
}
));
if let Some(index) =
self.overload_call_return_type
.and_then(
|overload_call_return_type| match overload_call_return_type {
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(
index,
) => Some(index),
_ => None,
},
)
{
diag.info(format_args!(
"Limit of argument type expansion reached at argument {index}"
));
}
if let Some((kind, function)) = function_type_and_kind {
let (overloads, implementation) =
function.overloads_and_implementation(context.db());
@ -1844,6 +1873,7 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
/// Describes how the return type of an overloaded call should be resolved.
///
/// When a `CallableBinding` has one of these set, it takes precedence when the return type of
/// the call is computed, before the matching overloads are consulted.
#[derive(Debug, Copy, Clone)]
enum OverloadCallReturnType<'db> {
/// Carries a concrete return type, which is used as-is for the call.
///
/// NOTE(review): presumably the type obtained by evaluating the expanded argument lists;
/// confirm at the site where this variant is assigned.
ArgumentTypeExpansion(Type<'db>),
/// Argument type expansion was abandoned because `Expansion::LimitReached` was encountered.
///
/// The contained `usize` is the argument index carried by `Expansion::LimitReached`. The
/// return type of the call resolves to `Type::unknown()`, and the index is reported in an
/// additional info message on the diagnostic.
ArgumentTypeExpansionLimitReached(usize),
/// The return type of the call resolves to `Type::unknown()`.
///
/// NOTE(review): the conditions under which a call is considered ambiguous are decided
/// elsewhere; see where this variant is assigned.
Ambiguous,
}