complete equality of tags

This commit is contained in:
Folkert 2021-02-10 03:17:24 +01:00
parent 6aaf12c49c
commit b70cedf587
4 changed files with 635 additions and 113 deletions

View file

@ -4,7 +4,7 @@ use crate::llvm::build_list::{list_len, load_list_ptr};
use crate::llvm::build_str::str_equal;
use crate::llvm::convert::{basic_type_from_layout, get_ptr_type};
use bumpalo::collections::Vec;
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, StructValue};
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue};
use inkwell::{AddressSpace, FloatPredicate, IntPredicate};
use roc_module::symbol::Symbol;
use roc_mono::layout::{Builtin, Layout, LayoutIds, UnionLayout};
@ -559,11 +559,30 @@ fn build_struct_eq_help<'a, 'ctx, 'env>(
}
WhenRecursive::Loop(union_layout) => {
let field_layout = Layout::Union(union_layout.clone());
let bt = basic_type_from_layout(
env.arena,
env.context,
&field_layout,
env.ptr_bytes,
);
// cast the i64 pointer to a pointer to block of memory
let field1_cast = env
.builder
.build_bitcast(field1, bt, "i64_to_opaque")
.into_pointer_value();
let field2_cast = env
.builder
.build_bitcast(field2, bt, "i64_to_opaque")
.into_pointer_value();
build_eq(
env,
layout_ids,
field1,
field2,
field1_cast.into(),
field2_cast.into(),
&field_layout,
&field_layout,
)
@ -647,6 +666,8 @@ fn build_tag_eq_help<'a, 'ctx, 'env>(
parent: FunctionValue<'ctx>,
union_layout: &UnionLayout<'a>,
) {
use inkwell::types::BasicType;
let ctx = env.context;
let builder = env.builder;
@ -681,16 +702,30 @@ fn build_tag_eq_help<'a, 'ctx, 'env>(
let entry = ctx.append_basic_block(parent, "entry");
let return_true = ctx.append_basic_block(parent, "return_true");
let return_false = ctx.append_basic_block(parent, "return_false");
{
env.builder.position_at_end(return_false);
env.builder
.build_return(Some(&env.context.bool_type().const_int(0, false)));
}
{
env.builder.position_at_end(return_true);
env.builder
.build_return(Some(&env.context.bool_type().const_int(1, false)));
}
env.builder.position_at_end(entry);
use UnionLayout::*;
match union_layout {
NonRecursive(tags) => {
let id1 = get_tag_id(env, union_layout, tag1);
let id2 = get_tag_id(env, union_layout, tag2);
// SAFETY we know that non-recursive tags cannot be NULL
let id1 = nonrec_tag_id(env, tag1.into_struct_value());
let id2 = nonrec_tag_id(env, tag2.into_struct_value());
let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields");
@ -740,18 +775,22 @@ fn build_tag_eq_help<'a, 'ctx, 'env>(
env.builder.build_return(Some(&answer));
cases.push((env.context.i8_type().const_int(tag_id as u64, false), block));
cases.push((
env.context.i64_type().const_int(tag_id as u64, false),
block,
));
}
env.builder.position_at_end(entry);
env.builder.position_at_end(compare_tag_fields);
let default = cases.pop().unwrap().1;
env.builder.build_switch(id1, default, &cases);
}
Recursive(tags) => {
let id1 = get_tag_id(env, union_layout, tag1);
let id2 = get_tag_id(env, union_layout, tag2);
// SAFETY we know that non-recursive tags cannot be NULL
let id1 = unsafe { rec_tag_id_unsafe(env, tag1.into_pointer_value()) };
let id2 = unsafe { rec_tag_id_unsafe(env, tag2.into_pointer_value()) };
let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields");
@ -772,127 +811,255 @@ fn build_tag_eq_help<'a, 'ctx, 'env>(
let block = env.context.append_basic_block(parent, "tag_id_modify");
env.builder.position_at_end(block);
// TODO drop tag id?
let answer = eq_ptr_to_struct(
env,
layout_ids,
union_layout,
field_layouts,
tag1.into_pointer_value(),
tag2.into_pointer_value(),
);
env.builder.build_return(Some(&answer));
cases.push((
env.context.i64_type().const_int(tag_id as u64, false),
block,
));
}
env.builder.position_at_end(compare_tag_fields);
let default = cases.pop().unwrap().1;
env.builder.build_switch(id1, default, &cases);
}
NullableUnwrapped { other_fields, .. } => {
// drop the tag id; it is not stored
let other_fields = &other_fields[1..];
// IDEA add up is_null_1 + is_null_2, then switch on the sum
let is_null_1 = env
.builder
.build_is_null(tag1.into_pointer_value(), "is_null");
let is_null_2 = env
.builder
.build_is_null(tag2.into_pointer_value(), "is_null");
let l_is_null = ctx.append_basic_block(parent, "l_is_null");
let l_is_not_null = ctx.append_basic_block(parent, "l_is_not_null");
let compare_other = ctx.append_basic_block(parent, "compare_other");
env.builder
.build_conditional_branch(is_null_1, l_is_null, l_is_not_null);
// LHS is NULL
env.builder.position_at_end(l_is_null);
env.builder.build_return(Some(&is_null_2));
// LHS is not NULL
env.builder.position_at_end(l_is_not_null);
env.builder
.build_conditional_branch(is_null_2, return_false, compare_other);
// compare the non-null case
env.builder.position_at_end(compare_other);
let answer = eq_ptr_to_struct(
env,
layout_ids,
union_layout,
other_fields,
tag1.into_pointer_value(),
tag2.into_pointer_value(),
);
env.builder.build_return(Some(&answer));
}
NullableWrapped {
nullable_id,
other_tags,
} => {
// IDEA add up is_null_1 + is_null_2, then switch on the sum
let is_null_1 = env
.builder
.build_is_null(tag1.into_pointer_value(), "is_null");
let is_null_2 = env
.builder
.build_is_null(tag2.into_pointer_value(), "is_null");
let l_is_null = ctx.append_basic_block(parent, "l_is_null");
let l_is_not_null = ctx.append_basic_block(parent, "l_is_not_null");
let compare_other = ctx.append_basic_block(parent, "compare_other");
env.builder
.build_conditional_branch(is_null_1, l_is_null, l_is_not_null);
// LHS is NULL
env.builder.position_at_end(l_is_null);
env.builder.build_return(Some(&is_null_2));
// LHS is not NULL
env.builder.position_at_end(l_is_not_null);
env.builder
.build_conditional_branch(is_null_2, return_false, compare_other);
// compare the non-null case
env.builder.position_at_end(compare_other);
// SAFETY we know at this point that tag1/tag2 are not NULL
let id1 = unsafe { rec_tag_id_unsafe(env, tag1.into_pointer_value()) };
let id2 = unsafe { rec_tag_id_unsafe(env, tag2.into_pointer_value()) };
let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields");
let same_tag =
env.builder
.build_int_compare(IntPredicate::EQ, id1, id2, "compare_tag_id");
env.builder
.build_conditional_branch(same_tag, compare_tag_fields, return_false);
env.builder.position_at_end(compare_tag_fields);
// switch on all the tag ids
let tags = other_tags;
let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
for (tag_id, field_layouts) in tags.iter().enumerate() {
let block = env.context.append_basic_block(parent, "tag_id_modify");
env.builder.position_at_end(block);
let answer = eq_ptr_to_struct(
env,
layout_ids,
union_layout,
field_layouts,
tag1.into_pointer_value(),
tag2.into_pointer_value(),
);
env.builder.build_return(Some(&answer));
cases.push((
env.context.i64_type().const_int(tag_id as u64, false),
block,
));
}
env.builder.position_at_end(compare_tag_fields);
let default = cases.pop().unwrap().1;
env.builder.build_switch(id1, default, &cases);
}
NonNullableUnwrapped(field_layouts) => {
let answer = eq_ptr_to_struct(
env,
layout_ids,
union_layout,
field_layouts,
tag1.into_pointer_value(),
tag2.into_pointer_value(),
);
env.builder.build_return(Some(&answer));
}
}
}
fn eq_ptr_to_struct<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
union_layout: &UnionLayout<'a>,
field_layouts: &'a [Layout<'a>],
tag1: PointerValue<'ctx>,
tag2: PointerValue<'ctx>,
) -> IntValue<'ctx> {
use inkwell::types::BasicType;
let struct_layout = Layout::Struct(field_layouts);
let wrapper_type =
basic_type_from_layout(env.arena, env.context, &struct_layout, env.ptr_bytes);
debug_assert!(wrapper_type.is_struct_type());
let struct1 = cast_block_of_memory_to_tag(
env.builder,
tag1.into_struct_value(),
wrapper_type,
);
let struct2 = cast_block_of_memory_to_tag(
env.builder,
tag2.into_struct_value(),
wrapper_type,
);
// cast the opaque pointer to a pointer of the correct shape
let struct1_ptr = env
.builder
.build_bitcast(
tag1,
wrapper_type.ptr_type(AddressSpace::Generic),
"opaque_to_correct",
)
.into_pointer_value();
let answer = build_struct_eq(
let struct2_ptr = env
.builder
.build_bitcast(
tag2,
wrapper_type.ptr_type(AddressSpace::Generic),
"opaque_to_correct",
)
.into_pointer_value();
let struct1 = env
.builder
.build_load(struct1_ptr, "load_struct1")
.into_struct_value();
let struct2 = env
.builder
.build_load(struct2_ptr, "load_struct2")
.into_struct_value();
build_struct_eq(
env,
layout_ids,
field_layouts,
WhenRecursive::Loop(union_layout.clone()),
struct1,
struct2,
);
env.builder.build_return(Some(&answer));
cases.push((env.context.i8_type().const_int(tag_id as u64, false), block));
}
env.builder.position_at_end(entry);
let default = cases.pop().unwrap().1;
env.builder.build_switch(id1, default, &cases);
}
_ => todo!(),
}
}
fn get_tag_id<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
union_layout: &UnionLayout<'a>,
tag: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
use UnionLayout::*;
match union_layout {
NonRecursive(_tags) | Recursive(_tags) => {
// for now, the tag is always the first field
complex_bitcast(
env.builder,
tag,
env.context.i64_type().into(),
"cast_for_tag_id",
)
.into_int_value()
}
}
NonNullableUnwrapped(_fields) => {
// there is only one tag; it has tag_id 0
env.context.i64_type().const_int(0, false)
}
fn nonrec_tag_id<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
tag: StructValue<'ctx>,
) -> IntValue<'ctx> {
complex_bitcast(
env.builder,
tag.into(),
env.context.i64_type().into(),
"load_tag_id",
)
.into_int_value()
}
NullableWrapped { nullable_id, .. } => {
debug_assert!(tag.is_pointer_value());
let tag_ptr = tag.into_pointer_value();
let is_null = env.builder.build_is_null(tag_ptr, "is_null");
let parent = env
.builder
.get_insert_block()
.unwrap()
.get_parent()
.unwrap();
let source_block = env.builder.get_insert_block().unwrap();
let merge_block = env.context.append_basic_block(parent, "merge");
let not_null_block = env.context.append_basic_block(parent, "not_null");
env.builder
.build_conditional_branch(is_null, merge_block, not_null_block);
// NOT NULL
env.builder.position_at_end(not_null_block);
let tag_ptr = env
unsafe fn rec_tag_id_unsafe<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
tag: PointerValue<'ctx>,
) -> IntValue<'ctx> {
let ptr = env
.builder
.build_bitcast(
tag_ptr,
tag,
env.context.i64_type().ptr_type(AddressSpace::Generic),
"to_tag_id_ptr",
"cast_for_tag_id",
)
.into_pointer_value();
let non_null_tag_id = env.builder.build_load(tag_ptr, "read_tag_id");
env.builder.build_unconditional_branch(merge_block);
// MERGE
env.builder.position_at_end(not_null_block);
let merged = env.builder.build_phi(env.context.i64_type(), "tag_id");
merged.add_incoming(&[
(
&env.context.i64_type().const_int(*nullable_id as u64, false),
source_block,
),
(&non_null_tag_id, not_null_block),
]);
merged.as_basic_value().into_int_value()
}
NullableUnwrapped { nullable_id, .. } => {
todo!()
}
}
env.builder.build_load(ptr, "load_tag_id").into_int_value()
}

