[ty] Improve ability to solve TypeVars when they appear in unions (#19829)

This commit is contained in:
Alex Waygood 2025-08-08 17:50:37 +01:00 committed by GitHub
parent 6b0eadfb4d
commit 8489816edc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 98 additions and 32 deletions

View file

@ -395,3 +395,40 @@ def decorated(t: T) -> None:
# error: [redundant-cast]
reveal_type(cast(T, t)) # revealed: T@decorated
```
## Solving TypeVars with upper bounds in unions
```py
from typing import Generic, TypeVar
class A: ...
T = TypeVar("T", bound=A)
class B(Generic[T]):
x: T
def f(c: T | None):
return None
def g(b: B[T]):
return f(b.x) # Fine
```
## Constrained TypeVar in a union
This is a regression test for an issue that surfaced in the primer report of an early version of
<https://github.com/astral-sh/ruff/pull/19811>, where we failed to solve the `TypeVar` here due to
the fact that it only appears in the function's type annotations as part of a union:
```py
from typing import TypeVar
T = TypeVar("T", str, bytes)
def NamedTemporaryFile(suffix: T | None, prefix: T | None) -> None:
return None
def f(x: str):
NamedTemporaryFile(prefix=x, suffix=".tar.gz") # Fine
```

View file

@ -404,3 +404,32 @@ def decorated[T](t: T) -> None:
# error: [redundant-cast]
reveal_type(cast(T, t)) # revealed: T@decorated
```
## Solving TypeVars with upper bounds in unions
```py
class A: ...
class B[T: A]:
x: T
def f[T: A](c: T | None):
return None
def g[T: A](b: B[T]):
return f(b.x) # Fine
```
## Constrained TypeVar in a union
This is a regression test for an issue that surfaced in the primer report of an early version of
<https://github.com/astral-sh/ruff/pull/19811>, where we failed to solve the `TypeVar` here due to
the fact that it only appears in the function's type annotations as part of a union:
```py
def f[T: (str, bytes)](suffix: T | None, prefix: T | None):
return None
def g(x: str):
f(prefix=x, suffix=".tar.gz")
```

View file

@ -792,6 +792,38 @@ impl<'db> SpecializationBuilder<'db> {
}
match (formal, actual) {
(Type::Union(formal), _) => {
// TODO: We haven't implemented a full unification solver yet. If typevars appear
// in multiple union elements, we ideally want to express that _only one_ of them
// needs to match, and that we should infer the smallest type mapping that allows
// that.
//
// For now, we punt on handling multiple typevar elements. Instead, if _precisely
// one_ union element _is_ a typevar (not _contains_ a typevar), then we go ahead
// and add a mapping between that typevar and the actual type. (Note that we've
// already handled above the case where the actual is assignable to a _non-typevar_
// union element.)
let mut typevars = formal.iter(self.db).filter_map(|ty| match ty {
Type::TypeVar(typevar) => Some(*typevar),
_ => None,
});
let typevar = typevars.next();
let additional_typevars = typevars.next();
if let (Some(typevar), None) = (typevar, additional_typevars) {
self.add_type_mapping(typevar, actual);
}
}
(Type::Intersection(formal), _) => {
// The actual type must be assignable to every (positive) element of the
// formal intersection, so we must infer type mappings for each of them. (The
// actual type must also be disjoint from every negative element of the
// intersection, but that doesn't help us infer any type mappings.)
for positive in formal.iter_positive(self.db) {
self.infer(positive, actual)?;
}
}
(Type::TypeVar(typevar), ty) | (ty, Type::TypeVar(typevar)) => {
match typevar.bound_or_constraints(self.db) {
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
@ -877,38 +909,6 @@ impl<'db> SpecializationBuilder<'db> {
}
}
(Type::Union(formal), _) => {
// TODO: We haven't implemented a full unification solver yet. If typevars appear
// in multiple union elements, we ideally want to express that _only one_ of them
// needs to match, and that we should infer the smallest type mapping that allows
// that.
//
// For now, we punt on handling multiple typevar elements. Instead, if _precisely
// one_ union element _is_ a typevar (not _contains_ a typevar), then we go ahead
// and add a mapping between that typevar and the actual type. (Note that we've
// already handled above the case where the actual is assignable to a _non-typevar_
// union element.)
let mut typevars = formal.iter(self.db).filter_map(|ty| match ty {
Type::TypeVar(typevar) => Some(*typevar),
_ => None,
});
let typevar = typevars.next();
let additional_typevars = typevars.next();
if let (Some(typevar), None) = (typevar, additional_typevars) {
self.add_type_mapping(typevar, actual);
}
}
(Type::Intersection(formal), _) => {
// The actual type must be assignable to every (positive) element of the
// formal intersection, so we must infer type mappings for each of them. (The
// actual type must also be disjoint from every negative element of the
// intersection, but that doesn't help us infer any type mappings.)
for positive in formal.iter_positive(self.db) {
self.infer(positive, actual)?;
}
}
// TODO: Add more forms that we can structurally induct into: type[C], callables
_ => {}
}