[red-knot] flatten unions (#11783)

Flatten union types. Fixes #11781
This commit is contained in:
Carl Meyer 2024-06-06 16:13:40 -06:00 committed by GitHub
parent 93eefb1417
commit b2fc0df6db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 1 deletions

View file

@ -234,7 +234,14 @@ impl TypeStore {
}
fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
self.add_or_get_module(file_id).add_union(elems)
let mut flattened = Vec::with_capacity(elems.len());
for ty in elems {
match ty {
Type::Union(union_id) => flattened.extend(union_id.elements(self)),
_ => flattened.push(*ty),
}
}
self.add_or_get_module(file_id).add_union(&flattened)
}
fn add_intersection(
@ -520,6 +527,13 @@ pub struct UnionTypeId {
union_id: ModuleUnionTypeId,
}
impl UnionTypeId {
pub fn elements(self, type_store: &TypeStore) -> Vec<Type> {
let union = type_store.get_union(self);
union.elements.iter().copied().collect()
}
}
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct IntersectionTypeId {
file_id: FileId,

View file

@ -608,4 +608,22 @@ mod tests {
assert_public_type(&case, "a", "a", "(Literal[1] | Literal[2])")
}
#[test]
fn ifexpr_nested() -> anyhow::Result<()> {
let case = create_test()?;
write_to_path(
&case,
"a.py",
"
class C1: pass
class C2: pass
class C3: pass
x = C1 if flag else C2 if flag2 else C3
",
)?;
assert_public_type(&case, "a", "x", "(Literal[C1] | Literal[C2] | Literal[C3])")
}
}