add == and != for lists

This commit is contained in:
Folkert 2021-01-03 20:09:28 +01:00
parent 23ed281345
commit a7cf98df9b
6 changed files with 328 additions and 11 deletions

View file

@ -2750,7 +2750,15 @@ fn run_low_level<'a, 'ctx, 'env>(
let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]); let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]);
list_contains(env, parent, elem, elem_layout, list, list_layout) list_contains(
env,
layout_ids,
parent,
elem,
elem_layout,
list,
list_layout,
)
} }
ListWalk => { ListWalk => {
debug_assert_eq!(args.len(), 3); debug_assert_eq!(args.len(), 3);
@ -2975,7 +2983,7 @@ fn run_low_level<'a, 'ctx, 'env>(
let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]); let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]);
let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]); let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]);
build_eq(env, lhs_arg, rhs_arg, lhs_layout, rhs_layout) build_eq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout)
} }
NotEq => { NotEq => {
debug_assert_eq!(args.len(), 2); debug_assert_eq!(args.len(), 2);
@ -2983,7 +2991,7 @@ fn run_low_level<'a, 'ctx, 'env>(
let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]); let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]);
let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]); let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]);
build_neq(env, lhs_arg, rhs_arg, lhs_layout, rhs_layout) build_neq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout)
} }
And => { And => {
// The (&&) operator // The (&&) operator

View file

@ -1003,6 +1003,7 @@ pub fn list_walk_backwards<'a, 'ctx, 'env>(
/// List.contains : List elem, elem -> Bool /// List.contains : List elem, elem -> Bool
pub fn list_contains<'a, 'ctx, 'env>( pub fn list_contains<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
parent: FunctionValue<'ctx>, parent: FunctionValue<'ctx>,
elem: BasicValueEnum<'ctx>, elem: BasicValueEnum<'ctx>,
elem_layout: &Layout<'a>, elem_layout: &Layout<'a>,
@ -1034,6 +1035,7 @@ pub fn list_contains<'a, 'ctx, 'env>(
list_contains_help( list_contains_help(
env, env,
layout_ids,
parent, parent,
length, length,
list_ptr, list_ptr,
@ -1045,6 +1047,7 @@ pub fn list_contains<'a, 'ctx, 'env>(
pub fn list_contains_help<'a, 'ctx, 'env>( pub fn list_contains_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
parent: FunctionValue<'ctx>, parent: FunctionValue<'ctx>,
length: IntValue<'ctx>, length: IntValue<'ctx>,
source_ptr: PointerValue<'ctx>, source_ptr: PointerValue<'ctx>,
@ -1082,7 +1085,14 @@ pub fn list_contains_help<'a, 'ctx, 'env>(
let current_elem = builder.build_load(current_elem_ptr, "load_elem"); let current_elem = builder.build_load(current_elem_ptr, "load_elem");
let has_found = build_eq(env, current_elem, elem, list_elem_layout, elem_layout); let has_found = build_eq(
env,
layout_ids,
current_elem,
elem,
list_elem_layout,
elem_layout,
);
builder.build_store(bool_alloca, has_found.into_int_value()); builder.build_store(bool_alloca, has_found.into_int_value());

View file

@ -1,11 +1,16 @@
use crate::llvm::build::Env; use crate::llvm::build::Env;
use crate::llvm::build::{set_name, FAST_CALL_CONV};
use crate::llvm::build_list::{list_len, load_list_ptr};
use crate::llvm::build_str::str_equal; use crate::llvm::build_str::str_equal;
use inkwell::values::{BasicValueEnum, IntValue}; use crate::llvm::convert::{basic_type_from_layout, get_ptr_type};
use inkwell::{FloatPredicate, IntPredicate}; use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, StructValue};
use roc_mono::layout::{Builtin, Layout}; use inkwell::{AddressSpace, FloatPredicate, IntPredicate};
use roc_module::symbol::Symbol;
use roc_mono::layout::{Builtin, Layout, LayoutIds};
pub fn build_eq<'a, 'ctx, 'env>( pub fn build_eq<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
lhs_val: BasicValueEnum<'ctx>, lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>, rhs_val: BasicValueEnum<'ctx>,
lhs_layout: &Layout<'a>, lhs_layout: &Layout<'a>,
@ -45,6 +50,25 @@ pub fn build_eq<'a, 'ctx, 'env>(
(Builtin::Float64, Builtin::Float64) => float_cmp(FloatPredicate::OEQ, "eq_f64"), (Builtin::Float64, Builtin::Float64) => float_cmp(FloatPredicate::OEQ, "eq_f64"),
(Builtin::Float32, Builtin::Float32) => float_cmp(FloatPredicate::OEQ, "eq_f32"), (Builtin::Float32, Builtin::Float32) => float_cmp(FloatPredicate::OEQ, "eq_f32"),
(Builtin::Str, Builtin::Str) => str_equal(env, lhs_val, rhs_val), (Builtin::Str, Builtin::Str) => str_equal(env, lhs_val, rhs_val),
(Builtin::EmptyList, Builtin::EmptyList) => {
env.context.bool_type().const_int(1, false).into()
}
(Builtin::List(_, _), Builtin::EmptyList)
| (Builtin::EmptyList, Builtin::List(_, _)) => {
unreachable!("the `==` operator makes sure its two arguments have the same type and thus layout")
}
(Builtin::List(_, elem1), Builtin::List(_, elem2)) => {
debug_assert_eq!(elem1, elem2);
build_list_eq(
env,
layout_ids,
lhs_layout,
elem1,
lhs_val.into_struct_value(),
rhs_val.into_struct_value(),
)
}
(b1, b2) => { (b1, b2) => {
todo!("Handle equals for builtin layouts {:?} == {:?}", b1, b2); todo!("Handle equals for builtin layouts {:?} == {:?}", b1, b2);
} }
@ -57,8 +81,206 @@ pub fn build_eq<'a, 'ctx, 'env>(
} }
} }
fn build_list_eq<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
list_layout: &Layout<'a>,
element_layout: &Layout<'a>,
list1: StructValue<'ctx>,
list2: StructValue<'ctx>,
) -> BasicValueEnum<'ctx> {
let block = env.builder.get_insert_block().expect("to be in a function");
let di_location = env.builder.get_current_debug_location().unwrap();
let symbol = Symbol::LIST_EQ;
let fn_name = layout_ids
.get(symbol, &element_layout)
.to_symbol_string(symbol, &env.interns);
let function = match env.module.get_function(fn_name.as_str()) {
Some(function_value) => function_value,
None => {
let arena = env.arena;
let arg_type = basic_type_from_layout(arena, env.context, &list_layout, env.ptr_bytes);
let function_value = crate::llvm::refcounting::build_header_help(
env,
&fn_name,
env.context.bool_type().into(),
&[arg_type, arg_type],
);
build_list_eq_help(env, layout_ids, function_value, element_layout);
function_value
}
};
env.builder.position_at_end(block);
env.builder
.set_current_debug_location(env.context, di_location);
let call = env
.builder
.build_call(function, &[list1.into(), list2.into()], "list_eq");
call.set_call_convention(FAST_CALL_CONV);
call.try_as_basic_value().left().unwrap()
}
fn build_list_eq_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
parent: FunctionValue<'ctx>,
element_layout: &Layout<'a>,
) {
let ctx = env.context;
let builder = env.builder;
{
use inkwell::debug_info::AsDIScope;
let func_scope = parent.get_subprogram().unwrap();
let lexical_block = env.dibuilder.create_lexical_block(
/* scope */ func_scope.as_debug_info_scope(),
/* file */ env.compile_unit.get_file(),
/* line_no */ 0,
/* column_no */ 0,
);
let loc = env.dibuilder.create_debug_location(
ctx,
/* line */ 0,
/* column */ 0,
/* current_scope */ lexical_block.as_debug_info_scope(),
/* inlined_at */ None,
);
builder.set_current_debug_location(&ctx, loc);
}
// Add args to scope
let mut it = parent.get_param_iter();
let list1 = it.next().unwrap().into_struct_value();
let list2 = it.next().unwrap().into_struct_value();
set_name(list1.into(), Symbol::ARG_1.ident_string(&env.interns));
set_name(list1.into(), Symbol::ARG_2.ident_string(&env.interns));
let entry = ctx.append_basic_block(parent, "entry");
env.builder.position_at_end(entry);
let return_true = ctx.append_basic_block(parent, "return_true");
let return_false = ctx.append_basic_block(parent, "return_false");
// first, check whether the length is equal
let len1 = list_len(env.builder, list1);
let len2 = list_len(env.builder, list2);
let length_equal: IntValue =
env.builder
.build_int_compare(IntPredicate::EQ, len1, len2, "bounds_check");
let then_block = ctx.append_basic_block(parent, "then");
env.builder
.build_conditional_branch(length_equal, then_block, return_false);
{
// the length is equal; check elements pointwise
env.builder.position_at_end(then_block);
{
let builder = env.builder;
let element_type =
basic_type_from_layout(env.arena, env.context, element_layout, env.ptr_bytes);
let ptr_type = get_ptr_type(&element_type, AddressSpace::Generic);
let ptr1 = load_list_ptr(env.builder, list1, ptr_type);
let ptr2 = load_list_ptr(env.builder, list2, ptr_type);
// we know that len1 == len2
let end = len1;
// constant 1i64
let one = ctx.i64_type().const_int(1, false);
// allocate a stack slot for the current index
let index_alloca = builder.build_alloca(ctx.i64_type(), "index");
builder.build_store(index_alloca, ctx.i64_type().const_zero());
let loop_bb = ctx.append_basic_block(parent, "loop");
let body_bb = ctx.append_basic_block(parent, "body");
let increment_bb = ctx.append_basic_block(parent, "increment");
builder.build_unconditional_branch(loop_bb);
builder.position_at_end(loop_bb);
let curr_index = builder.build_load(index_alloca, "index").into_int_value();
// #index < end
let loop_end_cond =
builder.build_int_compare(IntPredicate::ULT, curr_index, end, "bounds_check");
builder.build_conditional_branch(loop_end_cond, body_bb, return_true);
builder.position_at_end(body_bb);
{
// loop body
let elem1 = {
let elem_ptr =
unsafe { builder.build_in_bounds_gep(ptr1, &[curr_index], "load_index") };
builder.build_load(elem_ptr, "get_elem")
};
let elem2 = {
let elem_ptr =
unsafe { builder.build_in_bounds_gep(ptr2, &[curr_index], "load_index") };
builder.build_load(elem_ptr, "get_elem")
};
let are_equal = build_eq(
env,
layout_ids,
elem1,
elem2,
element_layout,
element_layout,
)
.into_int_value();
builder.build_conditional_branch(are_equal, increment_bb, return_false);
}
{
env.builder.position_at_end(increment_bb);
let next_index = builder.build_int_add(curr_index, one, "nextindex");
builder.build_store(index_alloca, next_index);
// jump back to the top of the loop
builder.build_unconditional_branch(loop_bb);
}
}
}
{
env.builder.position_at_end(return_true);
env.builder
.build_return(Some(&env.context.bool_type().const_int(1, false)));
}
{
env.builder.position_at_end(return_false);
env.builder
.build_return(Some(&env.context.bool_type().const_int(0, false)));
}
}
pub fn build_neq<'a, 'ctx, 'env>( pub fn build_neq<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
lhs_val: BasicValueEnum<'ctx>, lhs_val: BasicValueEnum<'ctx>,
rhs_val: BasicValueEnum<'ctx>, rhs_val: BasicValueEnum<'ctx>,
lhs_layout: &Layout<'a>, lhs_layout: &Layout<'a>,
@ -103,6 +325,29 @@ pub fn build_neq<'a, 'ctx, 'env>(
result.into() result.into()
} }
(Builtin::EmptyList, Builtin::EmptyList) => {
env.context.bool_type().const_int(0, false).into()
}
(Builtin::List(_, _), Builtin::EmptyList)
| (Builtin::EmptyList, Builtin::List(_, _)) => {
unreachable!("the `==` operator makes sure its two arguments have the same type and thus layout")
}
(Builtin::List(_, elem1), Builtin::List(_, elem2)) => {
debug_assert_eq!(elem1, elem2);
let equal = build_list_eq(
env,
layout_ids,
lhs_layout,
elem1,
lhs_val.into_struct_value(),
rhs_val.into_struct_value(),
);
let not_equal: IntValue = env.builder.build_not(equal.into_int_value(), "not");
not_equal.into()
}
(b1, b2) => { (b1, b2) => {
todo!("Handle not equals for builtin layouts {:?} == {:?}", b1, b2); todo!("Handle not equals for builtin layouts {:?} == {:?}", b1, b2);
} }

