[red-knot] Check assignability for two callable types (#16845)

## Summary

Part of #15382

This PR adds support for checking the assignability of two general
callable types.

This is built on top of #16804 by including the gradual parameters check
and accepting a function that performs the check between the two types.

## Test Plan

Update `is_assignable_to.md` with callable types section.
This commit is contained in:
Dhruv Manilawala 2025-03-23 02:28:44 +05:30 committed by GitHub
parent 92028efe3d
commit 1cffb323bc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 149 additions and 17 deletions

View file

@ -98,3 +98,22 @@ expression.
```py
reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=(x, y) -> Unknown) -> Unknown
```
## Assignment
This does not enumerate all combinations of parameter kinds as that should be covered by the
[subtype tests for callable types](./../type_properties/is_subtype_of.md#callable).
```py
from typing import Callable
a1: Callable[[], None] = lambda: None
a2: Callable[[int], None] = lambda x: None
a3: Callable[[int, int], None] = lambda x, y, z=1: None
a4: Callable[[int, int], None] = lambda *args: None
# error: [invalid-assignment]
a5: Callable[[], None] = lambda x: None
# error: [invalid-assignment]
a6: Callable[[int], None] = lambda: None
```

View file

@ -393,4 +393,87 @@ static_assert(is_assignable_to(Never, type[str]))
static_assert(is_assignable_to(Never, type[Any]))
```
## Callable
The examples provided below are only a subset of the possible cases and include the ones with
gradual types. The cases with fully static types and using different combinations of parameter kinds
are covered in the [subtyping tests](./is_subtype_of.md#callable).
### Return type
```py
from knot_extensions import CallableTypeFromFunction, Unknown, static_assert, is_assignable_to
from typing import Any, Callable
static_assert(is_assignable_to(Callable[[], Any], Callable[[], int]))
static_assert(is_assignable_to(Callable[[], int], Callable[[], Any]))
static_assert(is_assignable_to(Callable[[], int], Callable[[], float]))
static_assert(not is_assignable_to(Callable[[], float], Callable[[], int]))
```
The return types should be checked even if the parameter types uses gradual form (`...`).
```py
static_assert(is_assignable_to(Callable[..., int], Callable[..., float]))
static_assert(not is_assignable_to(Callable[..., float], Callable[..., int]))
```
And, if there is no return type, the return type is `Unknown`.
```py
static_assert(is_assignable_to(Callable[[], Unknown], Callable[[], int]))
static_assert(is_assignable_to(Callable[[], int], Callable[[], Unknown]))
```
### Parameter types
A `Callable` which uses the gradual form (`...`) for the parameter types is consistent with any
input signature.
```py
from knot_extensions import CallableTypeFromFunction, static_assert, is_assignable_to
from typing import Any, Callable
static_assert(is_assignable_to(Callable[[], None], Callable[..., None]))
static_assert(is_assignable_to(Callable[..., None], Callable[..., None]))
static_assert(is_assignable_to(Callable[[int, float, str], None], Callable[..., None]))
```
Even if it includes any other parameter kinds.
```py
def positional_only(a: int, b: int, /) -> None: ...
def positional_or_keyword(a: int, b: int) -> None: ...
def variadic(*args: int) -> None: ...
def keyword_only(*, a: int, b: int) -> None: ...
def keyword_variadic(**kwargs: int) -> None: ...
def mixed(a: int, /, b: int, *args: int, c: int, **kwargs: int) -> None: ...
static_assert(is_assignable_to(CallableTypeFromFunction[positional_only], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[positional_or_keyword], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[variadic], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[keyword_only], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[keyword_variadic], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[mixed], Callable[..., None]))
```
And, even if the parameters are unannotated.
```py
def positional_only(a, b, /) -> None: ...
def positional_or_keyword(a, b) -> None: ...
def variadic(*args) -> None: ...
def keyword_only(*, a, b) -> None: ...
def keyword_variadic(**kwargs) -> None: ...
def mixed(a, /, b, *args, c, **kwargs) -> None: ...
static_assert(is_assignable_to(CallableTypeFromFunction[positional_only], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[positional_or_keyword], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[variadic], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[keyword_only], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[keyword_variadic], Callable[..., None]))
static_assert(is_assignable_to(CallableTypeFromFunction[mixed], Callable[..., None]))
```
[typing documentation]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation

View file

@ -887,6 +887,11 @@ impl<'db> Type<'db> {
}
}
(
Type::Callable(CallableType::General(self_callable)),
Type::Callable(CallableType::General(target_callable)),
) => self_callable.is_assignable_to(db, target_callable),
// TODO other types containing gradual forms (e.g. generics containing Any/Unknown)
_ => self.is_subtype_of(db, target),
}
@ -4442,8 +4447,32 @@ impl<'db> GeneralCallableType<'db> {
})
}
/// Return `true` if `self` is assignable to `other`.
pub(crate) fn is_assignable_to(self, db: &'db dyn Db, other: Self) -> bool {
self.is_assignable_to_impl(db, other, |type1, type2| {
// In the context of a callable type, the `None` variant represents an `Unknown` type.
type1
.unwrap_or(Type::unknown())
.is_assignable_to(db, type2.unwrap_or(Type::unknown()))
})
}
/// Return `true` if `self` is a subtype of `other`.
pub(crate) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool {
self.is_assignable_to_impl(db, other, |type1, type2| {
// SAFETY: Subtype relation is only checked for fully static types.
type1.unwrap().is_subtype_of(db, type2.unwrap())
})
}
/// Implementation for the [`is_assignable_to`] and [`is_subtype_of`] for callable types.
///
/// [`is_assignable_to`]: Self::is_assignable_to
/// [`is_subtype_of`]: Self::is_subtype_of
fn is_assignable_to_impl<F>(self, db: &'db dyn Db, other: Self, check_types: F) -> bool
where
F: Fn(Option<Type<'db>>, Option<Type<'db>>) -> bool,
{
/// A helper struct to zip two slices of parameters together that provides control over the
/// two iterators individually. It also keeps track of the current parameter in each
/// iterator.
@ -4508,18 +4537,17 @@ impl<'db> GeneralCallableType<'db> {
let self_signature = self.signature(db);
let other_signature = other.signature(db);
// Check if `type1` is a subtype of `type2`. This is mainly to avoid `unwrap` calls
// scattered throughout the function.
let is_subtype = |type1: Option<Type<'db>>, type2: Option<Type<'db>>| {
// SAFETY: Subtype relation is only checked for fully static types.
type1.unwrap().is_subtype_of(db, type2.unwrap())
};
// Return types are covariant.
if !is_subtype(self_signature.return_ty, other_signature.return_ty) {
if !check_types(self_signature.return_ty, other_signature.return_ty) {
return false;
}
if self_signature.parameters().is_gradual() || other_signature.parameters().is_gradual() {
// If either of the parameter lists contains a gradual form (`...`), then it is
// assignable / subtype to and from any other callable type.
return true;
}
let mut parameters = ParametersZip {
current_self: None,
current_other: None,
@ -4577,7 +4605,7 @@ impl<'db> GeneralCallableType<'db> {
if self_default.is_none() && other_default.is_some() {
return false;
}
if !is_subtype(
if !check_types(
other_parameter.annotated_type(),
self_parameter.annotated_type(),
) {
@ -4602,7 +4630,7 @@ impl<'db> GeneralCallableType<'db> {
if self_default.is_none() && other_default.is_some() {
return false;
}
if !is_subtype(
if !check_types(
other_parameter.annotated_type(),
self_parameter.annotated_type(),
) {
@ -4611,7 +4639,7 @@ impl<'db> GeneralCallableType<'db> {
}
(ParameterKind::Variadic { .. }, ParameterKind::PositionalOnly { .. }) => {
if !is_subtype(
if !check_types(
other_parameter.annotated_type(),
self_parameter.annotated_type(),
) {
@ -4641,7 +4669,7 @@ impl<'db> GeneralCallableType<'db> {
// variadic parameter and is deferred to the next iteration.
break;
}
if !is_subtype(
if !check_types(
other_parameter.annotated_type(),
self_parameter.annotated_type(),
) {
@ -4652,7 +4680,7 @@ impl<'db> GeneralCallableType<'db> {
}
(ParameterKind::Variadic { .. }, ParameterKind::Variadic { .. }) => {
if !is_subtype(
if !check_types(
other_parameter.annotated_type(),
self_parameter.annotated_type(),
) {
@ -4730,7 +4758,7 @@ impl<'db> GeneralCallableType<'db> {
if self_default.is_none() && other_default.is_some() {
return false;
}
if !is_subtype(
if !check_types(
other_parameter.annotated_type(),
self_parameter.annotated_type(),
) {
@ -4742,8 +4770,10 @@ impl<'db> GeneralCallableType<'db> {
),
}
} else if let Some(self_keyword_variadic_type) = self_keyword_variadic {
if !is_subtype(other_parameter.annotated_type(), self_keyword_variadic_type)
{
if !check_types(
other_parameter.annotated_type(),
self_keyword_variadic_type,
) {
return false;
}
} else {
@ -4756,7 +4786,7 @@ impl<'db> GeneralCallableType<'db> {
// parameter, `self` must also have a keyword variadic parameter.
return false;
};
if !is_subtype(other_parameter.annotated_type(), self_keyword_variadic_type) {
if !check_types(other_parameter.annotated_type(), self_keyword_variadic_type) {
return false;
}
}