[ty] Unpack variadic argument type in specialization (#20130)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

This PR fixes various TODOs around overload call when a variadic
argument is used.

The reason this bug existed is because the specialization wouldn't
account for unpacking the type of the variadic argument.

This is fixed by expanding `MatchedArgument` to contain the type of that
argument _only_ when it is a variadic argument. The reason is that
there's a split for when the argument type is inferred -- the
non-variadic arguments are inferred using `infer_argument_types` _after_
parameter matching while the variadic argument type is inferred _during_
the parameter matching. And, the `MatchedArgument` is populated _during_
parameter matching which means the unpacking would need to happen during
parameter matching.

This split seems a bit inconsistent but I don't want to spend a lot of
time on trying to merge them such that all argument type inference
happens in a single place. I might look into it while adding support for
`**kwargs`.

## Test Plan

Update existing tests by resolving the todos.

The ecosystem changes looks correct to me except for the `slice` call
but it seems that it's unrelated to this PR as we infer `slice[Any, Any,
Any]` for a `slice(1, 2, 3)` call on `main` as well
([playground](https://play.ty.dev/9eacce00-c7d5-4dd5-a932-4265cb2bb4f6)).
This commit is contained in:
Dhruv Manilawala 2025-08-29 09:57:28 +05:30 committed by GitHub
parent a8039f80f0
commit 4ca38b2974
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 94 additions and 53 deletions

View file

@ -290,16 +290,10 @@ from overloaded import A, f
def _(x: int, y: A | int): def _(x: int, y: A | int):
reveal_type(f(x)) # revealed: int reveal_type(f(x)) # revealed: int
# TODO: revealed: int reveal_type(f(*(x,))) # revealed: int
# TODO: no error
# error: [no-matching-overload]
reveal_type(f(*(x,))) # revealed: Unknown
reveal_type(f(y)) # revealed: A | int reveal_type(f(y)) # revealed: A | int
# TODO: revealed: A | int reveal_type(f(*(y,))) # revealed: A | int
# TODO: no error
# error: [no-matching-overload]
reveal_type(f(*(y,))) # revealed: Unknown
``` ```
### Generics (PEP 695) ### Generics (PEP 695)
@ -328,16 +322,10 @@ from overloaded import B, f
def _(x: int, y: B | int): def _(x: int, y: B | int):
reveal_type(f(x)) # revealed: int reveal_type(f(x)) # revealed: int
# TODO: revealed: int reveal_type(f(*(x,))) # revealed: int
# TODO: no error
# error: [no-matching-overload]
reveal_type(f(*(x,))) # revealed: Unknown
reveal_type(f(y)) # revealed: B | int reveal_type(f(y)) # revealed: B | int
# TODO: revealed: B | int reveal_type(f(*(y,))) # revealed: B | int
# TODO: no error
# error: [no-matching-overload]
reveal_type(f(*(y,))) # revealed: Unknown
``` ```
### Expanding `bool` ### Expanding `bool`
@ -1236,21 +1224,14 @@ def _(integer: int, string: str, any: Any, list_any: list[Any]):
reveal_type(f(*(integer, string))) # revealed: int reveal_type(f(*(integer, string))) # revealed: int
reveal_type(f(string, integer)) # revealed: int reveal_type(f(string, integer)) # revealed: int
# TODO: revealed: int reveal_type(f(*(string, integer))) # revealed: int
# TODO: no error
# error: [no-matching-overload]
reveal_type(f(*(string, integer))) # revealed: Unknown
# This matches the second overload and is _not_ the case of ambiguous overload matching. # This matches the second overload and is _not_ the case of ambiguous overload matching.
reveal_type(f(string, any)) # revealed: Any reveal_type(f(string, any)) # revealed: Any
# TODO: Any reveal_type(f(*(string, any))) # revealed: Any
reveal_type(f(*(string, any))) # revealed: tuple[str, Any]
reveal_type(f(string, list_any)) # revealed: list[Any] reveal_type(f(string, list_any)) # revealed: list[Any]
# TODO: revealed: list[Any] reveal_type(f(*(string, list_any))) # revealed: list[Any]
# TODO: no error
# error: [no-matching-overload]
reveal_type(f(*(string, list_any))) # revealed: Unknown
``` ```
### Generic `self` ### Generic `self`

View file

@ -4843,7 +4843,7 @@ impl<'db> Type<'db> {
argument_types: &CallArguments<'_, 'db>, argument_types: &CallArguments<'_, 'db>,
) -> Result<Bindings<'db>, CallError<'db>> { ) -> Result<Bindings<'db>, CallError<'db>> {
self.bindings(db) self.bindings(db)
.match_parameters(argument_types) .match_parameters(db, argument_types)
.check_types(db, argument_types) .check_types(db, argument_types)
} }
@ -4892,7 +4892,7 @@ impl<'db> Type<'db> {
Place::Type(dunder_callable, boundness) => { Place::Type(dunder_callable, boundness) => {
let bindings = dunder_callable let bindings = dunder_callable
.bindings(db) .bindings(db)
.match_parameters(argument_types) .match_parameters(db, argument_types)
.check_types(db, argument_types)?; .check_types(db, argument_types)?;
if boundness == Boundness::PossiblyUnbound { if boundness == Boundness::PossiblyUnbound {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));

View file

@ -7,7 +7,7 @@ use std::borrow::Cow;
use std::collections::HashSet; use std::collections::HashSet;
use std::fmt; use std::fmt;
use itertools::Itertools; use itertools::{Either, Itertools};
use ruff_db::parsed::parsed_module; use ruff_db::parsed::parsed_module;
use smallvec::{SmallVec, smallvec, smallvec_inline}; use smallvec::{SmallVec, smallvec, smallvec_inline};
@ -101,11 +101,15 @@ impl<'db> Bindings<'db> {
/// ///
/// Once you have argument types available, you can call [`check_types`][Self::check_types] to /// Once you have argument types available, you can call [`check_types`][Self::check_types] to
/// verify that each argument type is assignable to the corresponding parameter type. /// verify that each argument type is assignable to the corresponding parameter type.
pub(crate) fn match_parameters(mut self, arguments: &CallArguments<'_, 'db>) -> Self { pub(crate) fn match_parameters(
mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>,
) -> Self {
let mut argument_forms = vec![None; arguments.len()]; let mut argument_forms = vec![None; arguments.len()];
let mut conflicting_forms = vec![false; arguments.len()]; let mut conflicting_forms = vec![false; arguments.len()];
for binding in &mut self.elements { for binding in &mut self.elements {
binding.match_parameters(arguments, &mut argument_forms, &mut conflicting_forms); binding.match_parameters(db, arguments, &mut argument_forms, &mut conflicting_forms);
} }
self.argument_forms = argument_forms.into(); self.argument_forms = argument_forms.into();
self.conflicting_forms = conflicting_forms.into(); self.conflicting_forms = conflicting_forms.into();
@ -1243,6 +1247,7 @@ impl<'db> CallableBinding<'db> {
fn match_parameters( fn match_parameters(
&mut self, &mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>, arguments: &CallArguments<'_, 'db>,
argument_forms: &mut [Option<ParameterForm>], argument_forms: &mut [Option<ParameterForm>],
conflicting_forms: &mut [bool], conflicting_forms: &mut [bool],
@ -1252,7 +1257,7 @@ impl<'db> CallableBinding<'db> {
let arguments = arguments.with_self(self.bound_type); let arguments = arguments.with_self(self.bound_type);
for overload in &mut self.overloads { for overload in &mut self.overloads {
overload.match_parameters(arguments.as_ref(), argument_forms, conflicting_forms); overload.match_parameters(db, arguments.as_ref(), argument_forms, conflicting_forms);
} }
} }
@ -1903,7 +1908,7 @@ struct ArgumentMatcher<'a, 'db> {
conflicting_forms: &'a mut [bool], conflicting_forms: &'a mut [bool],
errors: &'a mut Vec<BindingError<'db>>, errors: &'a mut Vec<BindingError<'db>>,
argument_matches: Vec<MatchedArgument>, argument_matches: Vec<MatchedArgument<'db>>,
parameter_matched: Vec<bool>, parameter_matched: Vec<bool>,
next_positional: usize, next_positional: usize,
first_excess_positional: Option<usize>, first_excess_positional: Option<usize>,
@ -1947,6 +1952,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
&mut self, &mut self,
argument_index: usize, argument_index: usize,
argument: Argument<'a>, argument: Argument<'a>,
argument_type: Option<Type<'db>>,
parameter_index: usize, parameter_index: usize,
parameter: &Parameter<'db>, parameter: &Parameter<'db>,
positional: bool, positional: bool,
@ -1970,6 +1976,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
} }
let matched_argument = &mut self.argument_matches[argument_index]; let matched_argument = &mut self.argument_matches[argument_index];
matched_argument.parameters.push(parameter_index); matched_argument.parameters.push(parameter_index);
matched_argument.types.push(argument_type);
matched_argument.matched = true; matched_argument.matched = true;
self.parameter_matched[parameter_index] = true; self.parameter_matched[parameter_index] = true;
} }
@ -1978,6 +1985,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
&mut self, &mut self,
argument_index: usize, argument_index: usize,
argument: Argument<'a>, argument: Argument<'a>,
argument_type: Option<Type<'db>>,
) -> Result<(), ()> { ) -> Result<(), ()> {
if matches!(argument, Argument::Synthetic) { if matches!(argument, Argument::Synthetic) {
self.num_synthetic_args += 1; self.num_synthetic_args += 1;
@ -1996,6 +2004,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
self.assign_argument( self.assign_argument(
argument_index, argument_index,
argument, argument,
argument_type,
parameter_index, parameter_index,
parameter, parameter,
!parameter.is_variadic(), !parameter.is_variadic(),
@ -2020,20 +2029,35 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
}); });
return Err(()); return Err(());
}; };
self.assign_argument(argument_index, argument, parameter_index, parameter, false); self.assign_argument(
argument_index,
argument,
None,
parameter_index,
parameter,
false,
);
Ok(()) Ok(())
} }
fn match_variadic( fn match_variadic(
&mut self, &mut self,
db: &'db dyn Db,
argument_index: usize, argument_index: usize,
argument: Argument<'a>, argument: Argument<'a>,
argument_type: Option<Type<'db>>,
length: TupleLength, length: TupleLength,
) -> Result<(), ()> { ) -> Result<(), ()> {
let tuple = argument_type.map(|ty| ty.iterate(db));
let mut argument_types = match tuple.as_ref() {
Some(tuple) => Either::Left(tuple.all_elements().copied()),
None => Either::Right(std::iter::empty()),
};
// We must be able to match up the fixed-length portion of the argument with positional // We must be able to match up the fixed-length portion of the argument with positional
// parameters, so we pass on any errors that occur. // parameters, so we pass on any errors that occur.
for _ in 0..length.minimum() { for _ in 0..length.minimum() {
self.match_positional(argument_index, argument)?; self.match_positional(argument_index, argument, argument_types.next())?;
} }
// If the tuple is variable-length, we assume that it will soak up all remaining positional // If the tuple is variable-length, we assume that it will soak up all remaining positional
@ -2044,14 +2068,14 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
.get_positional(self.next_positional) .get_positional(self.next_positional)
.is_some() .is_some()
{ {
self.match_positional(argument_index, argument)?; self.match_positional(argument_index, argument, argument_types.next())?;
} }
} }
Ok(()) Ok(())
} }
fn finish(self) -> Box<[MatchedArgument]> { fn finish(self) -> Box<[MatchedArgument<'db>]> {
if let Some(first_excess_argument_index) = self.first_excess_positional { if let Some(first_excess_argument_index) = self.first_excess_positional {
self.errors.push(BindingError::TooManyPositionalArguments { self.errors.push(BindingError::TooManyPositionalArguments {
first_excess_argument_index: self.get_argument_index(first_excess_argument_index), first_excess_argument_index: self.get_argument_index(first_excess_argument_index),
@ -2088,7 +2112,7 @@ struct ArgumentTypeChecker<'a, 'db> {
db: &'db dyn Db, db: &'db dyn Db,
signature: &'a Signature<'db>, signature: &'a Signature<'db>,
arguments: &'a CallArguments<'a, 'db>, arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument], argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>], parameter_tys: &'a mut [Option<Type<'db>>],
errors: &'a mut Vec<BindingError<'db>>, errors: &'a mut Vec<BindingError<'db>>,
@ -2101,7 +2125,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
db: &'db dyn Db, db: &'db dyn Db,
signature: &'a Signature<'db>, signature: &'a Signature<'db>,
arguments: &'a CallArguments<'a, 'db>, arguments: &'a CallArguments<'a, 'db>,
argument_matches: &'a [MatchedArgument], argument_matches: &'a [MatchedArgument<'db>],
parameter_tys: &'a mut [Option<Type<'db>>], parameter_tys: &'a mut [Option<Type<'db>>],
errors: &'a mut Vec<BindingError<'db>>, errors: &'a mut Vec<BindingError<'db>>,
) -> Self { ) -> Self {
@ -2156,12 +2180,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
for (argument_index, adjusted_argument_index, _, argument_type) in for (argument_index, adjusted_argument_index, _, argument_type) in
self.enumerate_argument_types() self.enumerate_argument_types()
{ {
for parameter_index in &self.argument_matches[argument_index].parameters { for (parameter_index, variadic_argument_type) in
let parameter = &parameters[*parameter_index]; self.argument_matches[argument_index].iter()
{
let parameter = &parameters[parameter_index];
let Some(expected_type) = parameter.annotated_type() else { let Some(expected_type) = parameter.annotated_type() else {
continue; continue;
}; };
if let Err(error) = builder.infer(expected_type, argument_type) { if let Err(error) = builder.infer(
expected_type,
variadic_argument_type.unwrap_or(argument_type),
) {
self.errors.push(BindingError::SpecializationError { self.errors.push(BindingError::SpecializationError {
error, error,
argument_index: adjusted_argument_index, argument_index: adjusted_argument_index,
@ -2305,7 +2334,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
/// Information about which parameter(s) an argument was matched against. This is tracked /// Information about which parameter(s) an argument was matched against. This is tracked
/// separately for each overload. /// separately for each overload.
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct MatchedArgument { pub struct MatchedArgument<'db> {
/// The index of the parameter(s) that an argument was matched against. A splatted argument /// The index of the parameter(s) that an argument was matched against. A splatted argument
/// might be matched against multiple parameters. /// might be matched against multiple parameters.
pub parameters: SmallVec<[usize; 1]>, pub parameters: SmallVec<[usize; 1]>,
@ -2314,6 +2343,33 @@ pub struct MatchedArgument {
/// elements must have been successfully matched. (That means that this can be `false` while /// elements must have been successfully matched. (That means that this can be `false` while
/// the `parameters` field is non-empty.) /// the `parameters` field is non-empty.)
pub matched: bool, pub matched: bool,
/// The types of a variadic argument when it's unpacked.
///
/// The length of this vector is always the same as the `parameters` vector i.e., these are the
/// types assigned to each matched parameter. This isn't necessarily the same as the number of
/// types in the argument type which might not be a fixed-length iterable.
///
/// Another thing to note is that the way this is populated means that for any other argument
/// kind (synthetic, positional, keyword, keyword-variadic), this will be a single-element
/// vector containing `None`, since we don't know the type of the argument when this is
/// constructed. So, this field is populated only for variadic arguments.
///
/// For example, given a `*args` whose type is `tuple[A, B, C]` and the following parameters:
/// - `(x, *args)`: the `types` field will only have two elements (`B`, `C`) since `A` has been
/// matched with `x`.
/// - `(*args)`: the `types` field will have all the three elements (`A`, `B`, `C`)
types: SmallVec<[Option<Type<'db>>; 1]>,
}
impl<'db> MatchedArgument<'db> {
/// Returns an iterator over the parameter indices and the corresponding argument type.
pub fn iter(&self) -> impl Iterator<Item = (usize, Option<Type<'db>>)> + '_ {
self.parameters
.iter()
.copied()
.zip(self.types.iter().copied())
}
} }
/// Binding information for one of the overloads of a callable. /// Binding information for one of the overloads of a callable.
@ -2341,7 +2397,7 @@ pub(crate) struct Binding<'db> {
/// Information about which parameter(s) each argument was matched with, in argument source /// Information about which parameter(s) each argument was matched with, in argument source
/// order. /// order.
argument_matches: Box<[MatchedArgument]>, argument_matches: Box<[MatchedArgument<'db>]>,
/// Bound types for parameters, in parameter source order, or `None` if no argument was matched /// Bound types for parameters, in parameter source order, or `None` if no argument was matched
/// to that parameter. /// to that parameter.
@ -2374,6 +2430,7 @@ impl<'db> Binding<'db> {
pub(crate) fn match_parameters( pub(crate) fn match_parameters(
&mut self, &mut self,
db: &'db dyn Db,
arguments: &CallArguments<'_, 'db>, arguments: &CallArguments<'_, 'db>,
argument_forms: &mut [Option<ParameterForm>], argument_forms: &mut [Option<ParameterForm>],
conflicting_forms: &mut [bool], conflicting_forms: &mut [bool],
@ -2386,16 +2443,17 @@ impl<'db> Binding<'db> {
conflicting_forms, conflicting_forms,
&mut self.errors, &mut self.errors,
); );
for (argument_index, (argument, _)) in arguments.iter().enumerate() { for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() {
match argument { match argument {
Argument::Positional | Argument::Synthetic => { Argument::Positional | Argument::Synthetic => {
let _ = matcher.match_positional(argument_index, argument); let _ = matcher.match_positional(argument_index, argument, None);
} }
Argument::Keyword(name) => { Argument::Keyword(name) => {
let _ = matcher.match_keyword(argument_index, argument, name); let _ = matcher.match_keyword(argument_index, argument, name);
} }
Argument::Variadic(length) => { Argument::Variadic(length) => {
let _ = matcher.match_variadic(argument_index, argument, length); let _ =
matcher.match_variadic(db, argument_index, argument, argument_type, length);
} }
Argument::Keywords => { Argument::Keywords => {
// TODO // TODO
@ -2532,7 +2590,7 @@ impl<'db> Binding<'db> {
/// Returns a vector where each index corresponds to an argument position, /// Returns a vector where each index corresponds to an argument position,
/// and the value is the parameter index that argument maps to (if any). /// and the value is the parameter index that argument maps to (if any).
pub(crate) fn argument_matches(&self) -> &[MatchedArgument] { pub(crate) fn argument_matches(&self) -> &[MatchedArgument<'db>] {
&self.argument_matches &self.argument_matches
} }
} }
@ -2542,7 +2600,7 @@ struct BindingSnapshot<'db> {
return_ty: Type<'db>, return_ty: Type<'db>,
specialization: Option<Specialization<'db>>, specialization: Option<Specialization<'db>>,
inherited_specialization: Option<Specialization<'db>>, inherited_specialization: Option<Specialization<'db>>,
argument_matches: Box<[MatchedArgument]>, argument_matches: Box<[MatchedArgument<'db>]>,
parameter_tys: Box<[Option<Type<'db>>]>, parameter_tys: Box<[Option<Type<'db>>]>,
errors: Vec<BindingError<'db>>, errors: Vec<BindingError<'db>>,
} }

View file

@ -801,7 +801,7 @@ pub struct CallSignatureDetails<'db> {
/// Mapping from argument indices to parameter indices. This helps /// Mapping from argument indices to parameter indices. This helps
/// determine which parameter corresponds to which argument position. /// determine which parameter corresponds to which argument position.
pub argument_to_parameter_mapping: Vec<MatchedArgument>, pub argument_to_parameter_mapping: Vec<MatchedArgument<'db>>,
} }
/// Extract signature details from a function call expression. /// Extract signature details from a function call expression.
@ -821,7 +821,9 @@ pub fn call_signature_details<'db>(
CallArguments::from_arguments(db, &call_expr.arguments, |_, splatted_value| { CallArguments::from_arguments(db, &call_expr.arguments, |_, splatted_value| {
splatted_value.inferred_type(model) splatted_value.inferred_type(model)
}); });
let bindings = callable_type.bindings(db).match_parameters(&call_arguments); let bindings = callable_type
.bindings(db)
.match_parameters(db, &call_arguments);
// Extract signature details from all callable bindings // Extract signature details from all callable bindings
bindings bindings

View file

@ -6360,7 +6360,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let bindings = callable_type let bindings = callable_type
.bindings(self.db()) .bindings(self.db())
.match_parameters(&call_arguments); .match_parameters(self.db(), &call_arguments);
self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms); self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms);
// Validate `TypedDict` constructor calls after argument type inference // Validate `TypedDict` constructor calls after argument type inference
@ -8812,7 +8812,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}; };
let binding = Binding::single(value_ty, generic_context.signature(self.db())); let binding = Binding::single(value_ty, generic_context.signature(self.db()));
let bindings = match Bindings::from(binding) let bindings = match Bindings::from(binding)
.match_parameters(&call_argument_types) .match_parameters(self.db(), &call_argument_types)
.check_types(self.db(), &call_argument_types) .check_types(self.db(), &call_argument_types)
{ {
Ok(bindings) => bindings, Ok(bindings) => bindings,