ruff/crates/ty_python_semantic/resources/mdtest/bidirectional.md
Shunsuke Shibayama dc64c08633
[ty] bidirectional type inference using function return type annotations (#20528)
## Summary

Implements bidirectional type inference using function return type
annotations.

This PR was originally proposed to solve astral-sh/ty#1167, but it does
not fully resolve that issue on its own.
To fully resolve it, I believe we additionally need to allow dataclasses
to generate their own `__new__` methods, to [use constructor return types for
inference](5844c0103d/crates/ty_python_semantic/src/types.rs (L5326-L5328)),
and to add a mechanism for discarding type narrowing such as `& ~AlwaysFalsy`
when necessary (at a more general level than this PR).

## Test Plan

`mdtest/bidirectional.md` is added.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ibraheem Ahmed <ibraheem@ibraheem.ca>
2025-10-11 00:38:35 +00:00


# Bidirectional type inference

ty partially supports bidirectional type inference, a mechanism for inferring the type of an expression "from the outside in". Normally, type inference proceeds "from the inside out": to infer the type of an expression, the types of all of its sub-expressions must be inferred first, and there is no dependency in the reverse direction. However, in more complex cases, such as calls to generic functions, the type expected for an outer expression can help infer its inner expressions. Bidirectional type inference propagates such "expected types" (for example, a variable's declared type or a function's return type annotation) down into the inference of inner expressions.

## Propagating target type annotation

```toml
[environment]
python-version = "3.12"
```
```py
def list1[T](x: T) -> list[T]:
    return [x]

l1 = list1(1)
reveal_type(l1)  # revealed: list[Literal[1]]
l2: list[int] = list1(1)
reveal_type(l2)  # revealed: list[int]

# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l2 = l1

intermediate = list1(1)
# TODO: no error should be emitted here once we can infer the type of `intermediate` as `list[int]`
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l3: list[int] = intermediate
# TODO: it would be nice if this were `list[int]`
reveal_type(intermediate)  # revealed: list[Literal[1]]
reveal_type(l3)  # revealed: list[int]

l4: list[int | str] | None = list1(1)
reveal_type(l4)  # revealed: list[int | str]

def _(l: list[int] | None = None):
    l1 = l or list()
    reveal_type(l1)  # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]

    l2: list[int] = l or list()
    # arguably, this should be `list[int]` (https://github.com/astral-sh/ty/issues/136)
    reveal_type(l2)  # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]

def f[T](x: T, cond: bool) -> T | list[T]:
    return x if cond else [x]

# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True)
```

`typed_dict.py`:

```py
from typing import TypedDict

class TD(TypedDict):
    x: int

d1 = {"x": 1}
d2: TD = {"x": 1}
d3: dict[str, int] = {"x": 1}

reveal_type(d1)  # revealed: dict[Unknown | str, Unknown | int]
reveal_type(d2)  # revealed: TD
reveal_type(d3)  # revealed: dict[str, int]

def _() -> TD:
    return {"x": 1}

def _() -> TD:
    # error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
    return {}
```

## Propagating return type annotation

```toml
[environment]
python-version = "3.12"
```
```py
from typing import overload, Callable

def list1[T](x: T) -> list[T]:
    return [x]

def get_data() -> dict | None:
    return {}

def wrap_data() -> list[dict]:
    if not (res := get_data()):
        return list1({})
    reveal_type(list1(res))  # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
    # `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible,
    # but the return type check passes here because the type of `list1(res)` is inferred
    # by bidirectional type inference using the annotated return type, and the type of `res` is not used.
    return list1(res)

def wrap_data2() -> list[dict] | None:
    if not (res := get_data()):
        return None
    reveal_type(list1(res))  # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
    return list1(res)

def deco[T](func: Callable[[], T]) -> Callable[[], T]:
    return func

def outer() -> Callable[[], list[dict]]:
    @deco
    def inner() -> list[dict]:
        if not (res := get_data()):
            return list1({})
        reveal_type(list1(res))  # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
        return list1(res)
    return inner

@overload
def f(x: int) -> list[int]: ...
@overload
def f(x: str) -> list[str]: ...
def f(x: int | str) -> list[int] | list[str]:
    # `list[int] | list[str]` is disjoint from `list[int | str]`.
    if isinstance(x, int):
        return list1(x)
    else:
        return list1(x)

reveal_type(f(1))  # revealed: list[int]
reveal_type(f("a"))  # revealed: list[str]

async def g() -> list[int | str]:
    return list1(1)

def h[T](x: T, cond: bool) -> T | list[T]:
    return i(x, cond)

def i[T](x: T, cond: bool) -> T | list[T]:
    return x if cond else [x]
```