View file

@ -8,6 +8,7 @@ use bumpalo::collections::Vec;
use inkwell::context::Context; use inkwell::context::Context;
use inkwell::debug_info::AsDIScope; use inkwell::debug_info::AsDIScope;
use inkwell::module::Linkage; use inkwell::module::Linkage;
use inkwell::types::{AnyTypeEnum, BasicTypeEnum};
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue}; use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue};
use inkwell::{AddressSpace, IntPredicate}; use inkwell::{AddressSpace, IntPredicate};
use roc_module::symbol::Symbol; use roc_module::symbol::Symbol;
@ -993,12 +994,28 @@ pub fn build_header<'a, 'ctx, 'env>(
fn_name: &str, fn_name: &str,
) -> FunctionValue<'ctx> { ) -> FunctionValue<'ctx> {
let arena = env.arena; let arena = env.arena;
let context = &env.context;
let arg_type = basic_type_from_layout(arena, env.context, &layout, env.ptr_bytes); let arg_type = basic_type_from_layout(arena, env.context, &layout, env.ptr_bytes);
build_header_help(env, fn_name, env.context.void_type().into(), &[arg_type])
}
// inc and dec return void /// Build an increment or decrement function for a specific layout
let fn_type = context.void_type().fn_type(&[arg_type], false); pub fn build_header_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
fn_name: &str,
return_type: AnyTypeEnum<'ctx>,
arguments: &[BasicTypeEnum<'ctx>],
) -> FunctionValue<'ctx> {
use inkwell::types::AnyTypeEnum::*;
let fn_type = match return_type {
ArrayType(t) => t.fn_type(arguments, false),
FloatType(t) => t.fn_type(arguments, false),
FunctionType(_) => unreachable!("functions cannot return functions"),
IntType(t) => t.fn_type(arguments, false),
PointerType(t) => t.fn_type(arguments, false),
StructType(t) => t.fn_type(arguments, false),
VectorType(t) => t.fn_type(arguments, false),
VoidType(t) => t.fn_type(arguments, false),
};
let fn_val = env let fn_val = env
.module .module

View file

@ -1703,4 +1703,40 @@ mod gen_list {
assert_evals_to!("List.sum [ 1, 2, 3 ]", 6, i64); assert_evals_to!("List.sum [ 1, 2, 3 ]", 6, i64);
assert_evals_to!("List.sum [ 1.1, 2.2, 3.3 ]", 6.6, f64); assert_evals_to!("List.sum [ 1.1, 2.2, 3.3 ]", 6.6, f64);
} }
#[test]
fn list_eq_empty() {
assert_evals_to!("[] == []", true, bool);
assert_evals_to!("[] != []", false, bool);
}
#[test]
fn list_eq_by_length() {
assert_evals_to!("[1] == []", false, bool);
assert_evals_to!("[] == [1]", false, bool);
}
#[test]
fn list_eq_compare_pointwise() {
assert_evals_to!("[1] == [1]", true, bool);
assert_evals_to!("[2] == [1]", false, bool);
}
#[test]
fn list_eq_nested() {
assert_evals_to!("[[1]] == [[1]]", true, bool);
assert_evals_to!("[[2]] == [[1]]", false, bool);
}
#[test]
fn list_neq_compare_pointwise() {
assert_evals_to!("[1] != [1]", false, bool);
assert_evals_to!("[2] != [1]", true, bool);
}
#[test]
fn list_neq_nested() {
assert_evals_to!("[[1]] != [[1]]", false, bool);
assert_evals_to!("[[2]] != [[1]]", true, bool);
}
} }

View file

@ -738,6 +738,7 @@ define_builtins! {
10 INC: "#inc" // internal function that increments the refcount 10 INC: "#inc" // internal function that increments the refcount
11 DEC: "#dec" // internal function that increments the refcount 11 DEC: "#dec" // internal function that increments the refcount
12 ARG_CLOSURE: "#arg_closure" // symbol used to store the closure record 12 ARG_CLOSURE: "#arg_closure" // symbol used to store the closure record
13 LIST_EQ: "#list_eq" // internal function that checks list equality
} }
1 NUM: "Num" => { 1 NUM: "Num" => {
0 NUM_NUM: "Num" imported // the Num.Num type alias 0 NUM_NUM: "Num" imported // the Num.Num type alias