diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md index 865eb48788..6090c9c4ac 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/in.md @@ -92,3 +92,106 @@ if (x := f()) in (1,): else: reveal_type(x) # revealed: Literal[2, 3] ``` + +## Union with `Literal`, `None` and `int` + +```py +from typing import Literal + +def test(x: Literal["a", "b", "c"] | None | int = None): + if x in ("a", "b"): + # int is included because custom __eq__ methods could make + # an int equal to "a" or "b", so we can't eliminate it + reveal_type(x) # revealed: Literal["a", "b"] | int + else: + reveal_type(x) # revealed: Literal["c"] | None | int +``` + +## Direct `not in` conditional + +```py +from typing import Literal + +def test(x: Literal["a", "b", "c"] | None | int = None): + if x not in ("a", "c"): + # int is included because custom __eq__ methods could make + # an int equal to "a" or "c", so we can't eliminate it + reveal_type(x) # revealed: Literal["b"] | None | int + else: + reveal_type(x) # revealed: Literal["a", "c"] | int +``` + +## bool + +```py +def _(x: bool): + if x in (True,): + reveal_type(x) # revealed: Literal[True] + else: + reveal_type(x) # revealed: Literal[False] + +def _(x: bool | str): + if x in (False,): + # `str` remains due to possible custom __eq__ methods on a subclass + reveal_type(x) # revealed: Literal[False] | str + else: + reveal_type(x) # revealed: Literal[True] | str +``` + +## LiteralString + +```py +from typing_extensions import LiteralString + +def _(x: LiteralString): + if x in ("a", "b", "c"): + reveal_type(x) # revealed: Literal["a", "b", "c"] + else: + reveal_type(x) # revealed: LiteralString & ~Literal["a"] & ~Literal["b"] & ~Literal["c"] + +def _(x: LiteralString | int): + if x in ("a", "b", "c"): + reveal_type(x) # revealed: Literal["a", "b", 
"c"] | int + else: + reveal_type(x) # revealed: (LiteralString & ~Literal["a"] & ~Literal["b"] & ~Literal["c"]) | int +``` + +## enums + +```py +from enum import Enum + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + +def _(x: Color): + if x in (Color.RED, Color.GREEN): + # TODO should be `Literal[Color.RED, Color.GREEN]` + reveal_type(x) # revealed: Color + else: + # TODO should be `Literal[Color.BLUE]` + reveal_type(x) # revealed: Color +``` + +## Union with enum and `int` + +```py +from enum import Enum + +class Status(Enum): + PENDING = 1 + APPROVED = 2 + REJECTED = 3 + +def test(x: Status | int): + if x in (Status.PENDING, Status.APPROVED): + # TODO should be `Literal[Status.PENDING, Status.APPROVED] | int` + # int is included because custom __eq__ methods could make + # an int equal to Status.PENDING or Status.APPROVED, so we can't eliminate it + reveal_type(x) # revealed: Status | int + else: + # TODO should be `Literal[Status.REJECTED] | int` + reveal_type(x) # revealed: Status | int +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 23f4422867..6f62b815db 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1054,6 +1054,16 @@ impl<'db> Type<'db> { || self.is_literal_string() } + pub(crate) fn is_union_with_single_valued(&self, db: &'db dyn Db) -> bool { + self.into_union().is_some_and(|union| { + union + .elements(db) + .iter() + .any(|ty| ty.is_single_valued(db) || ty.is_bool(db) || ty.is_literal_string()) + }) || self.is_bool(db) + || self.is_literal_string() + } + pub(crate) fn into_string_literal(self) -> Option> { match self { Type::StringLiteral(string_literal) => Some(string_literal), @@ -9953,14 +9963,6 @@ impl<'db> StringLiteralType<'db> { pub(crate) fn python_len(self, db: &'db dyn Db) -> usize { self.value(db).chars().count() } - - /// Return an iterator over each character in the string literal. 
- /// as would be returned by Python's `iter()`. - pub(crate) fn iter_each_char(self, db: &'db dyn Db) -> impl Iterator { - self.value(db) - .chars() - .map(|c| StringLiteralType::new(db, c.to_string().into_boxed_str())) - } } /// # Ordering diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 46a0b5a8f5..fcfae4851a 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -615,24 +615,88 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } + // TODO `expr_in` and `expr_not_in` should perhaps be unified with `expr_eq` and `expr_ne`, + // since `eq` and `ne` are equivalent to `in` and `not in` with only one element in the RHS. fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) { - if let Type::StringLiteral(string_literal) = rhs_ty { - Some(UnionType::from_elements( - self.db, - string_literal - .iter_each_char(self.db) - .map(Type::StringLiteral), - )) - } else if let Some(tuple_spec) = rhs_ty.tuple_instance_spec(self.db) { - // N.B. Strictly speaking this is unsound, since a tuple subclass might override `__contains__` - // but we'd still apply the narrowing here. This seems unlikely, however, and narrowing is - // generally unsound in numerous ways anyway (attribute narrowing, subscript, narrowing, - // narrowing of globals, etc.). So this doesn't seem worth worrying about too much. - Some(UnionType::from_elements(self.db, tuple_spec.all_elements())) - } else { - None + rhs_ty + .try_iterate(self.db) + .ok() + .map(|iterable| iterable.homogeneous_element_type(self.db)) + } else if lhs_ty.is_union_with_single_valued(self.db) { + let rhs_values = rhs_ty + .try_iterate(self.db) + .ok()? 
+ .homogeneous_element_type(self.db); + + let mut builder = UnionBuilder::new(self.db); + + // Add the narrowed values from the RHS first, to keep literals before broader types. + builder = builder.add(rhs_values); + + if let Some(lhs_union) = lhs_ty.into_union() { + for element in lhs_union.elements(self.db) { + // Keep only the non-single-valued portion of the original type. + if !element.is_single_valued(self.db) + && !element.is_literal_string() + && !element.is_bool(self.db) + { + builder = builder.add(*element); + } + } } + Some(builder.build()) + } else { + None + } + } + + fn evaluate_expr_not_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { + let rhs_values = rhs_ty + .try_iterate(self.db) + .ok()? + .homogeneous_element_type(self.db); + + if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) { + // Exclude the RHS values from the entire (single-valued) LHS domain. + let complement = IntersectionBuilder::new(self.db) + .add_positive(lhs_ty) + .add_negative(rhs_values) + .build(); + Some(complement) + } else if lhs_ty.is_union_with_single_valued(self.db) { + // Split LHS into single-valued portion and the rest. Exclude RHS values from the + // single-valued portion, keep the rest intact. + let mut single_builder = UnionBuilder::new(self.db); + let mut rest_builder = UnionBuilder::new(self.db); + + if let Some(lhs_union) = lhs_ty.into_union() { + for element in lhs_union.elements(self.db) { + if element.is_single_valued(self.db) + || element.is_literal_string() + || element.is_bool(self.db) + { + single_builder = single_builder.add(*element); + } else { + rest_builder = rest_builder.add(*element); + } + } + } + + let single_union = single_builder.build(); + let rest_union = rest_builder.build(); + + let narrowed_single = IntersectionBuilder::new(self.db) + .add_positive(single_union) + .add_negative(rhs_values) + .build(); + + // Keep order: first literal complement, then broader arms. 
+ let result = UnionBuilder::new(self.db) + .add(narrowed_single) + .add(rest_union) + .build(); + Some(result) } else { None } @@ -660,9 +724,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ast::CmpOp::Eq => self.evaluate_expr_eq(lhs_ty, rhs_ty), ast::CmpOp::NotEq => self.evaluate_expr_ne(lhs_ty, rhs_ty), ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty), - ast::CmpOp::NotIn => self - .evaluate_expr_in(lhs_ty, rhs_ty) - .map(|ty| ty.negate(self.db)), + ast::CmpOp::NotIn => self.evaluate_expr_not_in(lhs_ty, rhs_ty), _ => None, } }