View file

@ -0,0 +1,335 @@
#[macro_use]
extern crate pretty_assertions;
#[macro_use]
extern crate indoc;
extern crate bumpalo;
extern crate inkwell;
extern crate libc;
extern crate roc_gen;
#[macro_use]
mod helpers;
#[cfg(test)]
mod gen_num {
#[test]
fn eq_i64() {
assert_evals_to!(
indoc!(
r#"
i : I64
i = 1
i == i
"#
),
true,
bool
);
}
#[test]
fn neq_i64() {
assert_evals_to!(
indoc!(
r#"
i : I64
i = 1
i != i
"#
),
false,
bool
);
}
#[test]
fn eq_u64() {
assert_evals_to!(
indoc!(
r#"
i : U64
i = 1
i == i
"#
),
true,
bool
);
}
#[test]
fn neq_u64() {
assert_evals_to!(
indoc!(
r#"
i : U64
i = 1
i != i
"#
),
false,
bool
);
}
#[test]
fn eq_f64() {
assert_evals_to!(
indoc!(
r#"
i : F64
i = 1
i == i
"#
),
true,
bool
);
}
#[test]
fn neq_f64() {
assert_evals_to!(
indoc!(
r#"
i : F64
i = 1
i != i
"#
),
false,
bool
);
}
#[test]
fn eq_bool_tag() {
assert_evals_to!(
indoc!(
r#"
true : Bool
true = True
true == True
"#
),
true,
bool
);
}
#[test]
fn neq_bool_tag() {
assert_evals_to!(
indoc!(
r#"
true : Bool
true = True
true == False
"#
),
false,
bool
);
}
#[test]
fn empty_record() {
assert_evals_to!("{} == {}", true, bool);
assert_evals_to!("{} != {}", false, bool);
}
#[test]
fn unit() {
assert_evals_to!("Unit == Unit", true, bool);
assert_evals_to!("Unit != Unit", false, bool);
}
#[test]
fn newtype() {
assert_evals_to!("Identity 42 == Identity 42", true, bool);
assert_evals_to!("Identity 42 != Identity 42", false, bool);
}
#[test]
fn small_str() {
assert_evals_to!("\"aaa\" == \"aaa\"", true, bool);
assert_evals_to!("\"aaa\" == \"bbb\"", false, bool);
assert_evals_to!("\"aaa\" != \"aaa\"", false, bool);
}
#[test]
fn large_str() {
assert_evals_to!(
indoc!(
r#"
x = "Unicode can represent text values which span multiple languages"
y = "Unicode can represent text values which span multiple languages"
x == y
"#
),
true,
bool
);
assert_evals_to!(
indoc!(
r#"
x = "Unicode can represent text values which span multiple languages"
y = "Here are some valid Roc strings"
x != y
"#
),
true,
bool
);
}
#[test]
fn eq_result_tag_true() {
assert_evals_to!(
indoc!(
r#"
x : Result I64 I64
x = Ok 1
y : Result I64 I64
y = Ok 1
x == y
"#
),
true,
bool
);
}
#[test]
fn eq_result_tag_false() {
assert_evals_to!(
indoc!(
r#"
x : Result I64 I64
x = Ok 1
y : Result I64 I64
y = Err 1
x == y
"#
),
false,
bool
);
}
#[test]
fn eq_expr() {
assert_evals_to!(
indoc!(
r#"
Expr : [ Add Expr Expr, Mul Expr Expr, Val I64, Var I64 ]
x : Expr
x = Val 0
y : Expr
y = Val 0
x == y
"#
),
true,
bool
);
}
#[test]
fn eq_linked_list() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
x : LinkedList I64
x = Nil
y : LinkedList I64
y = Nil
x == y
"#
),
true,
bool
);
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
x : LinkedList I64
x = Cons 1 Nil
y : LinkedList I64
y = Cons 1 Nil
x == y
"#
),
true,
bool
);
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
x : LinkedList I64
x = Cons 1 (Cons 2 Nil)
y : LinkedList I64
y = Cons 1 (Cons 2 Nil)
x == y
"#
),
true,
bool
);
}
#[test]
fn eq_linked_list_false() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
x : LinkedList I64
x = Cons 1 Nil
y : LinkedList I64
y = Cons 1 (Cons 2 Nil)
y == x
"#
),
false,
bool
);
}
}

