From c9dfb51f49a99ecc838dd21ee4ed564eec193740 Mon Sep 17 00:00:00 2001 From: Eric Mark Martin Date: Thu, 16 Oct 2025 03:50:32 -0400 Subject: [PATCH] [ty] Fix match pattern value narrowing to use equality semantics (#20882) ## Summary Resolves https://github.com/astral-sh/ty/issues/1349. Fix match statement value patterns to use equality comparison semantics instead of incorrectly narrowing to literal types directly. Value patterns use equality for matching, and equality can be overridden, so we can't always narrow to the matched literal. ## Test Plan Updated match.md with corrected expected types and an additional example with explanation --------- Co-authored-by: David Peter --- .../resources/mdtest/narrow/match.md | 176 +++++++++++------- crates/ty_python_semantic/src/types/narrow.rs | 76 +++++--- 2 files changed, 154 insertions(+), 98 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index b6b0ec90e8..55772eab24 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -71,98 +71,140 @@ reveal_type(x) # revealed: object ## Value patterns +Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on +one can only give us information where we know how the subject type implements equality. + +Consider the following example. 
+ ```py -def get_object() -> object: - return object() +from typing import Literal -x = get_object() +def _(x: Literal["foo"] | int): + match x: + case "foo": + reveal_type(x) # revealed: Literal["foo"] | int -reveal_type(x) # revealed: object + match x: + case "bar": + reveal_type(x) # revealed: int +``` -match x: - case "foo": - reveal_type(x) # revealed: Literal["foo"] - case 42: - reveal_type(x) # revealed: Literal[42] - case 6.0: - reveal_type(x) # revealed: float - case 1j: - reveal_type(x) # revealed: complex - case b"foo": - reveal_type(x) # revealed: Literal[b"foo"] +In the first `match`'s `case "foo"` all we know is `x == "foo"`. `x` could be an instance of an +arbitrary `int` subclass with an arbitrary `__eq__`, so we can't actually narrow to +`Literal["foo"]`. -reveal_type(x) # revealed: object +In the second `match`'s `case "bar"` we know `x == "bar"`. As discussed above, this isn't enough to +rule out `int`, but we know that `"foo" == "bar"` is false so we can eliminate `Literal["foo"]`. + +More examples follow. 
+ +```py +from typing import Literal + +class C: + pass + +def _(x: Literal["foo", "bar", 42, b"foo"] | bool | complex): + match x: + case "foo": + reveal_type(x) # revealed: Literal["foo"] | int | float | complex + case 42: + reveal_type(x) # revealed: int | float | complex + case 6.0: + reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex + case 1j: + reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex + case b"foo": + reveal_type(x) # revealed: (int & ~Literal[42]) | Literal[b"foo"] | float | complex + case _: + reveal_type(x) # revealed: Literal["bar"] | (int & ~Literal[42]) | float | complex ``` ## Value patterns with guard ```py -def get_object() -> object: - return object() +from typing import Literal -x = get_object() +class C: + pass -reveal_type(x) # revealed: object - -match x: - case "foo" if reveal_type(x): # revealed: Literal["foo"] - pass - case 42 if reveal_type(x): # revealed: Literal[42] - pass - case 6.0 if reveal_type(x): # revealed: float - pass - case 1j if reveal_type(x): # revealed: complex - pass - case b"foo" if reveal_type(x): # revealed: Literal[b"foo"] - pass - -reveal_type(x) # revealed: object +def _(x: Literal["foo", b"bar"] | int): + match x: + case "foo" if reveal_type(x): # revealed: Literal["foo"] | int + pass + case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int + pass + case 42 if reveal_type(x): # revealed: int + pass ``` ## Or patterns ```py -def get_object() -> object: - return object() +from typing import Literal +from enum import Enum -x = get_object() +class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 -reveal_type(x) # revealed: object +def _(color: Color): + match color: + case Color.RED | Color.GREEN: + reveal_type(color) # revealed: Literal[Color.RED, Color.GREEN] + case Color.BLUE: + reveal_type(color) # revealed: Literal[Color.BLUE] -match x: - case "foo" | 42 | None: - reveal_type(x) # revealed: Literal["foo", 42] | None - 
case "foo" | tuple(): - reveal_type(x) # revealed: tuple[Unknown, ...] - case True | False: - reveal_type(x) # revealed: bool - case 3.14 | 2.718 | 1.414: - reveal_type(x) # revealed: float + match color: + case Color.RED | Color.GREEN | Color.BLUE: + reveal_type(color) # revealed: Color -reveal_type(x) # revealed: object + match color: + case Color.RED: + reveal_type(color) # revealed: Literal[Color.RED] + case _: + reveal_type(color) # revealed: Literal[Color.GREEN, Color.BLUE] + +class A: ... +class B: ... +class C: ... + +def _(x: A | B | C): + match x: + case A() | B(): + reveal_type(x) # revealed: A | B + case C(): + reveal_type(x) # revealed: C & ~A & ~B + case _: + reveal_type(x) # revealed: Never + + match x: + case A() | B() | C(): + reveal_type(x) # revealed: A | B | C + case _: + reveal_type(x) # revealed: Never + + match x: + case A(): + reveal_type(x) # revealed: A + case _: + reveal_type(x) # revealed: (B & ~A) | (C & ~A) ``` ## Or patterns with guard ```py -def get_object() -> object: - return object() +from typing import Literal -x = get_object() - -reveal_type(x) # revealed: object - -match x: - case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None - pass - case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple[Unknown, ...] 
- pass - case True | False if reveal_type(x): # revealed: bool - pass - case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float - pass - -reveal_type(x) # revealed: object +def _(x: Literal["foo", b"bar"] | int): + match x: + case "foo" | 42 if reveal_type(x): # revealed: Literal["foo"] | int + pass + case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int + pass + case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int + pass ``` ## Narrowing due to guard @@ -179,7 +221,7 @@ match x: case str() | float() if type(x) is str: reveal_type(x) # revealed: str case "foo" | 42 | None if isinstance(x, int): - reveal_type(x) # revealed: Literal[42] + reveal_type(x) # revealed: int case False if x: reveal_type(x) # revealed: Never case "foo" if x := "bar": @@ -201,7 +243,7 @@ reveal_type(x) # revealed: object match x: case str() | float() if type(x) is str and reveal_type(x): # revealed: str pass - case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42] + case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: int pass case False if x and reveal_type(x): # revealed: Never pass diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 20d505bcab..f725f32220 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -263,19 +263,19 @@ type NarrowingConstraints<'db> = FxHashMap>; fn merge_constraints_and<'db>( into: &mut NarrowingConstraints<'db>, - from: NarrowingConstraints<'db>, + from: &NarrowingConstraints<'db>, db: &'db dyn Db, ) { for (key, value) in from { - match into.entry(key) { + match into.entry(*key) { Entry::Occupied(mut entry) => { *entry.get_mut() = IntersectionBuilder::new(db) .add_positive(*entry.get()) - .add_positive(value) + .add_positive(*value) .build(); } Entry::Vacant(entry) => { - entry.insert(value); + entry.insert(*value); } } } @@ -303,12 +303,6 @@ fn 
merge_constraints_or<'db>( } } -fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) { - for (_place, ty) in constraints.iter_mut() { - *ty = ty.negate_if(db, yes); - } -} - fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> { match expr { ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()), @@ -399,12 +393,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) -> Option<NarrowingConstraints<'db>> { match pattern_predicate_kind { PatternPredicateKind::Singleton(singleton) => { - self.evaluate_match_pattern_singleton(subject, *singleton) + self.evaluate_match_pattern_singleton(subject, *singleton, is_positive) } PatternPredicateKind::Class(cls, kind) => { self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive) } - PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr), + PatternPredicateKind::Value(expr) => { + self.evaluate_match_pattern_value(subject, *expr, is_positive) + } PatternPredicateKind::Or(predicates) => { self.evaluate_match_pattern_or(subject, predicates, is_positive) } @@ -420,12 +416,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { pattern: PatternPredicate<'db>, is_positive: bool, ) -> Option<NarrowingConstraints<'db>> { - let subject = pattern.subject(self.db); - self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive) - .map(|mut constraints| { - negate_if(&mut constraints, self.db, !is_positive); - constraints - }) + self.evaluate_pattern_predicate_kind( + pattern.kind(self.db), + pattern.subject(self.db), + is_positive, + ) } fn places(&self) -> &'db PlaceTable { @@ -709,7 +704,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { lhs_ty: Type<'db>, rhs_ty: Type<'db>, op: ast::CmpOp, + is_positive: bool, ) -> Option<Type<'db>> { + let op = if is_positive { op } else { op.negate() }; + match op { ast::CmpOp::IsNot => { if rhs_ty.is_singleton(self.db) { @@ -792,13 +790,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | ast::Expr::Attribute(_)
| ast::Expr::Subscript(_) | ast::Expr::Named(_) => { - if let Some(left) = place_expr(left) { - let op = if is_positive { *op } else { op.negate() }; - - if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { - let place = self.expect_place(&left); - constraints.insert(place, ty); - } + if let Some(left) = place_expr(left) + && let Some(ty) = + self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) + { + let place = self.expect_place(&left); + constraints.insert(place, ty); } } ast::Expr::Call(ast::ExprCall { @@ -954,6 +951,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, singleton: ast::Singleton, + is_positive: bool, ) -> Option<NarrowingConstraints<'db>> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); @@ -963,6 +961,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ast::Singleton::True => Type::BooleanLiteral(true), ast::Singleton::False => Type::BooleanLiteral(false), }; + let ty = ty.negate_if(self.db, !is_positive); Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -986,6 +985,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) .to_instance(self.db)?; + let ty = ty.negate_if(self.db, !is_positive); Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -993,13 +993,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, value: Expression<'db>, + is_positive: bool, ) -> Option<NarrowingConstraints<'db>> { - let subject = place_expr(subject.node_ref(self.db, self.module))?; - let place = self.expect_place(&subject); + let place = { + let subject = place_expr(subject.node_ref(self.db, self.module))?; + self.expect_place(&subject) + }; + let subject_ty = + infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module); - let ty = + let value_ty = infer_same_file_expression_type(self.db,
value, TypeContext::default(), self.module); - Some(NarrowingConstraints::from_iter([(place, ty)])) + + self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) + .map(|ty| NarrowingConstraints::from_iter([(place, ty)])) } fn evaluate_match_pattern_or( @@ -1010,13 +1017,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) -> Option<NarrowingConstraints<'db>> { let db = self.db; + // DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints. + let merge_constraints = if is_positive { + merge_constraints_or + } else { + merge_constraints_and + }; + predicates .iter() .filter_map(|predicate| { self.evaluate_pattern_predicate_kind(predicate, subject, is_positive) }) .reduce(|mut constraints, constraints_| { - merge_constraints_or(&mut constraints, &constraints_, db); + merge_constraints(&mut constraints, &constraints_, db); constraints }) } @@ -1048,7 +1062,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let mut aggregation: Option<NarrowingConstraints> = None; for sub_constraint in sub_constraints.into_iter().flatten() { if let Some(ref mut some_aggregation) = aggregation { - merge_constraints_and(some_aggregation, &sub_constraint, self.db); + merge_constraints_and(some_aggregation, &sub_constraint, self.db); } else { aggregation = Some(sub_constraint); }