[ty] Retry parameter matching for argument type expansion (#20153)

## Summary

This PR addresses an issue for a variadic argument when involved in
argument type expansion of overload call evaluation.

The issue is that the expansion of the variadic argument could result in
argument list of different arity. For example, in `*args: tuple[int] |
tuple[int, str]`, the expansion would lead to the variadic argument
being unpacked into 1 and 2 element respectively. This means that the
parameter matching that was performed initially isn't sufficient and
each expanded argument list would need to redo the parameter matching
again.

This is currently done by redoing the parameter matching directly,
maintaining the state of argument forms (and the conflicting forms), and
updating the `Bindings` values if it changes.

Closes: astral-sh/ty#735

## Test Plan

Update existing mdtest.
This commit is contained in:
Dhruv Manilawala 2025-09-12 14:10:07 +05:30 committed by GitHub
parent 1cd8ab3f26
commit bb9be263c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 205 additions and 156 deletions

View file

@ -202,10 +202,25 @@ impl<'a, 'db> CallArguments<'a, 'db> {
for subtype in &expanded_types {
let mut new_expanded_types = pre_expanded_types.to_vec();
new_expanded_types[index] = Some(*subtype);
expanded_arguments.push(CallArguments::new(
self.arguments.clone(),
new_expanded_types,
));
// Update the arguments list to handle variadic argument expansion
let mut new_arguments = self.arguments.clone();
if let Argument::Variadic(_) = self.arguments[index] {
// If the argument corresponding to this type is variadic, we need to
// update the tuple length because expanding could change the length.
// For example, in `tuple[int] | tuple[int, int]`, the length of the
// first type is 1, while the length of the second type is 2.
if let Some(expanded_type) = new_expanded_types[index] {
let length = expanded_type
.try_iterate(db)
.map(|tuple| tuple.len())
.unwrap_or(TupleLength::unknown());
new_arguments[index] = Argument::Variadic(length);
}
}
expanded_arguments
.push(CallArguments::new(new_arguments, new_expanded_types));
}
}

View file

@ -3,7 +3,6 @@
//! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a
//! union of types, each of which might contain multiple overloads.
use std::borrow::Cow;
use std::collections::HashSet;
use std::fmt;
@ -28,7 +27,7 @@ use crate::types::function::{
};
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
use crate::types::signatures::{Parameter, ParameterForm, Parameters};
use crate::types::tuple::{Tuple, TupleLength, TupleType};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
KnownClass, KnownInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
@ -51,9 +50,7 @@ pub(crate) struct Bindings<'db> {
elements: SmallVec<[CallableBinding<'db>; 1]>,
/// Whether each argument will be used as a value and/or a type form in this call.
pub(crate) argument_forms: Box<[Option<ParameterForm>]>,
conflicting_forms: Box<[bool]>,
argument_forms: ArgumentForms,
}
impl<'db> Bindings<'db> {
@ -71,8 +68,7 @@ impl<'db> Bindings<'db> {
Self {
callable_type,
elements,
argument_forms: Box::from([]),
conflicting_forms: Box::from([]),
argument_forms: ArgumentForms::new(0),
}
}
@ -91,6 +87,10 @@ impl<'db> Bindings<'db> {
}
}
pub(crate) fn argument_forms(&self) -> &[Option<ParameterForm>] {
&self.argument_forms.values
}
/// Match the arguments of a call site against the parameters of a collection of possibly
/// unioned, possibly overloaded signatures.
///
@ -105,13 +105,12 @@ impl<'db> Bindings<'db> {
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>,
) -> Self {
let mut argument_forms = vec![None; arguments.len()];
let mut conflicting_forms = vec![false; arguments.len()];
let mut argument_forms = ArgumentForms::new(arguments.len());
for binding in &mut self.elements {
binding.match_parameters(db, arguments, &mut argument_forms, &mut conflicting_forms);
binding.match_parameters(db, arguments, &mut argument_forms);
}
self.argument_forms = argument_forms.into();
self.conflicting_forms = conflicting_forms.into();
argument_forms.shrink_to_fit();
self.argument_forms = argument_forms;
self
}
@ -130,7 +129,12 @@ impl<'db> Bindings<'db> {
argument_types: &CallArguments<'_, 'db>,
) -> Result<Self, CallError<'db>> {
for element in &mut self.elements {
element.check_types(db, argument_types);
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) {
// If this element returned a new set of argument forms (indicating successful
// argument type expansion), update the `Bindings` with these forms.
updated_argument_forms.shrink_to_fit();
self.argument_forms = updated_argument_forms;
}
}
self.evaluate_known_cases(db);
@ -153,7 +157,7 @@ impl<'db> Bindings<'db> {
let mut all_ok = true;
let mut any_binding_error = false;
let mut all_not_callable = true;
if self.conflicting_forms.contains(&true) {
if self.argument_forms.conflicting.contains(&true) {
all_ok = false;
any_binding_error = true;
all_not_callable = false;
@ -226,7 +230,7 @@ impl<'db> Bindings<'db> {
return;
}
for (index, conflicting_form) in self.conflicting_forms.iter().enumerate() {
for (index, conflicting_form) in self.argument_forms.conflicting.iter().enumerate() {
if *conflicting_form {
let node = BindingError::get_node(node, Some(index));
if let Some(builder) = context.report_lint(&CONFLICTING_ARGUMENT_FORMS, node) {
@ -1118,8 +1122,7 @@ impl<'db> From<CallableBinding<'db>> for Bindings<'db> {
Bindings {
callable_type: from.callable_type,
elements: smallvec_inline![from],
argument_forms: Box::from([]),
conflicting_forms: Box::from([]),
argument_forms: ArgumentForms::new(0),
}
}
}
@ -1140,8 +1143,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
Bindings {
callable_type,
elements: smallvec_inline![callable_binding],
argument_forms: Box::from([]),
conflicting_forms: Box::from([]),
argument_forms: ArgumentForms::new(0),
}
}
}
@ -1262,19 +1264,22 @@ impl<'db> CallableBinding<'db> {
&mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>,
argument_forms: &mut [Option<ParameterForm>],
conflicting_forms: &mut [bool],
argument_forms: &mut ArgumentForms,
) {
// If this callable is a bound method, prepend the self instance onto the arguments list
// before checking.
let arguments = arguments.with_self(self.bound_type);
for overload in &mut self.overloads {
overload.match_parameters(db, arguments.as_ref(), argument_forms, conflicting_forms);
overload.match_parameters(db, arguments.as_ref(), argument_forms);
}
}
fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>) {
fn check_types(
&mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
) -> Option<ArgumentForms> {
// If this callable is a bound method, prepend the self instance onto the arguments list
// before checking.
let argument_types = argument_types.with_self(self.bound_type);
@ -1288,14 +1293,14 @@ impl<'db> CallableBinding<'db> {
if let [overload] = self.overloads.as_mut_slice() {
overload.check_types(db, argument_types.as_ref());
}
return;
return None;
}
MatchingOverloadIndex::Single(index) => {
// If only one candidate overload remains, it is the winning match. Evaluate it as
// a regular (non-overloaded) call.
self.matching_overload_index = Some(index);
self.overloads[index].check_types(db, argument_types.as_ref());
return;
return None;
}
MatchingOverloadIndex::Multiple(indexes) => {
// If two or more candidate overloads remain, proceed to step 2.
@ -1303,12 +1308,6 @@ impl<'db> CallableBinding<'db> {
}
};
let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes);
// State of the bindings _before_ evaluating (type checking) the matching overloads using
// the non-expanded argument types.
let pre_evaluation_snapshot = snapshotter.take(self);
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
// whether it is compatible with the supplied argument list.
for (_, overload) in self.matching_overloads_mut() {
@ -1321,7 +1320,7 @@ impl<'db> CallableBinding<'db> {
}
MatchingOverloadIndex::Single(_) => {
// If only one overload evaluates without error, it is the winning match.
return;
return None;
}
MatchingOverloadIndex::Multiple(indexes) => {
// If two or more candidate overloads remain, proceed to step 4.
@ -1330,8 +1329,8 @@ impl<'db> CallableBinding<'db> {
// Step 5
self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes);
// We're returning here because this shouldn't lead to argument type expansion.
return;
// This shouldn't lead to argument type expansion.
return None;
}
}
@ -1339,27 +1338,14 @@ impl<'db> CallableBinding<'db> {
// https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
let mut expansions = argument_types.expand(db).peekable();
if expansions.peek().is_none() {
// Return early if there are no argument types to expand.
return;
}
// State of the bindings _after_ evaluating (type checking) the matching overloads using
// the non-expanded argument types.
let post_evaluation_snapshot = snapshotter.take(self);
// Restore the bindings state to the one prior to the type checking step in preparation
// for evaluating the expanded argument lists.
snapshotter.restore(self, pre_evaluation_snapshot);
// Return early if there are no argument types to expand.
expansions.peek()?;
// At this point, there's at least one argument that can be expanded.
//
// This heuristic tries to detect if there's any need to perform argument type expansion or
// not by checking whether there are any non-expandable argument type that cannot be
// assigned to any of the remaining overloads.
//
// This heuristic needs to be applied after restoring the bindings state to the one before
// type checking as argument type expansion would evaluate it from that point on.
// assigned to any of the overloads.
for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() {
// TODO: Remove `Keywords` once `**kwargs` support is added
if matches!(argument, Argument::Synthetic | Argument::Keywords) {
@ -1372,7 +1358,7 @@ impl<'db> CallableBinding<'db> {
continue;
}
let mut is_argument_assignable_to_any_overload = false;
'overload: for (_, overload) in self.matching_overloads() {
'overload: for overload in &self.overloads {
for parameter_index in &overload.argument_matches[argument_index].parameters {
let parameter_type = overload.signature.parameters()[*parameter_index]
.annotated_type()
@ -1389,11 +1375,16 @@ impl<'db> CallableBinding<'db> {
remaining overloads, skipping argument type expansion",
argument_type.display(db)
);
snapshotter.restore(self, post_evaluation_snapshot);
return;
return None;
}
}
let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes);
// State of the bindings _after_ evaluating (type checking) the matching overloads using
// the non-expanded argument types.
let post_evaluation_snapshot = snapshotter.take(self);
for expansion in expansions {
let expanded_argument_lists = match expansion {
Expansion::LimitReached(index) => {
@ -1401,7 +1392,7 @@ impl<'db> CallableBinding<'db> {
self.overload_call_return_type = Some(
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index),
);
return;
return None;
}
Expansion::Expanded(argument_lists) => argument_lists,
};
@ -1411,13 +1402,33 @@ impl<'db> CallableBinding<'db> {
// the expanded argument lists evaluated successfully.
let mut merged_evaluation_state: Option<CallableBindingSnapshot<'db>> = None;
// Merged argument forms after evaluating all the argument lists in this expansion.
let mut merged_argument_forms = ArgumentForms::default();
// The return types of each of the expanded argument lists that evaluated successfully.
let mut return_types = Vec::new();
for expanded_argument_types in &expanded_argument_lists {
let pre_evaluation_snapshot = snapshotter.take(self);
for expanded_arguments in &expanded_argument_lists {
let mut argument_forms = ArgumentForms::new(expanded_arguments.len());
// The spec mentions that each expanded argument list should be re-evaluated from
// step 2 but we need to re-evaluate from step 1 because our step 1 does more than
// what the spec mentions. Step 1 of the spec means only "eliminate impossible
// overloads due to arity mismatch" while our step 1 (`match_parameters`) also
// includes "match arguments to the parameters". This is important because it
// allows us to correctly handle cases involving a variadic argument that could
// expand into different number of arguments with each expansion. Refer to
// https://github.com/astral-sh/ty/issues/735 for more details.
for overload in &mut self.overloads {
// Clear the state of all overloads before re-evaluating from step 1
overload.reset();
overload.match_parameters(db, expanded_arguments, &mut argument_forms);
}
merged_argument_forms.merge(&argument_forms);
for (_, overload) in self.matching_overloads_mut() {
overload.check_types(db, expanded_argument_types);
overload.check_types(db, expanded_arguments);
}
let return_type = match self.matching_overload_index() {
@ -1430,7 +1441,7 @@ impl<'db> CallableBinding<'db> {
self.filter_overloads_using_any_or_unknown(
db,
expanded_argument_types,
expanded_arguments,
&matching_overload_indexes,
);
@ -1451,9 +1462,6 @@ impl<'db> CallableBinding<'db> {
merged_evaluation_state = Some(snapshotter.take(self));
}
// Restore the bindings state before evaluating the next argument list.
snapshotter.restore(self, pre_evaluation_snapshot);
if let Some(return_type) = return_type {
return_types.push(return_type);
} else {
@ -1481,7 +1489,7 @@ impl<'db> CallableBinding<'db> {
UnionType::from_elements(db, return_types),
));
return;
return Some(merged_argument_forms);
}
}
@ -1490,6 +1498,8 @@ impl<'db> CallableBinding<'db> {
// argument types. This is necessary because we restore the state to the pre-evaluation
// snapshot when processing the expanded argument lists.
snapshotter.restore(self, post_evaluation_snapshot);
None
}
/// Filter overloads based on [`Any`] or [`Unknown`] argument types.
@ -1915,10 +1925,59 @@ enum MatchingOverloadIndex {
Multiple(Vec<usize>),
}
#[derive(Default, Debug)]
struct ArgumentForms {
values: Vec<Option<ParameterForm>>,
conflicting: Vec<bool>,
}
impl ArgumentForms {
/// Create a new argument forms initialized to the given length and the default values.
fn new(len: usize) -> Self {
Self {
values: vec![None; len],
conflicting: vec![false; len],
}
}
fn merge(&mut self, other: &ArgumentForms) {
if self.values.len() < other.values.len() {
self.values.resize(other.values.len(), None);
self.conflicting.resize(other.conflicting.len(), false);
}
for (index, (other_form, other_conflict)) in other
.values
.iter()
.zip(other.conflicting.iter())
.enumerate()
{
if let Some(self_form) = &mut self.values[index] {
if let Some(other_form) = other_form {
if *self_form != *other_form {
// Different parameter forms, mark as conflicting
self.conflicting[index] = true;
*self_form = *other_form; // Use the new form
}
}
} else {
self.values[index] = *other_form;
}
// Update the conflicting form (true takes precedence)
self.conflicting[index] |= *other_conflict;
}
}
fn shrink_to_fit(&mut self) {
self.values.shrink_to_fit();
self.conflicting.shrink_to_fit();
}
}
struct ArgumentMatcher<'a, 'db> {
parameters: &'a Parameters<'db>,
argument_forms: &'a mut [Option<ParameterForm>],
conflicting_forms: &'a mut [bool],
argument_forms: &'a mut ArgumentForms,
errors: &'a mut Vec<BindingError<'db>>,
argument_matches: Vec<MatchedArgument<'db>>,
@ -1932,14 +1991,12 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
fn new(
arguments: &CallArguments,
parameters: &'a Parameters<'db>,
argument_forms: &'a mut [Option<ParameterForm>],
conflicting_forms: &'a mut [bool],
argument_forms: &'a mut ArgumentForms,
errors: &'a mut Vec<BindingError<'db>>,
) -> Self {
Self {
parameters,
argument_forms,
conflicting_forms,
errors,
argument_matches: vec![MatchedArgument::default(); arguments.len()],
parameter_matched: vec![false; parameters.len()],
@ -1971,11 +2028,13 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
positional: bool,
) {
if !matches!(argument, Argument::Synthetic) {
if let Some(existing) = self.argument_forms[argument_index - self.num_synthetic_args]
.replace(parameter.form)
let adjusted_argument_index = argument_index - self.num_synthetic_args;
if let Some(existing) =
self.argument_forms.values[adjusted_argument_index].replace(parameter.form)
{
if existing != parameter.form {
self.conflicting_forms[argument_index - self.num_synthetic_args] = true;
self.argument_forms.conflicting[argument_index - self.num_synthetic_args] =
true;
}
}
}
@ -2295,22 +2354,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
// how many elements the iterator will produce.
let argument_types = argument_type.iterate(self.db);
// TODO: When we perform argument expansion during overload resolution, we might need
// to retry both `match_parameters` _and_ `check_types` for each expansion. Currently
// we only retry `check_types`. The issue is that argument expansion might produce a
// splatted value with a different arity than what we originally inferred for the
// unexpanded value, and that in turn can affect which parameters the splatted value is
// matched with. As a workaround, make sure that the splatted tuple contains an
// arbitrary number of `Unknown`s at the end, so that if the expanded value has a
// smaller arity than the unexpanded value, we still have enough values to assign to
// the already matched parameters.
let argument_types = match argument_types.as_ref() {
Tuple::Fixed(_) => {
Cow::Owned(argument_types.concat(self.db, &Tuple::homogeneous(Type::unknown())))
}
Tuple::Variable(_) => argument_types,
};
// Resize the tuple of argument types to line up with the number of parameters this
// argument was matched against. If parameter matching succeeded, then we can (TODO:
// should be able to, see above) guarantee that all of the required elements of the
@ -2441,21 +2484,15 @@ impl<'db> Binding<'db> {
}
}
pub(crate) fn match_parameters(
fn match_parameters(
&mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>,
argument_forms: &mut [Option<ParameterForm>],
conflicting_forms: &mut [bool],
argument_forms: &mut ArgumentForms,
) {
let parameters = self.signature.parameters();
let mut matcher = ArgumentMatcher::new(
arguments,
parameters,
argument_forms,
conflicting_forms,
&mut self.errors,
);
let mut matcher =
ArgumentMatcher::new(arguments, parameters, argument_forms, &mut self.errors);
for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() {
match argument {
Argument::Positional | Argument::Synthetic => {
@ -2610,6 +2647,16 @@ impl<'db> Binding<'db> {
pub(crate) fn errors(&self) -> &[BindingError<'db>] {
&self.errors
}
/// Resets the state of this binding to its initial state.
fn reset(&mut self) {
self.return_ty = Type::unknown();
self.specialization = None;
self.inherited_specialization = None;
self.argument_matches = Box::from([]);
self.parameter_tys = Box::from([]);
self.errors.clear();
}
}
#[derive(Clone, Debug)]

View file

@ -5833,7 +5833,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let bindings = callable_type
.bindings(self.db())
.match_parameters(self.db(), &call_arguments);
self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms);
self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms());
// Validate `TypedDict` constructor calls after argument type inference
if let Some(class_literal) = callable_type.into_class_literal() {

View file

@ -1015,11 +1015,6 @@ impl<'db> Tuple<Type<'db>> {
UnionType::from_elements(db, self.all_elements())
}
/// Concatenates another tuple to the end of this tuple, returning a new tuple.
pub(crate) fn concat(&self, db: &'db dyn Db, other: &Self) -> Self {
TupleSpecBuilder::from(self).concat(db, other).build()
}
/// Resizes this tuple to a different length, if possible. If this tuple cannot satisfy the
/// desired minimum or maximum length, we return an error. If we return an `Ok` result, the
/// [`len`][Self::len] of the resulting tuple is guaranteed to be equal to `new_length`.