[ty] Support type context of union attribute assignments (#21170)

## Summary

Turns out this is easy to implement. Resolves
https://github.com/astral-sh/ty/issues/1375.
This commit is contained in:
Ibraheem Ahmed 2025-10-31 12:41:14 -04:00 committed by GitHub
parent 9664474c51
commit ff3a6a8fbd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 7 deletions

View file

@ -200,7 +200,7 @@ def f() -> list[Literal[1]]:
return [1]
```
## Instance attribute
## Instance attributes
```toml
[environment]
@ -235,6 +235,24 @@ def _(flag: bool):
C.x = lst(1)
```
For union targets, each element of the union is considered as a separate type context:
```py
from typing import Literal
class X:
x: list[int | str]
class Y:
x: list[int | None]
def lst[T](x: T) -> list[T]:
return [x]
def _(xy: X | Y):
xy.x = lst(1)
```
## Class constructor parameters
```toml

View file

@ -3574,7 +3574,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target: &ast::ExprAttribute,
object_ty: Type<'db>,
attribute: &str,
infer_value_ty: &dyn Fn(&mut Self, TypeContext<'db>) -> Type<'db>,
infer_value_ty: &mut dyn FnMut(&mut Self, TypeContext<'db>) -> Type<'db>,
emit_diagnostics: bool,
) -> bool {
let db = self.db();
@ -3651,7 +3651,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match object_ty {
Type::Union(union) => {
// TODO: We could perform multi-inference here with each element of the union as type context.
// First infer the value without type context, and then again for each union element.
let value_ty = infer_value_ty(self, TypeContext::default());
if union.elements(self.db()).iter().all(|elem| {
@ -3659,7 +3659,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target,
*elem,
attribute,
&|_, _| value_ty,
// Note that `infer_value_ty` silences diagnostics after the first inference.
&mut infer_value_ty,
false,
)
}) {
@ -3684,7 +3685,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
Type::Intersection(intersection) => {
// TODO: We could perform multi-inference here with each element of the union as type context.
// First infer the value without type context, and then again for each union element.
let value_ty = infer_value_ty(self, TypeContext::default());
// TODO: Handle negative intersection elements
@ -3693,7 +3694,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
target,
*elem,
attribute,
&|_, _| value_ty,
// Note that `infer_value_ty` silences diagnostics after the first inference.
&mut infer_value_ty,
false,
)
}) {
@ -4254,7 +4256,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let object_ty = self.infer_expression(object, TypeContext::default());
if let Some(infer_assigned_ty) = infer_assigned_ty {
let infer_assigned_ty = &|builder: &mut Self, tcx| {
let infer_assigned_ty = &mut |builder: &mut Self, tcx| {
let assigned_ty = infer_assigned_ty(builder, tcx);
builder.store_expression_type(target, assigned_ty);
assigned_ty