[ty] Fix match pattern value narrowing to use equality semantics (#20882)

## Summary

Resolves https://github.com/astral-sh/ty/issues/1349.

Fix match statement value patterns to use equality comparison semantics
instead of incorrectly narrowing to literal types directly. Value
patterns use equality for matching, and equality can be overridden, so
we can't always narrow to the matched literal.

## Test Plan

Updated match.md with corrected expected types and an additional example
with explanation

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Eric Mark Martin 2025-10-16 03:50:32 -04:00 committed by GitHub
parent fe4e3e2e75
commit c9dfb51f49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 154 additions and 98 deletions

View file

@ -71,98 +71,140 @@ reveal_type(x) # revealed: object
## Value patterns ## Value patterns
Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on
one can only give us information where we know how the subject type implements equality.
Consider the following example.
```py ```py
def get_object() -> object: from typing import Literal
return object()
x = get_object() def _(x: Literal["foo"] | int):
match x:
reveal_type(x) # revealed: object
match x:
case "foo": case "foo":
reveal_type(x) # revealed: Literal["foo"] reveal_type(x) # revealed: Literal["foo"] | int
case 42:
reveal_type(x) # revealed: Literal[42]
case 6.0:
reveal_type(x) # revealed: float
case 1j:
reveal_type(x) # revealed: complex
case b"foo":
reveal_type(x) # revealed: Literal[b"foo"]
reveal_type(x) # revealed: object match x:
case "bar":
reveal_type(x) # revealed: int
```
In the first `match`'s `case "foo"` all we know is `x == "foo"`. `x` could be an instance of an
arbitrary `int` subclass with an arbitrary `__eq__`, so we can't actually narrow to
`Literal["foo"]`.
In the second `match`'s `case "bar"` we know `x == "bar"`. As discussed above, this isn't enough to
rule out `int`, but we know that `"foo" == "bar"` is false so we can eliminate `Literal["foo"]`.
More examples follow.
```py
from typing import Literal
class C:
pass
def _(x: Literal["foo", "bar", 42, b"foo"] | bool | complex):
match x:
case "foo":
reveal_type(x) # revealed: Literal["foo"] | int | float | complex
case 42:
reveal_type(x) # revealed: int | float | complex
case 6.0:
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
case 1j:
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
case b"foo":
reveal_type(x) # revealed: (int & ~Literal[42]) | Literal[b"foo"] | float | complex
case _:
reveal_type(x) # revealed: Literal["bar"] | (int & ~Literal[42]) | float | complex
``` ```
## Value patterns with guard ## Value patterns with guard
```py ```py
def get_object() -> object: from typing import Literal
return object()
x = get_object() class C:
reveal_type(x) # revealed: object
match x:
case "foo" if reveal_type(x): # revealed: Literal["foo"]
pass
case 42 if reveal_type(x): # revealed: Literal[42]
pass
case 6.0 if reveal_type(x): # revealed: float
pass
case 1j if reveal_type(x): # revealed: complex
pass
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
pass pass
reveal_type(x) # revealed: object def _(x: Literal["foo", b"bar"] | int):
match x:
case "foo" if reveal_type(x): # revealed: Literal["foo"] | int
pass
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
pass
case 42 if reveal_type(x): # revealed: int
pass
``` ```
## Or patterns ## Or patterns
```py ```py
def get_object() -> object: from typing import Literal
return object() from enum import Enum
x = get_object() class Color(Enum):
RED = 1
GREEN = 2
BLUE = 3
reveal_type(x) # revealed: object def _(color: Color):
match color:
case Color.RED | Color.GREEN:
reveal_type(color) # revealed: Literal[Color.RED, Color.GREEN]
case Color.BLUE:
reveal_type(color) # revealed: Literal[Color.BLUE]
match x: match color:
case "foo" | 42 | None: case Color.RED | Color.GREEN | Color.BLUE:
reveal_type(x) # revealed: Literal["foo", 42] | None reveal_type(color) # revealed: Color
case "foo" | tuple():
reveal_type(x) # revealed: tuple[Unknown, ...]
case True | False:
reveal_type(x) # revealed: bool
case 3.14 | 2.718 | 1.414:
reveal_type(x) # revealed: float
reveal_type(x) # revealed: object match color:
case Color.RED:
reveal_type(color) # revealed: Literal[Color.RED]
case _:
reveal_type(color) # revealed: Literal[Color.GREEN, Color.BLUE]
class A: ...
class B: ...
class C: ...
def _(x: A | B | C):
match x:
case A() | B():
reveal_type(x) # revealed: A | B
case C():
reveal_type(x) # revealed: C & ~A & ~B
case _:
reveal_type(x) # revealed: Never
match x:
case A() | B() | C():
reveal_type(x) # revealed: A | B | C
case _:
reveal_type(x) # revealed: Never
match x:
case A():
reveal_type(x) # revealed: A
case _:
reveal_type(x) # revealed: (B & ~A) | (C & ~A)
``` ```
## Or patterns with guard ## Or patterns with guard
```py ```py
def get_object() -> object: from typing import Literal
return object()
x = get_object() def _(x: Literal["foo", b"bar"] | int):
match x:
reveal_type(x) # revealed: object case "foo" | 42 if reveal_type(x): # revealed: Literal["foo"] | int
match x:
case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None
pass pass
case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple[Unknown, ...] case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
pass pass
case True | False if reveal_type(x): # revealed: bool case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int
pass pass
case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float
pass
reveal_type(x) # revealed: object
``` ```
## Narrowing due to guard ## Narrowing due to guard
@ -179,7 +221,7 @@ match x:
case str() | float() if type(x) is str: case str() | float() if type(x) is str:
reveal_type(x) # revealed: str reveal_type(x) # revealed: str
case "foo" | 42 | None if isinstance(x, int): case "foo" | 42 | None if isinstance(x, int):
reveal_type(x) # revealed: Literal[42] reveal_type(x) # revealed: int
case False if x: case False if x:
reveal_type(x) # revealed: Never reveal_type(x) # revealed: Never
case "foo" if x := "bar": case "foo" if x := "bar":
@ -201,7 +243,7 @@ reveal_type(x) # revealed: object
match x: match x:
case str() | float() if type(x) is str and reveal_type(x): # revealed: str case str() | float() if type(x) is str and reveal_type(x): # revealed: str
pass pass
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42] case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: int
pass pass
case False if x and reveal_type(x): # revealed: Never case False if x and reveal_type(x): # revealed: Never
pass pass

