From ffb7bdd59535d5534d7f1f44a9542f3372102e9a Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 14 Nov 2025 15:19:08 -0500 Subject: [PATCH] [ty] Propagate type context through conditional expressions (#21443) ## Summary Resolves https://github.com/astral-sh/ty/issues/1543. --- .../mdtest/assignment/annotations.md | 2 -- .../resources/mdtest/bidirectional.md | 21 +++++++++++++++++++ .../src/types/infer/builder.rs | 12 +++++++---- .../types/infer/builder/type_expression.rs | 2 +- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 043380338b..34fca6af0e 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -341,8 +341,6 @@ d: X[Any] = X(1) reveal_type(d) # revealed: X[Any] def _(flag: bool): - # TODO: Handle unions correctly. - # error: [invalid-assignment] "Object of type `X[int]` is not assignable to `X[int | None]`" a: X[int | None] = X(1) if flag else X(2) reveal_type(a) # revealed: X[int | None] ``` diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 1211f92fe5..c2d0cfd45a 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -281,6 +281,27 @@ A(f(1)) A(f([])) ``` +## Conditional expressions + +```toml +[environment] +python-version = "3.12" +``` + +The type context is propagated through both branches of conditional expressions: + +```py +def f[T](x: T) -> list[T]: + raise NotImplementedError + +def _(flag: bool): + x1 = f(1) if flag else f(2) + reveal_type(x1) # revealed: list[Literal[1]] | list[Literal[2]] + + x2: list[int | None] = f(1) if flag else f(2) + reveal_type(x2) # revealed: list[int | None] +``` + ## Multi-inference diagnostics ```toml diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index ddd6a33edb..43e239787d 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -6894,7 +6894,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::Compare(compare) => self.infer_compare_expression(compare), ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript), ast::Expr::Slice(slice) => self.infer_slice_expression(slice), - ast::Expr::If(if_expression) => self.infer_if_expression(if_expression), + ast::Expr::If(if_expression) => self.infer_if_expression(if_expression, tcx), ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression), ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx), ast::Expr::Starred(starred) => self.infer_starred_expression(starred), @@ -7740,7 +7740,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }) } - fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type<'db> { + fn infer_if_expression( + &mut self, + if_expression: &ast::ExprIf, + tcx: TypeContext<'db>, + ) -> Type<'db> { let ast::ExprIf { range: _, node_index: _, @@ -7750,8 +7754,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = if_expression; let test_ty = self.infer_standalone_expression(test, TypeContext::default()); - let body_ty = self.infer_expression(body, TypeContext::default()); - let orelse_ty = self.infer_expression(orelse, TypeContext::default()); + let body_ty = self.infer_expression(body, tcx); + let orelse_ty = self.infer_expression(orelse, tcx); match test_ty.try_bool(self.db()).unwrap_or_else(|err| { err.report_diagnostic(&self.context, &**test); diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index 20c5362a0b..513361ff4e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -330,7 +330,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::If(if_expression) => { - self.infer_if_expression(if_expression); + self.infer_if_expression(if_expression, TypeContext::default()); self.report_invalid_type_expression( expression, format_args!("`if` expressions are not allowed in type expressions"),