ty_python_semantic: add union type context to function call type errors

This context gets added only when calling a function through a union
type.
This commit is contained in:
Andrew Gallant 2025-05-09 09:58:18 -04:00 committed by Andrew Gallant
parent 5ea3a52c8a
commit 346e82b572
42 changed files with 235 additions and 101 deletions

View file

@ -24,7 +24,7 @@ use crate::types::{
FunctionType, KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind,
PropertyInstanceType, TupleType, UnionType, WrapperDescriptorKind,
};
use ruff_db::diagnostic::{Annotation, Severity, SubDiagnostic};
use ruff_db::diagnostic::{Annotation, Diagnostic, Severity, SubDiagnostic};
use ruff_python_ast as ast;
/// Binding information for a possible union of callables. At a call site, the arguments must be
@ -199,8 +199,19 @@ impl<'db> Bindings<'db> {
}
}
// If this is not a union, then report a diagnostic for any
// errors as normal.
if let Some(binding) = self.single_element() {
binding.report_diagnostics(context, node, None);
return;
}
for binding in self {
binding.report_diagnostics(context, node);
let union_diag = UnionDiagnostic {
callable_type: self.callable_type(),
binding,
};
binding.report_diagnostics(context, node, Some(&union_diag));
}
}
@ -1043,23 +1054,34 @@ impl<'db> CallableBinding<'db> {
Type::unknown()
}
fn report_diagnostics(&self, context: &InferContext<'db>, node: ast::AnyNodeRef) {
fn report_diagnostics(
&self,
context: &InferContext<'db>,
node: ast::AnyNodeRef,
union_diag: Option<&UnionDiagnostic<'_, '_>>,
) {
if !self.is_callable() {
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, node) {
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"Object of type `{}` is not callable",
self.callable_type.display(context.db()),
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
return;
}
if self.dunder_call_is_possibly_unbound {
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, node) {
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"Object of type `{}` is not callable (possibly unbound `__call__` method)",
self.callable_type.display(context.db()),
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
return;
}
@ -1067,7 +1089,7 @@ impl<'db> CallableBinding<'db> {
let callable_description = CallableDescription::new(context.db(), self.callable_type);
if self.overloads.len() > 1 {
if let Some(builder) = context.report_lint(&NO_MATCHING_OVERLOAD, node) {
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"No overload{} matches arguments",
if let Some(CallableDescription { kind, name }) = callable_description {
format!(" of {kind} `{name}`")
@ -1075,6 +1097,9 @@ impl<'db> CallableBinding<'db> {
String::new()
}
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
return;
}
@ -1086,6 +1111,7 @@ impl<'db> CallableBinding<'db> {
node,
self.signature_type,
callable_description.as_ref(),
union_diag,
);
}
}
@ -1385,9 +1411,10 @@ impl<'db> Binding<'db> {
node: ast::AnyNodeRef,
callable_ty: Type<'db>,
callable_description: Option<&CallableDescription>,
union_diag: Option<&UnionDiagnostic<'_, '_>>,
) {
for error in &self.errors {
error.report_diagnostic(context, node, callable_ty, callable_description);
error.report_diagnostic(context, node, callable_ty, callable_description, union_diag);
}
}
@ -1539,12 +1566,13 @@ pub(crate) enum BindingError<'db> {
}
impl<'db> BindingError<'db> {
pub(super) fn report_diagnostic(
fn report_diagnostic(
&self,
context: &InferContext<'db>,
node: ast::AnyNodeRef,
callable_ty: Type<'db>,
callable_description: Option<&CallableDescription>,
union_diag: Option<&UnionDiagnostic<'_, '_>>,
) {
match self {
Self::InvalidArgumentType {
@ -1561,7 +1589,14 @@ impl<'db> BindingError<'db> {
let provided_ty_display = provided_ty.display(context.db());
let expected_ty_display = expected_ty.display(context.db());
let mut diag = builder.into_diagnostic("Argument to this function is incorrect");
let mut diag = builder.into_diagnostic(format_args!(
"Argument{} is incorrect",
if let Some(CallableDescription { kind, name }) = callable_description {
format!(" to {kind} `{name}`")
} else {
String::new()
}
));
diag.set_primary_message(format_args!(
"Expected `{expected_ty_display}`, found `{provided_ty_display}`"
));
@ -1575,6 +1610,9 @@ impl<'db> BindingError<'db> {
);
diag.sub(sub);
}
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
Self::TooManyPositionalArguments {
@ -1584,7 +1622,7 @@ impl<'db> BindingError<'db> {
} => {
let node = Self::get_node(node, *first_excess_argument_index);
if let Some(builder) = context.report_lint(&TOO_MANY_POSITIONAL_ARGUMENTS, node) {
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"Too many positional arguments{}: expected \
{expected_positional_count}, got {provided_positional_count}",
if let Some(CallableDescription { kind, name }) = callable_description {
@ -1593,13 +1631,16 @@ impl<'db> BindingError<'db> {
String::new()
}
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
}
Self::MissingArguments { parameters } => {
if let Some(builder) = context.report_lint(&MISSING_ARGUMENT, node) {
let s = if parameters.0.len() == 1 { "" } else { "s" };
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"No argument{s} provided for required parameter{s} {parameters}{}",
if let Some(CallableDescription { kind, name }) = callable_description {
format!(" of {kind} `{name}`")
@ -1607,6 +1648,9 @@ impl<'db> BindingError<'db> {
String::new()
}
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
}
@ -1616,7 +1660,7 @@ impl<'db> BindingError<'db> {
} => {
let node = Self::get_node(node, *argument_index);
if let Some(builder) = context.report_lint(&UNKNOWN_ARGUMENT, node) {
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"Argument `{argument_name}` does not match any known parameter{}",
if let Some(CallableDescription { kind, name }) = callable_description {
format!(" of {kind} `{name}`")
@ -1624,6 +1668,9 @@ impl<'db> BindingError<'db> {
String::new()
}
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
}
@ -1633,7 +1680,7 @@ impl<'db> BindingError<'db> {
} => {
let node = Self::get_node(node, *argument_index);
if let Some(builder) = context.report_lint(&PARAMETER_ALREADY_ASSIGNED, node) {
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"Multiple values provided for parameter {parameter}{}",
if let Some(CallableDescription { kind, name }) = callable_description {
format!(" of {kind} `{name}`")
@ -1641,6 +1688,9 @@ impl<'db> BindingError<'db> {
String::new()
}
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
}
@ -1657,7 +1707,14 @@ impl<'db> BindingError<'db> {
let argument_type = error.argument_type();
let argument_ty_display = argument_type.display(context.db());
let mut diag = builder.into_diagnostic("Argument to this function is incorrect");
let mut diag = builder.into_diagnostic(format_args!(
"Argument{} is incorrect",
if let Some(CallableDescription { kind, name }) = callable_description {
format!(" to {kind} `{name}`")
} else {
String::new()
}
));
diag.set_primary_message(format_args!(
"Argument type `{argument_ty_display}` does not satisfy {} of type variable `{}`",
match error {
@ -1671,12 +1728,15 @@ impl<'db> BindingError<'db> {
let mut sub = SubDiagnostic::new(Severity::Info, "Type variable defined here");
sub.annotate(Annotation::primary(typevar_range.into()));
diag.sub(sub);
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
Self::InternalCallError(reason) => {
let node = Self::get_node(node, None);
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, node) {
builder.into_diagnostic(format_args!(
let mut diag = builder.into_diagnostic(format_args!(
"Call{} failed: {reason}",
if let Some(CallableDescription { kind, name }) = callable_description {
format!(" of {kind} `{name}`")
@ -1684,6 +1744,9 @@ impl<'db> BindingError<'db> {
String::new()
}
));
if let Some(union_diag) = union_diag {
union_diag.add_union_context(context.db(), &mut diag);
}
}
}
}
@ -1708,3 +1771,39 @@ impl<'db> BindingError<'db> {
}
}
}
/// Contains additional context for union specific diagnostics.
///
/// This is used when a function call is inconsistent with one or more variants
/// of a union. This can be used to attach sub-diagnostics that clarify that
/// the error is part of a union.
struct UnionDiagnostic<'b, 'db> {
/// The type of the union.
callable_type: Type<'db>,
/// The specific binding that failed.
binding: &'b CallableBinding<'db>,
}
impl UnionDiagnostic<'_, '_> {
/// Adds context about any relevant union function types to the given
/// diagnostic.
fn add_union_context(&self, db: &'_ dyn Db, diag: &mut Diagnostic) {
let sub = SubDiagnostic::new(
Severity::Info,
format_args!(
"Union variant `{callable_ty}` is incompatible with this call site",
callable_ty = self.binding.callable_type.display(db),
),
);
diag.sub(sub);
let sub = SubDiagnostic::new(
Severity::Info,
format_args!(
"Attempted to call union type `{}`",
self.callable_type.display(db)
),
);
diag.sub(sub);
}
}