From b2fc0df6db3d49e67678b64f571b270af99ac34d Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 6 Jun 2024 16:13:40 -0600 Subject: [PATCH] [red-knot] flatten unions (#11783) Flatten union types. Fixes #11781 --- crates/red_knot/src/semantic/types.rs | 16 +++++++++++++++- crates/red_knot/src/semantic/types/infer.rs | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/crates/red_knot/src/semantic/types.rs b/crates/red_knot/src/semantic/types.rs index 8db0d7da9c..73ba85c081 100644 --- a/crates/red_knot/src/semantic/types.rs +++ b/crates/red_knot/src/semantic/types.rs @@ -234,7 +234,14 @@ impl TypeStore { } fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId { - self.add_or_get_module(file_id).add_union(elems) + let mut flattened = Vec::with_capacity(elems.len()); + for ty in elems { + match ty { + Type::Union(union_id) => flattened.extend(union_id.elements(self)), + _ => flattened.push(*ty), + } + } + self.add_or_get_module(file_id).add_union(&flattened) } fn add_intersection( @@ -520,6 +527,13 @@ pub struct UnionTypeId { union_id: ModuleUnionTypeId, } +impl UnionTypeId { + pub fn elements(self, type_store: &TypeStore) -> Vec { + let union = type_store.get_union(self); + union.elements.iter().copied().collect() + } +} + #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct IntersectionTypeId { file_id: FileId, diff --git a/crates/red_knot/src/semantic/types/infer.rs b/crates/red_knot/src/semantic/types/infer.rs index 3e42ad0388..8cb14f49cd 100644 --- a/crates/red_knot/src/semantic/types/infer.rs +++ b/crates/red_knot/src/semantic/types/infer.rs @@ -608,4 +608,22 @@ mod tests { assert_public_type(&case, "a", "a", "(Literal[1] | Literal[2])") } + + #[test] + fn ifexpr_nested() -> anyhow::Result<()> { + let case = create_test()?; + + write_to_path( + &case, + "a.py", + " + class C1: pass + class C2: pass + class C3: pass + x = C1 if flag else C2 if flag2 else C3 + ", + )?; + + assert_public_type(&case, "a", "x", "(Literal[C1] | Literal[C2] | Literal[C3])") + } }