diff --git a/crates/red_knot_python_semantic/resources/mdtest/call/union.md b/crates/red_knot_python_semantic/resources/mdtest/call/union.md index 88d6ad9c37..10055283a8 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/call/union.md +++ b/crates/red_knot_python_semantic/resources/mdtest/call/union.md @@ -162,6 +162,31 @@ def _(flag: bool): reveal_type(f("string")) # revealed: Literal["string", "'string'"] ``` +## Unions with literals and negations + +```py +from typing import Literal, Union +from knot_extensions import Not, AlwaysFalsy, static_assert, is_subtype_of, is_assignable_to + +static_assert(is_subtype_of(Literal["a", ""], Union[Literal["a", ""], Not[AlwaysFalsy]])) +static_assert(is_subtype_of(Not[AlwaysFalsy], Union[Literal["", "a"], Not[AlwaysFalsy]])) +static_assert(is_subtype_of(Literal["a", ""], Union[Not[AlwaysFalsy], Literal["a", ""]])) +static_assert(is_subtype_of(Not[AlwaysFalsy], Union[Not[AlwaysFalsy], Literal["a", ""]])) + +static_assert(is_subtype_of(Literal["a", ""], Union[Literal["a", ""], Not[Literal[""]]])) +static_assert(is_subtype_of(Not[Literal[""]], Union[Literal["a", ""], Not[Literal[""]]])) +static_assert(is_subtype_of(Literal["a", ""], Union[Not[Literal[""]], Literal["a", ""]])) +static_assert(is_subtype_of(Not[Literal[""]], Union[Not[Literal[""]], Literal["a", ""]])) + +def _( + x: Union[Literal["a", ""], Not[AlwaysFalsy]], + y: Union[Literal["a", ""], Not[Literal[""]]], +): + reveal_type(x) # revealed: Literal[""] | ~AlwaysFalsy + # TODO should be `object` + reveal_type(y) # revealed: Literal[""] | ~Literal[""] +``` + ## Cannot use an argument as both a value and a type form ```py diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 41d194652e..6bfecfe2a9 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -44,6 +44,40 @@ use crate::types::{ use crate::{Db, FxOrderSet}; use smallvec::SmallVec; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum LiteralKind { + Int, + String, + Bytes, +} + +impl<'db> Type<'db> { + /// Return `true` if this type can be a supertype of some literals of `kind` and not others. + fn splits_literals(self, db: &'db dyn Db, kind: LiteralKind) -> bool { + match (self, kind) { + (Type::AlwaysFalsy | Type::AlwaysTruthy, _) => true, + (Type::StringLiteral(_), LiteralKind::String) => true, + (Type::BytesLiteral(_), LiteralKind::Bytes) => true, + (Type::IntLiteral(_), LiteralKind::Int) => true, + (Type::Intersection(intersection), _) => { + intersection + .positive(db) + .iter() + .any(|ty| ty.splits_literals(db, kind)) + || intersection + .negative(db) + .iter() + .any(|ty| ty.splits_literals(db, kind)) + } + (Type::Union(union), _) => union + .elements(db) + .iter() + .any(|ty| ty.splits_literals(db, kind)), + _ => false, + } + } +} + enum UnionElement<'db> { IntLiterals(FxOrderSet), StringLiterals(FxOrderSet>), @@ -61,12 +95,9 @@ impl<'db> UnionElement<'db> { /// If this `UnionElement` is some other type, return `ReduceResult::Type` so `UnionBuilder` /// can perform more complex checks on it. fn try_reduce(&mut self, db: &'db dyn Db, other_type: Type<'db>) -> ReduceResult<'db> { - // `AlwaysTruthy` and `AlwaysFalsy` are the only types which can be a supertype of only - // _some_ literals of the same kind, so we need to walk the full set in this case. - let needs_filter = matches!(other_type, Type::AlwaysTruthy | Type::AlwaysFalsy); match self { UnionElement::IntLiterals(literals) => { - ReduceResult::KeepIf(if needs_filter { + ReduceResult::KeepIf(if other_type.splits_literals(db, LiteralKind::Int) { literals.retain(|literal| { !Type::IntLiteral(*literal).is_subtype_of(db, other_type) }); @@ -77,7 +108,7 @@ impl<'db> UnionElement<'db> { }) } UnionElement::StringLiterals(literals) => { - ReduceResult::KeepIf(if needs_filter { + ReduceResult::KeepIf(if other_type.splits_literals(db, LiteralKind::String) { literals.retain(|literal| { !Type::StringLiteral(*literal).is_subtype_of(db, other_type) }); @@ -88,7 +119,7 @@ impl<'db> UnionElement<'db> { }) } UnionElement::BytesLiterals(literals) => { - ReduceResult::KeepIf(if needs_filter { + ReduceResult::KeepIf(if other_type.splits_literals(db, LiteralKind::Bytes) { literals.retain(|literal| { !Type::BytesLiteral(*literal).is_subtype_of(db, other_type) });