[red-knot] Update call binding to return all matching overloads (#17618)

## Summary

This PR updates the existing overload matching methods to return an
iterator of all the matched overloads instead.

This would be useful once the overload call evaluation algorithm is
implemented which should provide an accurate picture of all the matched
overloads. The return type would then be picked from either the only
matched overload or the first overload from the ones that are matched.

In an earlier version of this PR, it tried to check if using an
intersection of return types from the matched overload would help reduce
the false positives but that's not enough. [This
comment](https://github.com/astral-sh/ruff/pull/17618#issuecomment-2842891696)
keep the ecosystem analysis for that change for prosperity.

> [!NOTE]
>
> The best way to review this PR is by hiding the whitespace changes
because there are two instances where a large match expression is
indented to be inside a loop over matching overlods
>
> <img width="1207" alt="Screenshot 2025-04-28 at 15 12 16"
src="https://github.com/user-attachments/assets/e06cbfa4-04fa-435f-84ef-4e5c3c5626d1"
/>

## Test Plan

Make sure existing test cases are unaffected and no ecosystem changes.
This commit is contained in:
Dhruv Manilawala 2025-05-01 01:33:21 +05:30 committed by GitHub
parent 6e765b4527
commit d2a238dfad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 903 additions and 858 deletions

View file

@ -4527,27 +4527,43 @@ impl<'db> Type<'db> {
new_call_outcome @ (None | Some(Ok(_))), new_call_outcome @ (None | Some(Ok(_))),
init_call_outcome @ (None | Some(Ok(_))), init_call_outcome @ (None | Some(Ok(_))),
) => { ) => {
fn combine_specializations<'db>(
db: &'db dyn Db,
s1: Option<Specialization<'db>>,
s2: Option<Specialization<'db>>,
) -> Option<Specialization<'db>> {
match (s1, s2) {
(None, None) => None,
(Some(s), None) | (None, Some(s)) => Some(s),
(Some(s1), Some(s2)) => Some(s1.combine(db, s2)),
}
}
fn combine_binding_specialization<'db>(
db: &'db dyn Db,
binding: &CallableBinding<'db>,
) -> Option<Specialization<'db>> {
binding
.matching_overloads()
.map(|(_, binding)| binding.inherited_specialization())
.reduce(|acc, specialization| {
combine_specializations(db, acc, specialization)
})
.flatten()
}
let new_specialization = new_call_outcome let new_specialization = new_call_outcome
.and_then(Result::ok) .and_then(Result::ok)
.as_ref() .as_ref()
.and_then(Bindings::single_element) .and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload) .and_then(|binding| combine_binding_specialization(db, binding));
.and_then(|(_, binding)| binding.inherited_specialization());
let init_specialization = init_call_outcome let init_specialization = init_call_outcome
.and_then(Result::ok) .and_then(Result::ok)
.as_ref() .as_ref()
.and_then(Bindings::single_element) .and_then(Bindings::single_element)
.and_then(CallableBinding::matching_overload) .and_then(|binding| combine_binding_specialization(db, binding));
.and_then(|(_, binding)| binding.inherited_specialization()); let specialization =
let specialization = match (new_specialization, init_specialization) { combine_specializations(db, new_specialization, init_specialization);
(None, None) => None,
(Some(specialization), None) | (None, Some(specialization)) => {
Some(specialization)
}
(Some(new_specialization), Some(init_specialization)) => {
Some(new_specialization.combine(db, init_specialization))
}
};
let specialized = specialization let specialized = specialization
.map(|specialization| { .map(|specialization| {
Type::instance( Type::instance(

View file

@ -221,25 +221,24 @@ impl<'db> Bindings<'db> {
// Each special case listed here should have a corresponding clause in `Type::signatures`. // Each special case listed here should have a corresponding clause in `Type::signatures`.
for (binding, callable_signature) in self.elements.iter_mut().zip(self.signatures.iter()) { for (binding, callable_signature) in self.elements.iter_mut().zip(self.signatures.iter()) {
let binding_type = binding.callable_type; let binding_type = binding.callable_type;
let Some((overload_index, overload)) = binding.matching_overload_mut() else { for (overload_index, overload) in binding.matching_overloads_mut() {
continue;
};
match binding_type { match binding_type {
Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => { Type::MethodWrapper(MethodWrapperKind::FunctionTypeDunderGet(function)) => {
if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) { if function.has_known_decorator(db, FunctionDecorators::CLASSMETHOD) {
match overload.parameter_types() { match overload.parameter_types() {
[_, Some(owner)] => { [_, Some(owner)] => {
overload.set_return_type(Type::BoundMethod(BoundMethodType::new( overload.set_return_type(Type::BoundMethod(
db, function, *owner, BoundMethodType::new(db, function, *owner),
))); ));
} }
[Some(instance), None] => { [Some(instance), None] => {
overload.set_return_type(Type::BoundMethod(BoundMethodType::new( overload.set_return_type(Type::BoundMethod(
BoundMethodType::new(
db, db,
function, function,
instance.to_meta_type(db), instance.to_meta_type(db),
))); ),
));
} }
_ => {} _ => {}
} }
@ -295,22 +294,24 @@ impl<'db> Bindings<'db> {
} }
} }
Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderGet) => match overload Type::WrapperDescriptor(WrapperDescriptorKind::PropertyDunderGet) => {
.parameter_types() match overload.parameter_types() {
{
[Some(property @ Type::PropertyInstance(_)), Some(instance), ..] [Some(property @ Type::PropertyInstance(_)), Some(instance), ..]
if instance.is_none(db) => if instance.is_none(db) =>
{ {
overload.set_return_type(*property); overload.set_return_type(*property);
} }
[Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeAliasType(type_alias))), ..] [Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeAliasType(
type_alias,
))), ..]
if property.getter(db).is_some_and(|getter| { if property.getter(db).is_some_and(|getter| {
getter getter
.into_function_literal() .into_function_literal()
.is_some_and(|f| f.name(db) == "__name__") .is_some_and(|f| f.name(db) == "__name__")
}) => }) =>
{ {
overload.set_return_type(Type::string_literal(db, type_alias.name(db))); overload
.set_return_type(Type::string_literal(db, type_alias.name(db)));
} }
[Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeVar(typevar))), ..] => { [Some(Type::PropertyInstance(property)), Some(Type::KnownInstance(KnownInstanceType::TypeVar(typevar))), ..] => {
match property match property
@ -319,12 +320,16 @@ impl<'db> Bindings<'db> {
.map(|f| f.name(db).as_str()) .map(|f| f.name(db).as_str())
{ {
Some("__name__") => { Some("__name__") => {
overload overload.set_return_type(Type::string_literal(
.set_return_type(Type::string_literal(db, typevar.name(db))); db,
typevar.name(db),
));
} }
Some("__bound__") => { Some("__bound__") => {
overload.set_return_type( overload.set_return_type(
typevar.upper_bound(db).unwrap_or_else(|| Type::none(db)), typevar
.upper_bound(db)
.unwrap_or_else(|| Type::none(db)),
); );
} }
Some("__constraints__") => { Some("__constraints__") => {
@ -346,7 +351,10 @@ impl<'db> Bindings<'db> {
[Some(Type::PropertyInstance(property)), Some(instance), ..] => { [Some(Type::PropertyInstance(property)), Some(instance), ..] => {
if let Some(getter) = property.getter(db) { if let Some(getter) = property.getter(db) {
if let Ok(return_ty) = getter if let Ok(return_ty) = getter
.try_call(db, &mut CallArgumentTypes::positional([*instance])) .try_call(
db,
&mut CallArgumentTypes::positional([*instance]),
)
.map(|binding| binding.return_type(db)) .map(|binding| binding.return_type(db))
{ {
overload.set_return_type(return_ty); overload.set_return_type(return_ty);
@ -357,14 +365,15 @@ impl<'db> Bindings<'db> {
overload.set_return_type(Type::unknown()); overload.set_return_type(Type::unknown());
} }
} else { } else {
overload overload.errors.push(BindingError::InternalCallError(
.errors "property has no getter",
.push(BindingError::InternalCallError("property has no getter")); ));
overload.set_return_type(Type::Never); overload.set_return_type(Type::Never);
} }
} }
_ => {} _ => {}
}, }
}
Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(property)) => { Type::MethodWrapper(MethodWrapperKind::PropertyDunderGet(property)) => {
match overload.parameter_types() { match overload.parameter_types() {
@ -374,7 +383,10 @@ impl<'db> Bindings<'db> {
[Some(instance), ..] => { [Some(instance), ..] => {
if let Some(getter) = property.getter(db) { if let Some(getter) = property.getter(db) {
if let Ok(return_ty) = getter if let Ok(return_ty) = getter
.try_call(db, &mut CallArgumentTypes::positional([*instance])) .try_call(
db,
&mut CallArgumentTypes::positional([*instance]),
)
.map(|binding| binding.return_type(db)) .map(|binding| binding.return_type(db))
{ {
overload.set_return_type(return_ty); overload.set_return_type(return_ty);
@ -409,9 +421,9 @@ impl<'db> Bindings<'db> {
)); ));
} }
} else { } else {
overload overload.errors.push(BindingError::InternalCallError(
.errors "property has no setter",
.push(BindingError::InternalCallError("property has no setter")); ));
} }
} }
} }
@ -428,9 +440,9 @@ impl<'db> Bindings<'db> {
)); ));
} }
} else { } else {
overload overload.errors.push(BindingError::InternalCallError(
.errors "property has no setter",
.push(BindingError::InternalCallError("property has no setter")); ));
} }
} }
} }
@ -446,7 +458,8 @@ impl<'db> Bindings<'db> {
} }
Type::DataclassTransformer(params) => { Type::DataclassTransformer(params) => {
if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() { if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types()
{
overload.set_return_type(Type::FunctionLiteral( overload.set_return_type(Type::FunctionLiteral(
function.with_dataclass_transformer_params(db, params), function.with_dataclass_transformer_params(db, params),
)); ));
@ -539,7 +552,8 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsFullyStatic) => { Some(KnownFunction::IsFullyStatic) => {
if let [Some(ty)] = overload.parameter_types() { if let [Some(ty)] = overload.parameter_types() {
overload.set_return_type(Type::BooleanLiteral(ty.is_fully_static(db))); overload
.set_return_type(Type::BooleanLiteral(ty.is_fully_static(db)));
} }
} }
@ -551,7 +565,8 @@ impl<'db> Bindings<'db> {
Some(KnownFunction::IsSingleValued) => { Some(KnownFunction::IsSingleValued) => {
if let [Some(ty)] = overload.parameter_types() { if let [Some(ty)] = overload.parameter_types() {
overload.set_return_type(Type::BooleanLiteral(ty.is_single_valued(db))); overload
.set_return_type(Type::BooleanLiteral(ty.is_single_valued(db)));
} }
} }
@ -649,7 +664,8 @@ impl<'db> Bindings<'db> {
Type::Never Type::Never
}; };
let union_with_default = |ty| UnionType::from_elements(db, [ty, default]); let union_with_default =
|ty| UnionType::from_elements(db, [ty, default]);
// TODO: we could emit a diagnostic here (if default is not set) // TODO: we could emit a diagnostic here (if default is not set)
overload.set_return_type( overload.set_return_type(
@ -759,7 +775,8 @@ impl<'db> Bindings<'db> {
dataclass_params.set(DataclassParams::ORDER, *order); dataclass_params.set(DataclassParams::ORDER, *order);
} }
overload.set_return_type(Type::DataclassDecorator(dataclass_params)); overload
.set_return_type(Type::DataclassDecorator(dataclass_params));
} }
} }
}, },
@ -805,6 +822,7 @@ impl<'db> Bindings<'db> {
} }
} }
} }
}
} }
impl<'a, 'db> IntoIterator for &'a Bindings<'db> { impl<'a, 'db> IntoIterator for &'a Bindings<'db> {
@ -868,7 +886,7 @@ impl<'db> CallableBinding<'db> {
// the matching overloads. Make sure to implement that as part of separating call binding into // the matching overloads. Make sure to implement that as part of separating call binding into
// two phases. // two phases.
// //
// [1] https://github.com/python/typing/pull/1839 // [1] https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
let overloads = signature let overloads = signature
.into_iter() .into_iter()
.map(|signature| { .map(|signature| {
@ -928,35 +946,39 @@ impl<'db> CallableBinding<'db> {
/// Returns whether there were any errors binding this call site. If the callable has multiple /// Returns whether there were any errors binding this call site. If the callable has multiple
/// overloads, they must _all_ have errors. /// overloads, they must _all_ have errors.
pub(crate) fn has_binding_errors(&self) -> bool { pub(crate) fn has_binding_errors(&self) -> bool {
self.matching_overload().is_none() self.matching_overloads().next().is_none()
} }
/// Returns the overload that matched for this call binding. Returns `None` if none of the /// Returns an iterator over all the overloads that matched for this call binding.
/// overloads matched. pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
pub(crate) fn matching_overload(&self) -> Option<(usize, &Binding<'db>)> {
self.overloads self.overloads
.iter() .iter()
.enumerate() .enumerate()
.find(|(_, overload)| overload.as_result().is_ok()) .filter(|(_, overload)| overload.as_result().is_ok())
} }
/// Returns the overload that matched for this call binding. Returns `None` if none of the /// Returns an iterator over all the mutable overloads that matched for this call binding.
/// overloads matched. pub(crate) fn matching_overloads_mut(
pub(crate) fn matching_overload_mut(&mut self) -> Option<(usize, &mut Binding<'db>)> { &mut self,
) -> impl Iterator<Item = (usize, &mut Binding<'db>)> {
self.overloads self.overloads
.iter_mut() .iter_mut()
.enumerate() .enumerate()
.find(|(_, overload)| overload.as_result().is_ok()) .filter(|(_, overload)| overload.as_result().is_ok())
} }
/// Returns the return type of this call. For a valid call, this is the return type of the /// Returns the return type of this call. For a valid call, this is the return type of the
/// overload that the arguments matched against. For an invalid call to a non-overloaded /// first overload that the arguments matched against. For an invalid call to a non-overloaded
/// function, this is the return type of the function. For an invalid call to an overloaded /// function, this is the return type of the function. For an invalid call to an overloaded
/// function, we return `Type::unknown`, since we cannot make any useful conclusions about /// function, we return `Type::unknown`, since we cannot make any useful conclusions about
/// which overload was intended to be called. /// which overload was intended to be called.
pub(crate) fn return_type(&self) -> Type<'db> { pub(crate) fn return_type(&self) -> Type<'db> {
if let Some((_, overload)) = self.matching_overload() { // TODO: Implement the overload call evaluation algorithm as mentioned in the spec [1] to
return overload.return_type(); // get the matching overload and use that to get the return type.
//
// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
if let Some((_, first_overload)) = self.matching_overloads().next() {
return first_overload.return_type();
} }
if let [overload] = self.overloads.as_slice() { if let [overload] = self.overloads.as_slice() {
return overload.return_type(); return overload.return_type();

View file

@ -4627,10 +4627,7 @@ impl<'db> TypeInferenceBuilder<'db> {
Ok(mut bindings) => { Ok(mut bindings) => {
for binding in &mut bindings { for binding in &mut bindings {
let binding_type = binding.callable_type; let binding_type = binding.callable_type;
let Some((_, overload)) = binding.matching_overload_mut() else { for (_, overload) in binding.matching_overloads_mut() {
continue;
};
match binding_type { match binding_type {
Type::FunctionLiteral(function_literal) => { Type::FunctionLiteral(function_literal) => {
let Some(known_function) = function_literal.known(self.db()) else { let Some(known_function) = function_literal.known(self.db()) else {
@ -4644,7 +4641,8 @@ impl<'db> TypeInferenceBuilder<'db> {
DiagnosticId::RevealedType, DiagnosticId::RevealedType,
Severity::Info, Severity::Info,
) { ) {
let mut diag = builder.into_diagnostic("Revealed type"); let mut diag =
builder.into_diagnostic("Revealed type");
let span = self.context.span(call_expression); let span = self.context.span(call_expression);
diag.annotate(Annotation::primary(span).message( diag.annotate(Annotation::primary(span).message(
format_args!( format_args!(
@ -4695,18 +4693,21 @@ impl<'db> TypeInferenceBuilder<'db> {
if let [Some(parameter_ty), message] = if let [Some(parameter_ty), message] =
overload.parameter_types() overload.parameter_types()
{ {
let truthiness = match parameter_ty.try_bool(self.db()) { let truthiness = match parameter_ty.try_bool(self.db())
{
Ok(truthiness) => truthiness, Ok(truthiness) => truthiness,
Err(err) => { Err(err) => {
let condition = arguments let condition = arguments
.find_argument("condition", 0) .find_argument("condition", 0)
.map(|argument| match argument { .map(|argument| {
match argument {
ruff_python_ast::ArgOrKeyword::Arg( ruff_python_ast::ArgOrKeyword::Arg(
expr, expr,
) => ast::AnyNodeRef::from(expr), ) => ast::AnyNodeRef::from(expr),
ruff_python_ast::ArgOrKeyword::Keyword( ruff_python_ast::ArgOrKeyword::Keyword(
keyword, keyword,
) => ast::AnyNodeRef::from(keyword), ) => ast::AnyNodeRef::from(keyword),
}
}) })
.unwrap_or(ast::AnyNodeRef::from( .unwrap_or(ast::AnyNodeRef::from(
call_expression, call_expression,
@ -4922,7 +4923,9 @@ impl<'db> TypeInferenceBuilder<'db> {
let name_param = name_param let name_param = name_param
.into_string_literal() .into_string_literal()
.map(|name| name.value(self.db()).as_ref()); .map(|name| name.value(self.db()).as_ref());
if name_param.is_none_or(|name_param| name_param != target.id) { if name_param
.is_none_or(|name_param| name_param != target.id)
{
if let Some(builder) = self.context.report_lint( if let Some(builder) = self.context.report_lint(
&INVALID_LEGACY_TYPE_VARIABLE, &INVALID_LEGACY_TYPE_VARIABLE,
call_expression, call_expression,
@ -4963,7 +4966,9 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|(_, ty)| ty) .map(|(_, ty)| ty)
.collect::<Box<_>>(), .collect::<Box<_>>(),
); );
Some(TypeVarBoundOrConstraints::Constraints(elements)) Some(TypeVarBoundOrConstraints::Constraints(
elements,
))
} }
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and // TODO: Emit a diagnostic that TypeVar cannot be both bounded and
@ -4993,6 +4998,7 @@ impl<'db> TypeInferenceBuilder<'db> {
_ => (), _ => (),
} }
} }
}
bindings.return_type(self.db()) bindings.return_type(self.db())
} }
@ -6637,7 +6643,8 @@ impl<'db> TypeInferenceBuilder<'db> {
.next() .next()
.expect("valid bindings should have one callable"); .expect("valid bindings should have one callable");
let (_, overload) = callable let (_, overload) = callable
.matching_overload() .matching_overloads()
.next()
.expect("valid bindings should have matching overload"); .expect("valid bindings should have matching overload");
let specialization = generic_context.specialize( let specialization = generic_context.specialize(
self.db(), self.db(),