mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 21:34:57 +00:00
[ty] Unpack variadic argument type in specialization (#20130)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run
## Summary This PR fixes various TODOs around overload call when a variadic argument is used. The reason this bug existed is because the specialization wouldn't account for unpacking the type of the variadic argument. This is fixed by expanding `MatchedArgument` to contain the type of that argument _only_ when it is a variadic argument. The reason is that there's a split for when the argument type is inferred -- the non-variadic arguments are inferred using `infer_argument_types` _after_ parameter matching while the variadic argument type is inferred _during_ the parameter matching. And, the `MatchedArgument` is populated _during_ parameter matching which means the unpacking would need to happen during parameter matching. This split seems a bit inconsistent but I don't want to spend a lot of time on trying to merge them such that all argument type inference happens in a single place. I might look into it while adding support for `**kwargs`. ## Test Plan Update existing tests by resolving the todos. The ecosystem changes looks correct to me except for the `slice` call but it seems that it's unrelated to this PR as we infer `slice[Any, Any, Any]` for a `slice(1, 2, 3)` call on `main` as well ([playground](https://play.ty.dev/9eacce00-c7d5-4dd5-a932-4265cb2bb4f6)).
This commit is contained in:
parent
a8039f80f0
commit
4ca38b2974
5 changed files with 94 additions and 53 deletions
|
@ -4843,7 +4843,7 @@ impl<'db> Type<'db> {
|
|||
argument_types: &CallArguments<'_, 'db>,
|
||||
) -> Result<Bindings<'db>, CallError<'db>> {
|
||||
self.bindings(db)
|
||||
.match_parameters(argument_types)
|
||||
.match_parameters(db, argument_types)
|
||||
.check_types(db, argument_types)
|
||||
}
|
||||
|
||||
|
@ -4892,7 +4892,7 @@ impl<'db> Type<'db> {
|
|||
Place::Type(dunder_callable, boundness) => {
|
||||
let bindings = dunder_callable
|
||||
.bindings(db)
|
||||
.match_parameters(argument_types)
|
||||
.match_parameters(db, argument_types)
|
||||
.check_types(db, argument_types)?;
|
||||
if boundness == Boundness::PossiblyUnbound {
|
||||
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
|
||||
|
|
|
@ -7,7 +7,7 @@ use std::borrow::Cow;
|
|||
use std::collections::HashSet;
|
||||
use std::fmt;
|
||||
|
||||
use itertools::Itertools;
|
||||
use itertools::{Either, Itertools};
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use smallvec::{SmallVec, smallvec, smallvec_inline};
|
||||
|
||||
|
@ -101,11 +101,15 @@ impl<'db> Bindings<'db> {
|
|||
///
|
||||
/// Once you have argument types available, you can call [`check_types`][Self::check_types] to
|
||||
/// verify that each argument type is assignable to the corresponding parameter type.
|
||||
pub(crate) fn match_parameters(mut self, arguments: &CallArguments<'_, 'db>) -> Self {
|
||||
pub(crate) fn match_parameters(
|
||||
mut self,
|
||||
db: &'db dyn Db,
|
||||
arguments: &CallArguments<'_, 'db>,
|
||||
) -> Self {
|
||||
let mut argument_forms = vec![None; arguments.len()];
|
||||
let mut conflicting_forms = vec![false; arguments.len()];
|
||||
for binding in &mut self.elements {
|
||||
binding.match_parameters(arguments, &mut argument_forms, &mut conflicting_forms);
|
||||
binding.match_parameters(db, arguments, &mut argument_forms, &mut conflicting_forms);
|
||||
}
|
||||
self.argument_forms = argument_forms.into();
|
||||
self.conflicting_forms = conflicting_forms.into();
|
||||
|
@ -1243,6 +1247,7 @@ impl<'db> CallableBinding<'db> {
|
|||
|
||||
fn match_parameters(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
arguments: &CallArguments<'_, 'db>,
|
||||
argument_forms: &mut [Option<ParameterForm>],
|
||||
conflicting_forms: &mut [bool],
|
||||
|
@ -1252,7 +1257,7 @@ impl<'db> CallableBinding<'db> {
|
|||
let arguments = arguments.with_self(self.bound_type);
|
||||
|
||||
for overload in &mut self.overloads {
|
||||
overload.match_parameters(arguments.as_ref(), argument_forms, conflicting_forms);
|
||||
overload.match_parameters(db, arguments.as_ref(), argument_forms, conflicting_forms);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1903,7 +1908,7 @@ struct ArgumentMatcher<'a, 'db> {
|
|||
conflicting_forms: &'a mut [bool],
|
||||
errors: &'a mut Vec<BindingError<'db>>,
|
||||
|
||||
argument_matches: Vec<MatchedArgument>,
|
||||
argument_matches: Vec<MatchedArgument<'db>>,
|
||||
parameter_matched: Vec<bool>,
|
||||
next_positional: usize,
|
||||
first_excess_positional: Option<usize>,
|
||||
|
@ -1947,6 +1952,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
&mut self,
|
||||
argument_index: usize,
|
||||
argument: Argument<'a>,
|
||||
argument_type: Option<Type<'db>>,
|
||||
parameter_index: usize,
|
||||
parameter: &Parameter<'db>,
|
||||
positional: bool,
|
||||
|
@ -1970,6 +1976,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
}
|
||||
let matched_argument = &mut self.argument_matches[argument_index];
|
||||
matched_argument.parameters.push(parameter_index);
|
||||
matched_argument.types.push(argument_type);
|
||||
matched_argument.matched = true;
|
||||
self.parameter_matched[parameter_index] = true;
|
||||
}
|
||||
|
@ -1978,6 +1985,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
&mut self,
|
||||
argument_index: usize,
|
||||
argument: Argument<'a>,
|
||||
argument_type: Option<Type<'db>>,
|
||||
) -> Result<(), ()> {
|
||||
if matches!(argument, Argument::Synthetic) {
|
||||
self.num_synthetic_args += 1;
|
||||
|
@ -1996,6 +2004,7 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
self.assign_argument(
|
||||
argument_index,
|
||||
argument,
|
||||
argument_type,
|
||||
parameter_index,
|
||||
parameter,
|
||||
!parameter.is_variadic(),
|
||||
|
@ -2020,20 +2029,35 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
});
|
||||
return Err(());
|
||||
};
|
||||
self.assign_argument(argument_index, argument, parameter_index, parameter, false);
|
||||
self.assign_argument(
|
||||
argument_index,
|
||||
argument,
|
||||
None,
|
||||
parameter_index,
|
||||
parameter,
|
||||
false,
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn match_variadic(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
argument_index: usize,
|
||||
argument: Argument<'a>,
|
||||
argument_type: Option<Type<'db>>,
|
||||
length: TupleLength,
|
||||
) -> Result<(), ()> {
|
||||
let tuple = argument_type.map(|ty| ty.iterate(db));
|
||||
let mut argument_types = match tuple.as_ref() {
|
||||
Some(tuple) => Either::Left(tuple.all_elements().copied()),
|
||||
None => Either::Right(std::iter::empty()),
|
||||
};
|
||||
|
||||
// We must be able to match up the fixed-length portion of the argument with positional
|
||||
// parameters, so we pass on any errors that occur.
|
||||
for _ in 0..length.minimum() {
|
||||
self.match_positional(argument_index, argument)?;
|
||||
self.match_positional(argument_index, argument, argument_types.next())?;
|
||||
}
|
||||
|
||||
// If the tuple is variable-length, we assume that it will soak up all remaining positional
|
||||
|
@ -2044,14 +2068,14 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
|||
.get_positional(self.next_positional)
|
||||
.is_some()
|
||||
{
|
||||
self.match_positional(argument_index, argument)?;
|
||||
self.match_positional(argument_index, argument, argument_types.next())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn finish(self) -> Box<[MatchedArgument]> {
|
||||
fn finish(self) -> Box<[MatchedArgument<'db>]> {
|
||||
if let Some(first_excess_argument_index) = self.first_excess_positional {
|
||||
self.errors.push(BindingError::TooManyPositionalArguments {
|
||||
first_excess_argument_index: self.get_argument_index(first_excess_argument_index),
|
||||
|
@ -2088,7 +2112,7 @@ struct ArgumentTypeChecker<'a, 'db> {
|
|||
db: &'db dyn Db,
|
||||
signature: &'a Signature<'db>,
|
||||
arguments: &'a CallArguments<'a, 'db>,
|
||||
argument_matches: &'a [MatchedArgument],
|
||||
argument_matches: &'a [MatchedArgument<'db>],
|
||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||
errors: &'a mut Vec<BindingError<'db>>,
|
||||
|
||||
|
@ -2101,7 +2125,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
db: &'db dyn Db,
|
||||
signature: &'a Signature<'db>,
|
||||
arguments: &'a CallArguments<'a, 'db>,
|
||||
argument_matches: &'a [MatchedArgument],
|
||||
argument_matches: &'a [MatchedArgument<'db>],
|
||||
parameter_tys: &'a mut [Option<Type<'db>>],
|
||||
errors: &'a mut Vec<BindingError<'db>>,
|
||||
) -> Self {
|
||||
|
@ -2156,12 +2180,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
for (argument_index, adjusted_argument_index, _, argument_type) in
|
||||
self.enumerate_argument_types()
|
||||
{
|
||||
for parameter_index in &self.argument_matches[argument_index].parameters {
|
||||
let parameter = ¶meters[*parameter_index];
|
||||
for (parameter_index, variadic_argument_type) in
|
||||
self.argument_matches[argument_index].iter()
|
||||
{
|
||||
let parameter = ¶meters[parameter_index];
|
||||
let Some(expected_type) = parameter.annotated_type() else {
|
||||
continue;
|
||||
};
|
||||
if let Err(error) = builder.infer(expected_type, argument_type) {
|
||||
if let Err(error) = builder.infer(
|
||||
expected_type,
|
||||
variadic_argument_type.unwrap_or(argument_type),
|
||||
) {
|
||||
self.errors.push(BindingError::SpecializationError {
|
||||
error,
|
||||
argument_index: adjusted_argument_index,
|
||||
|
@ -2305,7 +2334,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
/// Information about which parameter(s) an argument was matched against. This is tracked
|
||||
/// separately for each overload.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct MatchedArgument {
|
||||
pub struct MatchedArgument<'db> {
|
||||
/// The index of the parameter(s) that an argument was matched against. A splatted argument
|
||||
/// might be matched against multiple parameters.
|
||||
pub parameters: SmallVec<[usize; 1]>,
|
||||
|
@ -2314,6 +2343,33 @@ pub struct MatchedArgument {
|
|||
/// elements must have been successfully matched. (That means that this can be `false` while
|
||||
/// the `parameters` field is non-empty.)
|
||||
pub matched: bool,
|
||||
|
||||
/// The types of a variadic argument when it's unpacked.
|
||||
///
|
||||
/// The length of this vector is always the same as the `parameters` vector i.e., these are the
|
||||
/// types assigned to each matched parameter. This isn't necessarily the same as the number of
|
||||
/// types in the argument type which might not be a fixed-length iterable.
|
||||
///
|
||||
/// Another thing to note is that the way this is populated means that for any other argument
|
||||
/// kind (synthetic, positional, keyword, keyword-variadic), this will be a single-element
|
||||
/// vector containing `None`, since we don't know the type of the argument when this is
|
||||
/// constructed. So, this field is populated only for variadic arguments.
|
||||
///
|
||||
/// For example, given a `*args` whose type is `tuple[A, B, C]` and the following parameters:
|
||||
/// - `(x, *args)`: the `types` field will only have two elements (`B`, `C`) since `A` has been
|
||||
/// matched with `x`.
|
||||
/// - `(*args)`: the `types` field will have all the three elements (`A`, `B`, `C`)
|
||||
types: SmallVec<[Option<Type<'db>>; 1]>,
|
||||
}
|
||||
|
||||
impl<'db> MatchedArgument<'db> {
|
||||
/// Returns an iterator over the parameter indices and the corresponding argument type.
|
||||
pub fn iter(&self) -> impl Iterator<Item = (usize, Option<Type<'db>>)> + '_ {
|
||||
self.parameters
|
||||
.iter()
|
||||
.copied()
|
||||
.zip(self.types.iter().copied())
|
||||
}
|
||||
}
|
||||
|
||||
/// Binding information for one of the overloads of a callable.
|
||||
|
@ -2341,7 +2397,7 @@ pub(crate) struct Binding<'db> {
|
|||
|
||||
/// Information about which parameter(s) each argument was matched with, in argument source
|
||||
/// order.
|
||||
argument_matches: Box<[MatchedArgument]>,
|
||||
argument_matches: Box<[MatchedArgument<'db>]>,
|
||||
|
||||
/// Bound types for parameters, in parameter source order, or `None` if no argument was matched
|
||||
/// to that parameter.
|
||||
|
@ -2374,6 +2430,7 @@ impl<'db> Binding<'db> {
|
|||
|
||||
pub(crate) fn match_parameters(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
arguments: &CallArguments<'_, 'db>,
|
||||
argument_forms: &mut [Option<ParameterForm>],
|
||||
conflicting_forms: &mut [bool],
|
||||
|
@ -2386,16 +2443,17 @@ impl<'db> Binding<'db> {
|
|||
conflicting_forms,
|
||||
&mut self.errors,
|
||||
);
|
||||
for (argument_index, (argument, _)) in arguments.iter().enumerate() {
|
||||
for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() {
|
||||
match argument {
|
||||
Argument::Positional | Argument::Synthetic => {
|
||||
let _ = matcher.match_positional(argument_index, argument);
|
||||
let _ = matcher.match_positional(argument_index, argument, None);
|
||||
}
|
||||
Argument::Keyword(name) => {
|
||||
let _ = matcher.match_keyword(argument_index, argument, name);
|
||||
}
|
||||
Argument::Variadic(length) => {
|
||||
let _ = matcher.match_variadic(argument_index, argument, length);
|
||||
let _ =
|
||||
matcher.match_variadic(db, argument_index, argument, argument_type, length);
|
||||
}
|
||||
Argument::Keywords => {
|
||||
// TODO
|
||||
|
@ -2532,7 +2590,7 @@ impl<'db> Binding<'db> {
|
|||
|
||||
/// Returns a vector where each index corresponds to an argument position,
|
||||
/// and the value is the parameter index that argument maps to (if any).
|
||||
pub(crate) fn argument_matches(&self) -> &[MatchedArgument] {
|
||||
pub(crate) fn argument_matches(&self) -> &[MatchedArgument<'db>] {
|
||||
&self.argument_matches
|
||||
}
|
||||
}
|
||||
|
@ -2542,7 +2600,7 @@ struct BindingSnapshot<'db> {
|
|||
return_ty: Type<'db>,
|
||||
specialization: Option<Specialization<'db>>,
|
||||
inherited_specialization: Option<Specialization<'db>>,
|
||||
argument_matches: Box<[MatchedArgument]>,
|
||||
argument_matches: Box<[MatchedArgument<'db>]>,
|
||||
parameter_tys: Box<[Option<Type<'db>>]>,
|
||||
errors: Vec<BindingError<'db>>,
|
||||
}
|
||||
|
|
|
@ -801,7 +801,7 @@ pub struct CallSignatureDetails<'db> {
|
|||
|
||||
/// Mapping from argument indices to parameter indices. This helps
|
||||
/// determine which parameter corresponds to which argument position.
|
||||
pub argument_to_parameter_mapping: Vec<MatchedArgument>,
|
||||
pub argument_to_parameter_mapping: Vec<MatchedArgument<'db>>,
|
||||
}
|
||||
|
||||
/// Extract signature details from a function call expression.
|
||||
|
@ -821,7 +821,9 @@ pub fn call_signature_details<'db>(
|
|||
CallArguments::from_arguments(db, &call_expr.arguments, |_, splatted_value| {
|
||||
splatted_value.inferred_type(model)
|
||||
});
|
||||
let bindings = callable_type.bindings(db).match_parameters(&call_arguments);
|
||||
let bindings = callable_type
|
||||
.bindings(db)
|
||||
.match_parameters(db, &call_arguments);
|
||||
|
||||
// Extract signature details from all callable bindings
|
||||
bindings
|
||||
|
|
|
@ -6360,7 +6360,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
let bindings = callable_type
|
||||
.bindings(self.db())
|
||||
.match_parameters(&call_arguments);
|
||||
.match_parameters(self.db(), &call_arguments);
|
||||
self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms);
|
||||
|
||||
// Validate `TypedDict` constructor calls after argument type inference
|
||||
|
@ -8812,7 +8812,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
};
|
||||
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
|
||||
let bindings = match Bindings::from(binding)
|
||||
.match_parameters(&call_argument_types)
|
||||
.match_parameters(self.db(), &call_argument_types)
|
||||
.check_types(self.db(), &call_argument_types)
|
||||
{
|
||||
Ok(bindings) => bindings,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue