[ty] Use first matching constructor overload when inferring specializations (#18204)

This is a follow-on to #18155. For the example raised in
https://github.com/astral-sh/ty/issues/370:

```py
import tempfile

with tempfile.TemporaryDirectory() as tmp: ...
```

the new logic would notice that both overloads of `TemporaryDirectory`
match, and combine their specializations, resulting in an inferred type
of `str | bytes`.

This PR updates the logic to match our other handling of other calls,
where we only keep the _first_ matching overload. The result for this
example then becomes `str`, matching the runtime behavior. (We still do
not implement the full [overload resolution
algorithm](https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation)
from the spec.)
This commit is contained in:
Douglas Creager 2025-05-19 15:12:28 -04:00 committed by GitHub
parent 0ede831a3f
commit 4fad15805b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 20 additions and 15 deletions

View file

@ -354,6 +354,8 @@ class C(Generic[T]):
@overload
def __init__(self: "C[bytes]", x: bytes) -> None: ...
@overload
def __init__(self: "C[int]", x: bytes) -> None: ...
@overload
def __init__(self, x: int) -> None: ...
def __init__(self, x: str | bytes | int) -> None: ...
@ -369,6 +371,10 @@ C[bytes]("string") # error: [no-matching-overload]
C[bytes](b"bytes")
C[bytes](12)
C[int]("string") # error: [no-matching-overload]
C[int](b"bytes")
C[int](12)
C[None]("string") # error: [no-matching-overload]
C[None](b"bytes") # error: [no-matching-overload]
C[None](12)

View file

@ -290,6 +290,8 @@ class C[T]:
@overload
def __init__(self: C[bytes], x: bytes) -> None: ...
@overload
def __init__(self: C[int], x: bytes) -> None: ...
@overload
def __init__(self, x: int) -> None: ...
def __init__(self, x: str | bytes | int) -> None: ...
@ -305,6 +307,10 @@ C[bytes]("string") # error: [no-matching-overload]
C[bytes](b"bytes")
C[bytes](12)
C[int]("string") # error: [no-matching-overload]
C[int](b"bytes")
C[int](12)
C[None]("string") # error: [no-matching-overload]
C[None](b"bytes") # error: [no-matching-overload]
C[None](12)

View file

@ -4649,24 +4649,14 @@ impl<'db> Type<'db> {
}
}
fn combine_binding_specialization<'db>(
db: &'db dyn Db,
binding: &CallableBinding<'db>,
) -> Option<Specialization<'db>> {
binding
.matching_overloads()
.map(|(_, binding)| binding.inherited_specialization())
.reduce(|acc, specialization| {
combine_specializations(db, acc, specialization)
})
.flatten()
}
let new_specialization = new_call_outcome
.and_then(Result::ok)
.as_ref()
.and_then(Bindings::single_element)
.and_then(|binding| combine_binding_specialization(db, binding))
.into_iter()
.flat_map(CallableBinding::matching_overloads)
.next()
.and_then(|(_, binding)| binding.inherited_specialization())
.filter(|specialization| {
Some(specialization.generic_context(db)) == generic_context
});
@ -4674,7 +4664,10 @@ impl<'db> Type<'db> {
.and_then(Result::ok)
.as_ref()
.and_then(Bindings::single_element)
.and_then(|binding| combine_binding_specialization(db, binding))
.into_iter()
.flat_map(CallableBinding::matching_overloads)
.next()
.and_then(|(_, binding)| binding.inherited_specialization())
.filter(|specialization| {
Some(specialization.generic_context(db)) == generic_context
});