[red-knot] Update call binding to return all matching overloads (#17618)

## Summary

This PR updates the existing overload matching methods to return an
iterator of all the matched overloads instead.

This would be useful once the overload call evaluation algorithm is
implemented which should provide an accurate picture of all the matched
overloads. The return type would then be picked from either the only
matched overload or the first overload from the ones that are matched.

In an earlier version of this PR, it tried to check if using an
intersection of return types from the matched overload would help reduce
the false positives but that's not enough. [This
comment](https://github.com/astral-sh/ruff/pull/17618#issuecomment-2842891696)
keep the ecosystem analysis for that change for prosperity.

> [!NOTE]
>
> The best way to review this PR is by hiding the whitespace changes
because there are two instances where a large match expression is
indented to be inside a loop over matching overlods
>
> <img width="1207" alt="Screenshot 2025-04-28 at 15 12 16"
src="https://github.com/user-attachments/assets/e06cbfa4-04fa-435f-84ef-4e5c3c5626d1"
/>

## Test Plan

Make sure existing test cases are unaffected and no ecosystem changes.
This commit is contained in:
Dhruv Manilawala 2025-05-01 01:33:21 +05:30 committed by GitHub
parent 6e765b4527
commit d2a238dfad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 903 additions and 858 deletions

View file

@ -4527,27 +4527,43 @@ impl<'db> Type<'db> {
new_call_outcome @ (None | Some(Ok(_))),
init_call_outcome @ (None | Some(Ok(_))),
) => {
fn combine_specializations<'db>(
db: &'db dyn Db,
s1: Option<Specialization<'db>>,
s2: Option<Specialization<'db>>,
) -> Option<Specialization<'db>> {
match (s1, s2) {
(None, None) => None,
(Some(s), None) | (None, Some(s)) => Some(s),
(Some(s1), Some(s2)) => Some(s1.combine(db, s2)),
}
}
fn combine_binding_specialization<'db>(
db: &'db dyn Db,
binding: &CallableBinding<'db>,
) -> Option<Specialization<'db>> {
binding
.matching_overloads()
.map(|(_, binding)| binding.inherited_specialization())
.reduce(|acc, specialization| {
combine_specializations(db, acc, specialization)
})
.flatten()
}
let new_specialization = new_call_outcome
.and_then(Result::ok)
.as_ref()
.and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload)
.and_then(|(_, binding)| binding.inherited_specialization());
.and_then(|binding| combine_binding_specialization(db, binding));
let init_specialization = init_call_outcome
.and_then(Result::ok)
.as_ref()
.and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload)
.and_then(|(_, binding)| binding.inherited_specialization());
let specialization = match (new_specialization, init_specialization) {
(None, None) => None,
(Some(specialization), None) | (None, Some(specialization)) => {
Some(specialization)
}
(Some(new_specialization), Some(init_specialization)) => {
Some(new_specialization.combine(db, init_specialization))
}
};
.and_then(|binding| combine_binding_specialization(db, binding));
let specialization =
combine_specializations(db, new_specialization, init_specialization);
let specialized = specialization
.map(|specialization| {
Type::instance(

File diff suppressed because it is too large Load diff

View file

@ -4627,307 +4627,310 @@ impl<'db> TypeInferenceBuilder<'db> {
Ok(mut bindings) => {
for binding in &mut bindings {
let binding_type = binding.callable_type;
let Some((_, overload)) = binding.matching_overload_mut() else {
continue;
};
for (_, overload) in binding.matching_overloads_mut() {
match binding_type {
Type::FunctionLiteral(function_literal) => {
let Some(known_function) = function_literal.known(self.db()) else {
continue;
};
match binding_type {
Type::FunctionLiteral(function_literal) => {
let Some(known_function) = function_literal.known(self.db()) else {
continue;
};
match known_function {
KnownFunction::RevealType => {
if let [Some(revealed_type)] = overload.parameter_types() {
if let Some(builder) = self.context.report_diagnostic(
DiagnosticId::RevealedType,
Severity::Info,
) {
let mut diag = builder.into_diagnostic("Revealed type");
let span = self.context.span(call_expression);
diag.annotate(Annotation::primary(span).message(
format_args!(
"`{}`",
revealed_type.display(self.db())
),
));
match known_function {
KnownFunction::RevealType => {
if let [Some(revealed_type)] = overload.parameter_types() {
if let Some(builder) = self.context.report_diagnostic(
DiagnosticId::RevealedType,
Severity::Info,
) {
let mut diag =
builder.into_diagnostic("Revealed type");
let span = self.context.span(call_expression);
diag.annotate(Annotation::primary(span).message(
format_args!(
"`{}`",
revealed_type.display(self.db())
),
));
}
}
}
}
KnownFunction::AssertType => {
if let [Some(actual_ty), Some(asserted_ty)] =
overload.parameter_types()
{
if !actual_ty
.is_gradual_equivalent_to(self.db(), *asserted_ty)
KnownFunction::AssertType => {
if let [Some(actual_ty), Some(asserted_ty)] =
overload.parameter_types()
{
if let Some(builder) = self.context.report_lint(
&TYPE_ASSERTION_FAILURE,
call_expression,
) {
builder.into_diagnostic(format_args!(
"Actual type `{}` is not the same \
if !actual_ty
.is_gradual_equivalent_to(self.db(), *asserted_ty)
{
if let Some(builder) = self.context.report_lint(
&TYPE_ASSERTION_FAILURE,
call_expression,
) {
builder.into_diagnostic(format_args!(
"Actual type `{}` is not the same \
as asserted type `{}`",
actual_ty.display(self.db()),
asserted_ty.display(self.db()),
));
actual_ty.display(self.db()),
asserted_ty.display(self.db()),
));
}
}
}
}
}
KnownFunction::AssertNever => {
if let [Some(actual_ty)] = overload.parameter_types() {
if !actual_ty.is_equivalent_to(self.db(), Type::Never) {
if let Some(builder) = self.context.report_lint(
&TYPE_ASSERTION_FAILURE,
call_expression,
) {
builder.into_diagnostic(format_args!(
"Expected type `Never`, got `{}` instead",
actual_ty.display(self.db()),
));
KnownFunction::AssertNever => {
if let [Some(actual_ty)] = overload.parameter_types() {
if !actual_ty.is_equivalent_to(self.db(), Type::Never) {
if let Some(builder) = self.context.report_lint(
&TYPE_ASSERTION_FAILURE,
call_expression,
) {
builder.into_diagnostic(format_args!(
"Expected type `Never`, got `{}` instead",
actual_ty.display(self.db()),
));
}
}
}
}
}
KnownFunction::StaticAssert => {
if let [Some(parameter_ty), message] =
overload.parameter_types()
{
let truthiness = match parameter_ty.try_bool(self.db()) {
Ok(truthiness) => truthiness,
Err(err) => {
let condition = arguments
.find_argument("condition", 0)
.map(|argument| match argument {
KnownFunction::StaticAssert => {
if let [Some(parameter_ty), message] =
overload.parameter_types()
{
let truthiness = match parameter_ty.try_bool(self.db())
{
Ok(truthiness) => truthiness,
Err(err) => {
let condition = arguments
.find_argument("condition", 0)
.map(|argument| {
match argument {
ruff_python_ast::ArgOrKeyword::Arg(
expr,
) => ast::AnyNodeRef::from(expr),
ruff_python_ast::ArgOrKeyword::Keyword(
keyword,
) => ast::AnyNodeRef::from(keyword),
})
.unwrap_or(ast::AnyNodeRef::from(
call_expression,
));
}
})
.unwrap_or(ast::AnyNodeRef::from(
call_expression,
));
err.report_diagnostic(&self.context, condition);
err.report_diagnostic(&self.context, condition);
continue;
}
};
continue;
}
};
if let Some(builder) = self
.context
.report_lint(&STATIC_ASSERT_ERROR, call_expression)
{
if !truthiness.is_always_true() {
if let Some(message) = message
.and_then(Type::into_string_literal)
.map(|s| &**s.value(self.db()))
{
builder.into_diagnostic(format_args!(
"Static assertion error: {message}"
));
} else if *parameter_ty
== Type::BooleanLiteral(false)
{
builder.into_diagnostic(
"Static assertion error: \
if let Some(builder) = self
.context
.report_lint(&STATIC_ASSERT_ERROR, call_expression)
{
if !truthiness.is_always_true() {
if let Some(message) = message
.and_then(Type::into_string_literal)
.map(|s| &**s.value(self.db()))
{
builder.into_diagnostic(format_args!(
"Static assertion error: {message}"
));
} else if *parameter_ty
== Type::BooleanLiteral(false)
{
builder.into_diagnostic(
"Static assertion error: \
argument evaluates to `False`",
);
} else if truthiness.is_always_false() {
builder.into_diagnostic(format_args!(
"Static assertion error: \
);
} else if truthiness.is_always_false() {
builder.into_diagnostic(format_args!(
"Static assertion error: \
argument of type `{parameter_ty}` \
is statically known to be falsy",
parameter_ty =
parameter_ty.display(self.db())
));
} else {
builder.into_diagnostic(format_args!(
"Static assertion error: \
parameter_ty =
parameter_ty.display(self.db())
));
} else {
builder.into_diagnostic(format_args!(
"Static assertion error: \
argument of type `{parameter_ty}` \
has an ambiguous static truthiness",
parameter_ty =
parameter_ty.display(self.db())
parameter_ty =
parameter_ty.display(self.db())
));
}
}
}
}
}
KnownFunction::Cast => {
if let [Some(casted_type), Some(source_type)] =
overload.parameter_types()
{
let db = self.db();
if (source_type.is_equivalent_to(db, *casted_type)
|| source_type.normalized(db)
== casted_type.normalized(db))
&& !source_type.contains_todo(db)
{
if let Some(builder) = self
.context
.report_lint(&REDUNDANT_CAST, call_expression)
{
builder.into_diagnostic(format_args!(
"Value is already of type `{}`",
casted_type.display(db),
));
}
}
}
}
}
KnownFunction::Cast => {
if let [Some(casted_type), Some(source_type)] =
overload.parameter_types()
{
let db = self.db();
if (source_type.is_equivalent_to(db, *casted_type)
|| source_type.normalized(db)
== casted_type.normalized(db))
&& !source_type.contains_todo(db)
KnownFunction::GetProtocolMembers => {
if let [Some(Type::ClassLiteral(class))] =
overload.parameter_types()
{
if let Some(builder) = self
.context
.report_lint(&REDUNDANT_CAST, call_expression)
{
builder.into_diagnostic(format_args!(
"Value is already of type `{}`",
casted_type.display(db),
));
if !class.is_protocol(self.db()) {
report_bad_argument_to_get_protocol_members(
&self.context,
call_expression,
*class,
);
}
}
}
}
KnownFunction::GetProtocolMembers => {
if let [Some(Type::ClassLiteral(class))] =
overload.parameter_types()
{
if !class.is_protocol(self.db()) {
report_bad_argument_to_get_protocol_members(
&self.context,
call_expression,
*class,
);
}
}
}
KnownFunction::IsInstance | KnownFunction::IsSubclass => {
if let [_, Some(Type::ClassLiteral(class))] =
overload.parameter_types()
{
if let Some(protocol_class) =
class.into_protocol_class(self.db())
KnownFunction::IsInstance | KnownFunction::IsSubclass => {
if let [_, Some(Type::ClassLiteral(class))] =
overload.parameter_types()
{
if !protocol_class.is_runtime_checkable(self.db()) {
report_runtime_check_against_non_runtime_checkable_protocol(
if let Some(protocol_class) =
class.into_protocol_class(self.db())
{
if !protocol_class.is_runtime_checkable(self.db()) {
report_runtime_check_against_non_runtime_checkable_protocol(
&self.context,
call_expression,
protocol_class,
known_function
);
}
}
}
}
_ => {}
}
_ => {}
}
}
Type::ClassLiteral(class) => {
let Some(known_class) = class.known(self.db()) else {
continue;
};
Type::ClassLiteral(class) => {
let Some(known_class) = class.known(self.db()) else {
continue;
};
match known_class {
KnownClass::Super => {
// Handle the case where `super()` is called with no arguments.
// In this case, we need to infer the two arguments:
// 1. The nearest enclosing class
// 2. The first parameter of the current function (typically `self` or `cls`)
match overload.parameter_types() {
[] => {
let scope = self.scope();
match known_class {
KnownClass::Super => {
// Handle the case where `super()` is called with no arguments.
// In this case, we need to infer the two arguments:
// 1. The nearest enclosing class
// 2. The first parameter of the current function (typically `self` or `cls`)
match overload.parameter_types() {
[] => {
let scope = self.scope();
let Some(enclosing_class) =
self.enclosing_class_symbol(scope)
else {
overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(
let Some(enclosing_class) =
self.enclosing_class_symbol(scope)
else {
overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(
&self.context,
call_expression.into(),
);
continue;
};
let Some(first_param) =
self.first_param_type_in_scope(scope)
else {
overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(
&self.context,
call_expression.into(),
);
continue;
};
let bound_super = BoundSuperType::build(
self.db(),
enclosing_class,
first_param,
)
.unwrap_or_else(|err| {
err.report_diagnostic(
&self.context,
call_expression.into(),
);
continue;
};
Type::unknown()
});
let Some(first_param) =
self.first_param_type_in_scope(scope)
else {
overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(
&self.context,
call_expression.into(),
);
continue;
};
let bound_super = BoundSuperType::build(
self.db(),
enclosing_class,
first_param,
)
.unwrap_or_else(|err| {
err.report_diagnostic(
&self.context,
call_expression.into(),
);
Type::unknown()
});
overload.set_return_type(bound_super);
}
[Some(pivot_class_type), Some(owner_type)] => {
let bound_super = BoundSuperType::build(
self.db(),
*pivot_class_type,
*owner_type,
)
.unwrap_or_else(|err| {
err.report_diagnostic(
&self.context,
call_expression.into(),
);
Type::unknown()
});
overload.set_return_type(bound_super);
}
_ => (),
}
}
KnownClass::TypeVar => {
let assigned_to = (self.index)
.try_expression(call_expression_node)
.and_then(|expr| expr.assigned_to(self.db()));
let Some(target) =
assigned_to.as_ref().and_then(|assigned_to| {
match assigned_to.node().targets.as_slice() {
[ast::Expr::Name(target)] => Some(target),
_ => None,
overload.set_return_type(bound_super);
}
})
else {
if let Some(builder) = self.context.report_lint(
&INVALID_LEGACY_TYPE_VARIABLE,
call_expression,
) {
builder.into_diagnostic(format_args!(
[Some(pivot_class_type), Some(owner_type)] => {
let bound_super = BoundSuperType::build(
self.db(),
*pivot_class_type,
*owner_type,
)
.unwrap_or_else(|err| {
err.report_diagnostic(
&self.context,
call_expression.into(),
);
Type::unknown()
});
overload.set_return_type(bound_super);
}
_ => (),
}
}
KnownClass::TypeVar => {
let assigned_to = (self.index)
.try_expression(call_expression_node)
.and_then(|expr| expr.assigned_to(self.db()));
let Some(target) =
assigned_to.as_ref().and_then(|assigned_to| {
match assigned_to.node().targets.as_slice() {
[ast::Expr::Name(target)] => Some(target),
_ => None,
}
})
else {
if let Some(builder) = self.context.report_lint(
&INVALID_LEGACY_TYPE_VARIABLE,
call_expression,
) {
builder.into_diagnostic(format_args!(
"A legacy `typing.TypeVar` must be immediately assigned to a variable",
));
}
continue;
};
}
continue;
};
let [Some(name_param), constraints, bound, default, _contravariant, _covariant, _infer_variance] =
overload.parameter_types()
else {
continue;
};
let [Some(name_param), constraints, bound, default, _contravariant, _covariant, _infer_variance] =
overload.parameter_types()
else {
continue;
};
let name_param = name_param
.into_string_literal()
.map(|name| name.value(self.db()).as_ref());
if name_param.is_none_or(|name_param| name_param != target.id) {
if let Some(builder) = self.context.report_lint(
&INVALID_LEGACY_TYPE_VARIABLE,
call_expression,
) {
builder.into_diagnostic(format_args!(
let name_param = name_param
.into_string_literal()
.map(|name| name.value(self.db()).as_ref());
if name_param
.is_none_or(|name_param| name_param != target.id)
{
if let Some(builder) = self.context.report_lint(
&INVALID_LEGACY_TYPE_VARIABLE,
call_expression,
) {
builder.into_diagnostic(format_args!(
"The name of a legacy `typing.TypeVar`{} must match \
the name of the variable it is assigned to (`{}`)",
if let Some(name_param) = name_param {
@ -4937,60 +4940,63 @@ impl<'db> TypeInferenceBuilder<'db> {
},
target.id,
));
}
continue;
}
continue;
let bound_or_constraint = match (bound, constraints) {
(Some(bound), None) => {
Some(TypeVarBoundOrConstraints::UpperBound(*bound))
}
(None, Some(_constraints)) => {
// We don't use UnionType::from_elements or UnionBuilder here,
// because we don't want to simplify the list of constraints like
// we do with the elements of an actual union type.
// TODO: Consider using a new `OneOfType` connective here instead,
// since that more accurately represents the actual semantics of
// typevar constraints.
let elements = UnionType::new(
self.db(),
overload
.arguments_for_parameter(
&call_argument_types,
1,
)
.map(|(_, ty)| ty)
.collect::<Box<_>>(),
);
Some(TypeVarBoundOrConstraints::Constraints(
elements,
))
}
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and
// constrained
(Some(_), Some(_)) => continue,
(None, None) => None,
};
let containing_assignment =
self.index.expect_single_definition(target);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::TypeVar(TypeVarInstance::new(
self.db(),
target.id.clone(),
containing_assignment,
bound_or_constraint,
*default,
TypeVarKind::Legacy,
)),
));
}
let bound_or_constraint = match (bound, constraints) {
(Some(bound), None) => {
Some(TypeVarBoundOrConstraints::UpperBound(*bound))
}
(None, Some(_constraints)) => {
// We don't use UnionType::from_elements or UnionBuilder here,
// because we don't want to simplify the list of constraints like
// we do with the elements of an actual union type.
// TODO: Consider using a new `OneOfType` connective here instead,
// since that more accurately represents the actual semantics of
// typevar constraints.
let elements = UnionType::new(
self.db(),
overload
.arguments_for_parameter(
&call_argument_types,
1,
)
.map(|(_, ty)| ty)
.collect::<Box<_>>(),
);
Some(TypeVarBoundOrConstraints::Constraints(elements))
}
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and
// constrained
(Some(_), Some(_)) => continue,
(None, None) => None,
};
let containing_assignment =
self.index.expect_single_definition(target);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::TypeVar(TypeVarInstance::new(
self.db(),
target.id.clone(),
containing_assignment,
bound_or_constraint,
*default,
TypeVarKind::Legacy,
)),
));
_ => (),
}
_ => (),
}
_ => (),
}
_ => (),
}
}
bindings.return_type(self.db())
@ -6637,7 +6643,8 @@ impl<'db> TypeInferenceBuilder<'db> {
.next()
.expect("valid bindings should have one callable");
let (_, overload) = callable
.matching_overload()
.matching_overloads()
.next()
.expect("valid bindings should have matching overload");
let specialization = generic_context.specialize(
self.db(),