[red-knot] Reduce some repetitiveness in tests (#13135)

This commit is contained in:
Alex Waygood 2024-09-03 11:26:44 +01:00 committed by GitHub
parent facf6febf0
commit 9d517061f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 136 additions and 97 deletions

View file

@ -152,9 +152,9 @@ pub(crate) fn definitions_ty<'db>(
); );
let mut all_types = unbound_ty.into_iter().chain(def_types); let mut all_types = unbound_ty.into_iter().chain(def_types);
let Some(first) = all_types.next() else { let first = all_types
panic!("definitions_ty should never be called with zero definitions and no unbound_ty.") .next()
}; .expect("definitions_ty should never be called with zero definitions and no unbound_ty.");
if let Some(second) = all_types.next() { if let Some(second) = all_types.next() {
let mut builder = UnionBuilder::new(db); let mut builder = UnionBuilder::new(db);
@ -171,7 +171,7 @@ pub(crate) fn definitions_ty<'db>(
} }
/// Unique ID for a type. /// Unique ID for a type.
#[derive(Copy, Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type<'db> { pub enum Type<'db> {
/// the dynamic type: a statically-unknown set of values /// the dynamic type: a statically-unknown set of values
Any, Any,
@ -216,10 +216,6 @@ impl<'db> Type<'db> {
matches!(self, Type::Unbound) matches!(self, Type::Unbound)
} }
pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Unknown)
}
pub const fn is_never(&self) -> bool { pub const fn is_never(&self) -> bool {
matches!(self, Type::Never) matches!(self, Type::Never)
} }
@ -237,6 +233,78 @@ impl<'db> Type<'db> {
) )
} }
pub const fn into_class_type(self) -> Option<ClassType<'db>> {
match self {
Type::Class(class_type) => Some(class_type),
_ => None,
}
}
pub fn expect_class(self) -> ClassType<'db> {
self.into_class_type()
.expect("Expected a Type::Class variant")
}
pub const fn into_module_type(self) -> Option<File> {
match self {
Type::Module(file) => Some(file),
_ => None,
}
}
pub fn expect_module(self) -> File {
self.into_module_type()
.expect("Expected a Type::Module variant")
}
pub const fn into_union_type(self) -> Option<UnionType<'db>> {
match self {
Type::Union(union_type) => Some(union_type),
_ => None,
}
}
pub fn expect_union(self) -> UnionType<'db> {
self.into_union_type()
.expect("Expected a Type::Union variant")
}
pub const fn into_intersection_type(self) -> Option<IntersectionType<'db>> {
match self {
Type::Intersection(intersection_type) => Some(intersection_type),
_ => None,
}
}
pub fn expect_intersection(self) -> IntersectionType<'db> {
self.into_intersection_type()
.expect("Expected a Type::Intersection variant")
}
pub const fn into_function_type(self) -> Option<FunctionType<'db>> {
match self {
Type::Function(function_type) => Some(function_type),
_ => None,
}
}
pub fn expect_function(self) -> FunctionType<'db> {
self.into_function_type()
.expect("Expected a Type::Function variant")
}
pub const fn into_int_literal_type(self) -> Option<i64> {
match self {
Type::IntLiteral(value) => Some(value),
_ => None,
}
}
pub fn expect_int_literal(self) -> i64 {
self.into_int_literal_type()
.expect("Expected a Type::IntLiteral variant")
}
pub fn may_be_unbound(&self, db: &'db dyn Db) -> bool { pub fn may_be_unbound(&self, db: &'db dyn Db) -> bool {
match self { match self {
Type::Unbound => true, Type::Unbound => true,
@ -361,7 +429,7 @@ impl<'db> Type<'db> {
} }
#[must_use] #[must_use]
pub fn instance(&self) -> Type<'db> { pub fn to_instance(&self) -> Type<'db> {
match self { match self {
Type::Any => Type::Any, Type::Any => Type::Any,
Type::Unknown => Type::Unknown, Type::Unknown => Type::Unknown,

View file

@ -313,9 +313,11 @@ mod tests {
let db = setup_db(); let db = setup_db();
let t0 = Type::IntLiteral(0); let t0 = Type::IntLiteral(0);
let t1 = Type::IntLiteral(1); let t1 = Type::IntLiteral(1);
let Type::Union(union) = UnionBuilder::new(&db).add(t0).add(t1).build() else { let union = UnionBuilder::new(&db)
panic!("expected a union"); .add(t0)
}; .add(t1)
.build()
.expect_union();
assert_eq!(union.elements_vec(&db), &[t0, t1]); assert_eq!(union.elements_vec(&db), &[t0, t1]);
} }
@ -356,19 +358,20 @@ mod tests {
let t2 = Type::BooleanLiteral(false); let t2 = Type::BooleanLiteral(false);
let t3 = Type::IntLiteral(17); let t3 = Type::IntLiteral(17);
let Type::Union(union) = UnionBuilder::new(&db).add(t0).add(t1).add(t3).build() else { let union = UnionBuilder::new(&db)
panic!("expected a union"); .add(t0)
}; .add(t1)
.add(t3)
.build()
.expect_union();
assert_eq!(union.elements_vec(&db), &[t0, t3]); assert_eq!(union.elements_vec(&db), &[t0, t3]);
let Type::Union(union) = UnionBuilder::new(&db) let union = UnionBuilder::new(&db)
.add(t0) .add(t0)
.add(t1) .add(t1)
.add(t2) .add(t2)
.add(t3) .add(t3)
.build() .build()
else { .expect_union();
panic!("expected a union");
};
assert_eq!(union.elements_vec(&db), &[bool_ty, t3]); assert_eq!(union.elements_vec(&db), &[bool_ty, t3]);
} }
@ -380,9 +383,11 @@ mod tests {
let t1 = Type::IntLiteral(1); let t1 = Type::IntLiteral(1);
let t2 = Type::IntLiteral(2); let t2 = Type::IntLiteral(2);
let u1 = UnionBuilder::new(&db).add(t0).add(t1).build(); let u1 = UnionBuilder::new(&db).add(t0).add(t1).build();
let Type::Union(union) = UnionBuilder::new(&db).add(u1).add(t2).build() else { let union = UnionBuilder::new(&db)
panic!("expected a union"); .add(u1)
}; .add(t2)
.build()
.expect_union();
assert_eq!(union.elements_vec(&db), &[t0, t1, t2]); assert_eq!(union.elements_vec(&db), &[t0, t1, t2]);
} }
@ -402,16 +407,14 @@ mod tests {
let db = setup_db(); let db = setup_db();
let t0 = Type::IntLiteral(0); let t0 = Type::IntLiteral(0);
let ta = Type::Any; let ta = Type::Any;
let Type::Intersection(inter) = IntersectionBuilder::new(&db) let intersection = IntersectionBuilder::new(&db)
.add_positive(ta) .add_positive(ta)
.add_negative(t0) .add_negative(t0)
.build() .build()
else { .expect_intersection();
panic!("expected to be an intersection");
};
assert_eq!(inter.pos_vec(&db), &[ta]); assert_eq!(intersection.pos_vec(&db), &[ta]);
assert_eq!(inter.neg_vec(&db), &[t0]); assert_eq!(intersection.neg_vec(&db), &[t0]);
} }
#[test] #[test]
@ -424,16 +427,14 @@ mod tests {
.add_positive(ta) .add_positive(ta)
.add_negative(t1) .add_negative(t1)
.build(); .build();
let Type::Intersection(inter) = IntersectionBuilder::new(&db) let intersection = IntersectionBuilder::new(&db)
.add_positive(t2) .add_positive(t2)
.add_positive(i0) .add_positive(i0)
.build() .build()
else { .expect_intersection();
panic!("expected to be an intersection");
};
assert_eq!(inter.pos_vec(&db), &[t2, ta]); assert_eq!(intersection.pos_vec(&db), &[t2, ta]);
assert_eq!(inter.neg_vec(&db), &[t1]); assert_eq!(intersection.neg_vec(&db), &[t1]);
} }
#[test] #[test]
@ -446,16 +447,14 @@ mod tests {
.add_positive(ta) .add_positive(ta)
.add_negative(t1) .add_negative(t1)
.build(); .build();
let Type::Intersection(inter) = IntersectionBuilder::new(&db) let intersection = IntersectionBuilder::new(&db)
.add_positive(t2) .add_positive(t2)
.add_negative(i0) .add_negative(i0)
.build() .build()
else { .expect_intersection();
panic!("expected to be an intersection");
};
assert_eq!(inter.pos_vec(&db), &[t2, t1]); assert_eq!(intersection.pos_vec(&db), &[t2, t1]);
assert_eq!(inter.neg_vec(&db), &[ta]); assert_eq!(intersection.neg_vec(&db), &[ta]);
} }
#[test] #[test]
@ -466,13 +465,11 @@ mod tests {
let ta = Type::Any; let ta = Type::Any;
let u0 = UnionBuilder::new(&db).add(t0).add(t1).build(); let u0 = UnionBuilder::new(&db).add(t0).add(t1).build();
let Type::Union(union) = IntersectionBuilder::new(&db) let union = IntersectionBuilder::new(&db)
.add_positive(ta) .add_positive(ta)
.add_positive(u0) .add_positive(u0)
.build() .build()
else { .expect_union();
panic!("expected a union");
};
let [Type::Intersection(i0), Type::Intersection(i1)] = union.elements_vec(&db)[..] else { let [Type::Intersection(i0), Type::Intersection(i1)] = union.elements_vec(&db)[..] else {
panic!("expected a union of two intersections"); panic!("expected a union of two intersections");
}; };

View file

@ -125,10 +125,7 @@ impl Display for DisplayUnionType<'_> {
f.write_str("Literal[")?; f.write_str("Literal[")?;
if literal_kind == LiteralTypeKind::IntLiteral { if literal_kind == LiteralTypeKind::IntLiteral {
literals.sort_unstable_by_key(|ty| match ty { literals.sort_unstable_by_key(|ty| ty.expect_int_literal());
Type::IntLiteral(n) => *n,
_ => panic!("Expected only int literals when kind is IntLiteral"),
});
} }
for (i, literal_ty) in literals.iter().enumerate() { for (i, literal_ty) in literals.iter().enumerate() {
@ -294,9 +291,7 @@ mod tests {
let builder = vec.iter().fold(UnionBuilder::new(&db), |builder, literal| { let builder = vec.iter().fold(UnionBuilder::new(&db), |builder, literal| {
builder.add(*literal) builder.add(*literal)
}); });
let Type::Union(union) = builder.build() else { let union = builder.build().expect_union();
panic!("expected a union");
};
let display = format!("{}", union.display(&db)); let display = format!("{}", union.display(&db));
assert_eq!( assert_eq!(
display, display,

View file

@ -463,9 +463,10 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) { fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) {
let Some(type_params) = function.type_params.as_deref() else { let type_params = function
panic!("function type params scope without type params"); .type_params
}; .as_deref()
.expect("function type params scope without type params");
// TODO: this should also be applied to parameter annotations. // TODO: this should also be applied to parameter annotations.
if !self.is_stub() { if !self.is_stub() {
@ -1398,10 +1399,10 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Number::Int(n) => n ast::Number::Int(n) => n
.as_i64() .as_i64()
.map(Type::IntLiteral) .map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
ast::Number::Float(_) => builtins_symbol_ty_by_name(self.db, "float").instance(), ast::Number::Float(_) => builtins_symbol_ty_by_name(self.db, "float").to_instance(),
ast::Number::Complex { .. } => { ast::Number::Complex { .. } => {
builtins_symbol_ty_by_name(self.db, "complex").instance() builtins_symbol_ty_by_name(self.db, "complex").to_instance()
} }
} }
} }
@ -1501,7 +1502,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
// TODO generic // TODO generic
builtins_symbol_ty_by_name(self.db, "tuple").instance() builtins_symbol_ty_by_name(self.db, "tuple").to_instance()
} }
fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> { fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> {
@ -1516,7 +1517,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
// TODO generic // TODO generic
builtins_symbol_ty_by_name(self.db, "list").instance() builtins_symbol_ty_by_name(self.db, "list").to_instance()
} }
fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> { fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> {
@ -1527,7 +1528,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
// TODO generic // TODO generic
builtins_symbol_ty_by_name(self.db, "set").instance() builtins_symbol_ty_by_name(self.db, "set").to_instance()
} }
fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> { fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> {
@ -1539,7 +1540,7 @@ impl<'db> TypeInferenceBuilder<'db> {
} }
// TODO generic // TODO generic
builtins_symbol_ty_by_name(self.db, "dict").instance() builtins_symbol_ty_by_name(self.db, "dict").to_instance()
} }
/// Infer the type of the `iter` expression of the first comprehension. /// Infer the type of the `iter` expression of the first comprehension.
@ -1927,22 +1928,22 @@ impl<'db> TypeInferenceBuilder<'db> {
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n
.checked_add(m) .checked_add(m)
.map(Type::IntLiteral) .map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n
.checked_sub(m) .checked_sub(m)
.map(Type::IntLiteral) .map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n
.checked_mul(m) .checked_mul(m)
.map(Type::IntLiteral) .map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Div) => n (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Div) => n
.checked_div(m) .checked_div(m)
.map(Type::IntLiteral) .map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").instance()), .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n
.checked_rem(m) .checked_rem(m)
@ -2152,7 +2153,7 @@ impl<'db> TypeInferenceBuilder<'db> {
name.ctx name.ctx
); );
self.infer_name_expression(name).instance() self.infer_name_expression(name).to_instance()
} }
ast::Expr::NoneLiteral(_literal) => Type::None, ast::Expr::NoneLiteral(_literal) => Type::None,
@ -2328,7 +2329,7 @@ mod tests {
use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::FileScopeId; use crate::semantic_index::symbol::FileScopeId;
use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map};
use crate::types::{global_symbol_ty_by_name, infer_definition_types, symbol_ty_by_name, Type}; use crate::types::{global_symbol_ty_by_name, infer_definition_types, symbol_ty_by_name};
use crate::{HasTy, ProgramSettings, SemanticModel}; use crate::{HasTy, ProgramSettings, SemanticModel};
use super::TypeInferenceBuilder; use super::TypeInferenceBuilder;
@ -2587,9 +2588,7 @@ mod tests {
let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist."); let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
let ty = global_symbol_ty_by_name(&db, mod_file, "Sub"); let ty = global_symbol_ty_by_name(&db, mod_file, "Sub");
let Type::Class(class) = ty else { let class = ty.expect_class();
panic!("Sub is not a Class")
};
let base_names: Vec<_> = class let base_names: Vec<_> = class
.bases(&db) .bases(&db)
@ -2615,19 +2614,11 @@ mod tests {
let mod_file = system_path_to_file(&db, "src/mod.py").unwrap(); let mod_file = system_path_to_file(&db, "src/mod.py").unwrap();
let ty = global_symbol_ty_by_name(&db, mod_file, "C"); let ty = global_symbol_ty_by_name(&db, mod_file, "C");
let class_id = ty.expect_class();
let Type::Class(class_id) = ty else {
panic!("C is not a Class");
};
let member_ty = class_id.class_member(&db, &Name::new_static("f")); let member_ty = class_id.class_member(&db, &Name::new_static("f"));
let func = member_ty.expect_function();
let Type::Function(func) = member_ty else {
panic!("C.f is not a Function");
};
assert_eq!(func.name(&db), "f"); assert_eq!(func.name(&db), "f");
Ok(()) Ok(())
} }
@ -2826,11 +2817,7 @@ mod tests {
db.write_file("src/a.py", "def example() -> int: return 42")?; db.write_file("src/a.py", "def example() -> int: return 42")?;
let mod_file = system_path_to_file(&db, "src/a.py").unwrap(); let mod_file = system_path_to_file(&db, "src/a.py").unwrap();
let ty = global_symbol_ty_by_name(&db, mod_file, "example"); let function = global_symbol_ty_by_name(&db, mod_file, "example").expect_function();
let Type::Function(function) = ty else {
panic!("example is not a function");
};
let returns = function.return_type(&db); let returns = function.return_type(&db);
assert_eq!(returns.display(&db).to_string(), "int"); assert_eq!(returns.display(&db).to_string(), "int");
@ -3248,20 +3235,14 @@ mod tests {
let a = system_path_to_file(&db, "src/a.py").expect("Expected file to exist."); let a = system_path_to_file(&db, "src/a.py").expect("Expected file to exist.");
let c_ty = global_symbol_ty_by_name(&db, a, "C"); let c_ty = global_symbol_ty_by_name(&db, a, "C");
let Type::Class(c_class) = c_ty else { let c_class = c_ty.expect_class();
panic!("C is not a Class")
};
let mut c_bases = c_class.bases(&db); let mut c_bases = c_class.bases(&db);
let b_ty = c_bases.next().unwrap(); let b_ty = c_bases.next().unwrap();
let Type::Class(b_class) = b_ty else { let b_class = b_ty.expect_class();
panic!("B is not a Class")
};
assert_eq!(b_class.name(&db), "B"); assert_eq!(b_class.name(&db), "B");
let mut b_bases = b_class.bases(&db); let mut b_bases = b_class.bases(&db);
let a_ty = b_bases.next().unwrap(); let a_ty = b_bases.next().unwrap();
let Type::Class(a_class) = a_ty else { let a_class = a_ty.expect_class();
panic!("A is not a Class")
};
assert_eq!(a_class.name(&db), "A"); assert_eq!(a_class.name(&db), "A");
Ok(()) Ok(())
@ -3481,9 +3462,7 @@ mod tests {
// imported builtins module is the same file as the implicit builtins // imported builtins module is the same file as the implicit builtins
let file = system_path_to_file(&db, "/src/a.py").expect("Expected file to exist."); let file = system_path_to_file(&db, "/src/a.py").expect("Expected file to exist.");
let builtins_ty = global_symbol_ty_by_name(&db, file, "builtins"); let builtins_ty = global_symbol_ty_by_name(&db, file, "builtins");
let Type::Module(builtins_file) = builtins_ty else { let builtins_file = builtins_ty.expect_module();
panic!("Builtins are not a module?");
};
let implicit_builtins_file = builtins_scope(&db).expect("builtins to exist").file(&db); let implicit_builtins_file = builtins_scope(&db).expect("builtins to exist").file(&db);
assert_eq!(builtins_file, implicit_builtins_file); assert_eq!(builtins_file, implicit_builtins_file);