[red-knot] Add FunctionType::to_overloaded (#17585)

## Summary

This PR adds a new method `FunctionType::to_overloaded` which converts a
`FunctionType` into an `OverloadedFunction` which contains all the
`@overload`-ed `FunctionType` and the implementation `FunctionType` if
it exists.

There's a big caveat here (it's the way overloads work) which is that
this method can only "see" all the overloads that comes _before_ itself.
Consider the following example:

```py
from typing import overload

@overload
def foo() -> None: ...
@overload
def foo(x: int) -> int: ...
def foo(x: int | None) -> int | None:
	return x
```

Here, when the `to_overloaded` method is invoked on the
1. first `foo` definition, it would only contain a single overload which
is itself and no implementation.
2. second `foo` definition, it would contain both overloads and still no
implementation
3. third `foo` definition, it would contain both overloads and the
implementation which is itself

### Usages

This method will be used in the logic for checking invalid overload
usages. It can also be used for #17541.

## Test Plan

Make sure that existing tests pass.
This commit is contained in:
Dhruv Manilawala 2025-04-24 02:57:05 +05:30 committed by GitHub
parent bfc1650198
commit 7b6222700b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -5921,6 +5921,19 @@ impl<'db> IntoIterator for &'db FunctionSignature<'db> {
}
}
/// An overloaded function.
///
/// This is created by the [`to_overloaded`] method on [`FunctionType`].
///
/// [`to_overloaded`]: FunctionType::to_overloaded
#[derive(Debug, PartialEq, Eq, salsa::Update)]
struct OverloadedFunction<'db> {
/// The overloads of this function.
overloads: Vec<FunctionType<'db>>,
/// The implementation of this overloaded function, if any.
implementation: Option<FunctionType<'db>>,
}
#[salsa::interned(debug)]
pub struct FunctionType<'db> {
/// Name of the function at definition.
@ -6009,49 +6022,20 @@ impl<'db> FunctionType<'db> {
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked(return_ref)]
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
let internal_signature = self.internal_signature(db);
// The semantic model records a use for each function on the name node. This is used here
// to get the previous function definition with the same name.
let scope = self.definition(db).scope(db);
let use_def = semantic_index(db, scope.file(db)).use_def_map(scope.file_scope_id(db));
let use_id = self
.body_scope(db)
.node(db)
.expect_function()
.name
.scoped_use_id(db, scope);
if let Symbol::Type(Type::FunctionLiteral(function_literal), Boundness::Bound) =
symbol_from_bindings(db, use_def.bindings_at_use(use_id))
{
match function_literal.signature(db) {
FunctionSignature::Single(_) => {
debug_assert!(
!function_literal.has_known_decorator(db, FunctionDecorators::OVERLOAD),
"Expected `FunctionSignature::Overloaded` if the previous function was an overload"
);
}
FunctionSignature::Overloaded(_, Some(_)) => {
// If the previous overloaded function already has an implementation, then this
// new signature completely replaces it.
}
FunctionSignature::Overloaded(signatures, None) => {
return if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) {
let mut signatures = signatures.clone();
signatures.push(internal_signature);
FunctionSignature::Overloaded(signatures, None)
} else {
FunctionSignature::Overloaded(signatures.clone(), Some(internal_signature))
};
}
}
}
if self.has_known_decorator(db, FunctionDecorators::OVERLOAD) {
FunctionSignature::Overloaded(vec![internal_signature], None)
if let Some(overloaded) = self.to_overloaded(db) {
FunctionSignature::Overloaded(
overloaded
.overloads
.iter()
.copied()
.map(|overload| overload.internal_signature(db))
.collect(),
overloaded
.implementation
.map(|implementation| implementation.internal_signature(db)),
)
} else {
FunctionSignature::Single(internal_signature)
FunctionSignature::Single(self.internal_signature(db))
}
}
@ -6142,6 +6126,107 @@ impl<'db> FunctionType<'db> {
Some(specialization),
)
}
/// Returns `self` as [`OverloadedFunction`] if it is overloaded, [`None`] otherwise.
///
/// ## Note
///
/// The way this method works only allows us to "see" the overloads that are defined before
/// this function definition. This is because the semantic model records a use for each
/// function on the name node which is used to get the previous function definition with the
/// same name. This means that [`OverloadedFunction`] would only include the functions that
/// comes before this function definition. Consider the following example:
///
/// ```py
/// from typing import overload
///
/// @overload
/// def foo() -> None: ...
/// @overload
/// def foo(x: int) -> int: ...
/// def foo(x: int | None) -> int | None:
/// return x
/// ```
///
/// Here, when the `to_overloaded` method is invoked on the
/// 1. first `foo` definition, it would only contain a single overload which is itself and no
/// implementation
/// 2. second `foo` definition, it would contain both overloads and still no implementation
/// 3. third `foo` definition, it would contain both overloads and the implementation which is
/// itself
fn to_overloaded(self, db: &'db dyn Db) -> Option<&'db OverloadedFunction<'db>> {
#[allow(clippy::ref_option)] // TODO: Remove once salsa supports deref (https://github.com/salsa-rs/salsa/pull/772)
#[salsa::tracked(return_ref)]
fn to_overloaded_impl<'db>(
db: &'db dyn Db,
function: FunctionType<'db>,
) -> Option<OverloadedFunction<'db>> {
// The semantic model records a use for each function on the name node. This is used here
// to get the previous function definition with the same name.
let scope = function.definition(db).scope(db);
let use_def = semantic_index(db, scope.file(db)).use_def_map(scope.file_scope_id(db));
let use_id = function
.body_scope(db)
.node(db)
.expect_function()
.name
.scoped_use_id(db, scope);
if let Symbol::Type(Type::FunctionLiteral(function_literal), Boundness::Bound) =
symbol_from_bindings(db, use_def.bindings_at_use(use_id))
{
match function_literal.to_overloaded(db) {
None => {
debug_assert!(
!function_literal.has_known_decorator(db, FunctionDecorators::OVERLOAD),
"Expected `Some(OverloadedFunction)` if the previous function was an overload"
);
}
Some(OverloadedFunction {
implementation: Some(_),
..
}) => {
// If the previous overloaded function already has an implementation, then this
// new signature completely replaces it.
}
Some(OverloadedFunction {
overloads,
implementation: None,
}) => {
return Some(
if function.has_known_decorator(db, FunctionDecorators::OVERLOAD) {
let mut overloads = overloads.clone();
overloads.push(function);
OverloadedFunction {
overloads,
implementation: None,
}
} else {
OverloadedFunction {
overloads: overloads.clone(),
implementation: Some(function),
}
},
);
}
}
}
if function.has_known_decorator(db, FunctionDecorators::OVERLOAD) {
Some(OverloadedFunction {
overloads: vec![function],
implementation: None,
})
} else {
None
}
}
// HACK: This is required because salsa doesn't support returning `Option<&T>` from tracked
// functions yet. Refer to https://github.com/salsa-rs/salsa/pull/772. Remove the inner
// function once it's supported.
to_overloaded_impl(db, self).as_ref()
}
}
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might