mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-15 08:00:19 +00:00
[ty] Filter overloads based on Any
/ Unknown
(#18607)
## Summary Closes: astral-sh/ty#552 This PR adds support for step 5 of the overload call evaluation algorithm which specifies: > For all arguments, determine whether all possible materializations of the argument’s type are > assignable to the corresponding parameter type for each of the remaining overloads. If so, > eliminate all of the subsequent remaining overloads. The algorithm works in two parts: 1. Find out the participating parameter indexes. These are the parameters that aren't gradual equivalent to one or more parameter types at the same index in other overloads. 2. Loop over each overload and check whether that would be the _final_ overload for the argument types i.e., the remaining overloads will never be matched against these argument types For step 1, the participating parameter indexes are computed by just comparing whether all the parameter types at the corresponding index for all the overloads are **gradual equivalent**. The step 2 of the algorithm used is described in [this comment](https://github.com/astral-sh/ty/issues/552#issuecomment-2969165421). ## Test Plan Update the overload call tests.
This commit is contained in:
parent
1d458d4314
commit
c7e020df6b
3 changed files with 700 additions and 45 deletions
|
@ -5834,9 +5834,9 @@ impl<'db> KnownInstanceType<'db> {
|
|||
|
||||
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
|
||||
pub enum DynamicType {
|
||||
// An explicitly annotated `typing.Any`
|
||||
/// An explicitly annotated `typing.Any`
|
||||
Any,
|
||||
// An unannotated value, or a dynamic type resulting from an error
|
||||
/// An unannotated value, or a dynamic type resulting from an error
|
||||
Unknown,
|
||||
/// Temporary type for symbols that can't be inferred yet because of missing implementations.
|
||||
///
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
//! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a
|
||||
//! union of types, each of which might contain multiple overloads.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use itertools::Itertools;
|
||||
use ruff_db::parsed::parsed_module;
|
||||
use smallvec::{SmallVec, smallvec};
|
||||
|
@ -1029,7 +1031,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
|
|||
signature_type,
|
||||
dunder_call_is_possibly_unbound: false,
|
||||
bound_type: None,
|
||||
return_type: None,
|
||||
overload_call_return_type: None,
|
||||
overloads: smallvec![from],
|
||||
};
|
||||
Bindings {
|
||||
|
@ -1068,13 +1070,21 @@ pub(crate) struct CallableBinding<'db> {
|
|||
/// The type of the bound `self` or `cls` parameter if this signature is for a bound method.
|
||||
pub(crate) bound_type: Option<Type<'db>>,
|
||||
|
||||
/// The return type of this callable.
|
||||
/// The return type of this overloaded callable.
|
||||
///
|
||||
/// This is only `Some` if it's an overloaded callable, "argument type expansion" was
|
||||
/// performed, and one of the expansion evaluated successfully for all of the argument lists.
|
||||
/// This type is then the union of all the return types of the matched overloads for the
|
||||
/// expanded argument lists.
|
||||
return_type: Option<Type<'db>>,
|
||||
/// This is [`Some`] only in the following cases:
|
||||
/// 1. Argument type expansion was performed and one of the expansions evaluated successfully
|
||||
/// for all of the argument lists, or
|
||||
/// 2. Overload call evaluation was ambiguous, meaning that multiple overloads matched the
|
||||
/// argument lists, but they all had different return types
|
||||
///
|
||||
/// For (1), the final return type is the union of all the return types of the matched
|
||||
/// overloads for the expanded argument lists.
|
||||
///
|
||||
/// For (2), the final return type is [`Unknown`].
|
||||
///
|
||||
/// [`Unknown`]: crate::types::DynamicType::Unknown
|
||||
overload_call_return_type: Option<OverloadCallReturnType<'db>>,
|
||||
|
||||
/// The bindings of each overload of this callable. Will be empty if the type is not callable.
|
||||
///
|
||||
|
@ -1097,7 +1107,7 @@ impl<'db> CallableBinding<'db> {
|
|||
signature_type,
|
||||
dunder_call_is_possibly_unbound: false,
|
||||
bound_type: None,
|
||||
return_type: None,
|
||||
overload_call_return_type: None,
|
||||
overloads,
|
||||
}
|
||||
}
|
||||
|
@ -1108,7 +1118,7 @@ impl<'db> CallableBinding<'db> {
|
|||
signature_type,
|
||||
dunder_call_is_possibly_unbound: false,
|
||||
bound_type: None,
|
||||
return_type: None,
|
||||
overload_call_return_type: None,
|
||||
overloads: smallvec![],
|
||||
}
|
||||
}
|
||||
|
@ -1176,7 +1186,7 @@ impl<'db> CallableBinding<'db> {
|
|||
}
|
||||
};
|
||||
|
||||
let snapshotter = MatchingOverloadsSnapshotter::new(matching_overload_indexes);
|
||||
let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes);
|
||||
|
||||
// State of the bindings _before_ evaluating (type checking) the matching overloads using
|
||||
// the non-expanded argument types.
|
||||
|
@ -1196,9 +1206,13 @@ impl<'db> CallableBinding<'db> {
|
|||
// If only one overload evaluates without error, it is the winning match.
|
||||
return;
|
||||
}
|
||||
MatchingOverloadIndex::Multiple(_) => {
|
||||
MatchingOverloadIndex::Multiple(indexes) => {
|
||||
// If two or more candidate overloads remain, proceed to step 4.
|
||||
// TODO: Step 4 and Step 5 goes here...
|
||||
// TODO: Step 4
|
||||
|
||||
// Step 5
|
||||
self.filter_overloads_using_any_or_unknown(db, argument_types.types(), &indexes);
|
||||
|
||||
// We're returning here because this shouldn't lead to argument type expansion.
|
||||
return;
|
||||
}
|
||||
|
@ -1225,7 +1239,7 @@ impl<'db> CallableBinding<'db> {
|
|||
// This is the merged state of the bindings after evaluating all of the expanded
|
||||
// argument lists. This will be the final state to restore the bindings to if all of
|
||||
// the expanded argument lists evaluated successfully.
|
||||
let mut merged_evaluation_state: Option<MatchingOverloadsSnapshot<'db>> = None;
|
||||
let mut merged_evaluation_state: Option<CallableBindingSnapshot<'db>> = None;
|
||||
|
||||
let mut return_types = Vec::new();
|
||||
|
||||
|
@ -1241,10 +1255,16 @@ impl<'db> CallableBinding<'db> {
|
|||
MatchingOverloadIndex::Single(index) => {
|
||||
Some(self.overloads[index].return_type())
|
||||
}
|
||||
MatchingOverloadIndex::Multiple(index) => {
|
||||
// TODO: Step 4 and Step 5 goes here... but for now we just use the return
|
||||
// type of the first matched overload.
|
||||
Some(self.overloads[index[0]].return_type())
|
||||
MatchingOverloadIndex::Multiple(matching_overload_indexes) => {
|
||||
// TODO: Step 4
|
||||
|
||||
self.filter_overloads_using_any_or_unknown(
|
||||
db,
|
||||
expanded_argument_types,
|
||||
&matching_overload_indexes,
|
||||
);
|
||||
|
||||
Some(self.return_type())
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1274,17 +1294,23 @@ impl<'db> CallableBinding<'db> {
|
|||
}
|
||||
|
||||
if return_types.len() == expanded_argument_lists.len() {
|
||||
// If the number of return types is equal to the number of expanded argument lists,
|
||||
// they all evaluated successfully. So, we need to combine their return types by
|
||||
// union to determine the final return type.
|
||||
self.return_type = Some(UnionType::from_elements(db, return_types));
|
||||
|
||||
// Restore the bindings state to the one that merges the bindings state evaluating
|
||||
// each of the expanded argument list.
|
||||
//
|
||||
// Note that this needs to happen *before* setting the return type, because this
|
||||
// will restore the return type to the one before argument type expansion.
|
||||
if let Some(merged_evaluation_state) = merged_evaluation_state {
|
||||
snapshotter.restore(self, merged_evaluation_state);
|
||||
}
|
||||
|
||||
// If the number of return types is equal to the number of expanded argument lists,
|
||||
// they all evaluated successfully. So, we need to combine their return types by
|
||||
// union to determine the final return type.
|
||||
self.overload_call_return_type =
|
||||
Some(OverloadCallReturnType::ArgumentTypeExpansion(
|
||||
UnionType::from_elements(db, return_types),
|
||||
));
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -1296,6 +1322,137 @@ impl<'db> CallableBinding<'db> {
|
|||
snapshotter.restore(self, post_evaluation_snapshot);
|
||||
}
|
||||
|
||||
/// Filter overloads based on [`Any`] or [`Unknown`] argument types.
|
||||
///
|
||||
/// This is the step 5 of the [overload call evaluation algorithm][1].
|
||||
///
|
||||
/// The filtering works on the remaining overloads that are present at the
|
||||
/// `matching_overload_indexes` and are filtered out by marking them as unmatched overloads
|
||||
/// using the [`mark_as_unmatched_overload`] method.
|
||||
///
|
||||
/// [`Any`]: crate::types::DynamicType::Any
|
||||
/// [`Unknown`]: crate::types::DynamicType::Unknown
|
||||
/// [`mark_as_unmatched_overload`]: Binding::mark_as_unmatched_overload
|
||||
/// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
|
||||
fn filter_overloads_using_any_or_unknown(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
argument_types: &[Type<'db>],
|
||||
matching_overload_indexes: &[usize],
|
||||
) {
|
||||
// These are the parameter indexes that matches the arguments that participate in the
|
||||
// filtering process.
|
||||
//
|
||||
// The parameter types at these indexes have at least one overload where the type isn't
|
||||
// gradual equivalent to the parameter types at the same index for other overloads.
|
||||
let mut participating_parameter_indexes = HashSet::new();
|
||||
|
||||
// These only contain the top materialized argument types for the corresponding
|
||||
// participating parameter indexes.
|
||||
let mut top_materialized_argument_types = vec![];
|
||||
|
||||
for (argument_index, argument_type) in argument_types.iter().enumerate() {
|
||||
let mut first_parameter_type: Option<Type<'db>> = None;
|
||||
let mut participating_parameter_index = None;
|
||||
|
||||
for overload_index in matching_overload_indexes {
|
||||
let overload = &self.overloads[*overload_index];
|
||||
let Some(parameter_index) = overload.argument_parameters[argument_index] else {
|
||||
// There is no parameter for this argument in this overload.
|
||||
break;
|
||||
};
|
||||
// TODO: For an unannotated `self` / `cls` parameter, the type should be
|
||||
// `typing.Self` / `type[typing.Self]`
|
||||
let current_parameter_type = overload.signature.parameters()[parameter_index]
|
||||
.annotated_type()
|
||||
.unwrap_or(Type::unknown());
|
||||
if let Some(first_parameter_type) = first_parameter_type {
|
||||
if !first_parameter_type.is_gradual_equivalent_to(db, current_parameter_type) {
|
||||
participating_parameter_index = Some(parameter_index);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
first_parameter_type = Some(current_parameter_type);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(parameter_index) = participating_parameter_index {
|
||||
participating_parameter_indexes.insert(parameter_index);
|
||||
top_materialized_argument_types.push(argument_type.top_materialization(db));
|
||||
}
|
||||
}
|
||||
|
||||
let top_materialized_argument_type =
|
||||
TupleType::from_elements(db, top_materialized_argument_types);
|
||||
|
||||
// A flag to indicate whether we've found the overload that makes the remaining overloads
|
||||
// unmatched for the given argument types.
|
||||
let mut filter_remaining_overloads = false;
|
||||
|
||||
for (upto, current_index) in matching_overload_indexes.iter().enumerate() {
|
||||
if filter_remaining_overloads {
|
||||
self.overloads[*current_index].mark_as_unmatched_overload();
|
||||
continue;
|
||||
}
|
||||
let mut parameter_types = Vec::with_capacity(argument_types.len());
|
||||
for argument_index in 0..argument_types.len() {
|
||||
// The parameter types at the current argument index.
|
||||
let mut current_parameter_types = vec![];
|
||||
for overload_index in &matching_overload_indexes[..=upto] {
|
||||
let overload = &self.overloads[*overload_index];
|
||||
let Some(parameter_index) = overload.argument_parameters[argument_index] else {
|
||||
// There is no parameter for this argument in this overload.
|
||||
continue;
|
||||
};
|
||||
if !participating_parameter_indexes.contains(¶meter_index) {
|
||||
// This parameter doesn't participate in the filtering process.
|
||||
continue;
|
||||
}
|
||||
// TODO: For an unannotated `self` / `cls` parameter, the type should be
|
||||
// `typing.Self` / `type[typing.Self]`
|
||||
let parameter_type = overload.signature.parameters()[parameter_index]
|
||||
.annotated_type()
|
||||
.unwrap_or(Type::unknown());
|
||||
current_parameter_types.push(parameter_type);
|
||||
}
|
||||
if current_parameter_types.is_empty() {
|
||||
continue;
|
||||
}
|
||||
parameter_types.push(UnionType::from_elements(db, current_parameter_types));
|
||||
}
|
||||
if top_materialized_argument_type
|
||||
.is_assignable_to(db, TupleType::from_elements(db, parameter_types))
|
||||
{
|
||||
filter_remaining_overloads = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Once this filtering process is applied for all arguments, examine the return types of
|
||||
// the remaining overloads. If the resulting return types for all remaining overloads are
|
||||
// equivalent, proceed to step 6.
|
||||
let are_return_types_equivalent_for_all_matching_overloads = {
|
||||
let mut matching_overloads = self.matching_overloads();
|
||||
if let Some(first_overload_return_type) = matching_overloads
|
||||
.next()
|
||||
.map(|(_, overload)| overload.return_type())
|
||||
{
|
||||
matching_overloads.all(|(_, overload)| {
|
||||
overload
|
||||
.return_type()
|
||||
.is_equivalent_to(db, first_overload_return_type)
|
||||
})
|
||||
} else {
|
||||
// No matching overload
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if !are_return_types_equivalent_for_all_matching_overloads {
|
||||
// Overload matching is ambiguous.
|
||||
self.overload_call_return_type = Some(OverloadCallReturnType::Ambiguous);
|
||||
}
|
||||
}
|
||||
|
||||
fn as_result(&self) -> Result<(), CallErrorKind> {
|
||||
if !self.is_callable() {
|
||||
return Err(CallErrorKind::NotCallable);
|
||||
|
@ -1370,8 +1527,11 @@ impl<'db> CallableBinding<'db> {
|
|||
/// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot
|
||||
/// make any useful conclusions about which overload was intended to be called.
|
||||
pub(crate) fn return_type(&self) -> Type<'db> {
|
||||
if let Some(return_type) = self.return_type {
|
||||
return return_type;
|
||||
if let Some(overload_call_return_type) = self.overload_call_return_type {
|
||||
return match overload_call_return_type {
|
||||
OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type,
|
||||
OverloadCallReturnType::Ambiguous => Type::unknown(),
|
||||
};
|
||||
}
|
||||
if let Some((_, first_overload)) = self.matching_overloads().next() {
|
||||
return first_overload.return_type();
|
||||
|
@ -1521,6 +1681,12 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
enum OverloadCallReturnType<'db> {
|
||||
ArgumentTypeExpansion(Type<'db>),
|
||||
Ambiguous,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum MatchingOverloadIndex {
|
||||
/// No matching overloads found.
|
||||
|
@ -1855,6 +2021,11 @@ impl<'db> Binding<'db> {
|
|||
.map(|(arg_and_type, _)| arg_and_type)
|
||||
}
|
||||
|
||||
/// Mark this overload binding as an unmatched overload.
|
||||
fn mark_as_unmatched_overload(&mut self) {
|
||||
self.errors.push(BindingError::UnmatchedOverload);
|
||||
}
|
||||
|
||||
fn report_diagnostics(
|
||||
&self,
|
||||
context: &InferContext<'db, '_>,
|
||||
|
@ -1915,23 +2086,27 @@ struct BindingSnapshot<'db> {
|
|||
errors: Vec<BindingError<'db>>,
|
||||
}
|
||||
|
||||
/// Represents the snapshot of the matched overload bindings.
|
||||
///
|
||||
/// The reason that this only contains the matched overloads are:
|
||||
/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
|
||||
/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the
|
||||
/// expanded argument lists
|
||||
#[derive(Clone, Debug)]
|
||||
struct MatchingOverloadsSnapshot<'db>(Vec<(usize, BindingSnapshot<'db>)>);
|
||||
struct CallableBindingSnapshot<'db> {
|
||||
overload_return_type: Option<OverloadCallReturnType<'db>>,
|
||||
|
||||
impl<'db> MatchingOverloadsSnapshot<'db> {
|
||||
/// Represents the snapshot of the matched overload bindings.
|
||||
///
|
||||
/// The reason that this only contains the matched overloads are:
|
||||
/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
|
||||
/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all
|
||||
/// the expanded argument lists
|
||||
matching_overloads: Vec<(usize, BindingSnapshot<'db>)>,
|
||||
}
|
||||
|
||||
impl<'db> CallableBindingSnapshot<'db> {
|
||||
/// Update the state of the matched overload bindings in this snapshot with the current
|
||||
/// state in the given `binding`.
|
||||
fn update(&mut self, binding: &CallableBinding<'db>) {
|
||||
// Here, the `snapshot` is the state of this binding for the previous argument list and
|
||||
// `binding` would contain the state after evaluating the current argument list.
|
||||
for (snapshot, binding) in self
|
||||
.0
|
||||
.matching_overloads
|
||||
.iter_mut()
|
||||
.map(|(index, snapshot)| (snapshot, &binding.overloads[*index]))
|
||||
{
|
||||
|
@ -1967,13 +2142,13 @@ impl<'db> MatchingOverloadsSnapshot<'db> {
|
|||
|
||||
/// A helper to take snapshots of the matched overload bindings for the current state of the
|
||||
/// bindings.
|
||||
struct MatchingOverloadsSnapshotter(Vec<usize>);
|
||||
struct CallableBindingSnapshotter(Vec<usize>);
|
||||
|
||||
impl MatchingOverloadsSnapshotter {
|
||||
impl CallableBindingSnapshotter {
|
||||
/// Creates a new snapshotter for the given indexes of the matched overloads.
|
||||
fn new(indexes: Vec<usize>) -> Self {
|
||||
debug_assert!(indexes.len() > 1);
|
||||
MatchingOverloadsSnapshotter(indexes)
|
||||
CallableBindingSnapshotter(indexes)
|
||||
}
|
||||
|
||||
/// Takes a snapshot of the current state of the matched overload bindings.
|
||||
|
@ -1981,23 +2156,26 @@ impl MatchingOverloadsSnapshotter {
|
|||
/// # Panics
|
||||
///
|
||||
/// Panics if the indexes of the matched overloads are not valid for the given binding.
|
||||
fn take<'db>(&self, binding: &CallableBinding<'db>) -> MatchingOverloadsSnapshot<'db> {
|
||||
MatchingOverloadsSnapshot(
|
||||
self.0
|
||||
fn take<'db>(&self, binding: &CallableBinding<'db>) -> CallableBindingSnapshot<'db> {
|
||||
CallableBindingSnapshot {
|
||||
overload_return_type: binding.overload_call_return_type,
|
||||
matching_overloads: self
|
||||
.0
|
||||
.iter()
|
||||
.map(|index| (*index, binding.overloads[*index].snapshot()))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Restores the state of the matched overload bindings from the given snapshot.
|
||||
fn restore<'db>(
|
||||
&self,
|
||||
binding: &mut CallableBinding<'db>,
|
||||
snapshot: MatchingOverloadsSnapshot<'db>,
|
||||
snapshot: CallableBindingSnapshot<'db>,
|
||||
) {
|
||||
debug_assert_eq!(self.0.len(), snapshot.0.len());
|
||||
for (index, snapshot) in snapshot.0 {
|
||||
debug_assert_eq!(self.0.len(), snapshot.matching_overloads.len());
|
||||
binding.overload_call_return_type = snapshot.overload_return_type;
|
||||
for (index, snapshot) in snapshot.matching_overloads {
|
||||
binding.overloads[index].restore(snapshot);
|
||||
}
|
||||
}
|
||||
|
@ -2140,6 +2318,9 @@ pub(crate) enum BindingError<'db> {
|
|||
/// We use this variant to report errors in `property.__get__` and `property.__set__`, which
|
||||
/// can occur when the call to the underlying getter/setter fails.
|
||||
InternalCallError(&'static str),
|
||||
/// This overload binding of the callable does not match the arguments.
|
||||
// TODO: We could expand this with an enum to specify why the overload is unmatched.
|
||||
UnmatchedOverload,
|
||||
}
|
||||
|
||||
impl<'db> BindingError<'db> {
|
||||
|
@ -2332,6 +2513,8 @@ impl<'db> BindingError<'db> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self::UnmatchedOverload => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue