From e1bb74b25a0f730a5baff263d753ffe03cf857b1 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Thu, 25 Sep 2025 13:21:56 +0530 Subject: [PATCH] [ty] Match variadic argument to variadic parameter (#20511) ## Summary Closes: https://github.com/astral-sh/ty/issues/1236 This PR fixes a bug where the variadic argument wouldn't match against the variadic parameter in certain scenarios. This was happening because I didn't realize that the `all_elements` iterator wouldn't keep on returning the variable element (which is correct, I just didn't realize it back then). I don't think we can use the `resize` method here because we don't know how many parameters this variadic argument is matching against as this is where the actual parameter matching occurs. ## Test Plan Expand test cases to consider a few more combinations of arguments and parameters which are variadic. --- .../resources/mdtest/call/function.md | 90 +++++++++++++++++++ .../src/types/call/arguments.rs | 34 ++----- .../ty_python_semantic/src/types/call/bind.rs | 79 ++++++++-------- .../src/types/ide_support.rs | 2 +- .../src/types/infer/builder.rs | 4 +- crates/ty_python_semantic/src/types/tuple.rs | 8 ++ 6 files changed, 150 insertions(+), 67 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/function.md b/crates/ty_python_semantic/resources/mdtest/call/function.md index 58ee1428a3..de4bbe2e7f 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/function.md +++ b/crates/ty_python_semantic/resources/mdtest/call/function.md @@ -642,6 +642,96 @@ def f(*args: int) -> int: reveal_type(f("foo")) # revealed: int ``` +### Variadic argument, variadic parameter + +```toml +[environment] +python-version = "3.11" +``` + +```py +def f(*args: int) -> int: + return 1 + +def _(args: list[str]) -> None: + # error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `int`, found `str`" + reveal_type(f(*args)) # revealed: int +``` + +Considering a few different shapes of tuple for the splatted argument: + +```py +def f1(*args: str): ... +def _( + args1: tuple[str, ...], + args2: tuple[str, *tuple[str, ...]], + args3: tuple[str, *tuple[str, ...], str], + args4: tuple[int, *tuple[str, ...]], + args5: tuple[int, *tuple[str, ...], str], + args6: tuple[*tuple[str, ...], str], + args7: tuple[*tuple[str, ...], int], + args8: tuple[int, *tuple[str, ...], int], + args9: tuple[str, *tuple[str, ...], int], + args10: tuple[str, *tuple[int, ...], str], +): + f1(*args1) + f1(*args2) + f1(*args3) + f1(*args4) # error: [invalid-argument-type] + f1(*args5) # error: [invalid-argument-type] + f1(*args6) + f1(*args7) # error: [invalid-argument-type] + + # The reason for two errors here is because of the two fixed elements in the tuple of `args8` + # which are both `int` + # error: [invalid-argument-type] + # error: [invalid-argument-type] + f1(*args8) + + f1(*args9) # error: [invalid-argument-type] + f1(*args10) # error: [invalid-argument-type] +``` + +### Mixed argument and parameter containing variadic + +```toml +[environment] +python-version = "3.11" +``` + +```py +def f(x: int, *args: str) -> int: + return 1 + +def _( + args1: list[int], + args2: tuple[int], + args3: tuple[int, int], + args4: tuple[int, ...], + args5: tuple[int, *tuple[str, ...]], + args6: tuple[int, int, *tuple[str, ...]], +) -> None: + # error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`" + reveal_type(f(*args1)) # revealed: int + + # This shouldn't raise an error because the unpacking doesn't match the variadic parameter. + reveal_type(f(*args2)) # revealed: int + + # But, this should because the second tuple element is not assignable. + # error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`" + reveal_type(f(*args3)) # revealed: int + + # error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`" + reveal_type(f(*args4)) # revealed: int + + # The first element of the tuple matches the required argument; + # all subsequent elements match the variadic argument + reveal_type(f(*args5)) # revealed: int + + # error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`" + reveal_type(f(*args6)) # revealed: int +``` + ### Keyword argument, positional-or-keyword parameter ```py diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index ab0c9fece2..fc8bf871e5 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -6,7 +6,7 @@ use ruff_python_ast as ast; use crate::Db; use crate::types::KnownClass; use crate::types::enums::{enum_member_literals, enum_metadata}; -use crate::types::tuple::{Tuple, TupleLength, TupleType}; +use crate::types::tuple::{Tuple, TupleType}; use super::Type; @@ -17,7 +17,7 @@ pub(crate) enum Argument<'a> { /// A positional argument. Positional, /// A starred positional argument (e.g. `*args`) containing the specified number of elements. - Variadic(TupleLength), + Variadic, /// A keyword argument (e.g. `a=1`). Keyword(&'a str), /// The double-starred keywords argument (e.g. `**kwargs`). @@ -41,7 +41,6 @@ impl<'a, 'db> CallArguments<'a, 'db> { /// type of each splatted argument, so that we can determine its length. All other arguments /// will remain uninitialized as `Unknown`. pub(crate) fn from_arguments( - db: &'db dyn Db, arguments: &'a ast::Arguments, mut infer_argument_type: impl FnMut(Option<&ast::Expr>, &ast::Expr) -> Type<'db>, ) -> Self { @@ -51,11 +50,7 @@ impl<'a, 'db> CallArguments<'a, 'db> { ast::ArgOrKeyword::Arg(arg) => match arg { ast::Expr::Starred(ast::ExprStarred { value, .. }) => { let ty = infer_argument_type(Some(arg), value); - let length = ty - .try_iterate(db) - .map(|tuple| tuple.len()) - .unwrap_or(TupleLength::unknown()); - (Argument::Variadic(length), Some(ty)) + (Argument::Variadic, Some(ty)) } _ => (Argument::Positional, None), }, @@ -203,25 +198,10 @@ impl<'a, 'db> CallArguments<'a, 'db> { for subtype in &expanded_types { let mut new_expanded_types = pre_expanded_types.to_vec(); new_expanded_types[index] = Some(*subtype); - - // Update the arguments list to handle variadic argument expansion - let mut new_arguments = self.arguments.clone(); - if let Argument::Variadic(_) = self.arguments[index] { - // If the argument corresponding to this type is variadic, we need to - // update the tuple length because expanding could change the length. - // For example, in `tuple[int] | tuple[int, int]`, the length of the - // first type is 1, while the length of the second type is 2. - if let Some(expanded_type) = new_expanded_types[index] { - let length = expanded_type - .try_iterate(db) - .map(|tuple| tuple.len()) - .unwrap_or(TupleLength::unknown()); - new_arguments[index] = Argument::Variadic(length); - } - } - - expanded_arguments - .push(CallArguments::new(new_arguments, new_expanded_types)); + expanded_arguments.push(CallArguments::new( + self.arguments.clone(), + new_expanded_types, + )); } } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index f5b2e4e349..46628983e2 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2135,24 +2135,36 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { Ok(()) } + /// Match a variadic argument to the remaining positional, standard or variadic parameters. fn match_variadic( &mut self, db: &'db dyn Db, argument_index: usize, argument: Argument<'a>, argument_type: Option>, - length: TupleLength, ) -> Result<(), ()> { let tuple = argument_type.map(|ty| ty.iterate(db)); - let mut argument_types = match tuple.as_ref() { - Some(tuple) => Either::Left(tuple.all_elements().copied()), - None => Either::Right(std::iter::empty()), + let (mut argument_types, length, variable_element) = match tuple.as_ref() { + Some(tuple) => ( + Either::Left(tuple.all_elements().copied()), + tuple.len(), + tuple.variable_element().copied(), + ), + None => ( + Either::Right(std::iter::empty()), + TupleLength::unknown(), + None, + ), }; // We must be able to match up the fixed-length portion of the argument with positional // parameters, so we pass on any errors that occur. for _ in 0..length.minimum() { - self.match_positional(argument_index, argument, argument_types.next())?; + self.match_positional( + argument_index, + argument, + argument_types.next().or(variable_element), + )?; } // If the tuple is variable-length, we assume that it will soak up all remaining positional @@ -2163,7 +2175,24 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { .get_positional(self.next_positional) .is_some() { - self.match_positional(argument_index, argument, argument_types.next())?; + self.match_positional( + argument_index, + argument, + argument_types.next().or(variable_element), + )?; + } + } + + // Finally, if there is a variadic parameter we can match any of the remaining unpacked + // argument types to it, but only if there is at least one remaining argument type. This is + // because a variadic parameter is optional, so if this was done unconditionally, ty could + // raise a false positive as "too many arguments". + if self.parameters.variadic().is_some() { + if let Some(argument_type) = argument_types.next().or(variable_element) { + self.match_positional(argument_index, argument, Some(argument_type))?; + for argument_type in argument_types { + self.match_positional(argument_index, argument, Some(argument_type))?; + } } } @@ -2433,11 +2462,10 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { self.enumerate_argument_types() { match argument { - Argument::Variadic(_) => self.check_variadic_argument_type( + Argument::Variadic => self.check_variadic_argument_type( argument_index, adjusted_argument_index, argument, - argument_type, ), Argument::Keywords => self.check_keyword_variadic_argument_type( argument_index, @@ -2465,37 +2493,15 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument_index: usize, adjusted_argument_index: Option, argument: Argument<'a>, - argument_type: Type<'db>, ) { - // If the argument is splatted, convert its type into a tuple describing the splatted - // elements. For tuples, we don't have to do anything! For other types, we treat it as - // an iterator, and create a homogeneous tuple of its output type, since we don't know - // how many elements the iterator will produce. - let argument_types = argument_type.iterate(self.db); - - // Resize the tuple of argument types to line up with the number of parameters this - // argument was matched against. If parameter matching succeeded, then we can (TODO: - // should be able to, see above) guarantee that all of the required elements of the - // splatted tuple will have been matched with a parameter. But if parameter matching - // failed, there might be more required elements. That means we can't use - // TupleLength::Fixed below, because we would otherwise get a "too many values" error - // when parameter matching failed. - let desired_size = - TupleLength::Variable(self.argument_matches[argument_index].parameters.len(), 0); - let argument_types = argument_types - .resize(self.db, desired_size) - .expect("argument type should be consistent with its arity"); - - // Check the types by zipping through the splatted argument types and their matched - // parameters. - for (argument_type, parameter_index) in - (argument_types.all_elements()).zip(&self.argument_matches[argument_index].parameters) + for (parameter_index, variadic_argument_type) in + self.argument_matches[argument_index].iter() { self.check_argument_type( adjusted_argument_index, argument, - *argument_type, - *parameter_index, + variadic_argument_type.unwrap_or_else(Type::unknown), + parameter_index, ); } } @@ -2711,9 +2717,8 @@ impl<'db> Binding<'db> { Argument::Keyword(name) => { let _ = matcher.match_keyword(argument_index, argument, None, name); } - Argument::Variadic(length) => { - let _ = - matcher.match_variadic(db, argument_index, argument, argument_type, length); + Argument::Variadic => { + let _ = matcher.match_variadic(db, argument_index, argument, argument_type); } Argument::Keywords => { keywords_arguments.push((argument_index, argument_type)); diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index 4f435c89b0..1b00a41007 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -874,7 +874,7 @@ pub fn call_signature_details<'db>( // Use into_callable to handle all the complex type conversions if let Some(callable_type) = func_type.into_callable(db) { let call_arguments = - CallArguments::from_arguments(db, &call_expr.arguments, |_, splatted_value| { + CallArguments::from_arguments(&call_expr.arguments, |_, splatted_value| { splatted_value.inferred_type(model) }); let bindings = callable_type diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 6f20e78272..298c3b4644 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -1733,7 +1733,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into()); let mut call_arguments = - CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| { + CallArguments::from_arguments(arguments, |argument, splatted_value| { let ty = self.infer_expression(splatted_value, TypeContext::default()); if let Some(argument) = argument { self.store_expression_type(argument, ty); @@ -5831,7 +5831,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // arguments after matching them to parameters, but before checking that the argument types // are assignable to any parameter annotations. let mut call_arguments = - CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| { + CallArguments::from_arguments(arguments, |argument, splatted_value| { let ty = self.infer_expression(splatted_value, TypeContext::default()); if let Some(argument) = argument { self.store_expression_type(argument, ty); diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 2f10e79949..1b0ab92fb5 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -970,6 +970,14 @@ impl Tuple { FixedLengthTuple::from_elements(elements).into() } + /// Returns the variable-length element of this tuple, if it has one. + pub(crate) fn variable_element(&self) -> Option<&T> { + match self { + Tuple::Fixed(_) => None, + Tuple::Variable(tuple) => Some(&tuple.variable), + } + } + /// Returns an iterator of all of the fixed-length element types of this tuple. pub(crate) fn fixed_elements(&self) -> impl Iterator + '_ { match self {