diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index b2606cf434..031de910a0 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -134,6 +134,7 @@ def f1(a: int = 1) -> None: ... def f2(a: int = 2) -> None: ... static_assert(is_equivalent_to(CallableTypeOf[f1], CallableTypeOf[f2])) +static_assert(is_equivalent_to(CallableTypeOf[f1] | bool | CallableTypeOf[f2], CallableTypeOf[f2] | bool | CallableTypeOf[f1])) ``` The names of the positional-only, variadic and keyword-variadic parameters does not need to be the @@ -144,6 +145,7 @@ def f3(a1: int, /, *args1: int, **kwargs2: int) -> None: ... def f4(a2: int, /, *args2: int, **kwargs1: int) -> None: ... static_assert(is_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4])) +static_assert(is_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3])) ``` Putting it all together, the following two callables are equivalent: @@ -153,6 +155,7 @@ def f5(a1: int, /, b: float, c: bool = False, *args1: int, d: int = 1, e: str, * def f6(a2: int, /, b: float, c: bool = True, *args2: int, d: int = 2, e: str, **kwargs2: float) -> None: ... static_assert(is_equivalent_to(CallableTypeOf[f5], CallableTypeOf[f6])) +static_assert(is_equivalent_to(CallableTypeOf[f5] | bool | CallableTypeOf[f6], CallableTypeOf[f6] | bool | CallableTypeOf[f5])) ``` ### Not equivalent diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md index 6d96420e33..7bc43b69e7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md @@ -147,6 +147,9 @@ def f4(a=2): ... def f5(a): ... static_assert(is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f4])) +static_assert( + is_gradual_equivalent_to(CallableTypeOf[f3] | bool | CallableTypeOf[f4], CallableTypeOf[f4] | bool | CallableTypeOf[f3]) +) static_assert(not is_gradual_equivalent_to(CallableTypeOf[f3], CallableTypeOf[f5])) def f6(a, /): ... diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 879ba7098a..2bed15f04b 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -597,19 +597,22 @@ impl<'db> Type<'db> { } } - /// Return a normalized version of `self` in which all unions and intersections are sorted - /// according to a canonical order, no matter how "deeply" a union/intersection may be nested. + /// Return a "normalized" version of `self` that ensures that equivalent types have the same Salsa ID. + /// + /// A normalized type: + /// - Has all unions and intersections sorted according to a canonical order, + /// no matter how "deeply" a union/intersection may be nested. + /// - Strips the names of positional-only parameters and variadic parameters from `Callable` types, + /// as these are irrelevant to whether a callable type `X` is equivalent to a callable type `Y`. + /// - Strips the types of default values from parameters in `Callable` types: only whether a parameter + /// *has* or *does not have* a default value is relevant to whether two `Callable` types are equivalent. #[must_use] - pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self { + pub fn normalized(self, db: &'db dyn Db) -> Self { match self { - Type::Union(union) => Type::Union(union.to_sorted_union(db)), - Type::Intersection(intersection) => { - Type::Intersection(intersection.to_sorted_intersection(db)) - } - Type::Tuple(tuple) => Type::Tuple(tuple.with_sorted_unions_and_intersections(db)), - Type::Callable(callable) => { - Type::Callable(callable.with_sorted_unions_and_intersections(db)) - } + Type::Union(union) => Type::Union(union.normalized(db)), + Type::Intersection(intersection) => Type::Intersection(intersection.normalized(db)), + Type::Tuple(tuple) => Type::Tuple(tuple.normalized(db)), + Type::Callable(callable) => Type::Callable(callable.normalized(db)), Type::LiteralString | Type::Instance(_) | Type::PropertyInstance(_) @@ -4676,16 +4679,19 @@ impl<'db> CallableType<'db> { ) } - fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self { + /// Return a "normalized" version of this `Callable` type. + /// + /// See [`Type::normalized`] for more details. + fn normalized(self, db: &'db dyn Db) -> Self { let signature = self.signature(db); let parameters = signature .parameters() .iter() - .map(|param| param.clone().with_sorted_unions_and_intersections(db)) + .map(|param| param.normalized(db)) .collect(); let return_ty = signature .return_ty - .map(|return_ty| return_ty.with_sorted_unions_and_intersections(db)); + .map(|return_ty| return_ty.normalized(db)); CallableType::new(db, Signature::new(parameters, return_ty)) } @@ -5447,13 +5453,15 @@ impl<'db> UnionType<'db> { self.elements(db).iter().all(|ty| ty.is_fully_static(db)) } - /// Create a new union type with the elements sorted according to a canonical ordering. + /// Create a new union type with the elements normalized. + /// + /// See [`Type::normalized`] for more details. #[must_use] - pub fn to_sorted_union(self, db: &'db dyn Db) -> Self { + pub fn normalized(self, db: &'db dyn Db) -> Self { let mut new_elements: Vec> = self .elements(db) .iter() - .map(|element| element.with_sorted_unions_and_intersections(db)) + .map(|element| element.normalized(db)) .collect(); new_elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); UnionType::new(db, new_elements.into_boxed_slice()) @@ -5487,13 +5495,13 @@ impl<'db> UnionType<'db> { return true; } - let sorted_self = self.to_sorted_union(db); + let sorted_self = self.normalized(db); if sorted_self == other { return true; } - sorted_self == other.to_sorted_union(db) + sorted_self == other.normalized(db) } /// Return `true` if `self` has exactly the same set of possible static materializations as `other` @@ -5510,13 +5518,13 @@ impl<'db> UnionType<'db> { return false; } - let sorted_self = self.to_sorted_union(db); + let sorted_self = self.normalized(db); if sorted_self == other { return true; } - let sorted_other = other.to_sorted_union(db); + let sorted_other = other.normalized(db); if sorted_self == sorted_other { return true; @@ -5547,17 +5555,17 @@ pub struct IntersectionType<'db> { impl<'db> IntersectionType<'db> { /// Return a new `IntersectionType` instance with the positive and negative types sorted - /// according to a canonical ordering. + /// according to a canonical ordering, and other normalizations applied to each element as applicable. + /// + /// See [`Type::normalized`] for more details. #[must_use] - pub fn to_sorted_intersection(self, db: &'db dyn Db) -> Self { + pub fn normalized(self, db: &'db dyn Db) -> Self { fn normalized_set<'db>( db: &'db dyn Db, elements: &FxOrderSet>, ) -> FxOrderSet> { - let mut elements: FxOrderSet> = elements - .iter() - .map(|ty| ty.with_sorted_unions_and_intersections(db)) - .collect(); + let mut elements: FxOrderSet> = + elements.iter().map(|ty| ty.normalized(db)).collect(); elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); elements @@ -5620,13 +5628,13 @@ impl<'db> IntersectionType<'db> { return true; } - let sorted_self = self.to_sorted_intersection(db); + let sorted_self = self.normalized(db); if sorted_self == other { return true; } - sorted_self == other.to_sorted_intersection(db) + sorted_self == other.normalized(db) } /// Return `true` if `self` has exactly the same set of possible static materializations as `other` @@ -5642,13 +5650,13 @@ impl<'db> IntersectionType<'db> { return false; } - let sorted_self = self.to_sorted_intersection(db); + let sorted_self = self.normalized(db); if sorted_self == other { return true; } - let sorted_other = other.to_sorted_intersection(db); + let sorted_other = other.normalized(db); if sorted_self == sorted_other { return true; @@ -5834,14 +5842,15 @@ impl<'db> TupleType<'db> { Type::Tuple(Self::new(db, elements.into_boxed_slice())) } - /// Return a normalized version of `self` in which all unions and intersections are sorted - /// according to a canonical order, no matter how "deeply" a union/intersection may be nested. + /// Return a normalized version of `self`. + /// + /// See [`Type::normalized`] for more details. #[must_use] - pub fn with_sorted_unions_and_intersections(self, db: &'db dyn Db) -> Self { + pub fn normalized(self, db: &'db dyn Db) -> Self { let elements: Box<[Type<'db>]> = self .elements(db) .iter() - .map(|ty| ty.with_sorted_unions_and_intersections(db)) + .map(|ty| ty.normalized(db)) .collect(); TupleType::new(db, elements) } diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index 23ffcdc22c..f67fa7c79c 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -606,31 +606,54 @@ impl<'db> Parameter<'db> { self } - pub(crate) fn with_sorted_unions_and_intersections(mut self, db: &'db dyn Db) -> Self { - self.annotated_type = self - .annotated_type - .map(|ty| ty.with_sorted_unions_and_intersections(db)); + /// Strip information from the parameter so that two equivalent parameters compare equal. + /// Normalize nested unions and intersections in the annotated type, if any. + /// + /// See [`Type::normalized`] for more details. + pub(crate) fn normalized(&self, db: &'db dyn Db) -> Self { + let Parameter { + annotated_type, + kind, + form, + } = self; - self.kind = match self.kind { - ParameterKind::PositionalOnly { name, default_type } => ParameterKind::PositionalOnly { - name, - default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)), + // Ensure unions and intersections are ordered in the annotated type (if there is one) + let annotated_type = annotated_type.map(|ty| ty.normalized(db)); + + // Ensure that parameter names are stripped from positional-only, variadic and keyword-variadic parameters. + // Ensure that we only record whether a parameter *has* a default + // (strip the precise *type* of the default from the parameter, replacing it with `Never`). + let kind = match kind { + ParameterKind::PositionalOnly { + name: _, + default_type, + } => ParameterKind::PositionalOnly { + name: None, + default_type: default_type.map(|_| Type::Never), }, ParameterKind::PositionalOrKeyword { name, default_type } => { ParameterKind::PositionalOrKeyword { - name, - default_type: default_type - .map(|ty| ty.with_sorted_unions_and_intersections(db)), + name: name.clone(), + default_type: default_type.map(|_| Type::Never), } } ParameterKind::KeywordOnly { name, default_type } => ParameterKind::KeywordOnly { - name, - default_type: default_type.map(|ty| ty.with_sorted_unions_and_intersections(db)), + name: name.clone(), + default_type: default_type.map(|_| Type::Never), + }, + ParameterKind::Variadic { name: _ } => ParameterKind::Variadic { + name: Name::new_static("args"), + }, + ParameterKind::KeywordVariadic { name: _ } => ParameterKind::KeywordVariadic { + name: Name::new_static("kwargs"), }, - ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => self.kind, }; - self + Self { + annotated_type, + kind, + form: *form, + } } fn from_node_and_kind( diff --git a/crates/red_knot_python_semantic/src/types/type_ordering.rs b/crates/red_knot_python_semantic/src/types/type_ordering.rs index b73b59d57b..e7f199aa9a 100644 --- a/crates/red_knot_python_semantic/src/types/type_ordering.rs +++ b/crates/red_knot_python_semantic/src/types/type_ordering.rs @@ -73,13 +73,17 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::WrapperDescriptor(_), _) => Ordering::Less, (_, Type::WrapperDescriptor(_)) => Ordering::Greater, - (Type::Callable(left), Type::Callable(right)) => left.cmp(right), + (Type::Callable(left), Type::Callable(right)) => { + debug_assert_eq!(*left, left.normalized(db)); + debug_assert_eq!(*right, right.normalized(db)); + left.cmp(right) + } (Type::Callable(_), _) => Ordering::Less, (_, Type::Callable(_)) => Ordering::Greater, (Type::Tuple(left), Type::Tuple(right)) => { - debug_assert_eq!(*left, left.with_sorted_unions_and_intersections(db)); - debug_assert_eq!(*right, right.with_sorted_unions_and_intersections(db)); + debug_assert_eq!(*left, left.normalized(db)); + debug_assert_eq!(*right, right.normalized(db)); left.cmp(right) } (Type::Tuple(_), _) => Ordering::Less, @@ -271,8 +275,8 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( } (Type::Intersection(left), Type::Intersection(right)) => { - debug_assert_eq!(*left, left.to_sorted_intersection(db)); - debug_assert_eq!(*right, right.to_sorted_intersection(db)); + debug_assert_eq!(*left, left.normalized(db)); + debug_assert_eq!(*right, right.normalized(db)); if left == right { return Ordering::Equal;