[ty] Propagate type context through conditional expressions (#21443)

## Summary

Resolves https://github.com/astral-sh/ty/issues/1543.
This commit is contained in:
Ibraheem Ahmed 2025-11-14 15:19:08 -05:00 committed by GitHub
parent 0a55327d64
commit ffb7bdd595
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 30 additions and 7 deletions

View file

@ -341,8 +341,6 @@ d: X[Any] = X(1)
reveal_type(d) # revealed: X[Any]
def _(flag: bool):
# TODO: Handle unions correctly.
# error: [invalid-assignment] "Object of type `X[int]` is not assignable to `X[int | None]`"
a: X[int | None] = X(1) if flag else X(2)
reveal_type(a) # revealed: X[int | None]
```

View file

@ -281,6 +281,27 @@ A(f(1))
A(f([]))
```
## Conditional expressions
```toml
[environment]
python-version = "3.12"
```
The type context is propagated through both branches of conditional expressions:
```py
def f[T](x: T) -> list[T]:
raise NotImplementedError
def _(flag: bool):
x1 = f(1) if flag else f(2)
reveal_type(x1) # revealed: list[Literal[1]] | list[Literal[2]]
x2: list[int | None] = f(1) if flag else f(2)
reveal_type(x2) # revealed: list[int | None]
```
## Multi-inference diagnostics
```toml

View file

@ -6894,7 +6894,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::Compare(compare) => self.infer_compare_expression(compare),
ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript),
ast::Expr::Slice(slice) => self.infer_slice_expression(slice),
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression, tcx),
ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
ast::Expr::Starred(starred) => self.infer_starred_expression(starred),
@ -7740,7 +7740,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
})
}
fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type<'db> {
fn infer_if_expression(
&mut self,
if_expression: &ast::ExprIf,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprIf {
range: _,
node_index: _,
@ -7750,8 +7754,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = if_expression;
let test_ty = self.infer_standalone_expression(test, TypeContext::default());
let body_ty = self.infer_expression(body, TypeContext::default());
let orelse_ty = self.infer_expression(orelse, TypeContext::default());
let body_ty = self.infer_expression(body, tcx);
let orelse_ty = self.infer_expression(orelse, tcx);
match test_ty.try_bool(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, &**test);

View file

@ -330,7 +330,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ast::Expr::If(if_expression) => {
self.infer_if_expression(if_expression);
self.infer_if_expression(if_expression, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("`if` expressions are not allowed in type expressions"),