diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 3eaee2fefe..e61a0f4843 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -144,14 +144,7 @@ pub(crate) fn definitions_ty<'db>( .expect("definitions_ty should never be called with zero definitions and no unbound_ty."); if let Some(second) = all_types.next() { - let mut builder = UnionBuilder::new(db); - builder = builder.add(first).add(second); - - for variant in all_types { - builder = builder.add(variant); - } - - builder.build() + UnionType::from_elements(db, [first, second].into_iter().chain(all_types)) } else { first } @@ -410,13 +403,7 @@ impl<'db> Type<'db> { fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> { if let Type::Tuple(tuple_type) = self { return IterationOutcome::Iterable { - element_ty: tuple_type - .elements(db) - .iter() - .fold(UnionBuilder::new(db), |builder, element| { - builder.add(*element) - }) - .build(), + element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)), }; } @@ -497,6 +484,12 @@ impl<'db> Type<'db> { } } +impl<'db> From<&Type<'db>> for Type<'db> { + fn from(value: &Type<'db>) -> Self { + *value + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum IterationOutcome<'db> { Iterable { element_ty: Type<'db> }, @@ -636,19 +629,28 @@ impl<'db> UnionType<'db> { self.elements(db).contains(&ty) } + /// Create a union from a list of elements + /// (which may be eagerly simplified into a different variant of [`Type`] altogether) + pub fn from_elements>>( + db: &'db dyn Db, + elements: impl IntoIterator, + ) -> Type<'db> { + elements + .into_iter() + .fold(UnionBuilder::new(db), |builder, element| { + builder.add(element.into()) + }) + .build() + } + /// Apply a transformation function to all elements of the union, /// and create a new union from the resulting set of types pub fn map( &self, db: &'db dyn Db, - mut transform_fn: impl FnMut(&Type<'db>) -> Type<'db>, + transform_fn: impl Fn(&Type<'db>) -> Type<'db>, ) -> Type<'db> { - self.elements(db) - .into_iter() - .fold(UnionBuilder::new(db), |builder, element| { - builder.add(transform_fn(element)) - }) - .build() + Self::from_elements(db, self.elements(db).into_iter().map(transform_fn)) } } diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index c461459f05..0db9fee05a 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -169,11 +169,12 @@ impl<'db> IntersectionBuilder<'db> { if self.intersections.len() == 1 { self.intersections.pop().unwrap().build(self.db) } else { - let mut builder = UnionBuilder::new(self.db); - for inner in self.intersections { - builder = builder.add(inner.build(self.db)); - } - builder.build() + UnionType::from_elements( + self.db, + self.intersections + .into_iter() + .map(|inner| inner.build(self.db)), + ) } } } @@ -271,11 +272,11 @@ impl<'db> InnerIntersectionBuilder<'db> { #[cfg(test)] mod tests { - use super::{IntersectionBuilder, IntersectionType, Type, UnionBuilder, UnionType}; + use super::{IntersectionBuilder, IntersectionType, Type, UnionType}; use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; - use crate::types::builtins_symbol_ty; + use crate::types::{builtins_symbol_ty, UnionBuilder}; use crate::ProgramSettings; use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; @@ -310,11 +311,7 @@ mod tests { let db = setup_db(); let t0 = Type::IntLiteral(0); let t1 = Type::IntLiteral(1); - let union = UnionBuilder::new(&db) - .add(t0) - .add(t1) - .build() - .expect_union(); + let union = UnionType::from_elements(&db, [t0, t1]).expect_union(); assert_eq!(union.elements_vec(&db), &[t0, t1]); } @@ -323,8 +320,7 @@ mod tests { fn build_union_single() { let db = setup_db(); let t0 = Type::IntLiteral(0); - let ty = UnionBuilder::new(&db).add(t0).build(); - + let ty = UnionType::from_elements(&db, [t0]); assert_eq!(ty, t0); } @@ -332,7 +328,6 @@ mod tests { fn build_union_empty() { let db = setup_db(); let ty = UnionBuilder::new(&db).build(); - assert_eq!(ty, Type::Never); } @@ -340,8 +335,7 @@ mod tests { fn build_union_never() { let db = setup_db(); let t0 = Type::IntLiteral(0); - let ty = UnionBuilder::new(&db).add(t0).add(Type::Never).build(); - + let ty = UnionType::from_elements(&db, [t0, Type::Never]); assert_eq!(ty, t0); } @@ -355,21 +349,10 @@ mod tests { let t2 = Type::BooleanLiteral(false); let t3 = Type::IntLiteral(17); - let union = UnionBuilder::new(&db) - .add(t0) - .add(t1) - .add(t3) - .build() - .expect_union(); + let union = UnionType::from_elements(&db, [t0, t1, t3]).expect_union(); assert_eq!(union.elements_vec(&db), &[t0, t3]); - let union = UnionBuilder::new(&db) - .add(t0) - .add(t1) - .add(t2) - .add(t3) - .build() - .expect_union(); + let union = UnionType::from_elements(&db, [t0, t1, t2, t3]).expect_union(); assert_eq!(union.elements_vec(&db), &[bool_ty, t3]); } @@ -379,12 +362,8 @@ mod tests { let t0 = Type::IntLiteral(0); let t1 = Type::IntLiteral(1); let t2 = Type::IntLiteral(2); - let u1 = UnionBuilder::new(&db).add(t0).add(t1).build(); - let union = UnionBuilder::new(&db) - .add(u1) - .add(t2) - .build() - .expect_union(); + let u1 = UnionType::from_elements(&db, [t0, t1]); + let union = UnionType::from_elements(&db, [u1, t2]).expect_union(); assert_eq!(union.elements_vec(&db), &[t0, t1, t2]); } @@ -460,7 +439,7 @@ mod tests { let t0 = Type::IntLiteral(0); let t1 = Type::IntLiteral(1); let ta = Type::Any; - let u0 = UnionBuilder::new(&db).add(t0).add(t1).build(); + let u0 = UnionType::from_elements(&db, [t0, t1]); let union = IntersectionBuilder::new(&db) .add_positive(ta) diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 0d7e2ecf51..df398bc435 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -253,7 +253,7 @@ mod tests { use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; use crate::db::tests::TestDb; - use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionBuilder}; + use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType}; use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings}; fn setup_db() -> TestDb { @@ -295,7 +295,7 @@ mod tests { )?; let mod_file = system_path_to_file(&db, "src/main.py").expect("Expected file to exist."); - let vec: Vec> = vec![ + let union_elements = &[ Type::Unknown, Type::IntLiteral(-1), global_symbol_ty(&db, mod_file, "A"), @@ -311,10 +311,7 @@ mod tests { Type::BooleanLiteral(true), Type::None, ]; - let builder = vec.iter().fold(UnionBuilder::new(&db), |builder, literal| { - builder.add(*literal) - }); - let union = builder.build().expect_union(); + let union = UnionType::from_elements(&db, union_elements).expect_union(); let display = format!("{}", union.display(&db)); assert_eq!( display, diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 5df68b41cd..d9b88dddb5 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -49,7 +49,7 @@ use crate::stdlib::builtins_module_scope; use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; use crate::types::{ builtins_symbol_ty, definitions_ty, global_symbol_ty, symbol_ty, BytesLiteralType, ClassType, - FunctionType, StringLiteralType, TupleType, Type, UnionBuilder, + FunctionType, StringLiteralType, TupleType, Type, UnionType, }; use crate::Db; @@ -1827,10 +1827,7 @@ impl<'db> TypeInferenceBuilder<'db> { let body_ty = self.infer_expression(body); let orelse_ty = self.infer_expression(orelse); - UnionBuilder::new(self.db) - .add(body_ty) - .add(orelse_ty) - .build() + UnionType::from_elements(self.db, [body_ty, orelse_ty]) } fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {