specialize interpretation of PEP 604 unions

This commit is contained in:
Carl Meyer 2025-11-15 15:04:04 -08:00
parent df0c8e202d
commit f6b892c5aa
No known key found for this signature in database
GPG key ID: 2D1FB7916A52E121
4 changed files with 102 additions and 7 deletions

View file

@ -46,6 +46,48 @@ except MyExc as e:
reveal_type(e) # revealed: Exception
```
## Unknown type in PEP 604 union
If we run into an unexpected type in a PEP 604 union in the RHS of a PEP 613 type alias, we still
understand it as a union type, just with an unknown element.
```py
from typing import TypeAlias
from nonexistent import unknown_type # error: [unresolved-import]
MyAlias: TypeAlias = int | unknown_type | str
def _(x: MyAlias):
reveal_type(x) # revealed: int | Unknown | str
```
## Callable type in union
```py
from typing import TypeAlias, Callable
MyAlias: TypeAlias = int | Callable[[str], int]
def _(x: MyAlias):
# TODO: int | (str) -> int
reveal_type(x) # revealed: int | @Todo(Inference of subscript on special form)
```
## Subscripted generic alias in union
```py
from typing import TypeAlias, TypeVar
T = TypeVar("T")
Alias1: TypeAlias = list[T] | set[T]
MyAlias: TypeAlias = int | Alias1[str]
def _(x: MyAlias):
# TODO: int | list[str] | set[str]
reveal_type(x) # revealed: int | @Todo(Specialization of union type alias)
```
## Imported
`alias.py`:

View file

@ -849,6 +849,10 @@ impl<'db> Type<'db> {
.is_some()
}
fn is_typealias_special_form(&self) -> bool {
matches!(self, Type::SpecialForm(SpecialFormType::TypeAlias))
}
/// Return true if this type overrides __eq__ or __ne__ methods
fn overrides_equality(&self, db: &'db dyn Db) -> bool {
let check_dunder = |dunder_name, allowed_return_value| {

View file

@ -5444,10 +5444,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// Check if this is a PEP 613 `TypeAlias`. (This must come below the SpecialForm handling
// immediately below, since that can overwrite the type to be `TypeAlias`.)
let is_pep_613_type_alias = matches!(
declared.inner_type(),
Type::SpecialForm(SpecialFormType::TypeAlias)
);
let is_pep_613_type_alias = declared.inner_type().is_typealias_special_form();
// Handle various singletons.
if let Some(name_expr) = target.as_name_expr() {
@ -6926,7 +6923,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::Name(name) => self.infer_name_expression(name),
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
ast::Expr::UnaryOp(unary_op) => self.infer_unary_expression(unary_op),
ast::Expr::BinOp(binary) => self.infer_binary_expression(binary),
ast::Expr::BinOp(binary) => self.infer_binary_expression(binary, tcx),
ast::Expr::BoolOp(bool_op) => self.infer_boolean_expression(bool_op),
ast::Expr::Compare(compare) => self.infer_compare_expression(compare),
ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript),
@ -9200,7 +9197,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
fn infer_binary_expression(&mut self, binary: &ast::ExprBinOp) -> Type<'db> {
fn infer_binary_expression(
&mut self,
binary: &ast::ExprBinOp,
tcx: TypeContext<'db>,
) -> Type<'db> {
if tcx
.annotation
.is_some_and(|ty| ty.is_typealias_special_form())
{
return self.infer_pep_604_union_type_alias(binary, tcx);
}
let ast::ExprBinOp {
left,
op,
@ -9238,6 +9246,44 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
})
}
fn infer_pep_604_union_type_alias(
&mut self,
node: &ast::ExprBinOp,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprBinOp {
left,
op,
right,
range: _,
node_index: _,
} = node;
if *op != ast::Operator::BitOr {
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, node) {
let mut diag = builder.into_diagnostic(format_args!(
"Invalid binary operator `{}` in type alias",
op.as_str()
));
diag.info("Did you mean to use `|`?");
}
return Type::unknown();
}
let left_ty = self.infer_expression(left, tcx);
let right_ty = self.infer_expression(right, tcx);
if left_ty.is_equivalent_to(self.db(), right_ty) {
left_ty
} else {
Type::KnownInstance(KnownInstanceType::UnionType(InternedTypes::from_elements(
self.db(),
[left_ty, right_ty],
InferredAs::ValueExpression,
)))
}
}
fn infer_binary_expression_type(
&mut self,
node: AnyNodeRef<'_>,
@ -10916,6 +10962,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(Type::from)
.unwrap_or_else(Type::unknown);
}
Type::KnownInstance(KnownInstanceType::UnionType(_)) => {
return todo_type!("Specialization of union type alias");
}
_ => {}
}

View file

@ -147,7 +147,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
// anything else is an invalid annotation:
op => {
self.infer_binary_expression(binary);
self.infer_binary_expression(binary, TypeContext::default());
if let Some(mut diag) = self.report_invalid_type_expression(
expression,
format_args!(