View file

@ -263,19 +263,19 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
fn merge_constraints_and<'db>( fn merge_constraints_and<'db>(
into: &mut NarrowingConstraints<'db>, into: &mut NarrowingConstraints<'db>,
from: NarrowingConstraints<'db>, from: &NarrowingConstraints<'db>,
db: &'db dyn Db, db: &'db dyn Db,
) { ) {
for (key, value) in from { for (key, value) in from {
match into.entry(key) { match into.entry(*key) {
Entry::Occupied(mut entry) => { Entry::Occupied(mut entry) => {
*entry.get_mut() = IntersectionBuilder::new(db) *entry.get_mut() = IntersectionBuilder::new(db)
.add_positive(*entry.get()) .add_positive(*entry.get())
.add_positive(value) .add_positive(*value)
.build(); .build();
} }
Entry::Vacant(entry) => { Entry::Vacant(entry) => {
entry.insert(value); entry.insert(*value);
} }
} }
} }
@ -303,12 +303,6 @@ fn merge_constraints_or<'db>(
} }
} }
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
for (_place, ty) in constraints.iter_mut() {
*ty = ty.negate_if(db, yes);
}
}
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> { fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
match expr { match expr {
ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()), ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()),
@ -399,12 +393,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
match pattern_predicate_kind { match pattern_predicate_kind {
PatternPredicateKind::Singleton(singleton) => { PatternPredicateKind::Singleton(singleton) => {
self.evaluate_match_pattern_singleton(subject, *singleton) self.evaluate_match_pattern_singleton(subject, *singleton, is_positive)
} }
PatternPredicateKind::Class(cls, kind) => { PatternPredicateKind::Class(cls, kind) => {
self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive) self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive)
} }
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr), PatternPredicateKind::Value(expr) => {
self.evaluate_match_pattern_value(subject, *expr, is_positive)
}
PatternPredicateKind::Or(predicates) => { PatternPredicateKind::Or(predicates) => {
self.evaluate_match_pattern_or(subject, predicates, is_positive) self.evaluate_match_pattern_or(subject, predicates, is_positive)
} }
@ -420,12 +416,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
pattern: PatternPredicate<'db>, pattern: PatternPredicate<'db>,
is_positive: bool, is_positive: bool,
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db); self.evaluate_pattern_predicate_kind(
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive) pattern.kind(self.db),
.map(|mut constraints| { pattern.subject(self.db),
negate_if(&mut constraints, self.db, !is_positive); is_positive,
constraints )
})
} }
fn places(&self) -> &'db PlaceTable { fn places(&self) -> &'db PlaceTable {
@ -709,7 +704,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
lhs_ty: Type<'db>, lhs_ty: Type<'db>,
rhs_ty: Type<'db>, rhs_ty: Type<'db>,
op: ast::CmpOp, op: ast::CmpOp,
is_positive: bool,
) -> Option<Type<'db>> { ) -> Option<Type<'db>> {
let op = if is_positive { op } else { op.negate() };
match op { match op {
ast::CmpOp::IsNot => { ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) { if rhs_ty.is_singleton(self.db) {
@ -792,15 +790,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
| ast::Expr::Attribute(_) | ast::Expr::Attribute(_)
| ast::Expr::Subscript(_) | ast::Expr::Subscript(_)
| ast::Expr::Named(_) => { | ast::Expr::Named(_) => {
if let Some(left) = place_expr(left) { if let Some(left) = place_expr(left)
let op = if is_positive { *op } else { op.negate() }; && let Some(ty) =
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { {
let place = self.expect_place(&left); let place = self.expect_place(&left);
constraints.insert(place, ty); constraints.insert(place, ty);
} }
} }
}
ast::Expr::Call(ast::ExprCall { ast::Expr::Call(ast::ExprCall {
range: _, range: _,
node_index: _, node_index: _,
@ -954,6 +951,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self, &mut self,
subject: Expression<'db>, subject: Expression<'db>,
singleton: ast::Singleton, singleton: ast::Singleton,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?; let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject); let place = self.expect_place(&subject);
@ -963,6 +961,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::Singleton::True => Type::BooleanLiteral(true), ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false), ast::Singleton::False => Type::BooleanLiteral(false),
}; };
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)])) Some(NarrowingConstraints::from_iter([(place, ty)]))
} }
@ -986,6 +985,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
.to_instance(self.db)?; .to_instance(self.db)?;
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)])) Some(NarrowingConstraints::from_iter([(place, ty)]))
} }
@ -993,13 +993,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self, &mut self,
subject: Expression<'db>, subject: Expression<'db>,
value: Expression<'db>, value: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let place = {
let subject = place_expr(subject.node_ref(self.db, self.module))?; let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject); self.expect_place(&subject)
};
let subject_ty =
infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module);
let ty = let value_ty =
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
Some(NarrowingConstraints::from_iter([(place, ty)]))
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
.map(|ty| NarrowingConstraints::from_iter([(place, ty)]))
} }
fn evaluate_match_pattern_or( fn evaluate_match_pattern_or(
@ -1010,13 +1017,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let db = self.db; let db = self.db;
// DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints.
let merge_constraints = if is_positive {
merge_constraints_or
} else {
merge_constraints_and
};
predicates predicates
.iter() .iter()
.filter_map(|predicate| { .filter_map(|predicate| {
self.evaluate_pattern_predicate_kind(predicate, subject, is_positive) self.evaluate_pattern_predicate_kind(predicate, subject, is_positive)
}) })
.reduce(|mut constraints, constraints_| { .reduce(|mut constraints, constraints_| {
merge_constraints_or(&mut constraints, &constraints_, db); merge_constraints(&mut constraints, &constraints_, db);
constraints constraints
}) })
} }
@ -1048,7 +1062,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let mut aggregation: Option<NarrowingConstraints> = None; let mut aggregation: Option<NarrowingConstraints> = None;
for sub_constraint in sub_constraints.into_iter().flatten() { for sub_constraint in sub_constraints.into_iter().flatten() {
if let Some(ref mut some_aggregation) = aggregation { if let Some(ref mut some_aggregation) = aggregation {
merge_constraints_and(some_aggregation, sub_constraint, self.db); merge_constraints_and(some_aggregation, &sub_constraint, self.db);
} else { } else {
aggregation = Some(sub_constraint); aggregation = Some(sub_constraint);
} }