mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:24:57 +00:00
[ty] Limit argument expansion size for overload call evaluation (#20041)
## Summary This PR limits the argument type expansion size for an overload call evaluation to 512. The limit chosen is arbitrary but I've taken the 256 limit from Pyright into account and bumped it x2 to start with. Initially, I actually started out by trying to refactor the entire argument type expansion to be lazy. Currently, expanding a single argument at any position eagerly creates the combination (argument lists) and returns that (`Vec<CallArguments>`) but I thought we could make it lazier by converting the return type of `expand` from `Iterator<Item = Vec<CallArguments>>` to `Iterator<Item = Iterator<Item = CallArguments>>` but that's proving to be difficult to implement mainly because we **need** to maintain the previous expansion to generate the next expansion which is the main reason to use `std::iter::successors` in the first place. Another approach would be to eagerly expand all the argument types and then use the `combinations` from `itertools` to generate the combinations but we would need to find the "boundary" between arguments lists produced from expanding argument at position 1 and position 2 because that's important for the algorithm. Closes: https://github.com/astral-sh/ty/issues/868 ## Test Plan Add test case to demonstrate the limit along with the diagnostic snapshot stating that the limit has been reached.
This commit is contained in:
parent
b57cc5be33
commit
376e3ff395
4 changed files with 373 additions and 44 deletions
|
@ -127,31 +127,39 @@ impl<'a, 'db> CallArguments<'a, 'db> {
|
|||
/// contains the same arguments, but with one or more of the argument types expanded.
|
||||
///
|
||||
/// [argument type expansion]: https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
|
||||
pub(crate) fn expand(
|
||||
&self,
|
||||
db: &'db dyn Db,
|
||||
) -> impl Iterator<Item = Vec<CallArguments<'a, 'db>>> + '_ {
|
||||
pub(super) fn expand(&self, db: &'db dyn Db) -> impl Iterator<Item = Expansion<'a, 'db>> + '_ {
|
||||
/// Maximum number of argument lists that can be generated in a single expansion step.
|
||||
static MAX_EXPANSIONS: usize = 512;
|
||||
|
||||
/// Represents the state of the expansion process.
|
||||
enum State<'a, 'b, 'db> {
|
||||
LimitReached(usize),
|
||||
Expanding(ExpandingState<'a, 'b, 'db>),
|
||||
}
|
||||
|
||||
/// Represents the expanding state with either the initial types or the expanded types.
|
||||
///
|
||||
/// This is useful to avoid cloning the initial types vector if none of the types can be
|
||||
/// expanded.
|
||||
enum State<'a, 'b, 'db> {
|
||||
enum ExpandingState<'a, 'b, 'db> {
|
||||
Initial(&'b Vec<Option<Type<'db>>>),
|
||||
Expanded(Vec<CallArguments<'a, 'db>>),
|
||||
}
|
||||
|
||||
impl<'db> State<'_, '_, 'db> {
|
||||
impl<'db> ExpandingState<'_, '_, 'db> {
|
||||
fn len(&self) -> usize {
|
||||
match self {
|
||||
State::Initial(_) => 1,
|
||||
State::Expanded(expanded) => expanded.len(),
|
||||
ExpandingState::Initial(_) => 1,
|
||||
ExpandingState::Expanded(expanded) => expanded.len(),
|
||||
}
|
||||
}
|
||||
|
||||
fn iter(&self) -> impl Iterator<Item = &[Option<Type<'db>>]> + '_ {
|
||||
match self {
|
||||
State::Initial(types) => Either::Left(std::iter::once(types.as_slice())),
|
||||
State::Expanded(expanded) => {
|
||||
ExpandingState::Initial(types) => {
|
||||
Either::Left(std::iter::once(types.as_slice()))
|
||||
}
|
||||
ExpandingState::Expanded(expanded) => {
|
||||
Either::Right(expanded.iter().map(CallArguments::types))
|
||||
}
|
||||
}
|
||||
|
@ -160,44 +168,82 @@ impl<'a, 'db> CallArguments<'a, 'db> {
|
|||
|
||||
let mut index = 0;
|
||||
|
||||
std::iter::successors(Some(State::Initial(&self.types)), move |previous| {
|
||||
// Find the next type that can be expanded.
|
||||
let expanded_types = loop {
|
||||
let arg_type = self.types.get(index)?;
|
||||
if let Some(arg_type) = arg_type {
|
||||
if let Some(expanded_types) = expand_type(db, *arg_type) {
|
||||
break expanded_types;
|
||||
std::iter::successors(
|
||||
Some(State::Expanding(ExpandingState::Initial(&self.types))),
|
||||
move |previous| {
|
||||
let state = match previous {
|
||||
State::LimitReached(index) => return Some(State::LimitReached(*index)),
|
||||
State::Expanding(expanding_state) => expanding_state,
|
||||
};
|
||||
|
||||
// Find the next type that can be expanded.
|
||||
let expanded_types = loop {
|
||||
let arg_type = self.types.get(index)?;
|
||||
if let Some(arg_type) = arg_type {
|
||||
if let Some(expanded_types) = expand_type(db, *arg_type) {
|
||||
break expanded_types;
|
||||
}
|
||||
}
|
||||
index += 1;
|
||||
};
|
||||
|
||||
let expansion_size = expanded_types.len() * state.len();
|
||||
if expansion_size > MAX_EXPANSIONS {
|
||||
tracing::debug!(
|
||||
"Skipping argument type expansion as it would exceed the \
|
||||
maximum number of expansions ({MAX_EXPANSIONS})"
|
||||
);
|
||||
return Some(State::LimitReached(index));
|
||||
}
|
||||
|
||||
let mut expanded_arguments = Vec::with_capacity(expansion_size);
|
||||
|
||||
for pre_expanded_types in state.iter() {
|
||||
for subtype in &expanded_types {
|
||||
let mut new_expanded_types = pre_expanded_types.to_vec();
|
||||
new_expanded_types[index] = Some(*subtype);
|
||||
expanded_arguments.push(CallArguments::new(
|
||||
self.arguments.clone(),
|
||||
new_expanded_types,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Increment the index to move to the next argument type for the next iteration.
|
||||
index += 1;
|
||||
};
|
||||
|
||||
let mut expanded_arguments = Vec::with_capacity(expanded_types.len() * previous.len());
|
||||
|
||||
for pre_expanded_types in previous.iter() {
|
||||
for subtype in &expanded_types {
|
||||
let mut new_expanded_types = pre_expanded_types.to_vec();
|
||||
new_expanded_types[index] = Some(*subtype);
|
||||
expanded_arguments.push(CallArguments::new(
|
||||
self.arguments.clone(),
|
||||
new_expanded_types,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Increment the index to move to the next argument type for the next iteration.
|
||||
index += 1;
|
||||
|
||||
Some(State::Expanded(expanded_arguments))
|
||||
})
|
||||
Some(State::Expanding(ExpandingState::Expanded(
|
||||
expanded_arguments,
|
||||
)))
|
||||
},
|
||||
)
|
||||
.skip(1) // Skip the initial state, which has no expanded types.
|
||||
.map(|state| match state {
|
||||
State::Initial(_) => unreachable!("initial state should be skipped"),
|
||||
State::Expanded(expanded) => expanded,
|
||||
State::LimitReached(index) => Expansion::LimitReached(index),
|
||||
State::Expanding(ExpandingState::Initial(_)) => {
|
||||
unreachable!("initial state should be skipped")
|
||||
}
|
||||
State::Expanding(ExpandingState::Expanded(expanded)) => Expansion::Expanded(expanded),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a single element of the expansion process for argument types for [`expand`].
|
||||
///
|
||||
/// [`expand`]: CallArguments::expand
|
||||
pub(super) enum Expansion<'a, 'db> {
|
||||
/// Indicates that the expansion process has reached the maximum number of argument lists
|
||||
/// that can be generated in a single step.
|
||||
///
|
||||
/// The contained `usize` is the index of the argument type which would have been expanded
|
||||
/// next, if not for the limit.
|
||||
LimitReached(usize),
|
||||
|
||||
/// Contains the expanded argument lists, where each list contains the same arguments, but with
|
||||
/// one or more of the argument types expanded.
|
||||
Expanded(Vec<CallArguments<'a, 'db>>),
|
||||
}
|
||||
|
||||
impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<'a, 'db> {
|
||||
fn from_iter<T>(iter: T) -> Self
|
||||
where
|
||||
|
|
|
@ -16,7 +16,7 @@ use crate::Program;
|
|||
use crate::db::Db;
|
||||
use crate::dunder_all::dunder_all_names;
|
||||
use crate::place::{Boundness, Place};
|
||||
use crate::types::call::arguments::is_expandable_type;
|
||||
use crate::types::call::arguments::{Expansion, is_expandable_type};
|
||||
use crate::types::diagnostic::{
|
||||
CALL_NON_CALLABLE, CONFLICTING_ARGUMENT_FORMS, INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT,
|
||||
NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS,
|
||||
|
@ -1368,7 +1368,18 @@ impl<'db> CallableBinding<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
for expanded_argument_lists in expansions {
|
||||
for expansion in expansions {
|
||||
let expanded_argument_lists = match expansion {
|
||||
Expansion::LimitReached(index) => {
|
||||
snapshotter.restore(self, post_evaluation_snapshot);
|
||||
self.overload_call_return_type = Some(
|
||||
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index),
|
||||
);
|
||||
return;
|
||||
}
|
||||
Expansion::Expanded(argument_lists) => argument_lists,
|
||||
};
|
||||
|
||||
// This is the merged state of the bindings after evaluating all of the expanded
|
||||
// argument lists. This will be the final state to restore the bindings to if all of
|
||||
// the expanded argument lists evaluated successfully.
|
||||
|
@ -1667,7 +1678,8 @@ impl<'db> CallableBinding<'db> {
|
|||
if let Some(overload_call_return_type) = self.overload_call_return_type {
|
||||
return match overload_call_return_type {
|
||||
OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type,
|
||||
OverloadCallReturnType::Ambiguous => Type::unknown(),
|
||||
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(_)
|
||||
| OverloadCallReturnType::Ambiguous => Type::unknown(),
|
||||
};
|
||||
}
|
||||
if let Some((_, first_overload)) = self.matching_overloads().next() {
|
||||
|
@ -1778,6 +1790,23 @@ impl<'db> CallableBinding<'db> {
|
|||
String::new()
|
||||
}
|
||||
));
|
||||
|
||||
if let Some(index) =
|
||||
self.overload_call_return_type
|
||||
.and_then(
|
||||
|overload_call_return_type| match overload_call_return_type {
|
||||
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(
|
||||
index,
|
||||
) => Some(index),
|
||||
_ => None,
|
||||
},
|
||||
)
|
||||
{
|
||||
diag.info(format_args!(
|
||||
"Limit of argument type expansion reached at argument {index}"
|
||||
));
|
||||
}
|
||||
|
||||
if let Some((kind, function)) = function_type_and_kind {
|
||||
let (overloads, implementation) =
|
||||
function.overloads_and_implementation(context.db());
|
||||
|
@ -1844,6 +1873,7 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
|
|||
#[derive(Debug, Copy, Clone)]
|
||||
enum OverloadCallReturnType<'db> {
|
||||
ArgumentTypeExpansion(Type<'db>),
|
||||
ArgumentTypeExpansionLimitReached(usize),
|
||||
Ambiguous,
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue