diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index b6b0ec90e8..55772eab24 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -71,98 +71,140 @@ reveal_type(x) # revealed: object ## Value patterns +Value patterns are evaluated by equality, which is overridable. Therefore successfully matching on +one can only give us information where we know how the subject type implements equality. + +Consider the following example. + ```py -def get_object() -> object: - return object() +from typing import Literal -x = get_object() +def _(x: Literal["foo"] | int): + match x: + case "foo": + reveal_type(x) # revealed: Literal["foo"] | int -reveal_type(x) # revealed: object + match x: + case "bar": + reveal_type(x) # revealed: int +``` -match x: - case "foo": - reveal_type(x) # revealed: Literal["foo"] - case 42: - reveal_type(x) # revealed: Literal[42] - case 6.0: - reveal_type(x) # revealed: float - case 1j: - reveal_type(x) # revealed: complex - case b"foo": - reveal_type(x) # revealed: Literal[b"foo"] +In the first `match`'s `case "foo"` all we know is `x == "foo"`. `x` could be an instance of an +arbitrary `int` subclass with an arbitrary `__eq__`, so we can't actually narrow to +`Literal["foo"]`. -reveal_type(x) # revealed: object +In the second `match`'s `case "bar"` we know `x == "bar"`. As discussed above, this isn't enough to +rule out `int`, but we know that `"foo" == "bar"` is false so we can eliminate `Literal["foo"]`. + +More examples follow. + +```py +from typing import Literal + +class C: + pass + +def _(x: Literal["foo", "bar", 42, b"foo"] | bool | complex): + match x: + case "foo": + reveal_type(x) # revealed: Literal["foo"] | int | float | complex + case 42: + reveal_type(x) # revealed: int | float | complex + case 6.0: + reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex + case 1j: + reveal_type(x) # revealed: Literal["bar", b"foo"] | (int & ~Literal[42]) | float | complex + case b"foo": + reveal_type(x) # revealed: (int & ~Literal[42]) | Literal[b"foo"] | float | complex + case _: + reveal_type(x) # revealed: Literal["bar"] | (int & ~Literal[42]) | float | complex ``` ## Value patterns with guard ```py -def get_object() -> object: - return object() +from typing import Literal -x = get_object() +class C: + pass -reveal_type(x) # revealed: object - -match x: - case "foo" if reveal_type(x): # revealed: Literal["foo"] - pass - case 42 if reveal_type(x): # revealed: Literal[42] - pass - case 6.0 if reveal_type(x): # revealed: float - pass - case 1j if reveal_type(x): # revealed: complex - pass - case b"foo" if reveal_type(x): # revealed: Literal[b"foo"] - pass - -reveal_type(x) # revealed: object +def _(x: Literal["foo", b"bar"] | int): + match x: + case "foo" if reveal_type(x): # revealed: Literal["foo"] | int + pass + case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int + pass + case 42 if reveal_type(x): # revealed: int + pass ``` ## Or patterns ```py -def get_object() -> object: - return object() +from typing import Literal +from enum import Enum -x = get_object() +class Color(Enum): + RED = 1 + GREEN = 2 + BLUE = 3 -reveal_type(x) # revealed: object +def _(color: Color): + match color: + case Color.RED | Color.GREEN: + reveal_type(color) # revealed: Literal[Color.RED, Color.GREEN] + case Color.BLUE: + reveal_type(color) # revealed: Literal[Color.BLUE] -match x: - case "foo" | 42 | None: - reveal_type(x) # revealed: Literal["foo", 42] | None - case "foo" | tuple(): - reveal_type(x) # revealed: tuple[Unknown, ...] - case True | False: - reveal_type(x) # revealed: bool - case 3.14 | 2.718 | 1.414: - reveal_type(x) # revealed: float + match color: + case Color.RED | Color.GREEN | Color.BLUE: + reveal_type(color) # revealed: Color -reveal_type(x) # revealed: object + match color: + case Color.RED: + reveal_type(color) # revealed: Literal[Color.RED] + case _: + reveal_type(color) # revealed: Literal[Color.GREEN, Color.BLUE] + +class A: ... +class B: ... +class C: ... + +def _(x: A | B | C): + match x: + case A() | B(): + reveal_type(x) # revealed: A | B + case C(): + reveal_type(x) # revealed: C & ~A & ~B + case _: + reveal_type(x) # revealed: Never + + match x: + case A() | B() | C(): + reveal_type(x) # revealed: A | B | C + case _: + reveal_type(x) # revealed: Never + + match x: + case A(): + reveal_type(x) # revealed: A + case _: + reveal_type(x) # revealed: (B & ~A) | (C & ~A) ``` ## Or patterns with guard ```py -def get_object() -> object: - return object() +from typing import Literal -x = get_object() - -reveal_type(x) # revealed: object - -match x: - case "foo" | 42 | None if reveal_type(x): # revealed: Literal["foo", 42] | None - pass - case "foo" | tuple() if reveal_type(x): # revealed: Literal["foo"] | tuple[Unknown, ...] - pass - case True | False if reveal_type(x): # revealed: bool - pass - case 3.14 | 2.718 | 1.414 if reveal_type(x): # revealed: float - pass - -reveal_type(x) # revealed: object +def _(x: Literal["foo", b"bar"] | int): + match x: + case "foo" | 42 if reveal_type(x): # revealed: Literal["foo"] | int + pass + case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int + pass + case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int + pass ``` ## Narrowing due to guard @@ -179,7 +221,7 @@ match x: case str() | float() if type(x) is str: reveal_type(x) # revealed: str case "foo" | 42 | None if isinstance(x, int): - reveal_type(x) # revealed: Literal[42] + reveal_type(x) # revealed: int case False if x: reveal_type(x) # revealed: Never case "foo" if x := "bar": @@ -201,7 +243,7 @@ reveal_type(x) # revealed: object match x: case str() | float() if type(x) is str and reveal_type(x): # revealed: str pass - case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: Literal[42] + case "foo" | 42 | None if isinstance(x, int) and reveal_type(x): # revealed: int pass case False if x and reveal_type(x): # revealed: Never pass diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 20d505bcab..f725f32220 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -263,19 +263,19 @@ type NarrowingConstraints<'db> = FxHashMap>; fn merge_constraints_and<'db>( into: &mut NarrowingConstraints<'db>, - from: NarrowingConstraints<'db>, + from: &NarrowingConstraints<'db>, db: &'db dyn Db, ) { for (key, value) in from { - match into.entry(key) { + match into.entry(*key) { Entry::Occupied(mut entry) => { *entry.get_mut() = IntersectionBuilder::new(db) .add_positive(*entry.get()) - .add_positive(value) + .add_positive(*value) .build(); } Entry::Vacant(entry) => { - entry.insert(value); + entry.insert(*value); } } } @@ -303,12 +303,6 @@ fn merge_constraints_or<'db>( } } -fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, yes: bool) { - for (_place, ty) in constraints.iter_mut() { - *ty = ty.negate_if(db, yes); - } -} - fn place_expr(expr: &ast::Expr) -> Option { match expr { ast::Expr::Named(named) => PlaceExpr::try_from_expr(named.target.as_ref()), @@ -399,12 +393,14 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) -> Option> { match pattern_predicate_kind { PatternPredicateKind::Singleton(singleton) => { - self.evaluate_match_pattern_singleton(subject, *singleton) + self.evaluate_match_pattern_singleton(subject, *singleton, is_positive) } PatternPredicateKind::Class(cls, kind) => { self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive) } - PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr), + PatternPredicateKind::Value(expr) => { + self.evaluate_match_pattern_value(subject, *expr, is_positive) + } PatternPredicateKind::Or(predicates) => { self.evaluate_match_pattern_or(subject, predicates, is_positive) } @@ -420,12 +416,11 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { pattern: PatternPredicate<'db>, is_positive: bool, ) -> Option> { - let subject = pattern.subject(self.db); - self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive) - .map(|mut constraints| { - negate_if(&mut constraints, self.db, !is_positive); - constraints - }) + self.evaluate_pattern_predicate_kind( + pattern.kind(self.db), + pattern.subject(self.db), + is_positive, + ) } fn places(&self) -> &'db PlaceTable { @@ -709,7 +704,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { lhs_ty: Type<'db>, rhs_ty: Type<'db>, op: ast::CmpOp, + is_positive: bool, ) -> Option> { + let op = if is_positive { op } else { op.negate() }; + match op { ast::CmpOp::IsNot => { if rhs_ty.is_singleton(self.db) { @@ -792,13 +790,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) | ast::Expr::Named(_) => { - if let Some(left) = place_expr(left) { - let op = if is_positive { *op } else { op.negate() }; - - if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { - let place = self.expect_place(&left); - constraints.insert(place, ty); - } + if let Some(left) = place_expr(left) + && let Some(ty) = + self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) + { + let place = self.expect_place(&left); + constraints.insert(place, ty); } } ast::Expr::Call(ast::ExprCall { @@ -954,6 +951,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, singleton: ast::Singleton, + is_positive: bool, ) -> Option> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); @@ -963,6 +961,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ast::Singleton::True => Type::BooleanLiteral(true), ast::Singleton::False => Type::BooleanLiteral(false), }; + let ty = ty.negate_if(self.db, !is_positive); Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -986,6 +985,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) .to_instance(self.db)?; + let ty = ty.negate_if(self.db, !is_positive); Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -993,13 +993,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, value: Expression<'db>, + is_positive: bool, ) -> Option> { - let subject = place_expr(subject.node_ref(self.db, self.module))?; - let place = self.expect_place(&subject); + let place = { + let subject = place_expr(subject.node_ref(self.db, self.module))?; + self.expect_place(&subject) + }; + let subject_ty = + infer_same_file_expression_type(self.db, subject, TypeContext::default(), self.module); - let ty = + let value_ty = infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); - Some(NarrowingConstraints::from_iter([(place, ty)])) + + self.evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) + .map(|ty| NarrowingConstraints::from_iter([(place, ty)])) } fn evaluate_match_pattern_or( @@ -1010,13 +1017,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ) -> Option> { let db = self.db; + // DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints. + let merge_constraints = if is_positive { + merge_constraints_or + } else { + merge_constraints_and + }; + predicates .iter() .filter_map(|predicate| { self.evaluate_pattern_predicate_kind(predicate, subject, is_positive) }) .reduce(|mut constraints, constraints_| { - merge_constraints_or(&mut constraints, &constraints_, db); + merge_constraints(&mut constraints, &constraints_, db); constraints }) } @@ -1048,7 +1062,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let mut aggregation: Option = None; for sub_constraint in sub_constraints.into_iter().flatten() { if let Some(ref mut some_aggregation) = aggregation { - merge_constraints_and(some_aggregation, sub_constraint, self.db); + merge_constraints_and(some_aggregation, &sub_constraint, self.db); } else { aggregation = Some(sub_constraint); }