mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-21 07:41:53 +00:00
[ty] Fix match pattern value narrowing to use equality semantics (#20882)
## Summary Resolves https://github.com/astral-sh/ty/issues/1349. Fix match statement value patterns to use equality comparison semantics instead of incorrectly narrowing to literal types directly. Value patterns use equality for matching, and equality can be overridden, so we can't always narrow to the matched literal. ## Test Plan Updated match.md with corrected expected types and an additional example with explanation --------- Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
parent
fe4e3e2e75
commit
c9dfb51f49
2 changed files with 154 additions and 98 deletions
|
@ -71,98 +71,140 @@ reveal_type(x) # revealed: object
|
|||
|
||||
## Value patterns
|
||||
|
||||
Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on
|
||||
one can only give us information where we know how the subject type implements equality.
|
||||
|
||||
Consider the following example.
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
from typing import Literal
|
||||
|
||||
x = get_object()
|
||||
def _(x: Literal["foo"] | int):
|
||||
match x:
|
||||
case "foo":
|
||||
reveal_type(x) # revealed: Literal["foo"] | int
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
match x:
|
||||
case "bar":
|
||||
reveal_type(x) # revealed: int
|
||||
```
|
||||
|
||||
match x:
|
||||
case "foo":
|
||||
reveal_type(x) # revealed: Literal["foo"]
|
||||
case 42:
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
case 6.0:
|
||||
reveal_type(x) # revealed: float
|
||||
case 1j:
|
||||
reveal_type(x) # revealed: complex
|
||||
case b"foo":
|
||||
reveal_type(x) # revealed: Literal[b"foo"]
|
||||
In the first `match`'s `case "foo"` all we know is `x == "foo"`. `x` could be an instance of an
|
||||
arbitrary `int` subclass with an arbitrary `__eq__`, so we can't actually narrow to
|
||||
`Literal["foo"]`.
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
In the second `match`'s `case "bar"` we know `x == "bar"`. As discussed above, this isn't enough to
|
||||
rule out `int`, but we know that `"foo" == "bar"` is false so we can eliminate `Literal["foo"]`.
|
||||
|
||||
More examples follow.
|
||||
|
||||
```py
|
||||
from typing import Literal
|
||||
|
||||
class C:
|
||||
pass
|
||||
|
||||
def _(x: Literal["foo", "bar", 42, b"foo"] | bool | complex):
|
||||
match x:
|
||||
case "foo":
|
||||
reveal_type(x) # revealed: Literal["foo"] | int | float | complex
|
||||
case 42:
|
||||
reveal_type(x) # revealed: int | float | complex
|
||||
case 6.0:
|
||||
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
|
||||
case 1j:
|
||||
reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex
|
||||
case b"foo":
|
||||
reveal_type(x) # revealed: (int & ~Literal[42]) | Literal[b"foo"] | float | complex
|
||||
case _:
|
||||
reveal_type(x) # revealed: Literal["bar"] | (int & ~Literal[42]) | float | complex
|
||||
```
|
||||
|
||||
## Value patterns with guard
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
from typing import Literal
|
||||
|
||||
x = get_object()
|
||||
class C:
|
||||
pass
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case "foo" if reveal_type(x): # revealed: Literal["foo"]
|
||||
pass
|
||||
case 42 if reveal_type(x): # revealed: Literal[42]
|
||||
pass
|
||||
case 6.0 if reveal_type(x): # revealed: float
|
||||
pass
|
||||
case 1j if reveal_type(x): # revealed: complex
|
||||
pass
|
||||
case b"foo" if reveal_type(x): # revealed: Literal[b"foo"]
|
||||
pass
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
def _(x: Literal["foo", b"bar"] | int):
|
||||
match x:
|
||||
case "foo" if reveal_type(x): # revealed: Literal["foo"] | int
|
||||
pass
|
||||
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
|
||||
pass
|
||||
case 42 if reveal_type(x): # revealed: int
|
||||
pass
|
||||
```
|
||||
|
||||
## Or patterns
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
from typing import Literal
|
||||
from enum import Enum
|
||||
|
||||
x = get_object()
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
GREEN = 2
|
||||
BLUE = 3
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
def _(color: Color):
|
||||
match color:
|
||||
case Color.RED | Color.GREEN:
|
||||
reveal_type(color) # revealed: Literal[Color.RED, Color.GREEN]
|
||||
case Color.BLUE:
|
||||
reveal_type(color) # revealed: Literal[Color.BLUE]
|
||||
|
||||
match x:
|
||||
case "foo" | 42 | None:
|
||||
reveal_type(x) # revealed: Literal["foo", 42] | None
|
||||
case "foo" | tuple():
|
||||
reveal_type(x) # revealed: tuple[Unknown, ...]
|
||||
case True | False:
|
||||
reveal_type(x) # revealed: bool
|
||||
case 3.14 | 2.718 | 1.414:
|
||||
reveal_type(x) # revealed: float
|
||||
match color:
|
||||
case Color.RED | Color.GREEN | Color.BLUE:
|
||||
reveal_type(color) # revealed: Color
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
match color:
|
||||
case Color.RED:
|
||||
reveal_type(color) # revealed: Literal[Color.RED]
|
||||
case _:
|
||||
reveal_type(color) # revealed: Literal[Color.GREEN, Color.BLUE]
|
||||
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
|
||||
def _(x: A | B | C):
|
||||
match x:
|
||||
case A() | B():
|
||||
reveal_type(x) # revealed: A | B
|
||||
case C():
|
||||
reveal_type(x) # revealed: C & ~A & ~B
|
||||
case _:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
match x:
|
||||
case A() | B() | C():
|
||||
reveal_type(x) # revealed: A | B | C
|
||||
case _:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
match x:
|
||||
case A():
|
||||
reveal_type(x) # revealed: A
|
||||
case _:
|
||||
reveal_type(x) # revealed: (B & ~A) | (C & ~A)
|
||||
```
|
||||
|
||||
## Or patterns with guard
|
||||
|
||||
```py
|
||||
def get_object() -> object:
|
||||
return object()
|
||||
from typing import Literal
|
||||
|
||||
x = get_object()
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
|
||||
match x:
|
||||
case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None
|
||||
pass
|
||||
case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple[Unknown, ...]
|
||||
pass
|
||||
case True | False if reveal_type(x): # revealed: bool
|
||||
pass
|
||||
case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float
|
||||
pass
|
||||
|
||||
reveal_type(x) # revealed: object
|
||||
def _(x: Literal["foo", b"bar"] | int):
|
||||
match x:
|
||||
case "foo" | 42 if reveal_type(x): # revealed: Literal["foo"] | int
|
||||
pass
|
||||
case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int
|
||||
pass
|
||||
case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int
|
||||
pass
|
||||
```
|
||||
|
||||
## Narrowing due to guard
|
||||
|
@ -179,7 +221,7 @@ match x:
|
|||
case str() | float() if type(x) is str:
|
||||
reveal_type(x) # revealed: str
|
||||
case "foo" | 42 | None if isinstance(x, int):
|
||||
reveal_type(x) # revealed: Literal[42]
|
||||
reveal_type(x) # revealed: int
|
||||
case False if x:
|
||||
reveal_type(x) # revealed: Never
|
||||
case "foo" if x := "bar":
|
||||
|
@ -201,7 +243,7 @@ reveal_type(x) # revealed: object
|
|||
match x:
|
||||
case str() | float() if type(x) is str and reveal_type(x): # revealed: str
|
||||
pass
|
||||
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42]
|
||||
case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: int
|
||||
pass
|
||||
case False if x and reveal_type(x): # revealed: Never
|
||||
pass
|
||||
|
|
|
@ -263,19 +263,19 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
|
|||
|
||||
fn merge_constraints_and<'db>(
|
||||
into: &mut NarrowingConstraints<'db>,
|
||||
from: NarrowingConstraints<'db>,
|
||||
from: &NarrowingConstraints<'db>,
|
||||
db: &'db dyn Db,
|
||||
) {
|
||||
for (key, value) in from {
|
||||
match into.entry(key) {
|
||||
match into.entry(*key) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
*entry.get_mut() = IntersectionBuilder::new(db)
|
||||
.add_positive(*entry.get())
|
||||
.add_positive(value)
|
||||
.add_positive(*value)
|
||||
.build();
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(value);
|
||||
entry.insert(*value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -303,12 +303,6 @@ fn merge_constraints_or<'db>(
|
|||
}
|
||||
}
|
||||
|
||||
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
|
||||
for (_place, ty) in constraints.iter_mut() {
|
||||
*ty = ty.negate_if(db, yes);
|
||||
}
|
||||
}
|
||||
|
||||
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
|
||||
match expr {
|
||||
ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()),
|
||||
|
@ -399,12 +393,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
) -> Option<NarrowingConstraints<'db>> {
|
||||
match pattern_predicate_kind {
|
||||
PatternPredicateKind::Singleton(singleton) => {
|
||||
self.evaluate_match_pattern_singleton(subject, *singleton)
|
||||
self.evaluate_match_pattern_singleton(subject, *singleton, is_positive)
|
||||
}
|
||||
PatternPredicateKind::Class(cls, kind) => {
|
||||
self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive)
|
||||
}
|
||||
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
|
||||
PatternPredicateKind::Value(expr) => {
|
||||
self.evaluate_match_pattern_value(subject, *expr, is_positive)
|
||||
}
|
||||
PatternPredicateKind::Or(predicates) => {
|
||||
self.evaluate_match_pattern_or(subject, predicates, is_positive)
|
||||
}
|
||||
|
@ -420,12 +416,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
pattern: PatternPredicate<'db>,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let subject = pattern.subject(self.db);
|
||||
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive)
|
||||
.map(|mut constraints| {
|
||||
negate_if(&mut constraints, self.db, !is_positive);
|
||||
constraints
|
||||
})
|
||||
self.evaluate_pattern_predicate_kind(
|
||||
pattern.kind(self.db),
|
||||
pattern.subject(self.db),
|
||||
is_positive,
|
||||
)
|
||||
}
|
||||
|
||||
fn places(&self) -> &'db PlaceTable {
|
||||
|
@ -709,7 +704,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
lhs_ty: Type<'db>,
|
||||
rhs_ty: Type<'db>,
|
||||
op: ast::CmpOp,
|
||||
is_positive: bool,
|
||||
) -> Option<Type<'db>> {
|
||||
let op = if is_positive { op } else { op.negate() };
|
||||
|
||||
match op {
|
||||
ast::CmpOp::IsNot => {
|
||||
if rhs_ty.is_singleton(self.db) {
|
||||
|
@ -792,13 +790,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
| ast::Expr::Attribute(_)
|
||||
| ast::Expr::Subscript(_)
|
||||
| ast::Expr::Named(_) => {
|
||||
if let Some(left) = place_expr(left) {
|
||||
let op = if is_positive { *op } else { op.negate() };
|
||||
|
||||
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
|
||||
let place = self.expect_place(&left);
|
||||
constraints.insert(place, ty);
|
||||
}
|
||||
if let Some(left) = place_expr(left)
|
||||
&& let Some(ty) =
|
||||
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
|
||||
{
|
||||
let place = self.expect_place(&left);
|
||||
constraints.insert(place, ty);
|
||||
}
|
||||
}
|
||||
ast::Expr::Call(ast::ExprCall {
|
||||
|
@ -954,6 +951,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
&mut self,
|
||||
subject: Expression<'db>,
|
||||
singleton: ast::Singleton,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let subject = place_expr(subject.node_ref(self.db, self.module))?;
|
||||
let place = self.expect_place(&subject);
|
||||
|
@ -963,6 +961,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
ast::Singleton::True => Type::BooleanLiteral(true),
|
||||
ast::Singleton::False => Type::BooleanLiteral(false),
|
||||
};
|
||||
let ty = ty.negate_if(self.db, !is_positive);
|
||||
Some(NarrowingConstraints::from_iter([(place, ty)]))
|
||||
}
|
||||
|
||||
|
@ -986,6 +985,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
|
||||
.to_instance(self.db)?;
|
||||
|
||||
let ty = ty.negate_if(self.db, !is_positive);
|
||||
Some(NarrowingConstraints::from_iter([(place, ty)]))
|
||||
}
|
||||
|
||||
|
@ -993,13 +993,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
&mut self,
|
||||
subject: Expression<'db>,
|
||||
value: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let subject = place_expr(subject.node_ref(self.db, self.module))?;
|
||||
let place = self.expect_place(&subject);
|
||||
let place = {
|
||||
let subject = place_expr(subject.node_ref(self.db, self.module))?;
|
||||
self.expect_place(&subject)
|
||||
};
|
||||
let subject_ty =
|
||||
infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module);
|
||||
|
||||
let ty =
|
||||
let value_ty =
|
||||
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
|
||||
Some(NarrowingConstraints::from_iter([(place, ty)]))
|
||||
|
||||
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
|
||||
.map(|ty| NarrowingConstraints::from_iter([(place, ty)]))
|
||||
}
|
||||
|
||||
fn evaluate_match_pattern_or(
|
||||
|
@ -1010,13 +1017,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let db = self.db;
|
||||
|
||||
// DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints.
|
||||
let merge_constraints = if is_positive {
|
||||
merge_constraints_or
|
||||
} else {
|
||||
merge_constraints_and
|
||||
};
|
||||
|
||||
predicates
|
||||
.iter()
|
||||
.filter_map(|predicate| {
|
||||
self.evaluate_pattern_predicate_kind(predicate, subject, is_positive)
|
||||
})
|
||||
.reduce(|mut constraints, constraints_| {
|
||||
merge_constraints_or(&mut constraints, &constraints_, db);
|
||||
merge_constraints(&mut constraints, &constraints_, db);
|
||||
constraints
|
||||
})
|
||||
}
|
||||
|
@ -1048,7 +1062,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
|
|||
let mut aggregation: Option<NarrowingConstraints> = None;
|
||||
for sub_constraint in sub_constraints.into_iter().flatten() {
|
||||
if let Some(ref mut some_aggregation) = aggregation {
|
||||
merge_constraints_and(some_aggregation, sub_constraint, self.db);
|
||||
merge_constraints_and(some_aggregation, &sub_constraint, self.db);
|
||||
} else {
|
||||
aggregation = Some(sub_constraint);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue