roc/crates/compiler/gen_llvm/src/llvm/compare.rs
2023-01-25 13:14:38 -06:00

1459 lines
45 KiB
Rust

use crate::llvm::build::{get_tag_id, tag_pointer_clear_tag_id, Env, FAST_CALL_CONV};
use crate::llvm::build_list::{list_len, load_list_ptr};
use crate::llvm::build_str::str_equal;
use crate::llvm::convert::basic_type_from_layout;
use bumpalo::collections::Vec;
use inkwell::types::BasicType;
use inkwell::values::{
BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue,
};
use inkwell::{AddressSpace, FloatPredicate, IntPredicate};
use roc_builtins::bitcode;
use roc_builtins::bitcode::{FloatWidth, IntWidth};
use roc_error_macros::internal_error;
use roc_module::symbol::Symbol;
use roc_mono::layout::{
Builtin, InLayout, Layout, LayoutIds, LayoutInterner, STLayoutInterner, UnionLayout,
};
use super::build::{load_roc_value, use_roc_value, BuilderExt};
use super::convert::argument_type_from_union_layout;
use super::lowlevel::dec_binop_with_unchecked;
/// Emits code that compares `lhs_val` and `rhs_val` for equality.
///
/// Public wrapper: every argument is forwarded unchanged to `build_eq`, and the
/// value `build_eq` produces is returned as-is. `lhs_layout` / `rhs_layout`
/// describe the layouts of the two operands; `build_eq` panics when their
/// runtime representations differ.
pub fn generic_eq<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
lhs_layout: InLayout<'a>,
rhs_layout: InLayout<'a>,
) -> BasicValueEnum<'ctx> {
build_eq(
env,
layout_interner,
layout_ids,
lhs_val,
rhs_val,
lhs_layout,
rhs_layout,
)
}
/// Emits code that compares `lhs_val` and `rhs_val` for inequality.
///
/// Public wrapper: every argument is forwarded unchanged to `build_neq`, and
/// the value `build_neq` produces is returned as-is. `build_neq` panics when
/// `lhs_layout` and `rhs_layout` differ.
pub fn generic_neq<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
lhs_layout: InLayout<'a>,
rhs_layout: InLayout<'a>,
) -> BasicValueEnum<'ctx> {
build_neq(
env,
layout_interner,
layout_ids,
lhs_val,
rhs_val,
lhs_layout,
rhs_layout,
)
}
fn build_eq_builtin<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
builtin_layout: InLayout<'a>,
builtin: &Builtin<'a>,
) -> BasicValueEnum<'ctx> {
let int_cmp = |pred, label| {
let int_val = env.builder.build_int_compare(
pred,
lhs_val.into_int_value(),
rhs_val.into_int_value(),
label,
);
BasicValueEnum::IntValue(int_val)
};
let float_cmp = |pred, label| {
let int_val = env.builder.build_float_compare(
pred,
lhs_val.into_float_value(),
rhs_val.into_float_value(),
label,
);
BasicValueEnum::IntValue(int_val)
};
match builtin {
Builtin::Int(int_width) => {
use IntWidth::*;
let name = match int_width {
I128 => "eq_i128",
I64 => "eq_i64",
I32 => "eq_i32",
I16 => "eq_i16",
I8 => "eq_i8",
U128 => "eq_u128",
U64 => "eq_u64",
U32 => "eq_u32",
U16 => "eq_u16",
U8 => "eq_u8",
};
int_cmp(IntPredicate::EQ, name)
}
Builtin::Float(float_width) => {
use FloatWidth::*;
let name = match float_width {
F64 => "eq_f64",
F32 => "eq_f32",
};
float_cmp(FloatPredicate::OEQ, name)
}
Builtin::Bool => int_cmp(IntPredicate::EQ, "eq_i1"),
Builtin::Decimal => dec_binop_with_unchecked(env, bitcode::DEC_EQ, lhs_val, rhs_val),
Builtin::Str => str_equal(env, lhs_val, rhs_val),
Builtin::List(elem) => build_list_eq(
env,
layout_interner,
layout_ids,
builtin_layout,
*elem,
lhs_val.into_struct_value(),
rhs_val.into_struct_value(),
),
}
}
/// Emits an equality comparison for two values of arbitrary layout.
///
/// Both layouts are first resolved with `runtime_representation_in`; the
/// function panics if the resolved layouts differ. The comparison is then
/// dispatched on the resolved layout:
///
/// * builtins go to `build_eq_builtin`,
/// * structs, unions and boxes go to `build_struct_eq`, `build_tag_eq` and
///   `build_box_eq`,
/// * a recursion pointer is cast to the pointer type of the layout it refers
///   to and compared as that (recursive) union.
fn build_eq<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    lhs_val: BasicValueEnum<'ctx>,
    rhs_val: BasicValueEnum<'ctx>,
    lhs_layout: InLayout<'a>,
    rhs_layout: InLayout<'a>,
) -> BasicValueEnum<'ctx> {
    let lhs_repr = layout_interner.runtime_representation_in(lhs_layout);
    let rhs_repr = layout_interner.runtime_representation_in(rhs_layout);

    assert!(
        lhs_repr == rhs_repr,
        "Equality of different layouts; did you have a type mismatch?\n{:?} == {:?}",
        lhs_repr,
        rhs_repr
    );

    // From here on both sides are known to share `lhs_repr`.
    let shared_layout = lhs_repr;

    match layout_interner.get(shared_layout) {
        Layout::Builtin(builtin) => build_eq_builtin(
            env,
            layout_interner,
            layout_ids,
            lhs_val,
            rhs_val,
            shared_layout,
            &builtin,
        ),

        Layout::Struct { field_layouts, .. } => {
            let lhs_struct = lhs_val.into_struct_value();
            let rhs_struct = rhs_val.into_struct_value();

            build_struct_eq(
                env,
                layout_interner,
                layout_ids,
                shared_layout,
                field_layouts,
                lhs_struct,
                rhs_struct,
            )
        }

        Layout::LambdaSet(_) => unreachable!("cannot compare closures"),

        Layout::Union(union_layout) => build_tag_eq(
            env,
            layout_interner,
            layout_ids,
            shared_layout,
            &union_layout,
            lhs_val,
            rhs_val,
        ),

        Layout::Boxed(inner_layout) => build_box_eq(
            env,
            layout_interner,
            layout_ids,
            shared_layout,
            inner_layout,
            lhs_val,
            rhs_val,
        ),

        Layout::RecursivePointer(rec_layout) => {
            let target_type = basic_type_from_layout(env, layout_interner, rec_layout);

            // reinterpret both pointers as pointers to the layout they refer to
            let lhs_cast = env.builder.build_pointer_cast(
                lhs_val.into_pointer_value(),
                target_type.into_pointer_type(),
                "i64_to_opaque",
            );
            let rhs_cast = env.builder.build_pointer_cast(
                rhs_val.into_pointer_value(),
                target_type.into_pointer_type(),
                "i64_to_opaque",
            );

            // a recursion pointer must refer to a union, and not a non-recursive one
            let union_layout = match layout_interner.get(rec_layout) {
                Layout::Union(union_layout) => {
                    debug_assert!(!matches!(union_layout, UnionLayout::NonRecursive(..)));
                    union_layout
                }
                _ => internal_error!(),
            };

            build_tag_eq(
                env,
                layout_interner,
                layout_ids,
                rec_layout,
                &union_layout,
                lhs_cast.into(),
                rhs_cast.into(),
            )
        }
    }
}
fn build_neq_builtin<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>,
builtin_layout: InLayout<'a>,
builtin: &Builtin<'a>,
) -> BasicValueEnum<'ctx> {
let int_cmp = |pred, label| {
let int_val = env.builder.build_int_compare(
pred,
lhs_val.into_int_value(),
rhs_val.into_int_value(),
label,
);
BasicValueEnum::IntValue(int_val)
};
let float_cmp = |pred, label| {
let int_val = env.builder.build_float_compare(
pred,
lhs_val.into_float_value(),
rhs_val.into_float_value(),
label,
);
BasicValueEnum::IntValue(int_val)
};
match builtin {
Builtin::Int(int_width) => {
use IntWidth::*;
let name = match int_width {
I128 => "neq_i128",
I64 => "neq_i64",
I32 => "neq_i32",
I16 => "neq_i16",
I8 => "neq_i8",
U128 => "neq_u128",
U64 => "neq_u64",
U32 => "neq_u32",
U16 => "neq_u16",
U8 => "neq_u8",
};
int_cmp(IntPredicate::NE, name)
}
Builtin::Float(float_width) => {
use FloatWidth::*;
let name = match float_width {
F64 => "neq_f64",
F32 => "neq_f32",
};
float_cmp(FloatPredicate::ONE, name)
}
Builtin::Bool => int_cmp(IntPredicate::NE, "neq_i1"),
Builtin::Decimal => dec_binop_with_unchecked(env, bitcode::DEC_NEQ, lhs_val, rhs_val),
Builtin::Str => {
let is_equal = str_equal(env, lhs_val, rhs_val).into_int_value();
let result: IntValue = env.builder.build_not(is_equal, "negate");
result.into()
}
Builtin::List(elem) => {
let is_equal = build_list_eq(
env,
layout_interner,
layout_ids,
builtin_layout,
*elem,
lhs_val.into_struct_value(),
rhs_val.into_struct_value(),
)
.into_int_value();
let result: IntValue = env.builder.build_not(is_equal, "negate");
result.into()
}
}
}
/// Emits an inequality comparison for two values of arbitrary layout.
///
/// Both layouts are first resolved with `runtime_representation_in`, exactly
/// as `build_eq` does, so that `a != b` can be generated for the same operand
/// layouts that `a == b` accepts. The function panics if the resolved layouts
/// differ.
///
/// Builtins are handled by `build_neq_builtin`. Structs, unions and boxes are
/// compared for equality with the corresponding `build_*_eq` function and the
/// result is negated.
///
/// # Panics
///
/// * when the two resolved layouts are not the same,
/// * when the resolved layout is a recursion pointer or a lambda set.
fn build_neq<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    lhs_val: BasicValueEnum<'ctx>,
    rhs_val: BasicValueEnum<'ctx>,
    lhs_layout: InLayout<'a>,
    rhs_layout: InLayout<'a>,
) -> BasicValueEnum<'ctx> {
    // Compare and dispatch on runtime representations, mirroring `build_eq`.
    let lhs_layout = layout_interner.runtime_representation_in(lhs_layout);
    let rhs_layout = layout_interner.runtime_representation_in(rhs_layout);

    if lhs_layout != rhs_layout {
        panic!(
            "Inequality of different layouts; did you have a type mismatch?\n{:?} != {:?}",
            lhs_layout, rhs_layout
        );
    }

    // Turns an "is equal" result into an "is not equal" one.
    let negate = |is_equal: BasicValueEnum<'ctx>| -> BasicValueEnum<'ctx> {
        let result: IntValue = env.builder.build_not(is_equal.into_int_value(), "negate");

        result.into()
    };

    match layout_interner.get(lhs_layout) {
        Layout::Builtin(builtin) => build_neq_builtin(
            env,
            layout_interner,
            layout_ids,
            lhs_val,
            rhs_val,
            lhs_layout,
            &builtin,
        ),

        Layout::Struct { field_layouts, .. } => {
            let is_equal = build_struct_eq(
                env,
                layout_interner,
                layout_ids,
                lhs_layout,
                field_layouts,
                lhs_val.into_struct_value(),
                rhs_val.into_struct_value(),
            );

            negate(is_equal)
        }

        Layout::Union(union_layout) => {
            let is_equal = build_tag_eq(
                env,
                layout_interner,
                layout_ids,
                lhs_layout,
                &union_layout,
                lhs_val,
                rhs_val,
            );

            negate(is_equal)
        }

        Layout::Boxed(inner_layout) => {
            let is_equal = build_box_eq(
                env,
                layout_interner,
                layout_ids,
                lhs_layout,
                inner_layout,
                lhs_val,
                rhs_val,
            );

            negate(is_equal)
        }

        Layout::RecursivePointer(_) => {
            unreachable!("recursion pointers should never be compared directly")
        }

        Layout::LambdaSet(_) => unreachable!("cannot compare closure"),
    }
}
/// Emits a call to the list-equality function specialized for
/// `element_layout`, generating that function first when the module does not
/// contain it yet.
///
/// If `element_layout` is a recursion pointer, the layout it refers to is used
/// instead, both for naming the function and for comparing elements.
///
/// Generating the helper moves the builder, so the caller's insertion block and
/// debug location are captured up front and restored before the call is
/// emitted. Returns the value produced by the call.
fn build_list_eq<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    list_layout: InLayout<'a>,
    element_layout: InLayout<'a>,
    list1: StructValue<'ctx>,
    list2: StructValue<'ctx>,
) -> BasicValueEnum<'ctx> {
    let resume_block = env.builder.get_insert_block().expect("to be in a function");
    let resume_debug_loc = env.builder.get_current_debug_location().unwrap();

    let symbol = Symbol::LIST_EQ;

    let element_layout = match layout_interner.get(element_layout) {
        Layout::RecursivePointer(rec) => rec,
        _ => element_layout,
    };

    let fn_name = layout_ids
        .get(symbol, &element_layout)
        .to_symbol_string(symbol, &env.interns);

    let function = if let Some(existing) = env.module.get_function(fn_name.as_str()) {
        existing
    } else {
        // both parameters are lists, so they share one argument type
        let arg_type = basic_type_from_layout(env, layout_interner, list_layout);

        let new_function = crate::llvm::refcounting::build_header_help(
            env,
            &fn_name,
            env.context.bool_type().into(),
            &[arg_type, arg_type],
        );

        build_list_eq_help(
            env,
            layout_interner,
            layout_ids,
            new_function,
            element_layout,
        );

        new_function
    };

    // continue emitting code where the caller left off
    env.builder.position_at_end(resume_block);
    env.builder
        .set_current_debug_location(env.context, resume_debug_loc);

    let arguments = [list1.into(), list2.into()];
    let call = env.builder.build_call(function, &arguments, "list_eq");
    call.set_call_convention(FAST_CALL_CONV);

    call.try_as_basic_value().left().unwrap()
}
/// Fills in the body of a generated list-equality function.
///
/// `parent` is the function to populate; its two parameters are read as the
/// list structs to compare. The emitted code:
///
/// 1. returns false when the two lengths differ,
/// 2. otherwise walks both lists with a single index, loading one element of
///    `element_layout` from each list and comparing them with `build_eq`,
/// 3. returns false at the first pair that is not equal, and true once the
///    index reaches the length.
///
/// The builder is left positioned at the end of the `return_false` block;
/// `build_list_eq` repositions it after this function returns.
fn build_list_eq_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
parent: FunctionValue<'ctx>,
element_layout: InLayout<'a>,
) {
let ctx = env.context;
let builder = env.builder;
// Set a synthetic debug location (line 0, column 0) scoped to `parent`
// before any instruction of the new function is emitted.
{
use inkwell::debug_info::AsDIScope;
let func_scope = parent.get_subprogram().unwrap();
let lexical_block = env.dibuilder.create_lexical_block(
/* scope */ func_scope.as_debug_info_scope(),
/* file */ env.compile_unit.get_file(),
/* line_no */ 0,
/* column_no */ 0,
);
let loc = env.dibuilder.create_debug_location(
ctx,
/* line */ 0,
/* column */ 0,
/* current_scope */ lexical_block.as_debug_info_scope(),
/* inlined_at */ None,
);
builder.set_current_debug_location(ctx, loc);
}
// Add args to scope
let mut it = parent.get_param_iter();
let list1 = it.next().unwrap().into_struct_value();
let list2 = it.next().unwrap().into_struct_value();
list1.set_name(Symbol::ARG_1.as_str(&env.interns));
list2.set_name(Symbol::ARG_2.as_str(&env.interns));
let entry = ctx.append_basic_block(parent, "entry");
env.builder.position_at_end(entry);
// Both exit blocks are created now and filled in at the end of this function.
let return_true = ctx.append_basic_block(parent, "return_true");
let return_false = ctx.append_basic_block(parent, "return_false");
// first, check whether the length is equal
let len1 = list_len(env.builder, list1);
let len2 = list_len(env.builder, list2);
let length_equal: IntValue =
env.builder
.build_int_compare(IntPredicate::EQ, len1, len2, "bounds_check");
let then_block = ctx.append_basic_block(parent, "then");
// lists of different length can never be equal
env.builder
.build_conditional_branch(length_equal, then_block, return_false);
{
// the length is equal; check elements pointwise
env.builder.position_at_end(then_block);
let builder = env.builder;
let element_type = basic_type_from_layout(env, layout_interner, element_layout);
let ptr_type = element_type.ptr_type(AddressSpace::Generic);
let ptr1 = load_list_ptr(env.builder, list1, ptr_type);
let ptr2 = load_list_ptr(env.builder, list2, ptr_type);
// we know that len1 == len2
let end = len1;
// allocate a stack slot for the current index
let index_alloca = builder.build_alloca(env.ptr_int(), "index");
builder.build_store(index_alloca, env.ptr_int().const_zero());
let loop_bb = ctx.append_basic_block(parent, "loop");
let body_bb = ctx.append_basic_block(parent, "body");
let increment_bb = ctx.append_basic_block(parent, "increment");
// the "top" of the loop
builder.build_unconditional_branch(loop_bb);
builder.position_at_end(loop_bb);
let curr_index = builder
.new_build_load(env.ptr_int(), index_alloca, "index")
.into_int_value();
// #index < end
let loop_end_cond =
builder.build_int_compare(IntPredicate::ULT, curr_index, end, "bounds_check");
// if we're at the end, and all elements were equal so far, return true
// otherwise check the current elements for equality
builder.build_conditional_branch(loop_end_cond, body_bb, return_true);
{
// loop body
builder.position_at_end(body_bb);
// The `body` block is only entered when `curr_index < end`, and `end`
// is the length of both lists, which is what makes the in-bounds GEPs
// below valid for either element pointer.
let elem1 = {
let elem_ptr = unsafe {
builder.new_build_in_bounds_gep(element_type, ptr1, &[curr_index], "load_index")
};
load_roc_value(env, layout_interner, element_layout, elem_ptr, "get_elem")
};
let elem2 = {
let elem_ptr = unsafe {
builder.new_build_in_bounds_gep(element_type, ptr2, &[curr_index], "load_index")
};
load_roc_value(env, layout_interner, element_layout, elem_ptr, "get_elem")
};
let are_equal = build_eq(
env,
layout_interner,
layout_ids,
elem1,
elem2,
element_layout,
element_layout,
)
.into_int_value();
// if the elements are equal, increment the index and check the next element
// otherwise, return false
builder.build_conditional_branch(are_equal, increment_bb, return_false);
}
{
env.builder.position_at_end(increment_bb);
// constant 1isize
let one = env.ptr_int().const_int(1, false);
let next_index = builder.build_int_add(curr_index, one, "nextindex");
// write the new index back to the stack slot; `loop_bb` reloads it
builder.build_store(index_alloca, next_index);
// jump back to the top of the loop
builder.build_unconditional_branch(loop_bb);
}
}
// Fill in the two exit blocks: they return the constants 1 and 0 of the bool type.
{
env.builder.position_at_end(return_true);
env.builder
.build_return(Some(&env.context.bool_type().const_int(1, false)));
}
{
env.builder.position_at_end(return_false);
env.builder
.build_return(Some(&env.context.bool_type().const_int(0, false)));
}
}
/// Emits a call to the equality function specialized for `struct_layout`,
/// generating that function first when the module does not contain it yet.
///
/// `field_layouts` are the layouts of the struct's fields, in field order; they
/// are only used when the function has to be generated.
///
/// Generating the helper moves the builder, so the caller's insertion block and
/// debug location are captured up front and restored before the call is
/// emitted. Returns the value produced by the call.
fn build_struct_eq<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    struct_layout: InLayout<'a>,
    field_layouts: &'a [InLayout<'a>],
    struct1: StructValue<'ctx>,
    struct2: StructValue<'ctx>,
) -> BasicValueEnum<'ctx> {
    let resume_block = env.builder.get_insert_block().expect("to be in a function");
    let resume_debug_loc = env.builder.get_current_debug_location().unwrap();

    let symbol = Symbol::GENERIC_EQ;
    let fn_name = layout_ids
        .get(symbol, &struct_layout)
        .to_symbol_string(symbol, &env.interns);

    let function = if let Some(existing) = env.module.get_function(fn_name.as_str()) {
        existing
    } else {
        // both parameters are structs of the same layout
        let arg_type = basic_type_from_layout(env, layout_interner, struct_layout);

        let new_function = crate::llvm::refcounting::build_header_help(
            env,
            &fn_name,
            env.context.bool_type().into(),
            &[arg_type, arg_type],
        );

        build_struct_eq_help(
            env,
            layout_interner,
            layout_ids,
            new_function,
            field_layouts,
        );

        new_function
    };

    // continue emitting code where the caller left off
    env.builder.position_at_end(resume_block);
    env.builder
        .set_current_debug_location(env.context, resume_debug_loc);

    let arguments = [struct1.into(), struct2.into()];
    let call = env.builder.build_call(function, &arguments, "struct_eq");
    call.set_call_convention(FAST_CALL_CONV);

    call.try_as_basic_value().left().unwrap()
}
/// Fills in the body of a generated struct-equality function.
///
/// `parent` is the function to populate; its two parameters are read as the
/// structs to compare, and `field_layouts` gives the layout of each field in
/// field order. The emitted code compares the fields one after another: each
/// comparison branches to `return_false` on a mismatch and to a fresh
/// `eq_step_<index>` block otherwise. The block reached after the last field
/// (or `start` itself when there are no fields) jumps to `return_true`.
///
/// A field whose layout is a recursion pointer is cast to the pointer type of
/// the layout it refers to and compared using that layout.
///
/// The builder is left positioned at the end of the `return_false` block;
/// `build_struct_eq` repositions it after this function returns.
fn build_struct_eq_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
parent: FunctionValue<'ctx>,
field_layouts: &[InLayout<'a>],
) {
let ctx = env.context;
let builder = env.builder;
// Set a synthetic debug location (line 0, column 0) scoped to `parent`
// before any instruction of the new function is emitted.
{
use inkwell::debug_info::AsDIScope;
let func_scope = parent.get_subprogram().unwrap();
let lexical_block = env.dibuilder.create_lexical_block(
/* scope */ func_scope.as_debug_info_scope(),
/* file */ env.compile_unit.get_file(),
/* line_no */ 0,
/* column_no */ 0,
);
let loc = env.dibuilder.create_debug_location(
ctx,
/* line */ 0,
/* column */ 0,
/* current_scope */ lexical_block.as_debug_info_scope(),
/* inlined_at */ None,
);
builder.set_current_debug_location(ctx, loc);
}
// Add args to scope
let mut it = parent.get_param_iter();
let struct1 = it.next().unwrap().into_struct_value();
let struct2 = it.next().unwrap().into_struct_value();
struct1.set_name(Symbol::ARG_1.as_str(&env.interns));
struct2.set_name(Symbol::ARG_2.as_str(&env.interns));
// `entry` only jumps to `start`; the field comparisons begin in `start`.
let entry = ctx.append_basic_block(parent, "entry");
let start = ctx.append_basic_block(parent, "start");
env.builder.position_at_end(entry);
env.builder.build_unconditional_branch(start);
let return_true = ctx.append_basic_block(parent, "return_true");
let return_false = ctx.append_basic_block(parent, "return_false");
// `current` is the block in which the next field comparison is emitted.
let mut current = start;
for (index, field_layout) in field_layouts.iter().enumerate() {
env.builder.position_at_end(current);
let field1 = env
.builder
.build_extract_value(struct1, index as u32, "eq_field")
.unwrap();
let field2 = env
.builder
.build_extract_value(struct2, index as u32, "eq_field")
.unwrap();
let are_equal = if let Layout::RecursivePointer(rec_layout) =
layout_interner.get(*field_layout)
{
// a recursion pointer must refer to a union that is not non-recursive
debug_assert!(
matches!(layout_interner.get(rec_layout), Layout::Union(union_layout) if !matches!(union_layout, UnionLayout::NonRecursive(..)))
);
// from here on, compare using the layout the recursion pointer refers to
let field_layout = rec_layout;
let bt = basic_type_from_layout(env, layout_interner, field_layout);
// cast the i64 pointer to a pointer to block of memory
let field1_cast = env.builder.build_pointer_cast(
field1.into_pointer_value(),
bt.into_pointer_type(),
"i64_to_opaque",
);
let field2_cast = env.builder.build_pointer_cast(
field2.into_pointer_value(),
bt.into_pointer_type(),
"i64_to_opaque",
);
build_eq(
env,
layout_interner,
layout_ids,
field1_cast.into(),
field2_cast.into(),
field_layout,
field_layout,
)
.into_int_value()
} else {
// any other field: pass the extracted values through `use_roc_value`
// before comparing them with the field's own layout
let lhs = use_roc_value(env, layout_interner, *field_layout, field1, "field1");
let rhs = use_roc_value(env, layout_interner, *field_layout, field2, "field2");
build_eq(
env,
layout_interner,
layout_ids,
lhs,
rhs,
*field_layout,
*field_layout,
)
.into_int_value()
};
// continue with the next field in a fresh block, or bail out on a mismatch
current = ctx.append_basic_block(parent, &format!("eq_step_{}", index));
env.builder
.build_conditional_branch(are_equal, current, return_false);
}
// every field compared equal (or there were no fields at all)
env.builder.position_at_end(current);
env.builder.build_unconditional_branch(return_true);
// Fill in the two exit blocks: they return the constants 1 and 0 of the bool type.
{
env.builder.position_at_end(return_true);
env.builder
.build_return(Some(&env.context.bool_type().const_int(1, false)));
}
{
env.builder.position_at_end(return_false);
env.builder
.build_return(Some(&env.context.bool_type().const_int(0, false)));
}
}
/// Emits a call to the equality function specialized for the tag union
/// `tag_layout`, generating that function first when the module does not
/// contain it yet.
///
/// `union_layout` is the union description behind `tag_layout`; it determines
/// the argument type of the generated function and how its body is built.
///
/// Generating the helper moves the builder, so the caller's insertion block and
/// debug location are captured up front and restored before the call is
/// emitted. Returns the value produced by the call.
fn build_tag_eq<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    tag_layout: InLayout<'a>,
    union_layout: &UnionLayout<'a>,
    tag1: BasicValueEnum<'ctx>,
    tag2: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
    let resume_block = env.builder.get_insert_block().expect("to be in a function");
    let resume_debug_loc = env.builder.get_current_debug_location().unwrap();

    let symbol = Symbol::GENERIC_EQ;
    let fn_name = layout_ids
        .get(symbol, &tag_layout)
        .to_symbol_string(symbol, &env.interns);

    let function = if let Some(existing) = env.module.get_function(fn_name.as_str()) {
        existing
    } else {
        // both parameters use the argument representation of the union
        let arg_type = argument_type_from_union_layout(env, layout_interner, union_layout);

        let new_function = crate::llvm::refcounting::build_header_help(
            env,
            &fn_name,
            env.context.bool_type().into(),
            &[arg_type, arg_type],
        );

        build_tag_eq_help(
            env,
            layout_interner,
            layout_ids,
            new_function,
            union_layout,
        );

        new_function
    };

    // continue emitting code where the caller left off
    env.builder.position_at_end(resume_block);
    env.builder
        .set_current_debug_location(env.context, resume_debug_loc);

    let arguments = [tag1.into(), tag2.into()];
    let call = env.builder.build_call(function, &arguments, "tag_eq");
    call.set_call_convention(FAST_CALL_CONV);

    call.try_as_basic_value().left().unwrap()
}
/// Fills in the body of a generated equality function for a tag union.
///
/// `parent` is the function to populate; its two parameters are the tag values
/// to compare. Apart from the empty non-recursive union (whose body is just
/// `unreachable`), every variant first compares the two arguments as pointers
/// and returns true when they are identical. After that:
///
/// * `NonRecursive` / `Recursive`: the tag ids are compared; when they match,
///   a switch on the tag id selects the block that compares the payload with
///   the field layouts of that tag.
/// * `NullableUnwrapped`: if either pointer is null the values differ (the
///   pointers are already known to be distinct); otherwise the payloads are
///   compared with `other_fields`.
/// * `NullableWrapped`: null handling first, then the same tag id switch as
///   above. The switch cases use the real tag ids, i.e. the positions in
///   `other_tags` shifted past `nullable_id`.
/// * `NonNullableUnwrapped`: the payloads are compared directly.
///
/// The builder is left in whichever block was filled in last; `build_tag_eq`
/// repositions it after this function returns.
fn build_tag_eq_help<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    parent: FunctionValue<'ctx>,
    union_layout: &UnionLayout<'a>,
) {
    let ctx = env.context;
    let builder = env.builder;

    // Set a synthetic debug location (line 0, column 0) scoped to `parent`
    // before any instruction of the new function is emitted.
    {
        use inkwell::debug_info::AsDIScope;

        let func_scope = parent.get_subprogram().unwrap();
        let lexical_block = env.dibuilder.create_lexical_block(
            /* scope */ func_scope.as_debug_info_scope(),
            /* file */ env.compile_unit.get_file(),
            /* line_no */ 0,
            /* column_no */ 0,
        );

        let loc = env.dibuilder.create_debug_location(
            ctx,
            /* line */ 0,
            /* column */ 0,
            /* current_scope */ lexical_block.as_debug_info_scope(),
            /* inlined_at */ None,
        );
        builder.set_current_debug_location(ctx, loc);
    }

    // Add args to scope
    let mut it = parent.get_param_iter();
    let tag1 = it.next().unwrap();
    let tag2 = it.next().unwrap();

    tag1.set_name(Symbol::ARG_1.as_str(&env.interns));
    tag2.set_name(Symbol::ARG_2.as_str(&env.interns));

    let entry = ctx.append_basic_block(parent, "entry");

    // The two exit blocks are shared by all union variants below.
    let return_true = ctx.append_basic_block(parent, "return_true");
    let return_false = ctx.append_basic_block(parent, "return_false");

    {
        env.builder.position_at_end(return_false);
        env.builder
            .build_return(Some(&env.context.bool_type().const_int(0, false)));
    }

    {
        env.builder.position_at_end(return_true);
        env.builder
            .build_return(Some(&env.context.bool_type().const_int(1, false)));
    }

    env.builder.position_at_end(entry);

    use UnionLayout::*;

    match union_layout {
        NonRecursive(&[]) => {
            // we're comparing empty tag unions; this code is effectively unreachable
            env.builder.build_unreachable();
        }
        NonRecursive(tags) => {
            let ptr_equal = env.builder.build_int_compare(
                IntPredicate::EQ,
                env.builder
                    .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"),
                env.builder
                    .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"),
                "compare_pointers",
            );

            let compare_tag_ids = ctx.append_basic_block(parent, "compare_tag_ids");

            env.builder
                .build_conditional_branch(ptr_equal, return_true, compare_tag_ids);

            env.builder.position_at_end(compare_tag_ids);

            let id1 = get_tag_id(env, layout_interner, parent, union_layout, tag1);
            let id2 = get_tag_id(env, layout_interner, parent, union_layout, tag2);

            // for a non-recursive union the argument pointers are used as they
            // are; no tag id is cleared from them
            let tag1 = tag1.into_pointer_value();
            let tag2 = tag2.into_pointer_value();

            let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields");

            let same_tag =
                env.builder
                    .build_int_compare(IntPredicate::EQ, id1, id2, "compare_tag_id");

            env.builder
                .build_conditional_branch(same_tag, compare_tag_fields, return_false);

            env.builder.position_at_end(compare_tag_fields);

            // switch on all the tag ids
            let mut cases = Vec::with_capacity_in(tags.len(), env.arena);

            for (tag_id, field_layouts) in tags.iter().enumerate() {
                let block = env.context.append_basic_block(parent, "tag_id_modify");
                env.builder.position_at_end(block);

                let struct_layout =
                    layout_interner.insert(Layout::struct_no_name_order(field_layouts));

                let answer = eq_ptr_to_struct(
                    env,
                    layout_interner,
                    layout_ids,
                    struct_layout,
                    field_layouts,
                    tag1,
                    tag2,
                );

                env.builder.build_return(Some(&answer));

                cases.push((id1.get_type().const_int(tag_id as u64, false), block));
            }

            env.builder.position_at_end(compare_tag_fields);

            // the block of the last tag doubles as the default of the switch
            match cases.pop() {
                Some((_, default)) => {
                    env.builder.build_switch(id1, default, &cases);
                }
                None => {
                    // we're comparing empty tag unions; this code is effectively unreachable
                    env.builder.build_unreachable();
                }
            }
        }
        Recursive(tags) => {
            let ptr_equal = env.builder.build_int_compare(
                IntPredicate::EQ,
                env.builder
                    .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"),
                env.builder
                    .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"),
                "compare_pointers",
            );

            let compare_tag_ids = ctx.append_basic_block(parent, "compare_tag_ids");

            env.builder
                .build_conditional_branch(ptr_equal, return_true, compare_tag_ids);

            env.builder.position_at_end(compare_tag_ids);

            let id1 = get_tag_id(env, layout_interner, parent, union_layout, tag1);
            let id2 = get_tag_id(env, layout_interner, parent, union_layout, tag2);

            // clear the tag_id so we get a pointer to the actual data
            let tag1 = tag_pointer_clear_tag_id(env, tag1.into_pointer_value());
            let tag2 = tag_pointer_clear_tag_id(env, tag2.into_pointer_value());

            let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields");

            let same_tag =
                env.builder
                    .build_int_compare(IntPredicate::EQ, id1, id2, "compare_tag_id");

            env.builder
                .build_conditional_branch(same_tag, compare_tag_fields, return_false);

            env.builder.position_at_end(compare_tag_fields);

            // switch on all the tag ids
            let mut cases = Vec::with_capacity_in(tags.len(), env.arena);

            for (tag_id, field_layouts) in tags.iter().enumerate() {
                let block = env.context.append_basic_block(parent, "tag_id_modify");
                env.builder.position_at_end(block);

                let struct_layout =
                    layout_interner.insert(Layout::struct_no_name_order(field_layouts));

                let answer = eq_ptr_to_struct(
                    env,
                    layout_interner,
                    layout_ids,
                    struct_layout,
                    field_layouts,
                    tag1,
                    tag2,
                );

                env.builder.build_return(Some(&answer));

                cases.push((id1.get_type().const_int(tag_id as u64, false), block));
            }

            env.builder.position_at_end(compare_tag_fields);

            // the block of the last tag doubles as the default of the switch
            let default = cases.pop().unwrap().1;

            env.builder.build_switch(id1, default, &cases);
        }
        NullableUnwrapped { other_fields, .. } => {
            let ptr_equal = env.builder.build_int_compare(
                IntPredicate::EQ,
                env.builder
                    .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"),
                env.builder
                    .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"),
                "compare_pointers",
            );

            let check_for_null = ctx.append_basic_block(parent, "check_for_null");
            let compare_other = ctx.append_basic_block(parent, "compare_other");

            env.builder
                .build_conditional_branch(ptr_equal, return_true, check_for_null);

            // check for NULL

            env.builder.position_at_end(check_for_null);

            let is_null_1 = env
                .builder
                .build_is_null(tag1.into_pointer_value(), "is_null");

            let is_null_2 = env
                .builder
                .build_is_null(tag2.into_pointer_value(), "is_null");

            let either_null = env.builder.build_or(is_null_1, is_null_2, "either_null");

            // logic: the pointers are not the same, if one is NULL, the other one is not
            // therefore the two tags are not equal
            env.builder
                .build_conditional_branch(either_null, return_false, compare_other);

            // compare the non-null case

            env.builder.position_at_end(compare_other);

            let struct_layout = layout_interner.insert(Layout::struct_no_name_order(other_fields));

            let answer = eq_ptr_to_struct(
                env,
                layout_interner,
                layout_ids,
                struct_layout,
                other_fields,
                tag1.into_pointer_value(),
                tag2.into_pointer_value(),
            );

            env.builder.build_return(Some(&answer));
        }
        NullableWrapped {
            other_tags,
            nullable_id,
        } => {
            let ptr_equal = env.builder.build_int_compare(
                IntPredicate::EQ,
                env.builder
                    .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"),
                env.builder
                    .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"),
                "compare_pointers",
            );

            let check_for_null = ctx.append_basic_block(parent, "check_for_null");
            let compare_other = ctx.append_basic_block(parent, "compare_other");

            env.builder
                .build_conditional_branch(ptr_equal, return_true, check_for_null);

            // check for NULL

            env.builder.position_at_end(check_for_null);

            let is_null_1 = env
                .builder
                .build_is_null(tag1.into_pointer_value(), "is_null");

            let is_null_2 = env
                .builder
                .build_is_null(tag2.into_pointer_value(), "is_null");

            // Logic:
            //
            // NULL and NULL => equal
            // NULL and not => not equal
            // not and NULL => not equal
            // not and not => more work required

            let i8_type = env.context.i8_type();

            // the number of NULL pointers (0, 1 or 2) selects one of the cases above
            let sum = env.builder.build_int_add(
                env.builder
                    .build_int_cast_sign_flag(is_null_1, i8_type, false, "to_u8"),
                env.builder
                    .build_int_cast_sign_flag(is_null_2, i8_type, false, "to_u8"),
                "sum_is_null",
            );

            env.builder.build_switch(
                sum,
                compare_other,
                &[
                    (i8_type.const_int(2, false), return_true),
                    (i8_type.const_int(1, false), return_false),
                ],
            );

            // compare the non-null case

            env.builder.position_at_end(compare_other);

            let id1 = get_tag_id(env, layout_interner, parent, union_layout, tag1);
            let id2 = get_tag_id(env, layout_interner, parent, union_layout, tag2);

            // clear the tag_id so we get a pointer to the actual data
            let tag1 = tag_pointer_clear_tag_id(env, tag1.into_pointer_value());
            let tag2 = tag_pointer_clear_tag_id(env, tag2.into_pointer_value());

            let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields");

            let same_tag =
                env.builder
                    .build_int_compare(IntPredicate::EQ, id1, id2, "compare_tag_id");

            env.builder
                .build_conditional_branch(same_tag, compare_tag_fields, return_false);

            env.builder.position_at_end(compare_tag_fields);

            // switch on all the tag ids

            let tags = other_tags;
            let mut cases = Vec::with_capacity_in(tags.len(), env.arena);

            for (index, field_layouts) in tags.iter().enumerate() {
                // `other_tags` does not contain the nullable tag, but the
                // nullable tag still occupies the id `nullable_id`. Every tag at
                // or after that position therefore has an id one higher than its
                // index in `other_tags`; using the raw index would compare such
                // a tag with the field layouts of a different tag.
                let tag_id = if index >= *nullable_id as usize {
                    index + 1
                } else {
                    index
                };

                let block = env.context.append_basic_block(parent, "tag_id_modify");
                env.builder.position_at_end(block);

                let struct_layout =
                    layout_interner.insert(Layout::struct_no_name_order(&field_layouts));

                let answer = eq_ptr_to_struct(
                    env,
                    layout_interner,
                    layout_ids,
                    struct_layout,
                    field_layouts,
                    tag1,
                    tag2,
                );

                env.builder.build_return(Some(&answer));

                cases.push((id1.get_type().const_int(tag_id as u64, false), block));
            }

            env.builder.position_at_end(compare_tag_fields);

            // the block of the last tag doubles as the default of the switch
            let default = cases.pop().unwrap().1;

            env.builder.build_switch(id1, default, &cases);
        }
        NonNullableUnwrapped(field_layouts) => {
            let ptr_equal = env.builder.build_int_compare(
                IntPredicate::EQ,
                env.builder
                    .build_ptr_to_int(tag1.into_pointer_value(), env.ptr_int(), "pti"),
                env.builder
                    .build_ptr_to_int(tag2.into_pointer_value(), env.ptr_int(), "pti"),
                "compare_pointers",
            );

            let compare_fields = ctx.append_basic_block(parent, "compare_fields");

            env.builder
                .build_conditional_branch(ptr_equal, return_true, compare_fields);

            env.builder.position_at_end(compare_fields);

            let struct_layout =
                layout_interner.insert(Layout::struct_no_name_order(&field_layouts));

            let answer = eq_ptr_to_struct(
                env,
                layout_interner,
                layout_ids,
                struct_layout,
                field_layouts,
                tag1.into_pointer_value(),
                tag2.into_pointer_value(),
            );

            env.builder.build_return(Some(&answer));
        }
    }
}
/// Compares the structs that two pointers point at.
///
/// `tag1` and `tag2` are cast to pointers to the LLVM type of `struct_layout`,
/// a struct is loaded through each of them, and the two loaded structs are
/// compared with `build_struct_eq`. `field_layouts` are the layouts of the
/// struct's fields.
///
/// Returns the comparison result as an int value.
fn eq_ptr_to_struct<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    struct_layout: InLayout<'a>,
    field_layouts: &'a [InLayout<'a>],
    tag1: PointerValue<'ctx>,
    tag2: PointerValue<'ctx>,
) -> IntValue<'ctx> {
    let struct_type = basic_type_from_layout(env, layout_interner, struct_layout);
    debug_assert!(struct_type.is_struct_type());

    let struct_ptr_type = struct_type.ptr_type(AddressSpace::Generic);

    // give both opaque pointers the shape of the struct they point at
    let lhs_ptr = env
        .builder
        .build_pointer_cast(tag1, struct_ptr_type, "opaque_to_correct");
    let rhs_ptr = env
        .builder
        .build_pointer_cast(tag2, struct_ptr_type, "opaque_to_correct");

    let lhs_struct = env
        .builder
        .new_build_load(struct_type, lhs_ptr, "load_struct1")
        .into_struct_value();
    let rhs_struct = env
        .builder
        .new_build_load(struct_type, rhs_ptr, "load_struct2")
        .into_struct_value();

    let is_equal = build_struct_eq(
        env,
        layout_interner,
        layout_ids,
        struct_layout,
        field_layouts,
        lhs_struct,
        rhs_struct,
    );

    is_equal.into_int_value()
}
/// Emits a call to the equality function specialized for the boxed layout
/// `box_layout`, generating that function first when the module does not
/// contain it yet.
///
/// `inner_layout` is the layout of the value inside the box; it is only used
/// when the function has to be generated.
///
/// Generating the helper moves the builder, so the caller's insertion block and
/// debug location are captured up front and restored before the call is
/// emitted. Returns the value produced by the call.
fn build_box_eq<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    box_layout: InLayout<'a>,
    inner_layout: InLayout<'a>,
    tag1: BasicValueEnum<'ctx>,
    tag2: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
    let resume_block = env.builder.get_insert_block().expect("to be in a function");
    let resume_debug_loc = env.builder.get_current_debug_location().unwrap();

    let symbol = Symbol::GENERIC_EQ;
    let fn_name = layout_ids
        .get(symbol, &box_layout)
        .to_symbol_string(symbol, &env.interns);

    let function = if let Some(existing) = env.module.get_function(fn_name.as_str()) {
        existing
    } else {
        // both parameters are boxes of the same layout
        let arg_type = basic_type_from_layout(env, layout_interner, box_layout);

        let new_function = crate::llvm::refcounting::build_header_help(
            env,
            &fn_name,
            env.context.bool_type().into(),
            &[arg_type, arg_type],
        );

        build_box_eq_help(
            env,
            layout_interner,
            layout_ids,
            new_function,
            inner_layout,
        );

        new_function
    };

    // continue emitting code where the caller left off
    env.builder.position_at_end(resume_block);
    env.builder
        .set_current_debug_location(env.context, resume_debug_loc);

    let arguments = [tag1.into(), tag2.into()];
    let call = env.builder.build_call(function, &arguments, "tag_eq");
    call.set_call_convention(FAST_CALL_CONV);

    call.try_as_basic_value().left().unwrap()
}
/// Fills in the body of a generated equality function for boxed values.
///
/// `parent` is the function to populate; its two parameters are the boxes to
/// compare. When both parameters are the same pointer the function returns
/// true right away. Otherwise a value of `inner_layout` is loaded through each
/// pointer and the result of `build_eq` on the two values is returned.
///
/// The builder is left in the block that returns the `build_eq` result;
/// `build_box_eq` repositions it after this function returns.
fn build_box_eq_help<'a, 'ctx, 'env>(
    env: &Env<'a, 'ctx, 'env>,
    layout_interner: &mut STLayoutInterner<'a>,
    layout_ids: &mut LayoutIds<'a>,
    parent: FunctionValue<'ctx>,
    inner_layout: InLayout<'a>,
) {
    let context = env.context;
    let builder = env.builder;

    // Set a synthetic debug location (line 0, column 0) scoped to `parent`
    // before any instruction of the new function is emitted.
    {
        use inkwell::debug_info::AsDIScope;

        let subprogram = parent.get_subprogram().unwrap();

        let scope = env.dibuilder.create_lexical_block(
            subprogram.as_debug_info_scope(),
            env.compile_unit.get_file(),
            0,
            0,
        );

        let location = env.dibuilder.create_debug_location(
            context,
            0,
            0,
            scope.as_debug_info_scope(),
            None,
        );

        builder.set_current_debug_location(context, location);
    }

    // name the two parameters
    let mut params = parent.get_param_iter();
    let lhs_box = params.next().unwrap();
    let rhs_box = params.next().unwrap();

    lhs_box.set_name(Symbol::ARG_1.as_str(&env.interns));
    rhs_box.set_name(Symbol::ARG_2.as_str(&env.interns));

    let entry = context.append_basic_block(parent, "entry");
    builder.position_at_end(entry);

    // exit taken when both parameters are the very same pointer
    let return_true = context.append_basic_block(parent, "return_true");
    builder.position_at_end(return_true);
    builder.build_return(Some(&context.bool_type().const_all_ones()));

    builder.position_at_end(entry);

    let lhs_address = builder.build_ptr_to_int(lhs_box.into_pointer_value(), env.ptr_int(), "pti");
    let rhs_address = builder.build_ptr_to_int(rhs_box.into_pointer_value(), env.ptr_int(), "pti");

    let same_pointer = builder.build_int_compare(
        IntPredicate::EQ,
        lhs_address,
        rhs_address,
        "compare_pointers",
    );

    let compare_inner_value = context.append_basic_block(parent, "compare_inner_value");
    builder.build_conditional_branch(same_pointer, return_true, compare_inner_value);

    builder.position_at_end(compare_inner_value);

    // distinct pointers: load the boxed values and compare those
    let lhs_ptr = lhs_box.into_pointer_value();
    let rhs_ptr = rhs_box.into_pointer_value();

    let lhs_inner = load_roc_value(env, layout_interner, inner_layout, lhs_ptr, "load_box1");
    let rhs_inner = load_roc_value(env, layout_interner, inner_layout, rhs_ptr, "load_box2");

    let inner_equal = build_eq(
        env,
        layout_interner,
        layout_ids,
        lhs_inner,
        rhs_inner,
        inner_layout,
        inner_layout,
    );

    builder.build_return(Some(&inner_equal));
}