From ff3a6a8fbd5af1b3e9b42b4c5adb9a8954968de9 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 31 Oct 2025 12:41:14 -0400 Subject: [PATCH] [ty] Support type context of union attribute assignments (#21170) ## Summary Turns out this is easy to implement. Resolves https://github.com/astral-sh/ty/issues/1375. --- .../resources/mdtest/bidirectional.md | 20 ++++++++++++++++++- .../src/types/infer/builder.rs | 14 +++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 1cc3dba162..3fee0513ed 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -200,7 +200,7 @@ def f() -> list[Literal[1]]: return [1] ``` -## Instance attribute +## Instance attributes ```toml [environment] @@ -235,6 +235,24 @@ def _(flag: bool): C.x = lst(1) ``` +For union targets, each element of the union is considered as a separate type context: + +```py +from typing import Literal + +class X: + x: list[int | str] + +class Y: + x: list[int | None] + +def lst[T](x: T) -> list[T]: + return [x] + +def _(xy: X | Y): + xy.x = lst(1) +``` + ## Class constructor parameters ```toml diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 7094fdea07..f6055c0a0e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -3574,7 +3574,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { target: &ast::ExprAttribute, object_ty: Type<'db>, attribute: &str, - infer_value_ty: &dyn Fn(&mut Self, TypeContext<'db>) -> Type<'db>, + infer_value_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>, emit_diagnostics: bool, ) -> bool { let db = self.db(); @@ -3651,7 +3651,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match object_ty { Type::Union(union) => { - // TODO: We could perform multi-inference here with each element of the union as type context. + // First infer the value without type context, and then again for each union element. let value_ty = infer_value_ty(self, TypeContext::default()); if union.elements(self.db()).iter().all(|elem| { @@ -3659,7 +3659,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { target, *elem, attribute, - &|_, _| value_ty, + // Note that `infer_value_ty` silences diagnostics after the first inference. + &mut infer_value_ty, false, ) }) { @@ -3684,7 +3685,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } Type::Intersection(intersection) => { - // TODO: We could perform multi-inference here with each element of the union as type context. + // First infer the value without type context, and then again for each union element. let value_ty = infer_value_ty(self, TypeContext::default()); // TODO: Handle negative intersection elements @@ -3693,7 +3694,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { target, *elem, attribute, - &|_, _| value_ty, + // Note that `infer_value_ty` silences diagnostics after the first inference. + &mut infer_value_ty, false, ) }) { @@ -4254,7 +4256,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let object_ty = self.infer_expression(object, TypeContext::default()); if let Some(infer_assigned_ty) = infer_assigned_ty { - let infer_assigned_ty = &|builder: &mut Self, tcx| { + let infer_assigned_ty = &mut |builder: &mut Self, tcx| { let assigned_ty = infer_assigned_ty(builder, tcx); builder.store_expression_type(target, assigned_ty); assigned_ty