View file

@ -251,8 +251,8 @@ pub fn helper<'a>(
mode,
);
fn_val.print_to_stderr();
// module.print_to_stderr();
// fn_val.print_to_stderr();
module.print_to_stderr();
panic!(
"The preceding code was from {:?}, which failed LLVM verification in {} build.",

View file

@ -19,6 +19,28 @@ use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
pub const PRETTY_PRINT_IR_SYMBOLS: bool = false;
macro_rules! return_on_layout_error {
($env:expr, $layout_result:expr) => {
match $layout_result {
Ok(cached) => cached,
Err(LayoutProblem::UnresolvedTypeVar(_)) => {
return Stmt::RuntimeError($env.arena.alloc(format!(
"UnresolvedTypeVar {} line {}",
file!(),
line!()
)));
}
Err(LayoutProblem::Erroneous) => {
return Stmt::RuntimeError($env.arena.alloc(format!(
"Erroneous {} line {}",
file!(),
line!()
)));
}
}
};
}
#[derive(Clone, Debug, PartialEq)]
pub enum MonoProblem {
PatternProblem(crate::exhaustive::Error),
@ -3786,11 +3808,10 @@ pub fn with_hole<'a>(
)
.into_bump_slice();
let full_layout = layout_cache
.from_var(env.arena, fn_var, env.subs)
.unwrap_or_else(|err| {
panic!("TODO turn fn_var into a RuntimeError {:?}", err)
});
let full_layout = return_on_layout_error!(
env,
layout_cache.from_var(env.arena, fn_var, env.subs)
);
let arg_layouts = match full_layout {
Layout::FunctionPointer(args, _) => args,
@ -3798,11 +3819,10 @@ pub fn with_hole<'a>(
_ => unreachable!("function has layout that is not function pointer"),
};
let ret_layout = layout_cache
.from_var(env.arena, ret_var, env.subs)
.unwrap_or_else(|err| {
panic!("TODO turn fn_var into a RuntimeError {:?}", err)
});
let ret_layout = return_on_layout_error!(
env,
layout_cache.from_var(env.arena, ret_var, env.subs)
);
// if the function expression (loc_expr) is already a symbol,
// re-use that symbol, and don't define its value again