[ty] Surface matched overload diagnostic directly (#18452)
Some checks are pending
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo clippy (push) Blocked by required conditions
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

This PR resolves the way diagnostics are reported for an invalid call to
an overloaded function.

If any of the steps in the overload call evaluation algorithm yields a
matching overload but it's type checking that failed, the
`no-matching-overload` diagnostic is incorrect because there is a
matching overload, it's the arguments passed that are invalid as per the
signature. So, this PR improves that by surfacing the diagnostics on the
matching overload directly.

It also provides additional context, specifically the matching overload
where this error occurred and other non-matching overloads. Consider the
following example:

```py
from typing import overload


@overload
def f() -> None: ...
@overload
def f(x: int) -> int: ...
@overload
def f(x: int, y: int) -> int: ...
def f(x: int | None = None, y: int | None = None) -> int | None:
    return None


f("a")
```

We get:

<img width="857" alt="Screenshot 2025-06-18 at 11 07 10"
src="https://github.com/user-attachments/assets/8dbcaf13-2a74-4661-aa94-1225c9402ea6"
/>


## Test Plan

Update test cases, resolve existing todos and validate the updated
snapshots.
This commit is contained in:
Dhruv Manilawala 2025-06-20 08:36:49 +05:30 committed by GitHub
parent 20d73dd41c
commit 22177e6915
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 856 additions and 137 deletions

View file

@ -4,6 +4,7 @@
//! union of types, each of which might contain multiple overloads.
use std::collections::HashSet;
use std::fmt;
use itertools::Itertools;
use ruff_db::parsed::parsed_module;
@ -21,7 +22,9 @@ use crate::types::diagnostic::{
NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS,
UNKNOWN_ARGUMENT,
};
use crate::types::function::{DataclassTransformerParams, FunctionDecorators, KnownFunction};
use crate::types::function::{
DataclassTransformerParams, FunctionDecorators, FunctionType, KnownFunction, OverloadLiteral,
};
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::{
@ -257,7 +260,7 @@ impl<'db> Bindings<'db> {
}
};
// Each special case listed here should have a corresponding clause in `Type::signatures`.
// Each special case listed here should have a corresponding clause in `Type::bindings`.
for binding in &mut self.elements {
let binding_type = binding.callable_type;
for (overload_index, overload) in binding.matching_overloads_mut() {
@ -1032,6 +1035,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
dunder_call_is_possibly_unbound: false,
bound_type: None,
overload_call_return_type: None,
matching_overload_index: None,
overloads: smallvec![from],
};
Bindings {
@ -1086,6 +1090,22 @@ pub(crate) struct CallableBinding<'db> {
/// [`Unknown`]: crate::types::DynamicType::Unknown
overload_call_return_type: Option<OverloadCallReturnType<'db>>,
/// The index of the overload that matched for this overloaded callable.
///
/// This is [`Some`] only for step 1 and 4 of the [overload call evaluation algorithm][1].
///
/// The main use of this field is to surface the diagnostics for a matching overload directly
/// instead of using the `no-matching-overload` diagnostic. This is mentioned in the spec:
///
/// > If only one candidate overload remains, it is the winning match. Evaluate it as if it
/// > were a non-overloaded function call and stop.
///
/// Other steps of the algorithm do not set this field because this use case isn't relevant for
/// them.
///
/// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
matching_overload_index: Option<usize>,
/// The bindings of each overload of this callable. Will be empty if the type is not callable.
///
/// By using `SmallVec`, we avoid an extra heap allocation for the common case of a
@ -1108,6 +1128,7 @@ impl<'db> CallableBinding<'db> {
dunder_call_is_possibly_unbound: false,
bound_type: None,
overload_call_return_type: None,
matching_overload_index: None,
overloads,
}
}
@ -1119,6 +1140,7 @@ impl<'db> CallableBinding<'db> {
dunder_call_is_possibly_unbound: false,
bound_type: None,
overload_call_return_type: None,
matching_overload_index: None,
overloads: smallvec![],
}
}
@ -1169,10 +1191,9 @@ impl<'db> CallableBinding<'db> {
return;
}
MatchingOverloadIndex::Single(index) => {
// If only one candidate overload remains, it is the winning match.
// TODO: Evaluate it as a regular (non-overloaded) call. This means that any
// diagnostics reported in this check should be reported directly instead of
// reporting it as `no-matching-overload`.
// If only one candidate overload remains, it is the winning match. Evaluate it as
// a regular (non-overloaded) call.
self.matching_overload_index = Some(index);
self.overloads[index].check_types(
db,
argument_types.as_ref(),
@ -1585,15 +1606,48 @@ impl<'db> CallableBinding<'db> {
self.signature_type,
callable_description.as_ref(),
union_diag,
None,
);
}
_overloads => {
// When the number of unmatched overloads exceeds this number, we stop
// printing them to avoid excessive output.
// TODO: This should probably be adapted to handle more
// types of callables[1]. At present, it just handles
// standard function and method calls.
//
// An example of a routine with many many overloads:
// https://github.com/henribru/google-api-python-client-stubs/blob/master/googleapiclient-stubs/discovery.pyi
const MAXIMUM_OVERLOADS: usize = 50;
// [1]: https://github.com/astral-sh/ty/issues/274#issuecomment-2881856028
let function_type_and_kind = match self.signature_type {
Type::FunctionLiteral(function) => Some((FunctionKind::Function, function)),
Type::BoundMethod(bound_method) => Some((
FunctionKind::BoundMethod,
bound_method.function(context.db()),
)),
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
Some((FunctionKind::MethodWrapper, function))
}
_ => None,
};
// If there is a single matching overload, the diagnostics should be reported
// directly for that overload.
if let Some(matching_overload_index) = self.matching_overload_index {
let callable_description =
CallableDescription::new(context.db(), self.signature_type);
let matching_overload =
function_type_and_kind.map(|(kind, function)| MatchingOverloadLiteral {
index: matching_overload_index,
kind,
function,
});
self.overloads[matching_overload_index].report_diagnostics(
context,
node,
self.signature_type,
callable_description.as_ref(),
union_diag,
matching_overload.as_ref(),
);
return;
}
let Some(builder) = context.report_lint(&NO_MATCHING_OVERLOAD, node) else {
return;
@ -1608,18 +1662,6 @@ impl<'db> CallableBinding<'db> {
String::new()
}
));
// TODO: This should probably be adapted to handle more
// types of callables[1]. At present, it just handles
// standard function and method calls.
//
// [1]: https://github.com/astral-sh/ty/issues/274#issuecomment-2881856028
let function_type_and_kind = match self.signature_type {
Type::FunctionLiteral(function) => Some(("function", function)),
Type::BoundMethod(bound_method) => {
Some(("bound method", bound_method.function(context.db())))
}
_ => None,
};
if let Some((kind, function)) = function_type_and_kind {
let (overloads, implementation) =
function.overloads_and_implementation(context.db());
@ -2033,9 +2075,17 @@ impl<'db> Binding<'db> {
callable_ty: Type<'db>,
callable_description: Option<&CallableDescription>,
union_diag: Option<&UnionDiagnostic<'_, '_>>,
matching_overload: Option<&MatchingOverloadLiteral<'db>>,
) {
for error in &self.errors {
error.report_diagnostic(context, node, callable_ty, callable_description, union_diag);
error.report_diagnostic(
context,
node,
callable_ty,
callable_description,
union_diag,
matching_overload,
);
}
}
@ -2331,6 +2381,7 @@ impl<'db> BindingError<'db> {
callable_ty: Type<'db>,
callable_description: Option<&CallableDescription>,
union_diag: Option<&UnionDiagnostic<'_, '_>>,
matching_overload: Option<&MatchingOverloadLiteral<'_>>,
) {
match self {
Self::InvalidArgumentType {
@ -2358,7 +2409,48 @@ impl<'db> BindingError<'db> {
diag.set_primary_message(format_args!(
"Expected `{expected_ty_display}`, found `{provided_ty_display}`"
));
if let Some((name_span, parameter_span)) =
if let Some(matching_overload) = matching_overload {
if let Some((name_span, parameter_span)) =
matching_overload.get(context.db()).and_then(|overload| {
overload.parameter_span(context.db(), Some(parameter.index))
})
{
let mut sub =
SubDiagnostic::new(Severity::Info, "Matching overload defined here");
sub.annotate(Annotation::primary(name_span));
sub.annotate(
Annotation::secondary(parameter_span)
.message("Parameter declared here"),
);
diag.sub(sub);
diag.info(format_args!(
"Non-matching overloads for {} `{}`:",
matching_overload.kind,
matching_overload.function.name(context.db())
));
let (overloads, _) = matching_overload
.function
.overloads_and_implementation(context.db());
for (overload_index, overload) in
overloads.iter().enumerate().take(MAXIMUM_OVERLOADS)
{
if overload_index == matching_overload.index {
continue;
}
diag.info(format_args!(
" {}",
overload.signature(context.db(), None).display(context.db())
));
}
if overloads.len() > MAXIMUM_OVERLOADS {
diag.info(format_args!(
"... omitted {remaining} overloads",
remaining = overloads.len() - MAXIMUM_OVERLOADS
));
}
}
} else if let Some((name_span, parameter_span)) =
callable_ty.parameter_span(context.db(), Some(parameter.index))
{
let mut sub = SubDiagnostic::new(Severity::Info, "Function defined here");
@ -2368,6 +2460,7 @@ impl<'db> BindingError<'db> {
);
diag.sub(sub);
}
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
@ -2573,3 +2666,55 @@ impl UnionDiagnostic<'_, '_> {
diag.sub(sub);
}
}
/// Represents the matching overload of a function literal that was found via the overload call
/// evaluation algorithm.
struct MatchingOverloadLiteral<'db> {
/// The position of the matching overload in the list of overloads.
index: usize,
/// The kind of function this overload is for.
kind: FunctionKind,
/// The function literal that this overload belongs to.
///
/// This is used to retrieve the overload at the given index.
function: FunctionType<'db>,
}
impl<'db> MatchingOverloadLiteral<'db> {
/// Returns the [`OverloadLiteral`] representing this matching overload.
fn get(&self, db: &'db dyn Db) -> Option<OverloadLiteral<'db>> {
let (overloads, _) = self.function.overloads_and_implementation(db);
// TODO: This should actually be safe to index directly but isn't so as of this writing.
// The main reason is that we've custom overload signatures that are constructed manually
// and does not belong to any file. For example, the `__get__` method of a function literal
// has a custom overloaded signature. So, when we try to retrieve the actual overloads
// above, we get an empty list of overloads because the implementation of that method
// relies on it existing in the file.
overloads.get(self.index).copied()
}
}
#[derive(Clone, Copy, Debug)]
enum FunctionKind {
Function,
BoundMethod,
MethodWrapper,
}
impl fmt::Display for FunctionKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FunctionKind::Function => write!(f, "function"),
FunctionKind::BoundMethod => write!(f, "bound method"),
FunctionKind::MethodWrapper => write!(f, "method wrapper `__get__` of function"),
}
}
}
// When the number of unmatched overloads exceeds this number, we stop printing them to avoid
// excessive output.
//
// An example of a routine with many many overloads:
// https://github.com/henribru/google-api-python-client-stubs/blob/master/googleapiclient-stubs/discovery.pyi
const MAXIMUM_OVERLOADS: usize = 50;

View file

@ -296,7 +296,7 @@ impl<'db> OverloadLiteral<'db> {
)
}
fn parameter_span(
pub(crate) fn parameter_span(
self,
db: &'db dyn Db,
parameter_index: Option<usize>,