mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 13:51:16 +00:00
[ty] Retry parameter matching for argument type expansion (#20153)
## Summary This PR addresses an issue for a variadic argument when involved in argument type expansion of overload call evaluation. The issue is that the expansion of the variadic argument could result in argument list of different arity. For example, in `*args: tuple[int] | tuple[int, str]`, the expansion would lead to the variadic argument being unpacked into 1 and 2 element respectively. This means that the parameter matching that was performed initially isn't sufficient and each expanded argument list would need to redo the parameter matching again. This is currently done by redoing the parameter matching directly, maintaining the state of argument forms (and the conflicting forms), and updating the `Bindings` values if it changes. Closes: astral-sh/ty#735 ## Test Plan Update existing mdtest.
This commit is contained in:
parent
1cd8ab3f26
commit
bb9be263c7
6 changed files with 205 additions and 156 deletions
|
@ -613,56 +613,6 @@ def _(args: str) -> None:
|
||||||
takes_at_least_two_positional_only(*args)
|
takes_at_least_two_positional_only(*args)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Argument expansion regression
|
|
||||||
|
|
||||||
This is a regression that was highlighted by the ecosystem check, which shows that we might need to
|
|
||||||
rethink how we perform argument expansion during overload resolution. In particular, we might need
|
|
||||||
to retry both `match_parameters` *and* `check_types` for each expansion. Currently we only retry
|
|
||||||
`check_types`.
|
|
||||||
|
|
||||||
The issue is that argument expansion might produce a splatted value with a different arity than what
|
|
||||||
we originally inferred for the unexpanded value, and that in turn can affect which parameters the
|
|
||||||
splatted value is matched with.
|
|
||||||
|
|
||||||
The first example correctly produces an error. The `tuple[int, str]` union element has a precise
|
|
||||||
arity of two, and so parameter matching chooses the first overload. The second element of the tuple
|
|
||||||
does not match the second parameter type, which yielding an `invalid-argument-type` error.
|
|
||||||
|
|
||||||
The third example should produce the same error. However, because we have a union, we do not see the
|
|
||||||
precise arity of each union element during parameter matching. Instead, we infer an arity of "zero
|
|
||||||
or more" for the union as a whole, and use that less precise arity when matching parameters. We
|
|
||||||
therefore consider the second overload to still be a potential candidate for the `tuple[int, str]`
|
|
||||||
union element. During type checking, we have to force the arity of each union element to match the
|
|
||||||
inferred arity of the union as a whole (turning `tuple[int, str]` into `tuple[int | str, ...]`).
|
|
||||||
That less precise tuple type-checks successfully against the second overload, making us incorrectly
|
|
||||||
think that `tuple[int, str]` is a valid splatted call.
|
|
||||||
|
|
||||||
If we update argument expansion to retry parameter matching with the precise arity of each union
|
|
||||||
element, we will correctly rule out the second overload for `tuple[int, str]`, just like we do when
|
|
||||||
splatting that tuple directly (instead of as part of a union).
|
|
||||||
|
|
||||||
```py
|
|
||||||
from typing import overload
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def f(x: int, y: int) -> None: ...
|
|
||||||
@overload
|
|
||||||
def f(x: int, y: str, z: int) -> None: ...
|
|
||||||
def f(*args): ...
|
|
||||||
|
|
||||||
# Test all of the above with a number of different splatted argument types
|
|
||||||
|
|
||||||
def _(t: tuple[int, str]) -> None:
|
|
||||||
f(*t) # error: [invalid-argument-type]
|
|
||||||
|
|
||||||
def _(t: tuple[int, str, int]) -> None:
|
|
||||||
f(*t)
|
|
||||||
|
|
||||||
def _(t: tuple[int, str] | tuple[int, str, int]) -> None:
|
|
||||||
# TODO: error: [invalid-argument-type]
|
|
||||||
f(*t)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Wrong argument type
|
## Wrong argument type
|
||||||
|
|
||||||
### Positional argument, positional-or-keyword parameter
|
### Positional argument, positional-or-keyword parameter
|
||||||
|
|
|
@ -889,6 +889,48 @@ def _(a: int | None):
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Retry from parameter matching
|
||||||
|
|
||||||
|
As per the spec, the argument type expansion should retry evaluating the expanded argument list from
|
||||||
|
the type checking step. However, that creates an issue when variadic arguments are involved because
|
||||||
|
if a variadic argument is a union type, it could be expanded to have different arities. So, ty
|
||||||
|
retries it from the start which includes parameter matching as well.
|
||||||
|
|
||||||
|
`overloaded.pyi`:
|
||||||
|
|
||||||
|
```pyi
|
||||||
|
from typing import overload
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def f(x: int, y: int) -> None: ...
|
||||||
|
@overload
|
||||||
|
def f(x: int, y: str, z: int) -> None: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
```py
|
||||||
|
from overloaded import f
|
||||||
|
|
||||||
|
# Test all of the above with a number of different splatted argument types
|
||||||
|
|
||||||
|
def _(t: tuple[int, str]) -> None:
|
||||||
|
# This correctly produces an error because the first element of the union has a precise arity of
|
||||||
|
# 2, which matches the first overload, but the second element of the tuple doesn't match the
|
||||||
|
# second parameter type, yielding an `invalid-argument-type` error.
|
||||||
|
f(*t) # error: [invalid-argument-type]
|
||||||
|
|
||||||
|
def _(t: tuple[int, str, int]) -> None:
|
||||||
|
# This correctly produces no error because the first element of the union has a precise arity of
|
||||||
|
# 3, which matches the second overload.
|
||||||
|
f(*t)
|
||||||
|
|
||||||
|
def _(t: tuple[int, str] | tuple[int, str, int]) -> None:
|
||||||
|
# This produces an error because the expansion produces two argument lists: `[*tuple[int, str]]`
|
||||||
|
# and `[*tuple[int, str, int]]`. The first list produces produces a type checking error as
|
||||||
|
# described in the first example, while the second list matches the second overload. And,
|
||||||
|
# because not all of the expanded argument list evaluates successfully, we produce an error.
|
||||||
|
f(*t) # error: [no-matching-overload]
|
||||||
|
```
|
||||||
|
|
||||||
## Filtering based on `Any` / `Unknown`
|
## Filtering based on `Any` / `Unknown`
|
||||||
|
|
||||||
This is the step 5 of the overload call evaluation algorithm which specifies that:
|
This is the step 5 of the overload call evaluation algorithm which specifies that:
|
||||||
|
|
|
@ -202,10 +202,25 @@ impl<'a, 'db> CallArguments<'a, 'db> {
|
||||||
for subtype in &expanded_types {
|
for subtype in &expanded_types {
|
||||||
let mut new_expanded_types = pre_expanded_types.to_vec();
|
let mut new_expanded_types = pre_expanded_types.to_vec();
|
||||||
new_expanded_types[index] = Some(*subtype);
|
new_expanded_types[index] = Some(*subtype);
|
||||||
expanded_arguments.push(CallArguments::new(
|
|
||||||
self.arguments.clone(),
|
// Update the arguments list to handle variadic argument expansion
|
||||||
new_expanded_types,
|
let mut new_arguments = self.arguments.clone();
|
||||||
));
|
if let Argument::Variadic(_) = self.arguments[index] {
|
||||||
|
// If the argument corresponding to this type is variadic, we need to
|
||||||
|
// update the tuple length because expanding could change the length.
|
||||||
|
// For example, in `tuple[int] | tuple[int, int]`, the length of the
|
||||||
|
// first type is 1, while the length of the second type is 2.
|
||||||
|
if let Some(expanded_type) = new_expanded_types[index] {
|
||||||
|
let length = expanded_type
|
||||||
|
.try_iterate(db)
|
||||||
|
.map(|tuple| tuple.len())
|
||||||
|
.unwrap_or(TupleLength::unknown());
|
||||||
|
new_arguments[index] = Argument::Variadic(length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expanded_arguments
|
||||||
|
.push(CallArguments::new(new_arguments, new_expanded_types));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
//! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a
|
//! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a
|
||||||
//! union of types, each of which might contain multiple overloads.
|
//! union of types, each of which might contain multiple overloads.
|
||||||
|
|
||||||
use std::borrow::Cow;
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
@ -28,7 +27,7 @@ use crate::types::function::{
|
||||||
};
|
};
|
||||||
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
|
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
|
||||||
use crate::types::signatures::{Parameter, ParameterForm, Parameters};
|
use crate::types::signatures::{Parameter, ParameterForm, Parameters};
|
||||||
use crate::types::tuple::{Tuple, TupleLength, TupleType};
|
use crate::types::tuple::{TupleLength, TupleType};
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
||||||
KnownClass, KnownInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
|
KnownClass, KnownInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
|
||||||
|
@ -51,9 +50,7 @@ pub(crate) struct Bindings<'db> {
|
||||||
elements: SmallVec<[CallableBinding<'db>; 1]>,
|
elements: SmallVec<[CallableBinding<'db>; 1]>,
|
||||||
|
|
||||||
/// Whether each argument will be used as a value and/or a type form in this call.
|
/// Whether each argument will be used as a value and/or a type form in this call.
|
||||||
pub(crate) argument_forms: Box<[Option<ParameterForm>]>,
|
argument_forms: ArgumentForms,
|
||||||
|
|
||||||
conflicting_forms: Box<[bool]>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'db> Bindings<'db> {
|
impl<'db> Bindings<'db> {
|
||||||
|
@ -71,8 +68,7 @@ impl<'db> Bindings<'db> {
|
||||||
Self {
|
Self {
|
||||||
callable_type,
|
callable_type,
|
||||||
elements,
|
elements,
|
||||||
argument_forms: Box::from([]),
|
argument_forms: ArgumentForms::new(0),
|
||||||
conflicting_forms: Box::from([]),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,6 +87,10 @@ impl<'db> Bindings<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn argument_forms(&self) -> &[Option<ParameterForm>] {
|
||||||
|
&self.argument_forms.values
|
||||||
|
}
|
||||||
|
|
||||||
/// Match the arguments of a call site against the parameters of a collection of possibly
|
/// Match the arguments of a call site against the parameters of a collection of possibly
|
||||||
/// unioned, possibly overloaded signatures.
|
/// unioned, possibly overloaded signatures.
|
||||||
///
|
///
|
||||||
|
@ -105,13 +105,12 @@ impl<'db> Bindings<'db> {
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
arguments: &CallArguments<'_, 'db>,
|
arguments: &CallArguments<'_, 'db>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let mut argument_forms = vec![None; arguments.len()];
|
let mut argument_forms = ArgumentForms::new(arguments.len());
|
||||||
let mut conflicting_forms = vec![false; arguments.len()];
|
|
||||||
for binding in &mut self.elements {
|
for binding in &mut self.elements {
|
||||||
binding.match_parameters(db, arguments, &mut argument_forms, &mut conflicting_forms);
|
binding.match_parameters(db, arguments, &mut argument_forms);
|
||||||
}
|
}
|
||||||
self.argument_forms = argument_forms.into();
|
argument_forms.shrink_to_fit();
|
||||||
self.conflicting_forms = conflicting_forms.into();
|
self.argument_forms = argument_forms;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +129,12 @@ impl<'db> Bindings<'db> {
|
||||||
argument_types: &CallArguments<'_, 'db>,
|
argument_types: &CallArguments<'_, 'db>,
|
||||||
) -> Result<Self, CallError<'db>> {
|
) -> Result<Self, CallError<'db>> {
|
||||||
for element in &mut self.elements {
|
for element in &mut self.elements {
|
||||||
element.check_types(db, argument_types);
|
if let Some(mut updated_argument_forms) = element.check_types(db, argument_types) {
|
||||||
|
// If this element returned a new set of argument forms (indicating successful
|
||||||
|
// argument type expansion), update the `Bindings` with these forms.
|
||||||
|
updated_argument_forms.shrink_to_fit();
|
||||||
|
self.argument_forms = updated_argument_forms;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.evaluate_known_cases(db);
|
self.evaluate_known_cases(db);
|
||||||
|
@ -153,7 +157,7 @@ impl<'db> Bindings<'db> {
|
||||||
let mut all_ok = true;
|
let mut all_ok = true;
|
||||||
let mut any_binding_error = false;
|
let mut any_binding_error = false;
|
||||||
let mut all_not_callable = true;
|
let mut all_not_callable = true;
|
||||||
if self.conflicting_forms.contains(&true) {
|
if self.argument_forms.conflicting.contains(&true) {
|
||||||
all_ok = false;
|
all_ok = false;
|
||||||
any_binding_error = true;
|
any_binding_error = true;
|
||||||
all_not_callable = false;
|
all_not_callable = false;
|
||||||
|
@ -226,7 +230,7 @@ impl<'db> Bindings<'db> {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (index, conflicting_form) in self.conflicting_forms.iter().enumerate() {
|
for (index, conflicting_form) in self.argument_forms.conflicting.iter().enumerate() {
|
||||||
if *conflicting_form {
|
if *conflicting_form {
|
||||||
let node = BindingError::get_node(node, Some(index));
|
let node = BindingError::get_node(node, Some(index));
|
||||||
if let Some(builder) = context.report_lint(&CONFLICTING_ARGUMENT_FORMS, node) {
|
if let Some(builder) = context.report_lint(&CONFLICTING_ARGUMENT_FORMS, node) {
|
||||||
|
@ -1118,8 +1122,7 @@ impl<'db> From<CallableBinding<'db>> for Bindings<'db> {
|
||||||
Bindings {
|
Bindings {
|
||||||
callable_type: from.callable_type,
|
callable_type: from.callable_type,
|
||||||
elements: smallvec_inline![from],
|
elements: smallvec_inline![from],
|
||||||
argument_forms: Box::from([]),
|
argument_forms: ArgumentForms::new(0),
|
||||||
conflicting_forms: Box::from([]),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1140,8 +1143,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
|
||||||
Bindings {
|
Bindings {
|
||||||
callable_type,
|
callable_type,
|
||||||
elements: smallvec_inline![callable_binding],
|
elements: smallvec_inline![callable_binding],
|
||||||
argument_forms: Box::from([]),
|
argument_forms: ArgumentForms::new(0),
|
||||||
conflicting_forms: Box::from([]),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1262,19 +1264,22 @@ impl<'db> CallableBinding<'db> {
|
||||||
&mut self,
|
&mut self,
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
arguments: &CallArguments<'_, 'db>,
|
arguments: &CallArguments<'_, 'db>,
|
||||||
argument_forms: &mut [Option<ParameterForm>],
|
argument_forms: &mut ArgumentForms,
|
||||||
conflicting_forms: &mut [bool],
|
|
||||||
) {
|
) {
|
||||||
// If this callable is a bound method, prepend the self instance onto the arguments list
|
// If this callable is a bound method, prepend the self instance onto the arguments list
|
||||||
// before checking.
|
// before checking.
|
||||||
let arguments = arguments.with_self(self.bound_type);
|
let arguments = arguments.with_self(self.bound_type);
|
||||||
|
|
||||||
for overload in &mut self.overloads {
|
for overload in &mut self.overloads {
|
||||||
overload.match_parameters(db, arguments.as_ref(), argument_forms, conflicting_forms);
|
overload.match_parameters(db, arguments.as_ref(), argument_forms);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>) {
|
fn check_types(
|
||||||
|
&mut self,
|
||||||
|
db: &'db dyn Db,
|
||||||
|
argument_types: &CallArguments<'_, 'db>,
|
||||||
|
) -> Option<ArgumentForms> {
|
||||||
// If this callable is a bound method, prepend the self instance onto the arguments list
|
// If this callable is a bound method, prepend the self instance onto the arguments list
|
||||||
// before checking.
|
// before checking.
|
||||||
let argument_types = argument_types.with_self(self.bound_type);
|
let argument_types = argument_types.with_self(self.bound_type);
|
||||||
|
@ -1288,14 +1293,14 @@ impl<'db> CallableBinding<'db> {
|
||||||
if let [overload] = self.overloads.as_mut_slice() {
|
if let [overload] = self.overloads.as_mut_slice() {
|
||||||
overload.check_types(db, argument_types.as_ref());
|
overload.check_types(db, argument_types.as_ref());
|
||||||
}
|
}
|
||||||
return;
|
return None;
|
||||||
}
|
}
|
||||||
MatchingOverloadIndex::Single(index) => {
|
MatchingOverloadIndex::Single(index) => {
|
||||||
// If only one candidate overload remains, it is the winning match. Evaluate it as
|
// If only one candidate overload remains, it is the winning match. Evaluate it as
|
||||||
// a regular (non-overloaded) call.
|
// a regular (non-overloaded) call.
|
||||||
self.matching_overload_index = Some(index);
|
self.matching_overload_index = Some(index);
|
||||||
self.overloads[index].check_types(db, argument_types.as_ref());
|
self.overloads[index].check_types(db, argument_types.as_ref());
|
||||||
return;
|
return None;
|
||||||
}
|
}
|
||||||
MatchingOverloadIndex::Multiple(indexes) => {
|
MatchingOverloadIndex::Multiple(indexes) => {
|
||||||
// If two or more candidate overloads remain, proceed to step 2.
|
// If two or more candidate overloads remain, proceed to step 2.
|
||||||
|
@ -1303,12 +1308,6 @@ impl<'db> CallableBinding<'db> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes);
|
|
||||||
|
|
||||||
// State of the bindings _before_ evaluating (type checking) the matching overloads using
|
|
||||||
// the non-expanded argument types.
|
|
||||||
let pre_evaluation_snapshot = snapshotter.take(self);
|
|
||||||
|
|
||||||
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
|
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
|
||||||
// whether it is compatible with the supplied argument list.
|
// whether it is compatible with the supplied argument list.
|
||||||
for (_, overload) in self.matching_overloads_mut() {
|
for (_, overload) in self.matching_overloads_mut() {
|
||||||
|
@ -1321,7 +1320,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
}
|
}
|
||||||
MatchingOverloadIndex::Single(_) => {
|
MatchingOverloadIndex::Single(_) => {
|
||||||
// If only one overload evaluates without error, it is the winning match.
|
// If only one overload evaluates without error, it is the winning match.
|
||||||
return;
|
return None;
|
||||||
}
|
}
|
||||||
MatchingOverloadIndex::Multiple(indexes) => {
|
MatchingOverloadIndex::Multiple(indexes) => {
|
||||||
// If two or more candidate overloads remain, proceed to step 4.
|
// If two or more candidate overloads remain, proceed to step 4.
|
||||||
|
@ -1330,8 +1329,8 @@ impl<'db> CallableBinding<'db> {
|
||||||
// Step 5
|
// Step 5
|
||||||
self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes);
|
self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes);
|
||||||
|
|
||||||
// We're returning here because this shouldn't lead to argument type expansion.
|
// This shouldn't lead to argument type expansion.
|
||||||
return;
|
return None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1339,27 +1338,14 @@ impl<'db> CallableBinding<'db> {
|
||||||
// https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
|
// https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
|
||||||
let mut expansions = argument_types.expand(db).peekable();
|
let mut expansions = argument_types.expand(db).peekable();
|
||||||
|
|
||||||
if expansions.peek().is_none() {
|
|
||||||
// Return early if there are no argument types to expand.
|
// Return early if there are no argument types to expand.
|
||||||
return;
|
expansions.peek()?;
|
||||||
}
|
|
||||||
|
|
||||||
// State of the bindings _after_ evaluating (type checking) the matching overloads using
|
|
||||||
// the non-expanded argument types.
|
|
||||||
let post_evaluation_snapshot = snapshotter.take(self);
|
|
||||||
|
|
||||||
// Restore the bindings state to the one prior to the type checking step in preparation
|
|
||||||
// for evaluating the expanded argument lists.
|
|
||||||
snapshotter.restore(self, pre_evaluation_snapshot);
|
|
||||||
|
|
||||||
// At this point, there's at least one argument that can be expanded.
|
// At this point, there's at least one argument that can be expanded.
|
||||||
//
|
//
|
||||||
// This heuristic tries to detect if there's any need to perform argument type expansion or
|
// This heuristic tries to detect if there's any need to perform argument type expansion or
|
||||||
// not by checking whether there are any non-expandable argument type that cannot be
|
// not by checking whether there are any non-expandable argument type that cannot be
|
||||||
// assigned to any of the remaining overloads.
|
// assigned to any of the overloads.
|
||||||
//
|
|
||||||
// This heuristic needs to be applied after restoring the bindings state to the one before
|
|
||||||
// type checking as argument type expansion would evaluate it from that point on.
|
|
||||||
for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() {
|
for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() {
|
||||||
// TODO: Remove `Keywords` once `**kwargs` support is added
|
// TODO: Remove `Keywords` once `**kwargs` support is added
|
||||||
if matches!(argument, Argument::Synthetic | Argument::Keywords) {
|
if matches!(argument, Argument::Synthetic | Argument::Keywords) {
|
||||||
|
@ -1372,7 +1358,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let mut is_argument_assignable_to_any_overload = false;
|
let mut is_argument_assignable_to_any_overload = false;
|
||||||
'overload: for (_, overload) in self.matching_overloads() {
|
'overload: for overload in &self.overloads {
|
||||||
for parameter_index in &overload.argument_matches[argument_index].parameters {
|
for parameter_index in &overload.argument_matches[argument_index].parameters {
|
||||||
let parameter_type = overload.signature.parameters()[*parameter_index]
|
let parameter_type = overload.signature.parameters()[*parameter_index]
|
||||||
.annotated_type()
|
.annotated_type()
|
||||||
|
@ -1389,11 +1375,16 @@ impl<'db> CallableBinding<'db> {
|
||||||
remaining overloads, skipping argument type expansion",
|
remaining overloads, skipping argument type expansion",
|
||||||
argument_type.display(db)
|
argument_type.display(db)
|
||||||
);
|
);
|
||||||
snapshotter.restore(self, post_evaluation_snapshot);
|
return None;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes);
|
||||||
|
|
||||||
|
// State of the bindings _after_ evaluating (type checking) the matching overloads using
|
||||||
|
// the non-expanded argument types.
|
||||||
|
let post_evaluation_snapshot = snapshotter.take(self);
|
||||||
|
|
||||||
for expansion in expansions {
|
for expansion in expansions {
|
||||||
let expanded_argument_lists = match expansion {
|
let expanded_argument_lists = match expansion {
|
||||||
Expansion::LimitReached(index) => {
|
Expansion::LimitReached(index) => {
|
||||||
|
@ -1401,7 +1392,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
self.overload_call_return_type = Some(
|
self.overload_call_return_type = Some(
|
||||||
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index),
|
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index),
|
||||||
);
|
);
|
||||||
return;
|
return None;
|
||||||
}
|
}
|
||||||
Expansion::Expanded(argument_lists) => argument_lists,
|
Expansion::Expanded(argument_lists) => argument_lists,
|
||||||
};
|
};
|
||||||
|
@ -1411,13 +1402,33 @@ impl<'db> CallableBinding<'db> {
|
||||||
// the expanded argument lists evaluated successfully.
|
// the expanded argument lists evaluated successfully.
|
||||||
let mut merged_evaluation_state: Option<CallableBindingSnapshot<'db>> = None;
|
let mut merged_evaluation_state: Option<CallableBindingSnapshot<'db>> = None;
|
||||||
|
|
||||||
|
// Merged argument forms after evaluating all the argument lists in this expansion.
|
||||||
|
let mut merged_argument_forms = ArgumentForms::default();
|
||||||
|
|
||||||
|
// The return types of each of the expanded argument lists that evaluated successfully.
|
||||||
let mut return_types = Vec::new();
|
let mut return_types = Vec::new();
|
||||||
|
|
||||||
for expanded_argument_types in &expanded_argument_lists {
|
for expanded_arguments in &expanded_argument_lists {
|
||||||
let pre_evaluation_snapshot = snapshotter.take(self);
|
let mut argument_forms = ArgumentForms::new(expanded_arguments.len());
|
||||||
|
|
||||||
|
// The spec mentions that each expanded argument list should be re-evaluated from
|
||||||
|
// step 2 but we need to re-evaluate from step 1 because our step 1 does more than
|
||||||
|
// what the spec mentions. Step 1 of the spec means only "eliminate impossible
|
||||||
|
// overloads due to arity mismatch" while our step 1 (`match_parameters`) also
|
||||||
|
// includes "match arguments to the parameters". This is important because it
|
||||||
|
// allows us to correctly handle cases involving a variadic argument that could
|
||||||
|
// expand into different number of arguments with each expansion. Refer to
|
||||||
|
// https://github.com/astral-sh/ty/issues/735 for more details.
|
||||||
|
for overload in &mut self.overloads {
|
||||||
|
// Clear the state of all overloads before re-evaluating from step 1
|
||||||
|
overload.reset();
|
||||||
|
overload.match_parameters(db, expanded_arguments, &mut argument_forms);
|
||||||
|
}
|
||||||
|
|
||||||
|
merged_argument_forms.merge(&argument_forms);
|
||||||
|
|
||||||
for (_, overload) in self.matching_overloads_mut() {
|
for (_, overload) in self.matching_overloads_mut() {
|
||||||
overload.check_types(db, expanded_argument_types);
|
overload.check_types(db, expanded_arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
let return_type = match self.matching_overload_index() {
|
let return_type = match self.matching_overload_index() {
|
||||||
|
@ -1430,7 +1441,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
|
|
||||||
self.filter_overloads_using_any_or_unknown(
|
self.filter_overloads_using_any_or_unknown(
|
||||||
db,
|
db,
|
||||||
expanded_argument_types,
|
expanded_arguments,
|
||||||
&matching_overload_indexes,
|
&matching_overload_indexes,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1451,9 +1462,6 @@ impl<'db> CallableBinding<'db> {
|
||||||
merged_evaluation_state = Some(snapshotter.take(self));
|
merged_evaluation_state = Some(snapshotter.take(self));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore the bindings state before evaluating the next argument list.
|
|
||||||
snapshotter.restore(self, pre_evaluation_snapshot);
|
|
||||||
|
|
||||||
if let Some(return_type) = return_type {
|
if let Some(return_type) = return_type {
|
||||||
return_types.push(return_type);
|
return_types.push(return_type);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1481,7 +1489,7 @@ impl<'db> CallableBinding<'db> {
|
||||||
UnionType::from_elements(db, return_types),
|
UnionType::from_elements(db, return_types),
|
||||||
));
|
));
|
||||||
|
|
||||||
return;
|
return Some(merged_argument_forms);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1490,6 +1498,8 @@ impl<'db> CallableBinding<'db> {
|
||||||
// argument types. This is necessary because we restore the state to the pre-evaluation
|
// argument types. This is necessary because we restore the state to the pre-evaluation
|
||||||
// snapshot when processing the expanded argument lists.
|
// snapshot when processing the expanded argument lists.
|
||||||
snapshotter.restore(self, post_evaluation_snapshot);
|
snapshotter.restore(self, post_evaluation_snapshot);
|
||||||
|
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Filter overloads based on [`Any`] or [`Unknown`] argument types.
|
/// Filter overloads based on [`Any`] or [`Unknown`] argument types.
|
||||||
|
@ -1915,10 +1925,59 @@ enum MatchingOverloadIndex {
|
||||||
Multiple(Vec<usize>),
|
Multiple(Vec<usize>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Debug)]
|
||||||
|
struct ArgumentForms {
|
||||||
|
values: Vec<Option<ParameterForm>>,
|
||||||
|
conflicting: Vec<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ArgumentForms {
|
||||||
|
/// Create a new argument forms initialized to the given length and the default values.
|
||||||
|
fn new(len: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
values: vec![None; len],
|
||||||
|
conflicting: vec![false; len],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn merge(&mut self, other: &ArgumentForms) {
|
||||||
|
if self.values.len() < other.values.len() {
|
||||||
|
self.values.resize(other.values.len(), None);
|
||||||
|
self.conflicting.resize(other.conflicting.len(), false);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (index, (other_form, other_conflict)) in other
|
||||||
|
.values
|
||||||
|
.iter()
|
||||||
|
.zip(other.conflicting.iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if let Some(self_form) = &mut self.values[index] {
|
||||||
|
if let Some(other_form) = other_form {
|
||||||
|
if *self_form != *other_form {
|
||||||
|
// Different parameter forms, mark as conflicting
|
||||||
|
self.conflicting[index] = true;
|
||||||
|
*self_form = *other_form; // Use the new form
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.values[index] = *other_form;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the conflicting form (true takes precedence)
|
||||||
|
self.conflicting[index] |= *other_conflict;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shrink_to_fit(&mut self) {
|
||||||
|
self.values.shrink_to_fit();
|
||||||
|
self.conflicting.shrink_to_fit();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct ArgumentMatcher<'a, 'db> {
|
struct ArgumentMatcher<'a, 'db> {
|
||||||
parameters: &'a Parameters<'db>,
|
parameters: &'a Parameters<'db>,
|
||||||
argument_forms: &'a mut [Option<ParameterForm>],
|
argument_forms: &'a mut ArgumentForms,
|
||||||
conflicting_forms: &'a mut [bool],
|
|
||||||
errors: &'a mut Vec<BindingError<'db>>,
|
errors: &'a mut Vec<BindingError<'db>>,
|
||||||
|
|
||||||
argument_matches: Vec<MatchedArgument<'db>>,
|
argument_matches: Vec<MatchedArgument<'db>>,
|
||||||
|
@ -1932,14 +1991,12 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
||||||
fn new(
|
fn new(
|
||||||
arguments: &CallArguments,
|
arguments: &CallArguments,
|
||||||
parameters: &'a Parameters<'db>,
|
parameters: &'a Parameters<'db>,
|
||||||
argument_forms: &'a mut [Option<ParameterForm>],
|
argument_forms: &'a mut ArgumentForms,
|
||||||
conflicting_forms: &'a mut [bool],
|
|
||||||
errors: &'a mut Vec<BindingError<'db>>,
|
errors: &'a mut Vec<BindingError<'db>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
parameters,
|
parameters,
|
||||||
argument_forms,
|
argument_forms,
|
||||||
conflicting_forms,
|
|
||||||
errors,
|
errors,
|
||||||
argument_matches: vec![MatchedArgument::default(); arguments.len()],
|
argument_matches: vec![MatchedArgument::default(); arguments.len()],
|
||||||
parameter_matched: vec![false; parameters.len()],
|
parameter_matched: vec![false; parameters.len()],
|
||||||
|
@ -1971,11 +2028,13 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
|
||||||
positional: bool,
|
positional: bool,
|
||||||
) {
|
) {
|
||||||
if !matches!(argument, Argument::Synthetic) {
|
if !matches!(argument, Argument::Synthetic) {
|
||||||
if let Some(existing) = self.argument_forms[argument_index - self.num_synthetic_args]
|
let adjusted_argument_index = argument_index - self.num_synthetic_args;
|
||||||
.replace(parameter.form)
|
if let Some(existing) =
|
||||||
|
self.argument_forms.values[adjusted_argument_index].replace(parameter.form)
|
||||||
{
|
{
|
||||||
if existing != parameter.form {
|
if existing != parameter.form {
|
||||||
self.conflicting_forms[argument_index - self.num_synthetic_args] = true;
|
self.argument_forms.conflicting[argument_index - self.num_synthetic_args] =
|
||||||
|
true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2295,22 +2354,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
||||||
// how many elements the iterator will produce.
|
// how many elements the iterator will produce.
|
||||||
let argument_types = argument_type.iterate(self.db);
|
let argument_types = argument_type.iterate(self.db);
|
||||||
|
|
||||||
// TODO: When we perform argument expansion during overload resolution, we might need
|
|
||||||
// to retry both `match_parameters` _and_ `check_types` for each expansion. Currently
|
|
||||||
// we only retry `check_types`. The issue is that argument expansion might produce a
|
|
||||||
// splatted value with a different arity than what we originally inferred for the
|
|
||||||
// unexpanded value, and that in turn can affect which parameters the splatted value is
|
|
||||||
// matched with. As a workaround, make sure that the splatted tuple contains an
|
|
||||||
// arbitrary number of `Unknown`s at the end, so that if the expanded value has a
|
|
||||||
// smaller arity than the unexpanded value, we still have enough values to assign to
|
|
||||||
// the already matched parameters.
|
|
||||||
let argument_types = match argument_types.as_ref() {
|
|
||||||
Tuple::Fixed(_) => {
|
|
||||||
Cow::Owned(argument_types.concat(self.db, &Tuple::homogeneous(Type::unknown())))
|
|
||||||
}
|
|
||||||
Tuple::Variable(_) => argument_types,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Resize the tuple of argument types to line up with the number of parameters this
|
// Resize the tuple of argument types to line up with the number of parameters this
|
||||||
// argument was matched against. If parameter matching succeeded, then we can (TODO:
|
// argument was matched against. If parameter matching succeeded, then we can (TODO:
|
||||||
// should be able to, see above) guarantee that all of the required elements of the
|
// should be able to, see above) guarantee that all of the required elements of the
|
||||||
|
@ -2441,21 +2484,15 @@ impl<'db> Binding<'db> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn match_parameters(
|
fn match_parameters(
|
||||||
&mut self,
|
&mut self,
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
arguments: &CallArguments<'_, 'db>,
|
arguments: &CallArguments<'_, 'db>,
|
||||||
argument_forms: &mut [Option<ParameterForm>],
|
argument_forms: &mut ArgumentForms,
|
||||||
conflicting_forms: &mut [bool],
|
|
||||||
) {
|
) {
|
||||||
let parameters = self.signature.parameters();
|
let parameters = self.signature.parameters();
|
||||||
let mut matcher = ArgumentMatcher::new(
|
let mut matcher =
|
||||||
arguments,
|
ArgumentMatcher::new(arguments, parameters, argument_forms, &mut self.errors);
|
||||||
parameters,
|
|
||||||
argument_forms,
|
|
||||||
conflicting_forms,
|
|
||||||
&mut self.errors,
|
|
||||||
);
|
|
||||||
for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() {
|
for (argument_index, (argument, argument_type)) in arguments.iter().enumerate() {
|
||||||
match argument {
|
match argument {
|
||||||
Argument::Positional | Argument::Synthetic => {
|
Argument::Positional | Argument::Synthetic => {
|
||||||
|
@ -2610,6 +2647,16 @@ impl<'db> Binding<'db> {
|
||||||
pub(crate) fn errors(&self) -> &[BindingError<'db>] {
|
pub(crate) fn errors(&self) -> &[BindingError<'db>] {
|
||||||
&self.errors
|
&self.errors
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resets the state of this binding to its initial state.
|
||||||
|
fn reset(&mut self) {
|
||||||
|
self.return_ty = Type::unknown();
|
||||||
|
self.specialization = None;
|
||||||
|
self.inherited_specialization = None;
|
||||||
|
self.argument_matches = Box::from([]);
|
||||||
|
self.parameter_tys = Box::from([]);
|
||||||
|
self.errors.clear();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
|
|
@ -5833,7 +5833,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
let bindings = callable_type
|
let bindings = callable_type
|
||||||
.bindings(self.db())
|
.bindings(self.db())
|
||||||
.match_parameters(self.db(), &call_arguments);
|
.match_parameters(self.db(), &call_arguments);
|
||||||
self.infer_argument_types(arguments, &mut call_arguments, &bindings.argument_forms);
|
self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms());
|
||||||
|
|
||||||
// Validate `TypedDict` constructor calls after argument type inference
|
// Validate `TypedDict` constructor calls after argument type inference
|
||||||
if let Some(class_literal) = callable_type.into_class_literal() {
|
if let Some(class_literal) = callable_type.into_class_literal() {
|
||||||
|
|
|
@ -1015,11 +1015,6 @@ impl<'db> Tuple<Type<'db>> {
|
||||||
UnionType::from_elements(db, self.all_elements())
|
UnionType::from_elements(db, self.all_elements())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenates another tuple to the end of this tuple, returning a new tuple.
|
|
||||||
pub(crate) fn concat(&self, db: &'db dyn Db, other: &Self) -> Self {
|
|
||||||
TupleSpecBuilder::from(self).concat(db, other).build()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Resizes this tuple to a different length, if possible. If this tuple cannot satisfy the
|
/// Resizes this tuple to a different length, if possible. If this tuple cannot satisfy the
|
||||||
/// desired minimum or maximum length, we return an error. If we return an `Ok` result, the
|
/// desired minimum or maximum length, we return an error. If we return an `Ok` result, the
|
||||||
/// [`len`][Self::len] of the resulting tuple is guaranteed to be equal to `new_length`.
|
/// [`len`][Self::len] of the resulting tuple is guaranteed to be equal to `new_length`.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue