mirror of
https://github.com/astral-sh/ruff.git
synced 2025-12-23 09:19:39 +00:00
[ty] Don't add identical lower/upper bounds multiple times when inferring specializations (#22030)
When inferring a specialization of a `Callable` type, we use the new constraint set implementation. In the example in https://github.com/astral-sh/ty/issues/1968, we end up with a constraint set that includes all of the following clauses: ``` U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M1 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M2 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M3 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M4 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M5 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M6 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 M7 ≤ U_co ≤ M1 | M2 | M3 | M4 | M5 | M6 | M7 ``` In general, we take the upper bounds of those constraints to get the specialization. However, the upper bounds of those constraints are not all guaranteed to be the same, and so first we need to intersect them all together. In this case, the upper bounds are all identical, so their intersection is trivial: ``` U_co = M1 | M2 | M3 | M4 | M5 | M6 | M7 ``` But we were still doing the work of calculating that trivial intersection 7 times. And each time we have to do 7^2 comparisons of the `M*` classes, ending up with O(n^3) overall work. This pattern is common enough that we can put in a quick heuristic to prune identical copies of the same type before performing the intersection. Fixes https://github.com/astral-sh/ty/issues/1968
This commit is contained in:
parent
30c3f9aafe
commit
b36ff75a24
2 changed files with 51 additions and 6 deletions
|
|
@ -738,6 +738,8 @@ def f[T](x: T, y: Not[T]) -> T:
|
|||
|
||||
## `Callable` parameters
|
||||
|
||||
### Class constructors
|
||||
|
||||
We can recurse into the parameters and return values of `Callable` parameters to infer
|
||||
specializations of a generic function.
|
||||
|
||||
|
|
@ -891,3 +893,46 @@ def _(x: list[str]):
|
|||
# error: [invalid-argument-type]
|
||||
reveal_type(accepts_callable(GenericClass)(x, x))
|
||||
```
|
||||
|
||||
### Don't include identical lower/upper bounds in type mapping multiple times
|
||||
|
||||
This is was a performance regression reported in
|
||||
[ty#1968](https://github.com/astral-sh/ty/issues/1968). Before fixing this, we would see the
|
||||
`U ≤ M1 | ... | M7` upper bound 7 times. Since we intersect upper bounds before recording a single
|
||||
type mapping, we would perform 7 intersections. Each intersection would require 7^2 comparisons of
|
||||
the `Mx` types. We now have a simple heuristics that avoids processing any identical lower or upper
|
||||
bound more than once, since we know the extra copies cannot affect the result.
|
||||
|
||||
```py
|
||||
from typing import Callable, Generic, TypeVar, Union
|
||||
|
||||
class M1: ...
|
||||
class M2: ...
|
||||
class M3: ...
|
||||
class M4: ...
|
||||
class M5: ...
|
||||
class M6: ...
|
||||
class M7: ...
|
||||
|
||||
Msg = Union[M1, M2, M3, M4, M5, M6, M7]
|
||||
|
||||
T = TypeVar("T")
|
||||
U_co = TypeVar("U_co", covariant=True)
|
||||
|
||||
class Stream(Generic[T]):
|
||||
def apply(self, func: Callable[["Stream[T]"], "Stream[U_co]"]) -> "Stream[U_co]":
|
||||
return func(self)
|
||||
|
||||
TMsg = TypeVar("TMsg", bound=Msg)
|
||||
|
||||
class Builder(Generic[TMsg]):
|
||||
def build(self) -> Stream[TMsg]:
|
||||
stream: Stream[TMsg] = Stream()
|
||||
# TODO: no error
|
||||
# error: [invalid-assignment]
|
||||
stream = stream.apply(self._handler)
|
||||
return stream
|
||||
|
||||
def _handler(self, stream: Stream[Msg]) -> Stream[Msg]:
|
||||
return stream
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1586,8 +1586,8 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
) {
|
||||
#[derive(Default)]
|
||||
struct Bounds<'db> {
|
||||
lower: Vec<Type<'db>>,
|
||||
upper: Vec<Type<'db>>,
|
||||
lower: FxOrderSet<Type<'db>>,
|
||||
upper: FxOrderSet<Type<'db>>,
|
||||
}
|
||||
|
||||
let constraints = constraints.limit_to_valid_specializations(self.db);
|
||||
|
|
@ -1611,17 +1611,17 @@ impl<'db> SpecializationBuilder<'db> {
|
|||
let lower = constraint.lower(self.db);
|
||||
let upper = constraint.upper(self.db);
|
||||
let bounds = mappings.entry(typevar).or_default();
|
||||
bounds.lower.push(lower);
|
||||
bounds.upper.push(upper);
|
||||
bounds.lower.insert(lower);
|
||||
bounds.upper.insert(upper);
|
||||
|
||||
if let Type::TypeVar(lower_bound_typevar) = lower {
|
||||
let bounds = mappings.entry(lower_bound_typevar).or_default();
|
||||
bounds.upper.push(Type::TypeVar(typevar));
|
||||
bounds.upper.insert(Type::TypeVar(typevar));
|
||||
}
|
||||
|
||||
if let Type::TypeVar(upper_bound_typevar) = upper {
|
||||
let bounds = mappings.entry(upper_bound_typevar).or_default();
|
||||
bounds.lower.push(Type::TypeVar(typevar));
|
||||
bounds.lower.insert(Type::TypeVar(typevar));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue