[ty] Fix match pattern value narrowing to use equality semantics (#20882)

## Summary

Resolves https://github.com/astral-sh/ty/issues/1349.

Fix match statement value patterns to use equality comparison semantics
instead of incorrectly narrowing to literal types directly. Value
patterns use equality for matching, and equality can be overridden, so
we can't always narrow to the matched literal.

## Test Plan

Updated match.md with corrected expected types and an additional example
with explanation

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Eric Mark Martin 2025-10-16 03:50:32 -04:00 committed by GitHub
parent fe4e3e2e75
commit c9dfb51f49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 154 additions and 98 deletions

View file

@ -71,98 +71,140 @@ reveal_type(x) # revealed: object
## Value patterns
Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on
one can only give us information where we know how the subject type implements equality.
Consider the following example.
```py
def get_object() -> object:
return object()
from typing import Literal
x = get_object()
def _(x: Literal["foo"] | int):
match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"] | int
reveal_type(x) # revealed: object
match x:
case "bar":
reveal_type(x) # revealed: int
```
match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"]
case 42:
reveal_type(x) # revealed: Literal[42]
case 6.0:
reveal_type(x) # revealed: float
case 1j:
reveal_type(x) # revealed: complex
case b"foo":
reveal_type(x) # revealed: Literal[b"foo"]
In the first `match`'s `case "foo"` all we know is `x == "foo"`. `x` could be an instance of an
arbitrary `int` subclass with an arbitrary `__eq__`, so we can't actually narrow to
`Literal["foo"]`.
reveal_type(x) # revealed: object
In the second `match`'s `case "bar"` we know `x == "bar"`. As discussed above, this isn't enough to
rule out `int`, but we know that `"foo" == "bar"` is false so we can eliminate `Literal["foo"]`.
More examples follow.
```py
from typing import Literal
class C:
pass
def _(x: Literal["foo", "bar", 42, b"foo"] | bool | complex):
match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"] | int | float | complex
case 42:
reveal_type(x) # revealed: int | float | complex
case 6.0:
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
case 1j:
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
case b"foo":
reveal_type(x) # revealed: (int & ~Literal[42]) | Literal[b"foo"] | float | complex
case _:
reveal_type(x) # revealed: Literal["bar"] | (int & ~Literal[42]) | float | complex
```
## Value patterns with guard
```py
def get_object() -> object:
return object()
from typing import Literal
x = get_object()
class C:
pass
reveal_type(x) # revealed: object
match x:
case "foo" if reveal_type(x): # revealed: Literal["foo"]
pass
case 42 if reveal_type(x): # revealed: Literal[42]
pass
case 6.0 if reveal_type(x): # revealed: float
pass
case 1j if reveal_type(x): # revealed: complex
pass
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
pass
reveal_type(x) # revealed: object
def _(x: Literal["foo", b"bar"] | int):
match x:
case "foo" if reveal_type(x): # revealed: Literal["foo"] | int
pass
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
pass
case 42 if reveal_type(x): # revealed: int
pass
```
## Or patterns
```py
def get_object() -> object:
return object()
from typing import Literal
from enum import Enum
x = get_object()
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
reveal_type(x) # revealed: object
def _(color: Color):
match color:
case Color.RED | Color.GREEN:
reveal_type(color) # revealed: Literal[Color.RED, Color.GREEN]
case Color.BLUE:
reveal_type(color) # revealed: Literal[Color.BLUE]
match x:
case "foo" | 42 | None:
reveal_type(x) # revealed: Literal["foo", 42] | None
case "foo" | tuple():
reveal_type(x) # revealed: tuple[Unknown, ...]
case True | False:
reveal_type(x) # revealed: bool
case 3.14 | 2.718 | 1.414:
reveal_type(x) # revealed: float
match color:
case Color.RED | Color.GREEN | Color.BLUE:
reveal_type(color) # revealed: Color
reveal_type(x) # revealed: object
match color:
case Color.RED:
reveal_type(color) # revealed: Literal[Color.RED]
case _:
reveal_type(color) # revealed: Literal[Color.GREEN, Color.BLUE]
class A: ...
class B: ...
class C: ...
def _(x: A | B | C):
match x:
case A() | B():
reveal_type(x) # revealed: A | B
case C():
reveal_type(x) # revealed: C & ~A & ~B
case _:
reveal_type(x) # revealed: Never
match x:
case A() | B() | C():
reveal_type(x) # revealed: A | B | C
case _:
reveal_type(x) # revealed: Never
match x:
case A():
reveal_type(x) # revealed: A
case _:
reveal_type(x) # revealed: (B & ~A) | (C & ~A)
```
## Or patterns with guard
```py
def get_object() -> object:
return object()
from typing import Literal
x = get_object()
reveal_type(x) # revealed: object
match x:
case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None
pass
case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple[Unknown, ...]
pass
case True | False if reveal_type(x): # revealed: bool
pass
case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float
pass
reveal_type(x) # revealed: object
def _(x: Literal["foo", b"bar"] | int):
match x:
case "foo" | 42 if reveal_type(x): # revealed: Literal["foo"] | int
pass
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
pass
case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int
pass
```
## Narrowing due to guard
@ -179,7 +221,7 @@ match x:
case str() | float() if type(x) is str:
reveal_type(x) # revealed: str
case "foo" | 42 | None if isinstance(x, int):
reveal_type(x) # revealed: Literal[42]
reveal_type(x) # revealed: int
case False if x:
reveal_type(x) # revealed: Never
case "foo" if x := "bar":
@ -201,7 +243,7 @@ reveal_type(x) # revealed: object
match x:
case str() | float() if type(x) is str and reveal_type(x): # revealed: str
pass
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42]
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: int
pass
case False if x and reveal_type(x): # revealed: Never
pass

