[ty] Don't add identical lower/upper bounds multiple times when inferring specializations (#22030)

When inferring a specialization of a `Callable` type, we use the new
constraint set implementation. In the example in
https://github.com/astral-sh/ty/issues/1968, we end up with a constraint
set that includes all of the following clauses:

```
     U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
M1 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
M2 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
M3 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
M4 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
M5 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
M6 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
M7 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7
```

In general, we take the upper bounds of those constraints to get the
specialization. However, the upper bounds of those constraints are not
all guaranteed to be the same, and so first we need to intersect them
all together. In this case, the upper bounds are all identical, so their
intersection is trivial:

```
U_co = M1 | M2 | M3 | M4 | M5 | M6 | M7
```

But we were still doing the work of calculating that trivial
intersection 7 times. And each time we have to do 7^2 comparisons of the
`M*` classes, ending up with O(n^3) overall work.

This pattern is common enough that we can put in a quick heuristic to
prune identical copies of the same type before performing the
intersection.

Fixes https://github.com/astral-sh/ty/issues/1968
This commit is contained in:
Douglas Creager 2025-12-17 13:35:52 -05:00 committed by GitHub
parent 30c3f9aafe
commit b36ff75a24
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 51 additions and 6 deletions

View file

@ -738,6 +738,8 @@ def f[T](x: T, y: Not[T]) -> T:
## `Callable` parameters
### Class constructors
We can recurse into the parameters and return values of `Callable` parameters to infer
specializations of a generic function.
@ -891,3 +893,46 @@ def _(x: list[str]):
# error: [invalid-argument-type]
reveal_type(accepts_callable(GenericClass)(x, x))
```
### Don't include identical lower/upper bounds in type mapping multiple times
This is was a performance regression reported in
[ty#1968](https://github.com/astral-sh/ty/issues/1968). Before fixing this, we would see the
`U ≤ M1 | ... | M7` upper bound 7 times. Since we intersect upper bounds before recording a single
type mapping, we would perform 7 intersections. Each intersection would require 7^2 comparisons of
the `Mx` types. We now have a simple heuristics that avoids processing any identical lower or upper
bound more than once, since we know the extra copies cannot affect the result.
```py
from typing import Callable, Generic, TypeVar, Union
class M1: ...
class M2: ...
class M3: ...
class M4: ...
class M5: ...
class M6: ...
class M7: ...
Msg = Union[M1, M2, M3, M4, M5, M6, M7]
T = TypeVar("T")
U_co = TypeVar("U_co", covariant=True)
class Stream(Generic[T]):
def apply(self, func: Callable[["Stream[T]"], "Stream[U_co]"]) -> "Stream[U_co]":
return func(self)
TMsg = TypeVar("TMsg", bound=Msg)
class Builder(Generic[TMsg]):
def build(self) -> Stream[TMsg]:
stream: Stream[TMsg] = Stream()
# TODO: no error
# error: [invalid-assignment]
stream = stream.apply(self._handler)
return stream
def _handler(self, stream: Stream[Msg]) -> Stream[Msg]:
return stream
```

View file

@ -1586,8 +1586,8 @@ impl<'db> SpecializationBuilder<'db> {
) {
#[derive(Default)]
struct Bounds<'db> {
lower: Vec<Type<'db>>,
upper: Vec<Type<'db>>,
lower: FxOrderSet<Type<'db>>,
upper: FxOrderSet<Type<'db>>,
}
let constraints = constraints.limit_to_valid_specializations(self.db);
@ -1611,17 +1611,17 @@ impl<'db> SpecializationBuilder<'db> {
let lower = constraint.lower(self.db);
let upper = constraint.upper(self.db);
let bounds = mappings.entry(typevar).or_default();
bounds.lower.push(lower);
bounds.upper.push(upper);
bounds.lower.insert(lower);
bounds.upper.insert(upper);
if let Type::TypeVar(lower_bound_typevar) = lower {
let bounds = mappings.entry(lower_bound_typevar).or_default();
bounds.upper.push(Type::TypeVar(typevar));
bounds.upper.insert(Type::TypeVar(typevar));
}
if let Type::TypeVar(upper_bound_typevar) = upper {
let bounds = mappings.entry(upper_bound_typevar).or_default();
bounds.lower.push(Type::TypeVar(typevar));
bounds.lower.insert(Type::TypeVar(typevar));
}
}