[ty] Fix match pattern value narrowing to use equality semantics (#20882)

## Summary

Resolves https://github.com/astral-sh/ty/issues/1349.

Fix match statement value patterns to use equality comparison semantics
instead of incorrectly narrowing to literal types directly. Value
patterns use equality for matching, and equality can be overridden, so
we can't always narrow to the matched literal.

## Test Plan

Updated match.md with corrected expected types and an additional example
with explanation

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Eric Mark Martin 2025-10-16 03:50:32 -04:00 committed by GitHub
parent fe4e3e2e75
commit c9dfb51f49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 154 additions and 98 deletions

View file

@ -71,98 +71,140 @@ reveal_type(x) # revealed: object
## Value patterns
Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on
one can only give us information where we know how the subject type implements equality.
Consider the following example.
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
from typing import Literal
def _(x: Literal["foo"] | int):
match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"]
case 42:
reveal_type(x) # revealed: Literal[42]
case 6.0:
reveal_type(x) # revealed: float
case 1j:
reveal_type(x) # revealed: complex
case b"foo":
reveal_type(x) # revealed: Literal[b"foo"]
reveal_type(x) # revealed: Literal["foo"] | int
reveal_type(x) # revealed: object
match x:
case "bar":
reveal_type(x) # revealed: int
```
In the first `match`'s `case "foo"` all we know is `x == "foo"`. `x` could be an instance of an
arbitrary `int` subclass with an arbitrary `__eq__`, so we can't actually narrow to
`Literal["foo"]`.
In the second `match`'s `case "bar"` we know `x == "bar"`. As discussed above, this isn't enough to
rule out `int`, but we know that `"foo" == "bar"` is false so we can eliminate `Literal["foo"]`.
More examples follow.
```py
from typing import Literal
class C:
pass
def _(x: Literal["foo", "bar", 42, b"foo"] | bool | complex):
match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"] | int | float | complex
case 42:
reveal_type(x) # revealed: int | float | complex
case 6.0:
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
case 1j:
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
case b"foo":
reveal_type(x) # revealed: (int & ~Literal[42]) | Literal[b"foo"] | float | complex
case _:
reveal_type(x) # revealed: Literal["bar"] | (int & ~Literal[42]) | float | complex
```
## Value patterns with guard
```py
def get_object() -> object:
return object()
from typing import Literal
x = get_object()
reveal_type(x) # revealed: object
class C:
pass
def _(x: Literal["foo", b"bar"] | int):
match x:
case "foo" if reveal_type(x): # revealed: Literal["foo"]
case "foo" if reveal_type(x): # revealed: Literal["foo"] | int
pass
case 42 if reveal_type(x): # revealed: Literal[42]
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
pass
case 6.0 if reveal_type(x): # revealed: float
case 42 if reveal_type(x): # revealed: int
pass
case 1j if reveal_type(x): # revealed: complex
pass
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
pass
reveal_type(x) # revealed: object
```
## Or patterns
```py
def get_object() -> object:
return object()
from typing import Literal
from enum import Enum
x = get_object()
class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
reveal_type(x) # revealed: object
def _(color: Color):
match color:
case Color.RED | Color.GREEN:
reveal_type(color) # revealed: Literal[Color.RED, Color.GREEN]
case Color.BLUE:
reveal_type(color) # revealed: Literal[Color.BLUE]
match color:
case Color.RED | Color.GREEN | Color.BLUE:
reveal_type(color) # revealed: Color
match color:
case Color.RED:
reveal_type(color) # revealed: Literal[Color.RED]
case _:
reveal_type(color) # revealed: Literal[Color.GREEN, Color.BLUE]
class A: ...
class B: ...
class C: ...
def _(x: A | B | C):
match x:
case A() | B():
reveal_type(x) # revealed: A | B
case C():
reveal_type(x) # revealed: C & ~A & ~B
case _:
reveal_type(x) # revealed: Never
match x:
case "foo" | 42 | None:
reveal_type(x) # revealed: Literal["foo", 42] | None
case "foo" | tuple():
reveal_type(x) # revealed: tuple[Unknown, ...]
case True | False:
reveal_type(x) # revealed: bool
case 3.14 | 2.718 | 1.414:
reveal_type(x) # revealed: float
case A() | B() | C():
reveal_type(x) # revealed: A | B | C
case _:
reveal_type(x) # revealed: Never
reveal_type(x) # revealed: object
match x:
case A():
reveal_type(x) # revealed: A
case _:
reveal_type(x) # revealed: (B & ~A) | (C & ~A)
```
## Or patterns with guard
```py
def get_object() -> object:
return object()
x = get_object()
reveal_type(x) # revealed: object
from typing import Literal
def _(x: Literal["foo", b"bar"] | int):
match x:
case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None
case "foo" | 42 if reveal_type(x): # revealed: Literal["foo"] | int
pass
case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple[Unknown, ...]
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
pass
case True | False if reveal_type(x): # revealed: bool
case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int
pass
case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float
pass
reveal_type(x) # revealed: object
```
## Narrowing due to guard
@ -179,7 +221,7 @@ match x:
case str() | float() if type(x) is str:
reveal_type(x) # revealed: str
case "foo" | 42 | None if isinstance(x, int):
reveal_type(x) # revealed: Literal[42]
reveal_type(x) # revealed: int
case False if x:
reveal_type(x) # revealed: Never
case "foo" if x := "bar":
@ -201,7 +243,7 @@ reveal_type(x) # revealed: object
match x:
case str() | float() if type(x) is str and reveal_type(x): # revealed: str
pass
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42]
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: int
pass
case False if x and reveal_type(x): # revealed: Never
pass

View file

@ -263,19 +263,19 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
fn merge_constraints_and<'db>(
into: &mut NarrowingConstraints<'db>,
from: NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, value) in from {
match into.entry(key) {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = IntersectionBuilder::new(db)
.add_positive(*entry.get())
.add_positive(value)
.add_positive(*value)
.build();
}
Entry::Vacant(entry) => {
entry.insert(value);
entry.insert(*value);
}
}
}
@ -303,12 +303,6 @@ fn merge_constraints_or<'db>(
}
}
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
for (_place, ty) in constraints.iter_mut() {
*ty = ty.negate_if(db, yes);
}
}
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
match expr {
ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()),
@ -399,12 +393,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> {
match pattern_predicate_kind {
PatternPredicateKind::Singleton(singleton) => {
self.evaluate_match_pattern_singleton(subject, *singleton)
self.evaluate_match_pattern_singleton(subject, *singleton, is_positive)
}
PatternPredicateKind::Class(cls, kind) => {
self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive)
}
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
PatternPredicateKind::Value(expr) => {
self.evaluate_match_pattern_value(subject, *expr, is_positive)
}
PatternPredicateKind::Or(predicates) => {
self.evaluate_match_pattern_or(subject, predicates, is_positive)
}
@ -420,12 +416,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
pattern: PatternPredicate<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db);
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive)
.map(|mut constraints| {
negate_if(&mut constraints, self.db, !is_positive);
constraints
})
self.evaluate_pattern_predicate_kind(
pattern.kind(self.db),
pattern.subject(self.db),
is_positive,
)
}
fn places(&self) -> &'db PlaceTable {
@ -709,7 +704,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
op: ast::CmpOp,
is_positive: bool,
) -> Option<Type<'db>> {
let op = if is_positive { op } else { op.negate() };
match op {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
@ -792,15 +790,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_) => {
if let Some(left) = place_expr(left) {
let op = if is_positive { *op } else { op.negate() };
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
if let Some(left) = place_expr(left)
&& let Some(ty) =
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
let place = self.expect_place(&left);
constraints.insert(place, ty);
}
}
}
ast::Expr::Call(ast::ExprCall {
range: _,
node_index: _,
@ -954,6 +951,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
subject: Expression<'db>,
singleton: ast::Singleton,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
@ -963,6 +961,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -986,6 +985,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
.to_instance(self.db)?;
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -993,13 +993,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
subject: Expression<'db>,
value: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let place = {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
self.expect_place(&subject)
};
let subject_ty =
infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module);
let ty =
let value_ty =
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
Some(NarrowingConstraints::from_iter([(place, ty)]))
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
.map(|ty| NarrowingConstraints::from_iter([(place, ty)]))
}
fn evaluate_match_pattern_or(
@ -1010,13 +1017,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> {
let db = self.db;
// DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints.
let merge_constraints = if is_positive {
merge_constraints_or
} else {
merge_constraints_and
};
predicates
.iter()
.filter_map(|predicate| {
self.evaluate_pattern_predicate_kind(predicate, subject, is_positive)
})
.reduce(|mut constraints, constraints_| {
merge_constraints_or(&mut constraints, &constraints_, db);
merge_constraints(&mut constraints, &constraints_, db);
constraints
})
}
@ -1048,7 +1062,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let mut aggregation: Option<NarrowingConstraints> = None;
for sub_constraint in sub_constraints.into_iter().flatten() {
if let Some(ref mut some_aggregation) = aggregation {
merge_constraints_and(some_aggregation, sub_constraint, self.db);
merge_constraints_and(some_aggregation, &sub_constraint, self.db);
} else {
aggregation = Some(sub_constraint);
}