View file

@ -263,19 +263,19 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
fn merge_constraints_and<'db>(
into: &mut NarrowingConstraints<'db>,
from: NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, value) in from {
match into.entry(key) {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = IntersectionBuilder::new(db)
.add_positive(*entry.get())
.add_positive(value)
.add_positive(*value)
.build();
}
Entry::Vacant(entry) => {
entry.insert(value);
entry.insert(*value);
}
}
}
@ -303,12 +303,6 @@ fn merge_constraints_or<'db>(
}
}
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
for (_place, ty) in constraints.iter_mut() {
*ty = ty.negate_if(db, yes);
}
}
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
match expr {
ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()),
@ -399,12 +393,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> {
match pattern_predicate_kind {
PatternPredicateKind::Singleton(singleton) => {
self.evaluate_match_pattern_singleton(subject, *singleton)
self.evaluate_match_pattern_singleton(subject, *singleton, is_positive)
}
PatternPredicateKind::Class(cls, kind) => {
self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive)
}
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
PatternPredicateKind::Value(expr) => {
self.evaluate_match_pattern_value(subject, *expr, is_positive)
}
PatternPredicateKind::Or(predicates) => {
self.evaluate_match_pattern_or(subject, predicates, is_positive)
}
@ -420,12 +416,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
pattern: PatternPredicate<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db);
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive)
.map(|mut constraints| {
negate_if(&mut constraints, self.db, !is_positive);
constraints
})
self.evaluate_pattern_predicate_kind(
pattern.kind(self.db),
pattern.subject(self.db),
is_positive,
)
}
fn places(&self) -> &'db PlaceTable {
@ -709,7 +704,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
op: ast::CmpOp,
is_positive: bool,
) -> Option<Type<'db>> {
let op = if is_positive { op } else { op.negate() };
match op {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
@ -792,13 +790,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_) => {
if let Some(left) = place_expr(left) {
let op = if is_positive { *op } else { op.negate() };
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
let place = self.expect_place(&left);
constraints.insert(place, ty);
}
if let Some(left) = place_expr(left)
&& let Some(ty) =
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
let place = self.expect_place(&left);
constraints.insert(place, ty);
}
}
ast::Expr::Call(ast::ExprCall {
@ -954,6 +951,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
subject: Expression<'db>,
singleton: ast::Singleton,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
@ -963,6 +961,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -986,6 +985,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
.to_instance(self.db)?;
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -993,13 +993,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
subject: Expression<'db>,
value: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let place = {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
self.expect_place(&subject)
};
let subject_ty =
infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module);
let ty =
let value_ty =
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
Some(NarrowingConstraints::from_iter([(place, ty)]))
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
.map(|ty| NarrowingConstraints::from_iter([(place, ty)]))
}
fn evaluate_match_pattern_or(
@ -1010,13 +1017,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> {
let db = self.db;
// DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints.
let merge_constraints = if is_positive {
merge_constraints_or
} else {
merge_constraints_and
};
predicates
.iter()
.filter_map(|predicate| {
self.evaluate_pattern_predicate_kind(predicate, subject, is_positive)
})
.reduce(|mut constraints, constraints_| {
merge_constraints_or(&mut constraints, &constraints_, db);
merge_constraints(&mut constraints, &constraints_, db);
constraints
})
}
@ -1048,7 +1062,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let mut aggregation: Option<NarrowingConstraints> = None;
for sub_constraint in sub_constraints.into_iter().flatten() {
if let Some(ref mut some_aggregation) = aggregation {
merge_constraints_and(some_aggregation, sub_constraint, self.db);
merge_constraints_and(some_aggregation, &sub_constraint, self.db);
} else {
aggregation = Some(sub_constraint);
}