[ty] Match variadic argument to variadic parameter (#20511)

## Summary

Closes: https://github.com/astral-sh/ty/issues/1236

This PR fixes a bug where the variadic argument wouldn't match against
the variadic parameter in certain scenarios.

This was happening because I didn't realize that the `all_elements`
iterator wouldn't keep on returning the variable element (which is
correct, I just didn't realize it back then).

I don't think we can use the `resize` method here because we don't know
how many parameters this variadic argument is matching against as this
is where the actual parameter matching occurs.

## Test Plan

Expand test cases to consider a few more combinations of arguments and
parameters which are variadic.
This commit is contained in:
Dhruv Manilawala 2025-09-25 13:21:56 +05:30 committed by GitHub
parent edeb45804e
commit e1bb74b25a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 150 additions and 67 deletions

View file

@ -642,6 +642,96 @@ def f(*args: int) -> int:
reveal_type(f("foo")) # revealed: int
```
### Variadic argument, variadic parameter
```toml
[environment]
python-version = "3.11"
```
```py
def f(*args: int) -> int:
return 1
def _(args: list[str]) -> None:
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `int`, found `str`"
reveal_type(f(*args)) # revealed: int
```
Considering a few different shapes of tuple for the splatted argument:
```py
def f1(*args: str): ...
def _(
args1: tuple[str, ...],
args2: tuple[str, *tuple[str, ...]],
args3: tuple[str, *tuple[str, ...], str],
args4: tuple[int, *tuple[str, ...]],
args5: tuple[int, *tuple[str, ...], str],
args6: tuple[*tuple[str, ...], str],
args7: tuple[*tuple[str, ...], int],
args8: tuple[int, *tuple[str, ...], int],
args9: tuple[str, *tuple[str, ...], int],
args10: tuple[str, *tuple[int, ...], str],
):
f1(*args1)
f1(*args2)
f1(*args3)
f1(*args4) # error: [invalid-argument-type]
f1(*args5) # error: [invalid-argument-type]
f1(*args6)
f1(*args7) # error: [invalid-argument-type]
# The reason for two errors here is because of the two fixed elements in the tuple of `args8`
# which are both `int`
# error: [invalid-argument-type]
# error: [invalid-argument-type]
f1(*args8)
f1(*args9) # error: [invalid-argument-type]
f1(*args10) # error: [invalid-argument-type]
```
### Mixed argument and parameter containing variadic
```toml
[environment]
python-version = "3.11"
```
```py
def f(x: int, *args: str) -> int:
return 1
def _(
args1: list[int],
args2: tuple[int],
args3: tuple[int, int],
args4: tuple[int, ...],
args5: tuple[int, *tuple[str, ...]],
args6: tuple[int, int, *tuple[str, ...]],
) -> None:
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
reveal_type(f(*args1)) # revealed: int
# This shouldn't raise an error because the unpacking doesn't match the variadic parameter.
reveal_type(f(*args2)) # revealed: int
# But, this should because the second tuple element is not assignable.
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
reveal_type(f(*args3)) # revealed: int
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
reveal_type(f(*args4)) # revealed: int
# The first element of the tuple matches the required argument;
# all subsequent elements match the variadic argument
reveal_type(f(*args5)) # revealed: int
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
reveal_type(f(*args6)) # revealed: int
```
### Keyword argument, positional-or-keyword parameter
```py

View file

@ -6,7 +6,7 @@ use ruff_python_ast as ast;
use crate::Db;
use crate::types::KnownClass;
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::tuple::{Tuple, TupleLength, TupleType};
use crate::types::tuple::{Tuple, TupleType};
use super::Type;
@ -17,7 +17,7 @@ pub(crate) enum Argument<'a> {
/// A positional argument.
Positional,
/// A starred positional argument (e.g. `*args`) containing the specified number of elements.
Variadic(TupleLength),
Variadic,
/// A keyword argument (e.g. `a=1`).
Keyword(&'a str),
/// The double-starred keywords argument (e.g. `**kwargs`).
@ -41,7 +41,6 @@ impl<'a, 'db> CallArguments<'a, 'db> {
/// type of each splatted argument, so that we can determine its length. All other arguments
/// will remain uninitialized as `Unknown`.
pub(crate) fn from_arguments(
db: &'db dyn Db,
arguments: &'a ast::Arguments,
mut infer_argument_type: impl FnMut(Option<&ast::Expr>, &ast::Expr) -> Type<'db>,
) -> Self {
@ -51,11 +50,7 @@ impl<'a, 'db> CallArguments<'a, 'db> {
ast::ArgOrKeyword::Arg(arg) => match arg {
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
let ty = infer_argument_type(Some(arg), value);
let length = ty
.try_iterate(db)
.map(|tuple| tuple.len())
.unwrap_or(TupleLength::unknown());
(Argument::Variadic(length), Some(ty))
(Argument::Variadic, Some(ty))
}
_ => (Argument::Positional, None),
},
@ -203,25 +198,10 @@ impl<'a, 'db> CallArguments<'a, 'db> {
for subtype in &expanded_types {
let mut new_expanded_types = pre_expanded_types.to_vec();
new_expanded_types[index] = Some(*subtype);
// Update the arguments list to handle variadic argument expansion
let mut new_arguments = self.arguments.clone();
if let Argument::Variadic(_) = self.arguments[index] {
// If the argument corresponding to this type is variadic, we need to
// update the tuple length because expanding could change the length.
// For example, in `tuple[int] | tuple[int, int]`, the length of the
// first type is 1, while the length of the second type is 2.
if let Some(expanded_type) = new_expanded_types[index] {
let length = expanded_type
.try_iterate(db)
.map(|tuple| tuple.len())
.unwrap_or(TupleLength::unknown());
new_arguments[index] = Argument::Variadic(length);
}
}
expanded_arguments
.push(CallArguments::new(new_arguments, new_expanded_types));
expanded_arguments.push(CallArguments::new(
self.arguments.clone(),
new_expanded_types,
));
}
}

View file

@ -2135,24 +2135,36 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
Ok(())
}
/// Match a variadic argument to the remaining positional, standard or variadic parameters.
fn match_variadic(
&mut self,
db: &'db dyn Db,
argument_index: usize,
argument: Argument<'a>,
argument_type: Option<Type<'db>>,
length: TupleLength,
) -> Result<(), ()> {
let tuple = argument_type.map(|ty| ty.iterate(db));
let mut argument_types = match tuple.as_ref() {
Some(tuple) => Either::Left(tuple.all_elements().copied()),
None => Either::Right(std::iter::empty()),
let (mut argument_types, length, variable_element) = match tuple.as_ref() {
Some(tuple) => (
Either::Left(tuple.all_elements().copied()),
tuple.len(),
tuple.variable_element().copied(),
),
None => (
Either::Right(std::iter::empty()),
TupleLength::unknown(),
None,
),
};
// We must be able to match up the fixed-length portion of the argument with positional
// parameters, so we pass on any errors that occur.
for _ in 0..length.minimum() {
self.match_positional(argument_index, argument, argument_types.next())?;
self.match_positional(
argument_index,
argument,
argument_types.next().or(variable_element),
)?;
}
// If the tuple is variable-length, we assume that it will soak up all remaining positional
@ -2163,7 +2175,24 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
.get_positional(self.next_positional)
.is_some()
{
self.match_positional(argument_index, argument, argument_types.next())?;
self.match_positional(
argument_index,
argument,
argument_types.next().or(variable_element),
)?;
}
}
// Finally, if there is a variadic parameter we can match any of the remaining unpacked
// argument types to it, but only if there is at least one remaining argument type. This is
// because a variadic parameter is optional, so if this was done unconditionally, ty could
// raise a false positive as "too many arguments".
if self.parameters.variadic().is_some() {
if let Some(argument_type) = argument_types.next().or(variable_element) {
self.match_positional(argument_index, argument, Some(argument_type))?;
for argument_type in argument_types {
self.match_positional(argument_index, argument, Some(argument_type))?;
}
}
}
@ -2433,11 +2462,10 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
self.enumerate_argument_types()
{
match argument {
Argument::Variadic(_) => self.check_variadic_argument_type(
Argument::Variadic => self.check_variadic_argument_type(
argument_index,
adjusted_argument_index,
argument,
argument_type,
),
Argument::Keywords => self.check_keyword_variadic_argument_type(
argument_index,
@ -2465,37 +2493,15 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
argument_index: usize,
adjusted_argument_index: Option<usize>,
argument: Argument<'a>,
argument_type: Type<'db>,
) {
// If the argument is splatted, convert its type into a tuple describing the splatted
// elements. For tuples, we don't have to do anything! For other types, we treat it as
// an iterator, and create a homogeneous tuple of its output type, since we don't know
// how many elements the iterator will produce.
let argument_types = argument_type.iterate(self.db);
// Resize the tuple of argument types to line up with the number of parameters this
// argument was matched against. If parameter matching succeeded, then we can (TODO:
// should be able to, see above) guarantee that all of the required elements of the
// splatted tuple will have been matched with a parameter. But if parameter matching
// failed, there might be more required elements. That means we can't use
// TupleLength::Fixed below, because we would otherwise get a "too many values" error
// when parameter matching failed.
let desired_size =
TupleLength::Variable(self.argument_matches[argument_index].parameters.len(), 0);
let argument_types = argument_types
.resize(self.db, desired_size)
.expect("argument type should be consistent with its arity");
// Check the types by zipping through the splatted argument types and their matched
// parameters.
for (argument_type, parameter_index) in
(argument_types.all_elements()).zip(&self.argument_matches[argument_index].parameters)
for (parameter_index, variadic_argument_type) in
self.argument_matches[argument_index].iter()
{
self.check_argument_type(
adjusted_argument_index,
argument,
*argument_type,
*parameter_index,
variadic_argument_type.unwrap_or_else(Type::unknown),
parameter_index,
);
}
}
@ -2711,9 +2717,8 @@ impl<'db> Binding<'db> {
Argument::Keyword(name) => {
let _ = matcher.match_keyword(argument_index, argument, None, name);
}
Argument::Variadic(length) => {
let _ =
matcher.match_variadic(db, argument_index, argument, argument_type, length);
Argument::Variadic => {
let _ = matcher.match_variadic(db, argument_index, argument, argument_type);
}
Argument::Keywords => {
keywords_arguments.push((argument_index, argument_type));

View file

@ -874,7 +874,7 @@ pub fn call_signature_details<'db>(
// Use into_callable to handle all the complex type conversions
if let Some(callable_type) = func_type.into_callable(db) {
let call_arguments =
CallArguments::from_arguments(db, &call_expr.arguments, |_, splatted_value| {
CallArguments::from_arguments(&call_expr.arguments, |_, splatted_value| {
splatted_value.inferred_type(model)
});
let bindings = callable_type

View file

@ -1733,7 +1733,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let previous_deferred_state =
std::mem::replace(&mut self.deferred_state, in_stub.into());
let mut call_arguments =
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
CallArguments::from_arguments(arguments, |argument, splatted_value| {
let ty = self.infer_expression(splatted_value, TypeContext::default());
if let Some(argument) = argument {
self.store_expression_type(argument, ty);
@ -5831,7 +5831,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// arguments after matching them to parameters, but before checking that the argument types
// are assignable to any parameter annotations.
let mut call_arguments =
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
CallArguments::from_arguments(arguments, |argument, splatted_value| {
let ty = self.infer_expression(splatted_value, TypeContext::default());
if let Some(argument) = argument {
self.store_expression_type(argument, ty);

View file

@ -970,6 +970,14 @@ impl<T> Tuple<T> {
FixedLengthTuple::from_elements(elements).into()
}
/// Returns the variable-length element of this tuple, if it has one.
pub(crate) fn variable_element(&self) -> Option<&T> {
match self {
Tuple::Fixed(_) => None,
Tuple::Variable(tuple) => Some(&tuple.variable),
}
}
/// Returns an iterator of all of the fixed-length element types of this tuple.
pub(crate) fn fixed_elements(&self) -> impl Iterator<Item = &T> + '_ {
match self {