From b36ff75a240f9bd2b6bce5ace1388db854cfad3e Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Wed, 17 Dec 2025 13:35:52 -0500 Subject: [PATCH] [ty] Don't add identical lower/upper bounds multiple times when inferring specializations (#22030) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When inferring a specialization of a `Callable` type, we use the new constraint set implementation. In the example in https://github.com/astral-sh/ty/issues/1968, we end up with a constraint set that includes all of the following clauses: ``` U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M1 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M2 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M3 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M4 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M5 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M6 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M7 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 ``` In general, we take the upper bounds of those constraints to get the specialization. However, the upper bounds of those constraints are not all guaranteed to be the same, and so first we need to intersect them all together. In this case, the upper bounds are all identical, so their intersection is trivial: ``` U_co = M1 | M2 | M3 | M4 | M5 | M6 | M7 ``` But we were still doing the work of calculating that trivial intersection 7 times. And each time we have to do 7^2 comparisons of the `M*` classes, ending up with O(n^3) overall work. This pattern is common enough that we can put in a quick heuristic to prune identical copies of the same type before performing the intersection. 
Fixes https://github.com/astral-sh/ty/issues/1968 --- .../mdtest/generics/pep695/functions.md | 45 +++++++++++++++++++ .../ty_python_semantic/src/types/generics.rs | 12 ++--- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 8121ce5d26..3cdebe848e 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -738,6 +738,8 @@ def f[T](x: T, y: Not[T]) -> T: ## `Callable` parameters +### Class constructors + We can recurse into the parameters and return values of `Callable` parameters to infer specializations of a generic function. @@ -891,3 +893,46 @@ def _(x: list[str]): # error: [invalid-argument-type] reveal_type(accepts_callable(GenericClass)(x, x)) ``` + +### Don't include identical lower/upper bounds in type mapping multiple times + +This was a performance regression reported in +[ty#1968](https://github.com/astral-sh/ty/issues/1968). Before fixing this, we would see the +`U ≤ M1 | ... | M7` upper bound 7 times. Since we intersect upper bounds before recording a single +type mapping, we would perform 7 intersections. Each intersection would require 7^2 comparisons of +the `Mx` types. We now have a simple heuristic that avoids processing any identical lower or upper +bound more than once, since we know the extra copies cannot affect the result. + +```py +from typing import Callable, Generic, TypeVar, Union + +class M1: ... +class M2: ... +class M3: ... +class M4: ... +class M5: ... +class M6: ... +class M7: ... 
+ +Msg = Union[M1, M2, M3, M4, M5, M6, M7] + +T = TypeVar("T") +U_co = TypeVar("U_co", covariant=True) + +class Stream(Generic[T]): + def apply(self, func: Callable[["Stream[T]"], "Stream[U_co]"]) -> "Stream[U_co]": + return func(self) + +TMsg = TypeVar("TMsg", bound=Msg) + +class Builder(Generic[TMsg]): + def build(self) -> Stream[TMsg]: + stream: Stream[TMsg] = Stream() + # TODO: no error + # error: [invalid-assignment] + stream = stream.apply(self._handler) + return stream + + def _handler(self, stream: Stream[Msg]) -> Stream[Msg]: + return stream +``` diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index c012ab09f6..ca6a700bd2 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1586,8 +1586,8 @@ impl<'db> SpecializationBuilder<'db> { ) { #[derive(Default)] struct Bounds<'db> { - lower: Vec<Type<'db>>, - upper: Vec<Type<'db>>, + lower: FxOrderSet<Type<'db>>, + upper: FxOrderSet<Type<'db>>, } let constraints = constraints.limit_to_valid_specializations(self.db); @@ -1611,17 +1611,17 @@ impl<'db> SpecializationBuilder<'db> { let lower = constraint.lower(self.db); let upper = constraint.upper(self.db); let bounds = mappings.entry(typevar).or_default(); - bounds.lower.push(lower); - bounds.upper.push(upper); + bounds.lower.insert(lower); + bounds.upper.insert(upper); if let Type::TypeVar(lower_bound_typevar) = lower { let bounds = mappings.entry(lower_bound_typevar).or_default(); - bounds.upper.push(Type::TypeVar(typevar)); + bounds.upper.insert(Type::TypeVar(typevar)); } if let Type::TypeVar(upper_bound_typevar) = upper { let bounds = mappings.entry(upper_bound_typevar).or_default(); - bounds.lower.push(Type::TypeVar(typevar)); + bounds.lower.insert(Type::TypeVar(typevar)); } }