mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-01 06:11:21 +00:00
[red-knot] Fix equivalence of differently ordered unions that contain Callable
types (#17145)
## Summary Fixes https://github.com/astral-sh/ruff/issues/17058. Equivalent callable types were not understood as equivalent when they appeared nested inside unions and intersections. This PR fixes that by ensuring that `Callable` elements nested inside unions, intersections and tuples have their representations normalized before one union type is compared with another for equivalence, or before one intersection type is compared with another for equivalence. The normalizations applied to a `Callable` type are: - the type of the default value is stripped from all parameters (only whether the parameter _has_ a default value is relevant to whether one `Callable` type is equivalent to another) - The names of the parameters are stripped from positional-only parameters, variadic parameters and keyword-variadic parameters - Unions and intersections that are present (top-level or nested) inside parameter annotations or return annotations are normalized. Adding a `CallableType::normalized()` method also allows us to simplify the implementation of `CallableType::is_equivalent_to()`. ### Should these normalizations be done eagerly as part of a `CallableType` constructor? I considered this. It's something that we could still consider doing in the future; this PR doesn't rule it out as a possibility. However, I didn't pursue it for now, for several reasons: 1. Our current `Display` implementation doesn't handle well the possibility that a parameter might not have a name or an annotated type. Callable types with parameters like this would be displayed as follows: ```py (, ,) -> None: ... ``` That's fixable! It could easily become something like `(Unknown, Unknown) -> None: ...`. But it also illustrates that we probably want to retain the parameter names when displaying the signature of a `lambda` function if you're hovering over a reference to the lambda in an IDE. Currently we don't have a `LambdaType` struct for representing `lambda` functions; if we wanted to eagerly normalize signatures when creating `CallableType`s, we'd probably have to add a `LambdaType` struct so that we would retain the full signature of a `lambda` function, rather than representing it as an eagerly simplified `CallableType`. 2. In order to ensure that it's impossible to create `CallableType`s without the parameters being normalized, I'd either have to create an alternative `SimplifiedSignature` struct (which would duplicate a lot of code), or move `CallableType` to a new module so that the only way of constructing a `CallableType` instance would be via a constructor method that performs the normalizations eagerly on the callable's signature. Again, this isn't a dealbreaker, and I think it's still an option, but it would be a lot of churn, and it didn't seem necessary for now. Doing it this way, at least to start with, felt like it would create a diff that's easier to review and felt like it would create fewer merge conflicts for others. ## Test Plan - Added a regression mdtest for https://github.com/astral-sh/ruff/issues/17058 - Ran `QUICKCHECK_TESTS=1000000 cargo test --release -p red_knot_python_semantic -- --ignored types::property_tests::stable`
This commit is contained in:
parent
cb7dae1e96
commit
c2bb5d5250
5 changed files with 97 additions and 55 deletions
|
@ -134,6 +134,7 @@ def f1(a: int = 1) -> None: ...
|
|||
def f2(a: int = 2) -> None: ...
|
||||
|
||||
static_assert(is_equivalent_to(CallableTypeOf[f1], CallableTypeOf[f2]))
|
||||
static_assert(is_equivalent_to(CallableTypeOf[f1] | bool | CallableTypeOf[f2], CallableTypeOf[f2] | bool | CallableTypeOf[f1]))
|
||||
```
|
||||
|
||||
The names of the positional-only, variadic and keyword-variadic parameters does not need to be the
|
||||
|
@ -144,6 +145,7 @@ def f3(a1: int, /, *args1: int, **kwargs2: int) -> None: ...
|
|||
def f4(a2: int, /, *args2: int, **kwargs1: int) -> None: ...
|
||||
|
||||
static_assert(is_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4]))
|
||||
static_assert(is_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3]))
|
||||
```
|
||||
|
||||
Putting it all together, the following two callables are equivalent:
|
||||
|
@ -153,6 +155,7 @@ def f5(a1: int, /, b: float, c: bool = False, *args1: int, d: int = 1, e: str, *
|
|||
def f6(a2: int, /, b: float, c: bool = True, *args2: int, d: int = 2, e: str, **kwargs2: float) -> None: ...
|
||||
|
||||
static_assert(is_equivalent_to(CallableTypeOf[f5], CallableTypeOf[f6]))
|
||||
static_assert(is_equivalent_to(CallableTypeOf[f5] | bool | CallableTypeOf[f6], CallableTypeOf[f6] | bool | CallableTypeOf[f5]))
|
||||
```
|
||||
|
||||
### Not equivalent
|
||||
|
|
|
@ -147,6 +147,9 @@ def f4(a=2): ...
|
|||
def f5(a): ...
|
||||
|
||||
static_assert(is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4]))
|
||||
static_assert(
|
||||
is_gradual_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3])
|
||||
)
|
||||
static_assert(not is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f5]))
|
||||
|
||||
def f6(a, /): ...
|
||||
|
|
|
@ -597,19 +597,22 @@ impl<'db> Type<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Return a normalized version of `self` in which all unions and intersections are sorted
|
||||
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
|
||||
/// Return a "normalized" version of `self` that ensures that equivalent types have the same Salsa ID.
|
||||
///
|
||||
/// A normalized type:
|
||||
/// - Has all unions and intersections sorted according to a canonical order,
|
||||
/// no matter how "deeply" a union/intersection may be nested.
|
||||
/// - Strips the names of positional-only parameters and variadic parameters from `Callable` types,
|
||||
/// as these are irrelevant to whether a callable type `X` is equivalent to a callable type `Y`.
|
||||
/// - Strips the types of default values from parameters in `Callable` types: only whether a parameter
|
||||
/// *has* or *does not have* a default value is relevant to whether two `Callable` types are equivalent.
|
||||
#[must_use]
|
||||
pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
|
||||
pub fn normalized(self, db: &'db dyn Db) -> Self {
|
||||
match self {
|
||||
Type::Union(union) => Type::Union(union.to_sorted_union(db)),
|
||||
Type::Intersection(intersection) => {
|
||||
Type::Intersection(intersection.to_sorted_intersection(db))
|
||||
}
|
||||
Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions_and_intersections(db)),
|
||||
Type::Callable(callable) => {
|
||||
Type::Callable(callable.with_sorted_unions_and_intersections(db))
|
||||
}
|
||||
Type::Union(union) => Type::Union(union.normalized(db)),
|
||||
Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)),
|
||||
Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)),
|
||||
Type::Callable(callable) => Type::Callable(callable.normalized(db)),
|
||||
Type::LiteralString
|
||||
| Type::Instance(_)
|
||||
| Type::PropertyInstance(_)
|
||||
|
@ -4676,16 +4679,19 @@ impl<'db> CallableType<'db> {
|
|||
)
|
||||
}
|
||||
|
||||
fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
|
||||
/// Return a "normalized" version of this `Callable` type.
|
||||
///
|
||||
/// See [`Type::normalized`] for more details.
|
||||
fn normalized(self, db: &'db dyn Db) -> Self {
|
||||
let signature = self.signature(db);
|
||||
let parameters = signature
|
||||
.parameters()
|
||||
.iter()
|
||||
.map(|param| param.clone().with_sorted_unions_and_intersections(db))
|
||||
.map(|param| param.normalized(db))
|
||||
.collect();
|
||||
let return_ty = signature
|
||||
.return_ty
|
||||
.map(|return_ty| return_ty.with_sorted_unions_and_intersections(db));
|
||||
.map(|return_ty| return_ty.normalized(db));
|
||||
CallableType::new(db, Signature::new(parameters, return_ty))
|
||||
}
|
||||
|
||||
|
@ -5447,13 +5453,15 @@ impl<'db> UnionType<'db> {
|
|||
self.elements(db).iter().all(|ty| ty.is_fully_static(db))
|
||||
}
|
||||
|
||||
/// Create a new union type with the elements sorted according to a canonical ordering.
|
||||
/// Create a new union type with the elements normalized.
|
||||
///
|
||||
/// See [`Type::normalized`] for more details.
|
||||
#[must_use]
|
||||
pub fn to_sorted_union(self, db: &'db dyn Db) -> Self {
|
||||
pub fn normalized(self, db: &'db dyn Db) -> Self {
|
||||
let mut new_elements: Vec<Type<'db>> = self
|
||||
.elements(db)
|
||||
.iter()
|
||||
.map(|element| element.with_sorted_unions_and_intersections(db))
|
||||
.map(|element| element.normalized(db))
|
||||
.collect();
|
||||
new_elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
|
||||
UnionType::new(db, new_elements.into_boxed_slice())
|
||||
|
@ -5487,13 +5495,13 @@ impl<'db> UnionType<'db> {
|
|||
return true;
|
||||
}
|
||||
|
||||
let sorted_self = self.to_sorted_union(db);
|
||||
let sorted_self = self.normalized(db);
|
||||
|
||||
if sorted_self == other {
|
||||
return true;
|
||||
}
|
||||
|
||||
sorted_self == other.to_sorted_union(db)
|
||||
sorted_self == other.normalized(db)
|
||||
}
|
||||
|
||||
/// Return `true` if `self` has exactly the same set of possible static materializations as `other`
|
||||
|
@ -5510,13 +5518,13 @@ impl<'db> UnionType<'db> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let sorted_self = self.to_sorted_union(db);
|
||||
let sorted_self = self.normalized(db);
|
||||
|
||||
if sorted_self == other {
|
||||
return true;
|
||||
}
|
||||
|
||||
let sorted_other = other.to_sorted_union(db);
|
||||
let sorted_other = other.normalized(db);
|
||||
|
||||
if sorted_self == sorted_other {
|
||||
return true;
|
||||
|
@ -5547,17 +5555,17 @@ pub struct IntersectionType<'db> {
|
|||
|
||||
impl<'db> IntersectionType<'db> {
|
||||
/// Return a new `IntersectionType` instance with the positive and negative types sorted
|
||||
/// according to a canonical ordering.
|
||||
/// according to a canonical ordering, and other normalizations applied to each element as applicable.
|
||||
///
|
||||
/// See [`Type::normalized`] for more details.
|
||||
#[must_use]
|
||||
pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self {
|
||||
pub fn normalized(self, db: &'db dyn Db) -> Self {
|
||||
fn normalized_set<'db>(
|
||||
db: &'db dyn Db,
|
||||
elements: &FxOrderSet<Type<'db>>,
|
||||
) -> FxOrderSet<Type<'db>> {
|
||||
let mut elements: FxOrderSet<Type<'db>> = elements
|
||||
.iter()
|
||||
.map(|ty| ty.with_sorted_unions_and_intersections(db))
|
||||
.collect();
|
||||
let mut elements: FxOrderSet<Type<'db>> =
|
||||
elements.iter().map(|ty| ty.normalized(db)).collect();
|
||||
|
||||
elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r));
|
||||
elements
|
||||
|
@ -5620,13 +5628,13 @@ impl<'db> IntersectionType<'db> {
|
|||
return true;
|
||||
}
|
||||
|
||||
let sorted_self = self.to_sorted_intersection(db);
|
||||
let sorted_self = self.normalized(db);
|
||||
|
||||
if sorted_self == other {
|
||||
return true;
|
||||
}
|
||||
|
||||
sorted_self == other.to_sorted_intersection(db)
|
||||
sorted_self == other.normalized(db)
|
||||
}
|
||||
|
||||
/// Return `true` if `self` has exactly the same set of possible static materializations as `other`
|
||||
|
@ -5642,13 +5650,13 @@ impl<'db> IntersectionType<'db> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let sorted_self = self.to_sorted_intersection(db);
|
||||
let sorted_self = self.normalized(db);
|
||||
|
||||
if sorted_self == other {
|
||||
return true;
|
||||
}
|
||||
|
||||
let sorted_other = other.to_sorted_intersection(db);
|
||||
let sorted_other = other.normalized(db);
|
||||
|
||||
if sorted_self == sorted_other {
|
||||
return true;
|
||||
|
@ -5834,14 +5842,15 @@ impl<'db> TupleType<'db> {
|
|||
Type::Tuple(Self::new(db, elements.into_boxed_slice()))
|
||||
}
|
||||
|
||||
/// Return a normalized version of `self` in which all unions and intersections are sorted
|
||||
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
|
||||
/// Return a normalized version of `self`.
|
||||
///
|
||||
/// See [`Type::normalized`] for more details.
|
||||
#[must_use]
|
||||
pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self {
|
||||
pub fn normalized(self, db: &'db dyn Db) -> Self {
|
||||
let elements: Box<[Type<'db>]> = self
|
||||
.elements(db)
|
||||
.iter()
|
||||
.map(|ty| ty.with_sorted_unions_and_intersections(db))
|
||||
.map(|ty| ty.normalized(db))
|
||||
.collect();
|
||||
TupleType::new(db, elements)
|
||||
}
|
||||
|
|
|
@ -606,31 +606,54 @@ impl<'db> Parameter<'db> {
|
|||
self
|
||||
}
|
||||
|
||||
pub(crate) fn with_sorted_unions_and_intersections(mut self, db: &'db dyn Db) -> Self {
|
||||
self.annotated_type = self
|
||||
.annotated_type
|
||||
.map(|ty| ty.with_sorted_unions_and_intersections(db));
|
||||
/// Strip information from the parameter so that two equivalent parameters compare equal.
|
||||
/// Normalize nested unions and intersections in the annotated type, if any.
|
||||
///
|
||||
/// See [`Type::normalized`] for more details.
|
||||
pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self {
|
||||
let Parameter {
|
||||
annotated_type,
|
||||
kind,
|
||||
form,
|
||||
} = self;
|
||||
|
||||
self.kind = match self.kind {
|
||||
ParameterKind::PositionalOnly { name, default_type } => ParameterKind::PositionalOnly {
|
||||
name,
|
||||
default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)),
|
||||
// Ensure unions and intersections are ordered in the annotated type (if there is one)
|
||||
let annotated_type = annotated_type.map(|ty| ty.normalized(db));
|
||||
|
||||
// Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters.
|
||||
// Ensure that we only record whether a parameter *has* a default
|
||||
// (strip the precise *type* of the default from the parameter, replacing it with `Never`).
|
||||
let kind = match kind {
|
||||
ParameterKind::PositionalOnly {
|
||||
name: _,
|
||||
default_type,
|
||||
} => ParameterKind::PositionalOnly {
|
||||
name: None,
|
||||
default_type: default_type.map(|_| Type::Never),
|
||||
},
|
||||
ParameterKind::PositionalOrKeyword { name, default_type } => {
|
||||
ParameterKind::PositionalOrKeyword {
|
||||
name,
|
||||
default_type: default_type
|
||||
.map(|ty| ty.with_sorted_unions_and_intersections(db)),
|
||||
name: name.clone(),
|
||||
default_type: default_type.map(|_| Type::Never),
|
||||
}
|
||||
}
|
||||
ParameterKind::KeywordOnly { name, default_type } => ParameterKind::KeywordOnly {
|
||||
name,
|
||||
default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)),
|
||||
name: name.clone(),
|
||||
default_type: default_type.map(|_| Type::Never),
|
||||
},
|
||||
ParameterKind::Variadic { name: _ } => ParameterKind::Variadic {
|
||||
name: Name::new_static("args"),
|
||||
},
|
||||
ParameterKind::KeywordVariadic { name: _ } => ParameterKind::KeywordVariadic {
|
||||
name: Name::new_static("kwargs"),
|
||||
},
|
||||
ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => self.kind,
|
||||
};
|
||||
|
||||
self
|
||||
Self {
|
||||
annotated_type,
|
||||
kind,
|
||||
form: *form,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_node_and_kind(
|
||||
|
|
|
@ -73,13 +73,17 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
|
|||
(Type::WrapperDescriptor(_), _) => Ordering::Less,
|
||||
(_, Type::WrapperDescriptor(_)) => Ordering::Greater,
|
||||
|
||||
(Type::Callable(left), Type::Callable(right)) => left.cmp(right),
|
||||
(Type::Callable(left), Type::Callable(right)) => {
|
||||
debug_assert_eq!(*left, left.normalized(db));
|
||||
debug_assert_eq!(*right, right.normalized(db));
|
||||
left.cmp(right)
|
||||
}
|
||||
(Type::Callable(_), _) => Ordering::Less,
|
||||
(_, Type::Callable(_)) => Ordering::Greater,
|
||||
|
||||
(Type::Tuple(left), Type::Tuple(right)) => {
|
||||
debug_assert_eq!(*left, left.with_sorted_unions_and_intersections(db));
|
||||
debug_assert_eq!(*right, right.with_sorted_unions_and_intersections(db));
|
||||
debug_assert_eq!(*left, left.normalized(db));
|
||||
debug_assert_eq!(*right, right.normalized(db));
|
||||
left.cmp(right)
|
||||
}
|
||||
(Type::Tuple(_), _) => Ordering::Less,
|
||||
|
@ -271,8 +275,8 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
|
|||
}
|
||||
|
||||
(Type::Intersection(left), Type::Intersection(right)) => {
|
||||
debug_assert_eq!(*left, left.to_sorted_intersection(db));
|
||||
debug_assert_eq!(*right, right.to_sorted_intersection(db));
|
||||
debug_assert_eq!(*left, left.normalized(db));
|
||||
debug_assert_eq!(*right, right.normalized(db));
|
||||
|
||||
if left == right {
|
||||
return Ordering::Equal;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue