[ty]eliminate definitely-impossible types from union in equality narrowing (#20164)
Some checks are pending
CI / test ruff-lsp (push) Blocked by required conditions
CI / mkdocs (push) Waiting to run
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

solves https://github.com/astral-sh/ty/issues/939

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
Renkai Ge 2025-09-03 23:34:22 +08:00 committed by GitHub
parent b14fc96141
commit cda376afe0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 193 additions and 26 deletions

View file

@ -92,3 +92,106 @@ if (x := f()) in (1,):
else:
reveal_type(x) # revealed: Literal[2, 3]
```
## Union with `Literal`, `None` and `int`
```py
from typing import Literal
def test(x: Literal["a", "b", "c"] | None | int = None):
if x in ("a", "b"):
# int is included because custom __eq__ methods could make
# an int equal to "a" or "b", so we can't eliminate it
reveal_type(x) # revealed: Literal["a", "b"] | int
else:
reveal_type(x) # revealed: Literal["c"] | None | int
```
## Direct `not in` conditional
```py
from typing import Literal
def test(x: Literal["a", "b", "c"] | None | int = None):
if x not in ("a", "c"):
# int is included because custom __eq__ methods could make
# an int equal to "a" or "b", so we can't eliminate it
reveal_type(x) # revealed: Literal["b"] | None | int
else:
reveal_type(x) # revealed: Literal["a", "c"] | int
```
## bool
```py
def _(x: bool):
if x in (True,):
reveal_type(x) # revealed: Literal[True]
else:
reveal_type(x) # revealed: Literal[False]
def _(x: bool | str):
if x in (False,):
# `str` remains due to possible custom __eq__ methods on a subclass
reveal_type(x) # revealed: Literal[False] | str
else:
reveal_type(x) # revealed: Literal[True] | str
```
## LiteralString
```py
from typing_extensions import LiteralString
def _(x: LiteralString):
if x in ("a", "b", "c"):
reveal_type(x) # revealed: Literal["a", "b", "c"]
else:
reveal_type(x) # revealed: LiteralString & ~Literal["a"] & ~Literal["b"] & ~Literal["c"]
def _(x: LiteralString | int):
if x in ("a", "b", "c"):
reveal_type(x) # revealed: Literal["a", "b", "c"] | int
else:
reveal_type(x) # revealed: (LiteralString & ~Literal["a"] & ~Literal["b"] & ~Literal["c"]) | int
```
## enums
```py
from enum import Enum
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
def _(x: Color):
if x in (Color.RED, Color.GREEN):
# TODO should be `Literal[Color.RED, Color.GREEN]`
reveal_type(x) # revealed: Color
else:
# TODO should be `Literal[Color.BLUE]`
reveal_type(x) # revealed: Color
```
## Union with enum and `int`
```py
from enum import Enum
class Status(Enum):
PENDING = 1
APPROVED = 2
REJECTED = 3
def test(x: Status | int):
if x in (Status.PENDING, Status.APPROVED):
# TODO should be `Literal[Status.PENDING, Status.APPROVED] | int`
# int is included because custom __eq__ methods could make
# an int equal to Status.PENDING or Status.APPROVED, so we can't eliminate it
reveal_type(x) # revealed: Status | int
else:
# TODO should be `Literal[Status.REJECTED] | int`
reveal_type(x) # revealed: Status | int
```

View file

@ -1054,6 +1054,16 @@ impl<'db> Type<'db> {
|| self.is_literal_string()
}
pub(crate) fn is_union_with_single_valued(&self, db: &'db dyn Db) -> bool {
self.into_union().is_some_and(|union| {
union
.elements(db)
.iter()
.any(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string())
}) || self.is_bool(db)
|| self.is_literal_string()
}
pub(crate) fn into_string_literal(self) -> Option<StringLiteralType<'db>> {
match self {
Type::StringLiteral(string_literal) => Some(string_literal),
@ -9953,14 +9963,6 @@ impl<'db> StringLiteralType<'db> {
pub(crate) fn python_len(self, db: &'db dyn Db) -> usize {
self.value(db).chars().count()
}
/// Return an iterator over each character in the string literal.
/// as would be returned by Python's `iter()`.
pub(crate) fn iter_each_char(self, db: &'db dyn Db) -> impl Iterator<Item = Self> {
self.value(db)
.chars()
.map(|c| StringLiteralType::new(db, c.to_string().into_boxed_str()))
}
}
/// # Ordering

View file

@ -615,24 +615,88 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
}
}
// TODO `expr_in` and `expr_not_in` should perhaps be unified with `expr_eq` and `expr_ne`,
// since `eq` and `ne` are equivalent to `in` and `not in` with only one element in the RHS.
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
if let Type::StringLiteral(string_literal) = rhs_ty {
Some(UnionType::from_elements(
self.db,
string_literal
.iter_each_char(self.db)
.map(Type::StringLiteral),
))
} else if let Some(tuple_spec) = rhs_ty.tuple_instance_spec(self.db) {
// N.B. Strictly speaking this is unsound, since a tuple subclass might override `__contains__`
// but we'd still apply the narrowing here. This seems unlikely, however, and narrowing is
// generally unsound in numerous ways anyway (attribute narrowing, subscript, narrowing,
// narrowing of globals, etc.). So this doesn't seem worth worrying about too much.
Some(UnionType::from_elements(self.db, tuple_spec.all_elements()))
rhs_ty
.try_iterate(self.db)
.ok()
.map(|iterable| iterable.homogeneous_element_type(self.db))
} else if lhs_ty.is_union_with_single_valued(self.db) {
let rhs_values = rhs_ty
.try_iterate(self.db)
.ok()?
.homogeneous_element_type(self.db);
let mut builder = UnionBuilder::new(self.db);
// Add the narrowed values from the RHS first, to keep literals before broader types.
builder = builder.add(rhs_values);
if let Some(lhs_union) = lhs_ty.into_union() {
for element in lhs_union.elements(self.db) {
// Keep only the non-single-valued portion of the original type.
if !element.is_single_valued(self.db)
&& !element.is_literal_string()
&& !element.is_bool(self.db)
{
builder = builder.add(*element);
}
}
}
Some(builder.build())
} else {
None
}
}
fn evaluate_expr_not_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
let rhs_values = rhs_ty
.try_iterate(self.db)
.ok()?
.homogeneous_element_type(self.db);
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
// Exclude the RHS values from the entire (single-valued) LHS domain.
let complement = IntersectionBuilder::new(self.db)
.add_positive(lhs_ty)
.add_negative(rhs_values)
.build();
Some(complement)
} else if lhs_ty.is_union_with_single_valued(self.db) {
// Split LHS into single-valued portion and the rest. Exclude RHS values from the
// single-valued portion, keep the rest intact.
let mut single_builder = UnionBuilder::new(self.db);
let mut rest_builder = UnionBuilder::new(self.db);
if let Some(lhs_union) = lhs_ty.into_union() {
for element in lhs_union.elements(self.db) {
if element.is_single_valued(self.db)
|| element.is_literal_string()
|| element.is_bool(self.db)
{
single_builder = single_builder.add(*element);
} else {
rest_builder = rest_builder.add(*element);
}
}
}
let single_union = single_builder.build();
let rest_union = rest_builder.build();
let narrowed_single = IntersectionBuilder::new(self.db)
.add_positive(single_union)
.add_negative(rhs_values)
.build();
// Keep order: first literal complement, then broader arms.
let result = UnionBuilder::new(self.db)
.add(narrowed_single)
.add(rest_union)
.build();
Some(result)
} else {
None
}
@ -660,9 +724,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty),
ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty),
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
ast::CmpOp::NotIn => self
.evaluate_expr_in(lhs_ty, rhs_ty)
.map(|ty| ty.negate(self.db)),
ast::CmpOp::NotIn => self.evaluate_expr_not_in(lhs_ty, rhs_ty),
_ => None,
}
}