[ty] Fix match pattern value narrowing to use equality semantics (#20882)

## Summary

Resolves https://github.com/astral-sh/ty/issues/1349.

Fix match statement value patterns to use equality comparison semantics
instead of incorrectly narrowing to literal types directly. Value
patterns use equality for matching, and equality can be overridden, so
we can't always narrow to the matched literal.

## Test Plan

Updated match.md with corrected expected types and an additional example
with explanation

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Eric Mark Martin 2025-10-16 03:50:32 -04:00 committed by GitHub
parent fe4e3e2e75
commit c9dfb51f49
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 154 additions and 98 deletions

View file

@ -263,19 +263,19 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedPlaceId, Type<'db>>;
fn merge_constraints_and<'db>(
into: &mut NarrowingConstraints<'db>,
from: NarrowingConstraints<'db>,
from: &NarrowingConstraints<'db>,
db: &'db dyn Db,
) {
for (key, value) in from {
match into.entry(key) {
match into.entry(*key) {
Entry::Occupied(mut entry) => {
*entry.get_mut() = IntersectionBuilder::new(db)
.add_positive(*entry.get())
.add_positive(value)
.add_positive(*value)
.build();
}
Entry::Vacant(entry) => {
entry.insert(value);
entry.insert(*value);
}
}
}
@ -303,12 +303,6 @@ fn merge_constraints_or<'db>(
}
}
fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) {
for (_place, ty) in constraints.iter_mut() {
*ty = ty.negate_if(db, yes);
}
}
fn place_expr(expr: &ast::Expr) -> Option<PlaceExpr> {
match expr {
ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()),
@ -399,12 +393,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> {
match pattern_predicate_kind {
PatternPredicateKind::Singleton(singleton) => {
self.evaluate_match_pattern_singleton(subject, *singleton)
self.evaluate_match_pattern_singleton(subject, *singleton, is_positive)
}
PatternPredicateKind::Class(cls, kind) => {
self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive)
}
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
PatternPredicateKind::Value(expr) => {
self.evaluate_match_pattern_value(subject, *expr, is_positive)
}
PatternPredicateKind::Or(predicates) => {
self.evaluate_match_pattern_or(subject, predicates, is_positive)
}
@ -420,12 +416,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
pattern: PatternPredicate<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = pattern.subject(self.db);
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive)
.map(|mut constraints| {
negate_if(&mut constraints, self.db, !is_positive);
constraints
})
self.evaluate_pattern_predicate_kind(
pattern.kind(self.db),
pattern.subject(self.db),
is_positive,
)
}
fn places(&self) -> &'db PlaceTable {
@ -709,7 +704,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
op: ast::CmpOp,
is_positive: bool,
) -> Option<Type<'db>> {
let op = if is_positive { op } else { op.negate() };
match op {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
@ -792,13 +790,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
| ast::Expr::Attribute(_)
| ast::Expr::Subscript(_)
| ast::Expr::Named(_) => {
if let Some(left) = place_expr(left) {
let op = if is_positive { *op } else { op.negate() };
if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
let place = self.expect_place(&left);
constraints.insert(place, ty);
}
if let Some(left) = place_expr(left)
&& let Some(ty) =
self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive)
{
let place = self.expect_place(&left);
constraints.insert(place, ty);
}
}
ast::Expr::Call(ast::ExprCall {
@ -954,6 +951,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
subject: Expression<'db>,
singleton: ast::Singleton,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
@ -963,6 +961,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
ast::Singleton::True => Type::BooleanLiteral(true),
ast::Singleton::False => Type::BooleanLiteral(false),
};
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -986,6 +985,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
.to_instance(self.db)?;
let ty = ty.negate_if(self.db, !is_positive);
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -993,13 +993,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
&mut self,
subject: Expression<'db>,
value: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let place = {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
self.expect_place(&subject)
};
let subject_ty =
infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module);
let ty =
let value_ty =
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
Some(NarrowingConstraints::from_iter([(place, ty)]))
self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive)
.map(|ty| NarrowingConstraints::from_iter([(place, ty)]))
}
fn evaluate_match_pattern_or(
@ -1010,13 +1017,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
) -> Option<NarrowingConstraints<'db>> {
let db = self.db;
// DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints.
let merge_constraints = if is_positive {
merge_constraints_or
} else {
merge_constraints_and
};
predicates
.iter()
.filter_map(|predicate| {
self.evaluate_pattern_predicate_kind(predicate, subject, is_positive)
})
.reduce(|mut constraints, constraints_| {
merge_constraints_or(&mut constraints, &constraints_, db);
merge_constraints(&mut constraints, &constraints_, db);
constraints
})
}
@ -1048,7 +1062,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let mut aggregation: Option<NarrowingConstraints> = None;
for sub_constraint in sub_constraints.into_iter().flatten() {
if let Some(ref mut some_aggregation) = aggregation {
merge_constraints_and(some_aggregation, sub_constraint, self.db);
merge_constraints_and(some_aggregation, &sub_constraint, self.db);
} else {
aggregation = Some(sub_constraint);
}