[red-knot] Update call binding to return all matching overloads (#17618)

## Summary

This PR updates the existing overload matching methods to return an
iterator of all the matched overloads instead.

This would be useful once the overload call evaluation algorithm is
implemented which should provide an accurate picture of all the matched
overloads. The return type would then be picked from either the only
matched overload or the first overload from the ones that are matched.

In an earlier version of this PR, it tried to check if using an
intersection of return types from the matched overload would help reduce
the false positives but that's not enough. [This
comment](https://github.com/astral-sh/ruff/pull/17618#issuecomment-2842891696)
keep the ecosystem analysis for that change for prosperity.

> [!NOTE]
>
> The best way to review this PR is by hiding the whitespace changes
because there are two instances where a large match expression is
indented to be inside a loop over matching overlods
>
> <img width="1207" alt="Screenshot 2025-04-28 at 15 12 16"
src="https://github.com/user-attachments/assets/e06cbfa4-04fa-435f-84ef-4e5c3c5626d1"
/>

## Test Plan

Make sure existing test cases are unaffected and no ecosystem changes.
This commit is contained in:
Dhruv Manilawala 2025-05-01 01:33:21 +05:30 committed by GitHub
parent 6e765b4527
commit d2a238dfad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 903 additions and 858 deletions

View file

@ -4527,27 +4527,43 @@ impl<'db> Type<'db> {
new_call_outcome @ (None | Some(Ok(_))),
init_call_outcome @ (None | Some(Ok(_))),
) => {
fn combine_specializations<'db>(
db: &'db dyn Db,
s1: Option<Specialization<'db>>,
s2: Option<Specialization<'db>>,
) -> Option<Specialization<'db>> {
match (s1, s2) {
(None, None) => None,
(Some(s), None) | (None, Some(s)) => Some(s),
(Some(s1), Some(s2)) => Some(s1.combine(db, s2)),
}
}
fn combine_binding_specialization<'db>(
db: &'db dyn Db,
binding: &CallableBinding<'db>,
) -> Option<Specialization<'db>> {
binding
.matching_overloads()
.map(|(_, binding)| binding.inherited_specialization())
.reduce(|acc, specialization| {
combine_specializations(db, acc, specialization)
})
.flatten()
}
let new_specialization = new_call_outcome
.and_then(Result::ok)
.as_ref()
.and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload)
.and_then(|(_, binding)| binding.inherited_specialization());
.and_then(|binding| combine_binding_specialization(db, binding));
let init_specialization = init_call_outcome
.and_then(Result::ok)
.as_ref()
.and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload)
.and_then(|(_, binding)| binding.inherited_specialization());
let specialization = match (new_specialization, init_specialization) {
(None, None) => None,
(Some(specialization), None) | (None, Some(specialization)) => {
Some(specialization)
}
(Some(new_specialization), Some(init_specialization)) => {
Some(new_specialization.combine(db, init_specialization))
}
};
.and_then(|binding| combine_binding_specialization(db, binding));
let specialization =
combine_specializations(db, new_specialization, init_specialization);
let specialized = specialization
.map(|specialization| {
Type::instance(

View file

@ -221,25 +221,24 @@ impl<'db> Bindings<'db> {
// Each special case listed here should have a corresponding clause in `Type::signatures`.
for (binding, callable_signature) in self.elements.iter_mut().zip(self.signatures.iter()) {
let binding_type = binding.callable_type;
let Some((overload_index, overload)) = binding.matching_overload_mut() else {
continue;
};
for (overload_index, overload) in binding.matching_overloads_mut() {
match binding_type {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) {
match overload.parameter_types() {
[_, Some(owner)] => {
overload.set_return_type(Type::BoundMethod(BoundMethodType::new(
db, function, *owner,
)));
overload.set_return_type(Type::BoundMethod(
BoundMethodType::new(db, function, *owner),
));
}
[Some(instance), None] => {
overload.set_return_type(Type::BoundMethod(BoundMethodType::new(
overload.set_return_type(Type::BoundMethod(
BoundMethodType::new(
db,
function,
instance.to_meta_type(db),
)));
),
));
}
_ => {}
}
@ -295,22 +294,24 @@ impl<'db> Bindings<'db> {
}
}
Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderGet) => match overload
.parameter_types()
{
Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderGet) => {
match overload.parameter_types() {
[Some(property @ Type::PropertyInstance(_)), Some(instance), ..]
if instance.is_none(db) =>
{
overload.set_return_type(*property);
}
[Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeAliasType(type_alias))), ..]
[Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeAliasType(
type_alias,
))), ..]
if property.getter(db).is_some_and(|getter| {
getter
.into_function_literal()
.is_some_and(|f| f.name(db) == "__name__")
}) =>
{
overload.set_return_type(Type::string_literal(db, type_alias.name(db)));
overload
.set_return_type(Type::string_literal(db, type_alias.name(db)));
}
[Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeVar(typevar))), ..] => {
match property
@ -319,12 +320,16 @@ impl<'db> Bindings<'db> {
.map(|f| f.name(db).as_str())
{
Some("__name__") => {
overload
.set_return_type(Type::string_literal(db, typevar.name(db)));
overload.set_return_type(Type::string_literal(
db,
typevar.name(db),
));
}
Some("__bound__") => {
overload.set_return_type(
typevar.upper_bound(db).unwrap_or_else(|| Type::none(db)),
typevar
.upper_bound(db)
.unwrap_or_else(|| Type::none(db)),
);
}
Some("__constraints__") => {
@ -346,7 +351,10 @@ impl<'db> Bindings<'db> {
[Some(Type::PropertyInstance(property)), Some(instance), ..] => {
if let Some(getter) = property.getter(db) {
if let Ok(return_ty) = getter
.try_call(db, &mut CallArgumentTypes::positional([*instance]))
.try_call(
db,
&mut CallArgumentTypes::positional([*instance]),
)
.map(|binding| binding.return_type(db))
{
overload.set_return_type(return_ty);
@ -357,14 +365,15 @@ impl<'db> Bindings<'db> {
overload.set_return_type(Type::unknown());
}
} else {
overload
.errors
.push(BindingError::InternalCallError("property has no getter"));
overload.errors.push(BindingError::InternalCallError(
"property has no getter",
));
overload.set_return_type(Type::Never);
}
}
_ => {}
},
}
}
Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(property)) => {
match overload.parameter_types() {
@ -374,7 +383,10 @@ impl<'db> Bindings<'db> {
[Some(instance), ..] => {
if let Some(getter) = property.getter(db) {
if let Ok(return_ty) = getter
.try_call(db, &mut CallArgumentTypes::positional([*instance]))
.try_call(
db,
&mut CallArgumentTypes::positional([*instance]),
)
.map(|binding| binding.return_type(db))
{
overload.set_return_type(return_ty);
@ -409,9 +421,9 @@ impl<'db> Bindings<'db> {
));
}
} else {
overload
.errors
.push(BindingError::InternalCallError("property has no setter"));
overload.errors.push(BindingError::InternalCallError(
"property has no setter",
));
}
}
}
@ -428,9 +440,9 @@ impl<'db> Bindings<'db> {
));
}
} else {
overload
.errors
.push(BindingError::InternalCallError("property has no setter"));
overload.errors.push(BindingError::InternalCallError(
"property has no setter",
));
}
}
}
@ -446,7 +458,8 @@ impl<'db> Bindings<'db> {
}
Type::DataclassTransformer(params) => {
if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() {
if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types()
{
overload.set_return_type(Type::FunctionLiteral(
function.with_dataclass_transformer_params(db, params),
));
@ -539,7 +552,8 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsFullyStatic) => {
if let [Some(ty)] = overload.parameter_types() {
overload.set_return_type(Type::BooleanLiteral(ty.is_fully_static(db)));
overload
.set_return_type(Type::BooleanLiteral(ty.is_fully_static(db)));
}
}
@ -551,7 +565,8 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsSingleValued) => {
if let [Some(ty)] = overload.parameter_types() {
overload.set_return_type(Type::BooleanLiteral(ty.is_single_valued(db)));
overload
.set_return_type(Type::BooleanLiteral(ty.is_single_valued(db)));
}
}
@ -649,7 +664,8 @@ impl<'db> Bindings<'db> {
Type::Never
};
let union_with_default = |ty| UnionType::from_elements(db, [ty, default]);
let union_with_default =
|ty| UnionType::from_elements(db, [ty, default]);
// TODO: we could emit a diagnostic here (if default is not set)
overload.set_return_type(
@ -759,7 +775,8 @@ impl<'db> Bindings<'db> {
dataclass_params.set(DataclassParams::ORDER, *order);
}
overload.set_return_type(Type::DataclassDecorator(dataclass_params));
overload
.set_return_type(Type::DataclassDecorator(dataclass_params));
}
}
},
@ -805,6 +822,7 @@ impl<'db> Bindings<'db> {
}
}
}
}
}
impl<'a, 'db> IntoIterator for &'a Bindings<'db> {
@ -868,7 +886,7 @@ impl<'db> CallableBinding<'db> {
// the matching overloads. Make sure to implement that as part of separating call binding into
// two phases.
//
// [1] https://github.com/python/typing/pull/1839
// [1] https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
let overloads = signature
.into_iter()
.map(|signature| {
@ -928,35 +946,39 @@ impl<'db> CallableBinding<'db> {
/// Returns whether there were any errors binding this call site. If the callable has multiple
/// overloads, they must _all_ have errors.
pub(crate) fn has_binding_errors(&self) -> bool {
self.matching_overload().is_none()
self.matching_overloads().next().is_none()
}
/// Returns the overload that matched for this call binding. Returns `None` if none of the
/// overloads matched.
pub(crate) fn matching_overload(&self) -> Option<(usize, &Binding<'db>)> {
/// Returns an iterator over all the overloads that matched for this call binding.
pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
self.overloads
.iter()
.enumerate()
.find(|(_, overload)| overload.as_result().is_ok())
.filter(|(_, overload)| overload.as_result().is_ok())
}
/// Returns the overload that matched for this call binding. Returns `None` if none of the
/// overloads matched.
pub(crate) fn matching_overload_mut(&mut self) -> Option<(usize, &mut Binding<'db>)> {
/// Returns an iterator over all the mutable overloads that matched for this call binding.
pub(crate) fn matching_overloads_mut(
&mut self,
) -> impl Iterator<Item = (usize, &mut Binding<'db>)> {
self.overloads
.iter_mut()
.enumerate()
.find(|(_, overload)| overload.as_result().is_ok())
.filter(|(_, overload)| overload.as_result().is_ok())
}
/// Returns the return type of this call. For a valid call, this is the return type of the
/// overload that the arguments matched against. For an invalid call to a non-overloaded
/// first overload that the arguments matched against. For an invalid call to a non-overloaded
/// function, this is the return type of the function. For an invalid call to an overloaded
/// function, we return `Type::unknown`, since we cannot make any useful conclusions about
/// which overload was intended to be called.
pub(crate) fn return_type(&self) -> Type<'db> {
if let Some((_, overload)) = self.matching_overload() {
return overload.return_type();
// TODO: Implement the overload call evaluation algorithm as mentioned in the spec [1] to
// get the matching overload and use that to get the return type.
//
// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
if let Some((_, first_overload)) = self.matching_overloads().next() {
return first_overload.return_type();
}
if let [overload] = self.overloads.as_slice() {
return overload.return_type();

View file

@ -4627,10 +4627,7 @@ impl<'db> TypeInferenceBuilder<'db> {
Ok(mut bindings) => {
for binding in &mut bindings {
let binding_type = binding.callable_type;
let Some((_, overload)) = binding.matching_overload_mut() else {
continue;
};
for (_, overload) in binding.matching_overloads_mut() {
match binding_type {
Type::FunctionLiteral(function_literal) => {
let Some(known_function) = function_literal.known(self.db()) else {
@ -4644,7 +4641,8 @@ impl<'db> TypeInferenceBuilder<'db> {
DiagnosticId::RevealedType,
Severity::Info,
) {
let mut diag = builder.into_diagnostic("Revealed type");
let mut diag =
builder.into_diagnostic("Revealed type");
let span = self.context.span(call_expression);
diag.annotate(Annotation::primary(span).message(
format_args!(
@ -4695,18 +4693,21 @@ impl<'db> TypeInferenceBuilder<'db> {
if let [Some(parameter_ty), message] =
overload.parameter_types()
{
let truthiness = match parameter_ty.try_bool(self.db()) {
let truthiness = match parameter_ty.try_bool(self.db())
{
Ok(truthiness) => truthiness,
Err(err) => {
let condition = arguments
.find_argument("condition", 0)
.map(|argument| match argument {
.map(|argument| {
match argument {
ruff_python_ast::ArgOrKeyword::Arg(
expr,
) => ast::AnyNodeRef::from(expr),
ruff_python_ast::ArgOrKeyword::Keyword(
keyword,
) => ast::AnyNodeRef::from(keyword),
}
})
.unwrap_or(ast::AnyNodeRef::from(
call_expression,
@ -4922,7 +4923,9 @@ impl<'db> TypeInferenceBuilder<'db> {
let name_param = name_param
.into_string_literal()
.map(|name| name.value(self.db()).as_ref());
if name_param.is_none_or(|name_param| name_param != target.id) {
if name_param
.is_none_or(|name_param| name_param != target.id)
{
if let Some(builder) = self.context.report_lint(
&INVALID_LEGACY_TYPE_VARIABLE,
call_expression,
@ -4963,7 +4966,9 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|(_, ty)| ty)
.collect::<Box<_>>(),
);
Some(TypeVarBoundOrConstraints::Constraints(elements))
Some(TypeVarBoundOrConstraints::Constraints(
elements,
))
}
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and
@ -4993,6 +4998,7 @@ impl<'db> TypeInferenceBuilder<'db> {
_ => (),
}
}
}
bindings.return_type(self.db())
}
@ -6637,7 +6643,8 @@ impl<'db> TypeInferenceBuilder<'db> {
.next()
.expect("valid bindings should have one callable");
let (_, overload) = callable
.matching_overload()
.matching_overloads()
.next()
.expect("valid bindings should have matching overload");
let specialization = generic_context.specialize(
self.db(),