[red-knot] Add a convenience method for constructing a union from a list of elements (#13315)

This commit is contained in:
Alex Waygood 2024-09-10 17:38:56 -04:00 committed by GitHub
parent acab1f4fd8
commit e6b927a583
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 45 additions and 70 deletions

View file

@ -144,14 +144,7 @@ pub(crate) fn definitions_ty<'db>(
.expect("definitions_ty should never be called with zero definitions and no unbound_ty.");
if let Some(second) = all_types.next() {
let mut builder = UnionBuilder::new(db);
builder = builder.add(first).add(second);
for variant in all_types {
builder = builder.add(variant);
}
builder.build()
UnionType::from_elements(db, [first, second].into_iter().chain(all_types))
} else {
first
}
@ -410,13 +403,7 @@ impl<'db> Type<'db> {
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
if let Type::Tuple(tuple_type) = self {
return IterationOutcome::Iterable {
element_ty: tuple_type
.elements(db)
.iter()
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(*element)
})
.build(),
element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)),
};
}
@ -497,6 +484,12 @@ impl<'db> Type<'db> {
}
}
impl<'db> From<&Type<'db>> for Type<'db> {
fn from(value: &Type<'db>) -> Self {
*value
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
@ -636,19 +629,28 @@ impl<'db> UnionType<'db> {
self.elements(db).contains(&ty)
}
/// Create a union from a list of elements
/// (which may be eagerly simplified into a different variant of [`Type`] altogether)
pub fn from_elements<T: Into<Type<'db>>>(
db: &'db dyn Db,
elements: impl IntoIterator<Item = T>,
) -> Type<'db> {
elements
.into_iter()
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(element.into())
})
.build()
}
/// Apply a transformation function to all elements of the union,
/// and create a new union from the resulting set of types
pub fn map(
&self,
db: &'db dyn Db,
mut transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
transform_fn: impl Fn(&Type<'db>) -> Type<'db>,
) -> Type<'db> {
self.elements(db)
.into_iter()
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(transform_fn(element))
})
.build()
Self::from_elements(db, self.elements(db).into_iter().map(transform_fn))
}
}

View file

@ -169,11 +169,12 @@ impl<'db> IntersectionBuilder<'db> {
if self.intersections.len() == 1 {
self.intersections.pop().unwrap().build(self.db)
} else {
let mut builder = UnionBuilder::new(self.db);
for inner in self.intersections {
builder = builder.add(inner.build(self.db));
}
builder.build()
UnionType::from_elements(
self.db,
self.intersections
.into_iter()
.map(|inner| inner.build(self.db)),
)
}
}
}
@ -271,11 +272,11 @@ impl<'db> InnerIntersectionBuilder<'db> {
#[cfg(test)]
mod tests {
use super::{IntersectionBuilder, IntersectionType, Type, UnionBuilder, UnionType};
use super::{IntersectionBuilder, IntersectionType, Type, UnionType};
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
use crate::types::builtins_symbol_ty;
use crate::types::{builtins_symbol_ty, UnionBuilder};
use crate::ProgramSettings;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
@ -310,11 +311,7 @@ mod tests {
let db = setup_db();
let t0 = Type::IntLiteral(0);
let t1 = Type::IntLiteral(1);
let union = UnionBuilder::new(&db)
.add(t0)
.add(t1)
.build()
.expect_union();
let union = UnionType::from_elements(&db, [t0, t1]).expect_union();
assert_eq!(union.elements_vec(&db), &[t0, t1]);
}
@ -323,8 +320,7 @@ mod tests {
fn build_union_single() {
let db = setup_db();
let t0 = Type::IntLiteral(0);
let ty = UnionBuilder::new(&db).add(t0).build();
let ty = UnionType::from_elements(&db, [t0]);
assert_eq!(ty, t0);
}
@ -332,7 +328,6 @@ mod tests {
fn build_union_empty() {
let db = setup_db();
let ty = UnionBuilder::new(&db).build();
assert_eq!(ty, Type::Never);
}
@ -340,8 +335,7 @@ mod tests {
fn build_union_never() {
let db = setup_db();
let t0 = Type::IntLiteral(0);
let ty = UnionBuilder::new(&db).add(t0).add(Type::Never).build();
let ty = UnionType::from_elements(&db, [t0, Type::Never]);
assert_eq!(ty, t0);
}
@ -355,21 +349,10 @@ mod tests {
let t2 = Type::BooleanLiteral(false);
let t3 = Type::IntLiteral(17);
let union = UnionBuilder::new(&db)
.add(t0)
.add(t1)
.add(t3)
.build()
.expect_union();
let union = UnionType::from_elements(&db, [t0, t1, t3]).expect_union();
assert_eq!(union.elements_vec(&db), &[t0, t3]);
let union = UnionBuilder::new(&db)
.add(t0)
.add(t1)
.add(t2)
.add(t3)
.build()
.expect_union();
let union = UnionType::from_elements(&db, [t0, t1, t2, t3]).expect_union();
assert_eq!(union.elements_vec(&db), &[bool_ty, t3]);
}
@ -379,12 +362,8 @@ mod tests {
let t0 = Type::IntLiteral(0);
let t1 = Type::IntLiteral(1);
let t2 = Type::IntLiteral(2);
let u1 = UnionBuilder::new(&db).add(t0).add(t1).build();
let union = UnionBuilder::new(&db)
.add(u1)
.add(t2)
.build()
.expect_union();
let u1 = UnionType::from_elements(&db, [t0, t1]);
let union = UnionType::from_elements(&db, [u1, t2]).expect_union();
assert_eq!(union.elements_vec(&db), &[t0, t1, t2]);
}
@ -460,7 +439,7 @@ mod tests {
let t0 = Type::IntLiteral(0);
let t1 = Type::IntLiteral(1);
let ta = Type::Any;
let u0 = UnionBuilder::new(&db).add(t0).add(t1).build();
let u0 = UnionType::from_elements(&db, [t0, t1]);
let union = IntersectionBuilder::new(&db)
.add_positive(ta)

View file

@ -253,7 +253,7 @@ mod tests {
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use crate::db::tests::TestDb;
use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionBuilder};
use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionType};
use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings};
fn setup_db() -> TestDb {
@ -295,7 +295,7 @@ mod tests {
)?;
let mod_file = system_path_to_file(&db, "src/main.py").expect("Expected file to exist.");
let vec: Vec<Type<'_>> = vec![
let union_elements = &[
Type::Unknown,
Type::IntLiteral(-1),
global_symbol_ty(&db, mod_file, "A"),
@ -311,10 +311,7 @@ mod tests {
Type::BooleanLiteral(true),
Type::None,
];
let builder = vec.iter().fold(UnionBuilder::new(&db), |builder, literal| {
builder.add(*literal)
});
let union = builder.build().expect_union();
let union = UnionType::from_elements(&db, union_elements).expect_union();
let display = format!("{}", union.display(&db));
assert_eq!(
display,

View file

@ -49,7 +49,7 @@ use crate::stdlib::builtins_module_scope;
use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{
builtins_symbol_ty, definitions_ty, global_symbol_ty, symbol_ty, BytesLiteralType, ClassType,
FunctionType, StringLiteralType, TupleType, Type, UnionBuilder,
FunctionType, StringLiteralType, TupleType, Type, UnionType,
};
use crate::Db;
@ -1827,10 +1827,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let body_ty = self.infer_expression(body);
let orelse_ty = self.infer_expression(orelse);
UnionBuilder::new(self.db)
.add(body_ty)
.add(orelse_ty)
.build()
UnionType::from_elements(self.db, [body_ty, orelse_ty])
}
fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {