[red-knot] Update call binding to return all matching overloads (#17618)

## Summary

This PR updates the existing overload matching methods to return an
iterator of all the matched overloads instead.

This would be useful once the overload call evaluation algorithm is
implemented which should provide an accurate picture of all the matched
overloads. The return type would then be picked from either the only
matched overload or the first overload from the ones that are matched.

In an earlier version of this PR, it tried to check if using an
intersection of return types from the matched overload would help reduce
the false positives but that's not enough. [This
comment](https://github.com/astral-sh/ruff/pull/17618#issuecomment-2842891696)
keep the ecosystem analysis for that change for prosperity.

> [!NOTE]
>
> The best way to review this PR is by hiding the whitespace changes
because there are two instances where a large match expression is
indented to be inside a loop over matching overlods
>
> <img width="1207" alt="Screenshot 2025-04-28 at 15 12 16"
src="https://github.com/user-attachments/assets/e06cbfa4-04fa-435f-84ef-4e5c3c5626d1"
/>

## Test Plan

Make sure existing test cases are unaffected and no ecosystem changes.
This commit is contained in:
Dhruv Manilawala 2025-05-01 01:33:21 +05:30 committed by GitHub
parent 6e765b4527
commit d2a238dfad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 903 additions and 858 deletions

View file

@ -4527,27 +4527,43 @@ impl<'db> Type<'db> {
new_call_outcome @ (None | Some(Ok(_))), new_call_outcome @ (None | Some(Ok(_))),
init_call_outcome @ (None | Some(Ok(_))), init_call_outcome @ (None | Some(Ok(_))),
) => { ) => {
fn combine_specializations<'db>(
db: &'db dyn Db,
s1: Option<Specialization<'db>>,
s2: Option<Specialization<'db>>,
) -> Option<Specialization<'db>> {
match (s1, s2) {
(None, None) => None,
(Some(s), None) | (None, Some(s)) => Some(s),
(Some(s1), Some(s2)) => Some(s1.combine(db, s2)),
}
}
fn combine_binding_specialization<'db>(
db: &'db dyn Db,
binding: &CallableBinding<'db>,
) -> Option<Specialization<'db>> {
binding
.matching_overloads()
.map(|(_, binding)| binding.inherited_specialization())
.reduce(|acc, specialization| {
combine_specializations(db, acc, specialization)
})
.flatten()
}
let new_specialization = new_call_outcome let new_specialization = new_call_outcome
.and_then(Result::ok) .and_then(Result::ok)
.as_ref() .as_ref()
.and_then(Bindings::single_element) .and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload) .and_then(|binding| combine_binding_specialization(db, binding));
.and_then(|(_, binding)| binding.inherited_specialization());
let init_specialization = init_call_outcome let init_specialization = init_call_outcome
.and_then(Result::ok) .and_then(Result::ok)
.as_ref() .as_ref()
.and_then(Bindings::single_element) .and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload) .and_then(|binding| combine_binding_specialization(db, binding));
.and_then(|(_, binding)| binding.inherited_specialization()); let specialization =
let specialization = match (new_specialization, init_specialization) { combine_specializations(db, new_specialization, init_specialization);
(None, None) => None,
(Some(specialization), None) | (None, Some(specialization)) => {
Some(specialization)
}
(Some(new_specialization), Some(init_specialization)) => {
Some(new_specialization.combine(db, init_specialization))
}
};
let specialized = specialization let specialized = specialization
.map(|specialization| { .map(|specialization| {
Type::instance( Type::instance(

File diff suppressed because it is too large Load diff

View file

@ -4627,307 +4627,310 @@ impl<'db> TypeInferenceBuilder<'db> {
Ok(mut bindings) => { Ok(mut bindings) => {
for binding in &mut bindings { for binding in &mut bindings {
let binding_type = binding.callable_type; let binding_type = binding.callable_type;
let Some((_, overload)) = binding.matching_overload_mut() else { for (_, overload) in binding.matching_overloads_mut() {
continue; match binding_type {
}; Type::FunctionLiteral(function_literal) => {
let Some(known_function) = function_literal.known(self.db()) else {
continue;
};
match binding_type { match known_function {
Type::FunctionLiteral(function_literal) => { KnownFunction::RevealType => {
let Some(known_function) = function_literal.known(self.db()) else { if let [Some(revealed_type)] = overload.parameter_types() {
continue; if let Some(builder) = self.context.report_diagnostic(
}; DiagnosticId::RevealedType,
Severity::Info,
match known_function { ) {
KnownFunction::RevealType => { let mut diag =
if let [Some(revealed_type)] = overload.parameter_types() { builder.into_diagnostic("Revealed type");
if let Some(builder) = self.context.report_diagnostic( let span = self.context.span(call_expression);
DiagnosticId::RevealedType, diag.annotate(Annotation::primary(span).message(
Severity::Info, format_args!(
) { "`{}`",
let mut diag = builder.into_diagnostic("Revealed type"); revealed_type.display(self.db())
let span = self.context.span(call_expression); ),
diag.annotate(Annotation::primary(span).message( ));
format_args!( }
"`{}`",
revealed_type.display(self.db())
),
));
} }
} }
} KnownFunction::AssertType => {
KnownFunction::AssertType => { if let [Some(actual_ty), Some(asserted_ty)] =
if let [Some(actual_ty), Some(asserted_ty)] = overload.parameter_types()
overload.parameter_types()
{
if !actual_ty
.is_gradual_equivalent_to(self.db(), *asserted_ty)
{ {
if let Some(builder) = self.context.report_lint( if !actual_ty
&TYPE_ASSERTION_FAILURE, .is_gradual_equivalent_to(self.db(), *asserted_ty)
call_expression, {
) { if let Some(builder) = self.context.report_lint(
builder.into_diagnostic(format_args!( &TYPE_ASSERTION_FAILURE,
"Actual type `{}` is not the same \ call_expression,
) {
builder.into_diagnostic(format_args!(
"Actual type `{}` is not the same \
as asserted type `{}`", as asserted type `{}`",
actual_ty.display(self.db()), actual_ty.display(self.db()),
asserted_ty.display(self.db()), asserted_ty.display(self.db()),
)); ));
}
} }
} }
} }
} KnownFunction::AssertNever => {
KnownFunction::AssertNever => { if let [Some(actual_ty)] = overload.parameter_types() {
if let [Some(actual_ty)] = overload.parameter_types() { if !actual_ty.is_equivalent_to(self.db(), Type::Never) {
if !actual_ty.is_equivalent_to(self.db(), Type::Never) { if let Some(builder) = self.context.report_lint(
if let Some(builder) = self.context.report_lint( &TYPE_ASSERTION_FAILURE,
&TYPE_ASSERTION_FAILURE, call_expression,
call_expression, ) {
) { builder.into_diagnostic(format_args!(
builder.into_diagnostic(format_args!( "Expected type `Never`, got `{}` instead",
"Expected type `Never`, got `{}` instead", actual_ty.display(self.db()),
actual_ty.display(self.db()), ));
)); }
} }
} }
} }
} KnownFunction::StaticAssert => {
KnownFunction::StaticAssert => { if let [Some(parameter_ty), message] =
if let [Some(parameter_ty), message] = overload.parameter_types()
overload.parameter_types() {
{ let truthiness = match parameter_ty.try_bool(self.db())
let truthiness = match parameter_ty.try_bool(self.db()) { {
Ok(truthiness) => truthiness, Ok(truthiness) => truthiness,
Err(err) => { Err(err) => {
let condition = arguments let condition = arguments
.find_argument("condition", 0) .find_argument("condition", 0)
.map(|argument| match argument { .map(|argument| {
match argument {
ruff_python_ast::ArgOrKeyword::Arg( ruff_python_ast::ArgOrKeyword::Arg(
expr, expr,
) => ast::AnyNodeRef::from(expr), ) => ast::AnyNodeRef::from(expr),
ruff_python_ast::ArgOrKeyword::Keyword( ruff_python_ast::ArgOrKeyword::Keyword(
keyword, keyword,
) => ast::AnyNodeRef::from(keyword), ) => ast::AnyNodeRef::from(keyword),
}) }
.unwrap_or(ast::AnyNodeRef::from( })
call_expression, .unwrap_or(ast::AnyNodeRef::from(
)); call_expression,
));
err.report_diagnostic(&self.context, condition); err.report_diagnostic(&self.context, condition);
continue; continue;
} }
}; };
if let Some(builder) = self if let Some(builder) = self
.context .context
.report_lint(&STATIC_ASSERT_ERROR, call_expression) .report_lint(&STATIC_ASSERT_ERROR, call_expression)
{ {
if !truthiness.is_always_true() { if !truthiness.is_always_true() {
if let Some(message) = message if let Some(message) = message
.and_then(Type::into_string_literal) .and_then(Type::into_string_literal)
.map(|s| &**s.value(self.db())) .map(|s| &**s.value(self.db()))
{ {
builder.into_diagnostic(format_args!( builder.into_diagnostic(format_args!(
"Static assertion error: {message}" "Static assertion error: {message}"
)); ));
} else if *parameter_ty } else if *parameter_ty
== Type::BooleanLiteral(false) == Type::BooleanLiteral(false)
{ {
builder.into_diagnostic( builder.into_diagnostic(
"Static assertion error: \ "Static assertion error: \
argument evaluates to `False`", argument evaluates to `False`",
); );
} else if truthiness.is_always_false() { } else if truthiness.is_always_false() {
builder.into_diagnostic(format_args!( builder.into_diagnostic(format_args!(
"Static assertion error: \ "Static assertion error: \
argument of type `{parameter_ty}` \ argument of type `{parameter_ty}` \
is statically known to be falsy", is statically known to be falsy",
parameter_ty = parameter_ty =
parameter_ty.display(self.db()) parameter_ty.display(self.db())
)); ));
} else { } else {
builder.into_diagnostic(format_args!( builder.into_diagnostic(format_args!(
"Static assertion error: \ "Static assertion error: \
argument of type `{parameter_ty}` \ argument of type `{parameter_ty}` \
has an ambiguous static truthiness", has an ambiguous static truthiness",
parameter_ty = parameter_ty =
parameter_ty.display(self.db()) parameter_ty.display(self.db())
));
}
}
}
}
}
KnownFunction::Cast => {
if let [Some(casted_type), Some(source_type)] =
overload.parameter_types()
{
let db = self.db();
if (source_type.is_equivalent_to(db, *casted_type)
|| source_type.normalized(db)
== casted_type.normalized(db))
&& !source_type.contains_todo(db)
{
if let Some(builder) = self
.context
.report_lint(&REDUNDANT_CAST, call_expression)
{
builder.into_diagnostic(format_args!(
"Value is already of type `{}`",
casted_type.display(db),
)); ));
} }
} }
} }
} }
} KnownFunction::GetProtocolMembers => {
KnownFunction::Cast => { if let [Some(Type::ClassLiteral(class))] =
if let [Some(casted_type), Some(source_type)] = overload.parameter_types()
overload.parameter_types()
{
let db = self.db();
if (source_type.is_equivalent_to(db, *casted_type)
|| source_type.normalized(db)
== casted_type.normalized(db))
&& !source_type.contains_todo(db)
{ {
if let Some(builder) = self if !class.is_protocol(self.db()) {
.context report_bad_argument_to_get_protocol_members(
.report_lint(&REDUNDANT_CAST, call_expression) &self.context,
{ call_expression,
builder.into_diagnostic(format_args!( *class,
"Value is already of type `{}`", );
casted_type.display(db),
));
} }
} }
} }
} KnownFunction::IsInstance | KnownFunction::IsSubclass => {
KnownFunction::GetProtocolMembers => { if let [_, Some(Type::ClassLiteral(class))] =
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types()
overload.parameter_types()
{
if !class.is_protocol(self.db()) {
report_bad_argument_to_get_protocol_members(
&self.context,
call_expression,
*class,
);
}
}
}
KnownFunction::IsInstance | KnownFunction::IsSubclass => {
if let [_, Some(Type::ClassLiteral(class))] =
overload.parameter_types()
{
if let Some(protocol_class) =
class.into_protocol_class(self.db())
{ {
if !protocol_class.is_runtime_checkable(self.db()) { if let Some(protocol_class) =
report_runtime_check_against_non_runtime_checkable_protocol( class.into_protocol_class(self.db())
{
if !protocol_class.is_runtime_checkable(self.db()) {
report_runtime_check_against_non_runtime_checkable_protocol(
&self.context, &self.context,
call_expression, call_expression,
protocol_class, protocol_class,
known_function known_function
); );
}
} }
} }
} }
_ => {}
} }
_ => {}
} }
}
Type::ClassLiteral(class) => { Type::ClassLiteral(class) => {
let Some(known_class) = class.known(self.db()) else { let Some(known_class) = class.known(self.db()) else {
continue; continue;
}; };
match known_class { match known_class {
KnownClass::Super => { KnownClass::Super => {
// Handle the case where `super()` is called with no arguments. // Handle the case where `super()` is called with no arguments.
// In this case, we need to infer the two arguments: // In this case, we need to infer the two arguments:
// 1. The nearest enclosing class // 1. The nearest enclosing class
// 2. The first parameter of the current function (typically `self` or `cls`) // 2. The first parameter of the current function (typically `self` or `cls`)
match overload.parameter_types() { match overload.parameter_types() {
[] => { [] => {
let scope = self.scope(); let scope = self.scope();
let Some(enclosing_class) = let Some(enclosing_class) =
self.enclosing_class_symbol(scope) self.enclosing_class_symbol(scope)
else { else {
overload.set_return_type(Type::unknown()); overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments BoundSuperError::UnavailableImplicitArguments
.report_diagnostic( .report_diagnostic(
&self.context,
call_expression.into(),
);
continue;
};
let Some(first_param) =
self.first_param_type_in_scope(scope)
else {
overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(
&self.context,
call_expression.into(),
);
continue;
};
let bound_super = BoundSuperType::build(
self.db(),
enclosing_class,
first_param,
)
.unwrap_or_else(|err| {
err.report_diagnostic(
&self.context, &self.context,
call_expression.into(), call_expression.into(),
); );
continue; Type::unknown()
}; });
let Some(first_param) = overload.set_return_type(bound_super);
self.first_param_type_in_scope(scope)
else {
overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments
.report_diagnostic(
&self.context,
call_expression.into(),
);
continue;
};
let bound_super = BoundSuperType::build(
self.db(),
enclosing_class,
first_param,
)
.unwrap_or_else(|err| {
err.report_diagnostic(
&self.context,
call_expression.into(),
);
Type::unknown()
});
overload.set_return_type(bound_super);
}
[Some(pivot_class_type), Some(owner_type)] => {
let bound_super = BoundSuperType::build(
self.db(),
*pivot_class_type,
*owner_type,
)
.unwrap_or_else(|err| {
err.report_diagnostic(
&self.context,
call_expression.into(),
);
Type::unknown()
});
overload.set_return_type(bound_super);
}
_ => (),
}
}
KnownClass::TypeVar => {
let assigned_to = (self.index)
.try_expression(call_expression_node)
.and_then(|expr| expr.assigned_to(self.db()));
let Some(target) =
assigned_to.as_ref().and_then(|assigned_to| {
match assigned_to.node().targets.as_slice() {
[ast::Expr::Name(target)] => Some(target),
_ => None,
} }
}) [Some(pivot_class_type), Some(owner_type)] => {
else { let bound_super = BoundSuperType::build(
if let Some(builder) = self.context.report_lint( self.db(),
&INVALID_LEGACY_TYPE_VARIABLE, *pivot_class_type,
call_expression, *owner_type,
) { )
builder.into_diagnostic(format_args!( .unwrap_or_else(|err| {
err.report_diagnostic(
&self.context,
call_expression.into(),
);
Type::unknown()
});
overload.set_return_type(bound_super);
}
_ => (),
}
}
KnownClass::TypeVar => {
let assigned_to = (self.index)
.try_expression(call_expression_node)
.and_then(|expr| expr.assigned_to(self.db()));
let Some(target) =
assigned_to.as_ref().and_then(|assigned_to| {
match assigned_to.node().targets.as_slice() {
[ast::Expr::Name(target)] => Some(target),
_ => None,
}
})
else {
if let Some(builder) = self.context.report_lint(
&INVALID_LEGACY_TYPE_VARIABLE,
call_expression,
) {
builder.into_diagnostic(format_args!(
"A legacy `typing.TypeVar` must be immediately assigned to a variable", "A legacy `typing.TypeVar` must be immediately assigned to a variable",
)); ));
} }
continue; continue;
}; };
let [Some(name_param), constraints, bound, default, _contravariant, _covariant, _infer_variance] = let [Some(name_param), constraints, bound, default, _contravariant, _covariant, _infer_variance] =
overload.parameter_types() overload.parameter_types()
else { else {
continue; continue;
}; };
let name_param = name_param let name_param = name_param
.into_string_literal() .into_string_literal()
.map(|name| name.value(self.db()).as_ref()); .map(|name| name.value(self.db()).as_ref());
if name_param.is_none_or(|name_param| name_param != target.id) { if name_param
if let Some(builder) = self.context.report_lint( .is_none_or(|name_param| name_param != target.id)
&INVALID_LEGACY_TYPE_VARIABLE, {
call_expression, if let Some(builder) = self.context.report_lint(
) { &INVALID_LEGACY_TYPE_VARIABLE,
builder.into_diagnostic(format_args!( call_expression,
) {
builder.into_diagnostic(format_args!(
"The name of a legacy `typing.TypeVar`{} must match \ "The name of a legacy `typing.TypeVar`{} must match \
the name of the variable it is assigned to (`{}`)", the name of the variable it is assigned to (`{}`)",
if let Some(name_param) = name_param { if let Some(name_param) = name_param {
@ -4937,60 +4940,63 @@ impl<'db> TypeInferenceBuilder<'db> {
}, },
target.id, target.id,
)); ));
}
continue;
} }
continue;
let bound_or_constraint = match (bound, constraints) {
(Some(bound), None) => {
Some(TypeVarBoundOrConstraints::UpperBound(*bound))
}
(None, Some(_constraints)) => {
// We don't use UnionType::from_elements or UnionBuilder here,
// because we don't want to simplify the list of constraints like
// we do with the elements of an actual union type.
// TODO: Consider using a new `OneOfType` connective here instead,
// since that more accurately represents the actual semantics of
// typevar constraints.
let elements = UnionType::new(
self.db(),
overload
.arguments_for_parameter(
&call_argument_types,
1,
)
.map(|(_, ty)| ty)
.collect::<Box<_>>(),
);
Some(TypeVarBoundOrConstraints::Constraints(
elements,
))
}
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and
// constrained
(Some(_), Some(_)) => continue,
(None, None) => None,
};
let containing_assignment =
self.index.expect_single_definition(target);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::TypeVar(TypeVarInstance::new(
self.db(),
target.id.clone(),
containing_assignment,
bound_or_constraint,
*default,
TypeVarKind::Legacy,
)),
));
} }
let bound_or_constraint = match (bound, constraints) { _ => (),
(Some(bound), None) => {
Some(TypeVarBoundOrConstraints::UpperBound(*bound))
}
(None, Some(_constraints)) => {
// We don't use UnionType::from_elements or UnionBuilder here,
// because we don't want to simplify the list of constraints like
// we do with the elements of an actual union type.
// TODO: Consider using a new `OneOfType` connective here instead,
// since that more accurately represents the actual semantics of
// typevar constraints.
let elements = UnionType::new(
self.db(),
overload
.arguments_for_parameter(
&call_argument_types,
1,
)
.map(|(_, ty)| ty)
.collect::<Box<_>>(),
);
Some(TypeVarBoundOrConstraints::Constraints(elements))
}
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and
// constrained
(Some(_), Some(_)) => continue,
(None, None) => None,
};
let containing_assignment =
self.index.expect_single_definition(target);
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::TypeVar(TypeVarInstance::new(
self.db(),
target.id.clone(),
containing_assignment,
bound_or_constraint,
*default,
TypeVarKind::Legacy,
)),
));
} }
_ => (),
} }
_ => (),
} }
_ => (),
} }
} }
bindings.return_type(self.db()) bindings.return_type(self.db())
@ -6637,7 +6643,8 @@ impl<'db> TypeInferenceBuilder<'db> {
.next() .next()
.expect("valid bindings should have one callable"); .expect("valid bindings should have one callable");
let (_, overload) = callable let (_, overload) = callable
.matching_overload() .matching_overloads()
.next()
.expect("valid bindings should have matching overload"); .expect("valid bindings should have matching overload");
let specialization = generic_context.specialize( let specialization = generic_context.specialize(
self.db(), self.db(),