mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 06:44:46 +00:00
add ==
and !=
for lists
This commit is contained in:
parent
23ed281345
commit
a7cf98df9b
6 changed files with 328 additions and 11 deletions
|
@ -2750,7 +2750,15 @@ fn run_low_level<'a, 'ctx, 'env>(
|
|||
|
||||
let (elem, elem_layout) = load_symbol_and_layout(env, scope, &args[1]);
|
||||
|
||||
list_contains(env, parent, elem, elem_layout, list, list_layout)
|
||||
list_contains(
|
||||
env,
|
||||
layout_ids,
|
||||
parent,
|
||||
elem,
|
||||
elem_layout,
|
||||
list,
|
||||
list_layout,
|
||||
)
|
||||
}
|
||||
ListWalk => {
|
||||
debug_assert_eq!(args.len(), 3);
|
||||
|
@ -2975,7 +2983,7 @@ fn run_low_level<'a, 'ctx, 'env>(
|
|||
let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]);
|
||||
let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]);
|
||||
|
||||
build_eq(env, lhs_arg, rhs_arg, lhs_layout, rhs_layout)
|
||||
build_eq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout)
|
||||
}
|
||||
NotEq => {
|
||||
debug_assert_eq!(args.len(), 2);
|
||||
|
@ -2983,7 +2991,7 @@ fn run_low_level<'a, 'ctx, 'env>(
|
|||
let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]);
|
||||
let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]);
|
||||
|
||||
build_neq(env, lhs_arg, rhs_arg, lhs_layout, rhs_layout)
|
||||
build_neq(env, layout_ids, lhs_arg, rhs_arg, lhs_layout, rhs_layout)
|
||||
}
|
||||
And => {
|
||||
// The (&&) operator
|
||||
|
|
|
@ -1003,6 +1003,7 @@ pub fn list_walk_backwards<'a, 'ctx, 'env>(
|
|||
/// List.contains : List elem, elem -> Bool
|
||||
pub fn list_contains<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
parent: FunctionValue<'ctx>,
|
||||
elem: BasicValueEnum<'ctx>,
|
||||
elem_layout: &Layout<'a>,
|
||||
|
@ -1034,6 +1035,7 @@ pub fn list_contains<'a, 'ctx, 'env>(
|
|||
|
||||
list_contains_help(
|
||||
env,
|
||||
layout_ids,
|
||||
parent,
|
||||
length,
|
||||
list_ptr,
|
||||
|
@ -1045,6 +1047,7 @@ pub fn list_contains<'a, 'ctx, 'env>(
|
|||
|
||||
pub fn list_contains_help<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
parent: FunctionValue<'ctx>,
|
||||
length: IntValue<'ctx>,
|
||||
source_ptr: PointerValue<'ctx>,
|
||||
|
@ -1082,7 +1085,14 @@ pub fn list_contains_help<'a, 'ctx, 'env>(
|
|||
|
||||
let current_elem = builder.build_load(current_elem_ptr, "load_elem");
|
||||
|
||||
let has_found = build_eq(env, current_elem, elem, list_elem_layout, elem_layout);
|
||||
let has_found = build_eq(
|
||||
env,
|
||||
layout_ids,
|
||||
current_elem,
|
||||
elem,
|
||||
list_elem_layout,
|
||||
elem_layout,
|
||||
);
|
||||
|
||||
builder.build_store(bool_alloca, has_found.into_int_value());
|
||||
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
use crate::llvm::build::Env;
|
||||
use crate::llvm::build::{set_name, FAST_CALL_CONV};
|
||||
use crate::llvm::build_list::{list_len, load_list_ptr};
|
||||
use crate::llvm::build_str::str_equal;
|
||||
use inkwell::values::{BasicValueEnum, IntValue};
|
||||
use inkwell::{FloatPredicate, IntPredicate};
|
||||
use roc_mono::layout::{Builtin, Layout};
|
||||
use crate::llvm::convert::{basic_type_from_layout, get_ptr_type};
|
||||
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, StructValue};
|
||||
use inkwell::{AddressSpace, FloatPredicate, IntPredicate};
|
||||
use roc_module::symbol::Symbol;
|
||||
use roc_mono::layout::{Builtin, Layout, LayoutIds};
|
||||
|
||||
pub fn build_eq<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
lhs_val: BasicValueEnum<'ctx>,
|
||||
rhs_val: BasicValueEnum<'ctx>,
|
||||
lhs_layout: &Layout<'a>,
|
||||
|
@ -45,6 +50,25 @@ pub fn build_eq<'a, 'ctx, 'env>(
|
|||
(Builtin::Float64, Builtin::Float64) => float_cmp(FloatPredicate::OEQ, "eq_f64"),
|
||||
(Builtin::Float32, Builtin::Float32) => float_cmp(FloatPredicate::OEQ, "eq_f32"),
|
||||
(Builtin::Str, Builtin::Str) => str_equal(env, lhs_val, rhs_val),
|
||||
(Builtin::EmptyList, Builtin::EmptyList) => {
|
||||
env.context.bool_type().const_int(1, false).into()
|
||||
}
|
||||
(Builtin::List(_, _), Builtin::EmptyList)
|
||||
| (Builtin::EmptyList, Builtin::List(_, _)) => {
|
||||
unreachable!("the `==` operator makes sure its two arguments have the same type and thus layout")
|
||||
}
|
||||
(Builtin::List(_, elem1), Builtin::List(_, elem2)) => {
|
||||
debug_assert_eq!(elem1, elem2);
|
||||
|
||||
build_list_eq(
|
||||
env,
|
||||
layout_ids,
|
||||
lhs_layout,
|
||||
elem1,
|
||||
lhs_val.into_struct_value(),
|
||||
rhs_val.into_struct_value(),
|
||||
)
|
||||
}
|
||||
(b1, b2) => {
|
||||
todo!("Handle equals for builtin layouts {:?} == {:?}", b1, b2);
|
||||
}
|
||||
|
@ -57,8 +81,206 @@ pub fn build_eq<'a, 'ctx, 'env>(
|
|||
}
|
||||
}
|
||||
|
||||
fn build_list_eq<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
list_layout: &Layout<'a>,
|
||||
element_layout: &Layout<'a>,
|
||||
list1: StructValue<'ctx>,
|
||||
list2: StructValue<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let block = env.builder.get_insert_block().expect("to be in a function");
|
||||
let di_location = env.builder.get_current_debug_location().unwrap();
|
||||
|
||||
let symbol = Symbol::LIST_EQ;
|
||||
let fn_name = layout_ids
|
||||
.get(symbol, &element_layout)
|
||||
.to_symbol_string(symbol, &env.interns);
|
||||
|
||||
let function = match env.module.get_function(fn_name.as_str()) {
|
||||
Some(function_value) => function_value,
|
||||
None => {
|
||||
let arena = env.arena;
|
||||
let arg_type = basic_type_from_layout(arena, env.context, &list_layout, env.ptr_bytes);
|
||||
|
||||
let function_value = crate::llvm::refcounting::build_header_help(
|
||||
env,
|
||||
&fn_name,
|
||||
env.context.bool_type().into(),
|
||||
&[arg_type, arg_type],
|
||||
);
|
||||
|
||||
build_list_eq_help(env, layout_ids, function_value, element_layout);
|
||||
|
||||
function_value
|
||||
}
|
||||
};
|
||||
|
||||
env.builder.position_at_end(block);
|
||||
env.builder
|
||||
.set_current_debug_location(env.context, di_location);
|
||||
let call = env
|
||||
.builder
|
||||
.build_call(function, &[list1.into(), list2.into()], "list_eq");
|
||||
|
||||
call.set_call_convention(FAST_CALL_CONV);
|
||||
|
||||
call.try_as_basic_value().left().unwrap()
|
||||
}
|
||||
|
||||
fn build_list_eq_help<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
parent: FunctionValue<'ctx>,
|
||||
element_layout: &Layout<'a>,
|
||||
) {
|
||||
let ctx = env.context;
|
||||
let builder = env.builder;
|
||||
|
||||
{
|
||||
use inkwell::debug_info::AsDIScope;
|
||||
|
||||
let func_scope = parent.get_subprogram().unwrap();
|
||||
let lexical_block = env.dibuilder.create_lexical_block(
|
||||
/* scope */ func_scope.as_debug_info_scope(),
|
||||
/* file */ env.compile_unit.get_file(),
|
||||
/* line_no */ 0,
|
||||
/* column_no */ 0,
|
||||
);
|
||||
|
||||
let loc = env.dibuilder.create_debug_location(
|
||||
ctx,
|
||||
/* line */ 0,
|
||||
/* column */ 0,
|
||||
/* current_scope */ lexical_block.as_debug_info_scope(),
|
||||
/* inlined_at */ None,
|
||||
);
|
||||
builder.set_current_debug_location(&ctx, loc);
|
||||
}
|
||||
|
||||
// Add args to scope
|
||||
let mut it = parent.get_param_iter();
|
||||
let list1 = it.next().unwrap().into_struct_value();
|
||||
let list2 = it.next().unwrap().into_struct_value();
|
||||
|
||||
set_name(list1.into(), Symbol::ARG_1.ident_string(&env.interns));
|
||||
set_name(list1.into(), Symbol::ARG_2.ident_string(&env.interns));
|
||||
|
||||
let entry = ctx.append_basic_block(parent, "entry");
|
||||
env.builder.position_at_end(entry);
|
||||
|
||||
let return_true = ctx.append_basic_block(parent, "return_true");
|
||||
let return_false = ctx.append_basic_block(parent, "return_false");
|
||||
|
||||
// first, check whether the length is equal
|
||||
|
||||
let len1 = list_len(env.builder, list1);
|
||||
let len2 = list_len(env.builder, list2);
|
||||
|
||||
let length_equal: IntValue =
|
||||
env.builder
|
||||
.build_int_compare(IntPredicate::EQ, len1, len2, "bounds_check");
|
||||
|
||||
let then_block = ctx.append_basic_block(parent, "then");
|
||||
|
||||
env.builder
|
||||
.build_conditional_branch(length_equal, then_block, return_false);
|
||||
|
||||
{
|
||||
// the length is equal; check elements pointwise
|
||||
env.builder.position_at_end(then_block);
|
||||
|
||||
{
|
||||
let builder = env.builder;
|
||||
let element_type =
|
||||
basic_type_from_layout(env.arena, env.context, element_layout, env.ptr_bytes);
|
||||
let ptr_type = get_ptr_type(&element_type, AddressSpace::Generic);
|
||||
let ptr1 = load_list_ptr(env.builder, list1, ptr_type);
|
||||
let ptr2 = load_list_ptr(env.builder, list2, ptr_type);
|
||||
|
||||
// we know that len1 == len2
|
||||
let end = len1;
|
||||
|
||||
// constant 1i64
|
||||
let one = ctx.i64_type().const_int(1, false);
|
||||
|
||||
// allocate a stack slot for the current index
|
||||
let index_alloca = builder.build_alloca(ctx.i64_type(), "index");
|
||||
builder.build_store(index_alloca, ctx.i64_type().const_zero());
|
||||
|
||||
let loop_bb = ctx.append_basic_block(parent, "loop");
|
||||
let body_bb = ctx.append_basic_block(parent, "body");
|
||||
let increment_bb = ctx.append_basic_block(parent, "increment");
|
||||
|
||||
builder.build_unconditional_branch(loop_bb);
|
||||
builder.position_at_end(loop_bb);
|
||||
|
||||
let curr_index = builder.build_load(index_alloca, "index").into_int_value();
|
||||
|
||||
// #index < end
|
||||
let loop_end_cond =
|
||||
builder.build_int_compare(IntPredicate::ULT, curr_index, end, "bounds_check");
|
||||
|
||||
builder.build_conditional_branch(loop_end_cond, body_bb, return_true);
|
||||
|
||||
builder.position_at_end(body_bb);
|
||||
|
||||
{
|
||||
// loop body
|
||||
let elem1 = {
|
||||
let elem_ptr =
|
||||
unsafe { builder.build_in_bounds_gep(ptr1, &[curr_index], "load_index") };
|
||||
builder.build_load(elem_ptr, "get_elem")
|
||||
};
|
||||
|
||||
let elem2 = {
|
||||
let elem_ptr =
|
||||
unsafe { builder.build_in_bounds_gep(ptr2, &[curr_index], "load_index") };
|
||||
builder.build_load(elem_ptr, "get_elem")
|
||||
};
|
||||
|
||||
let are_equal = build_eq(
|
||||
env,
|
||||
layout_ids,
|
||||
elem1,
|
||||
elem2,
|
||||
element_layout,
|
||||
element_layout,
|
||||
)
|
||||
.into_int_value();
|
||||
|
||||
builder.build_conditional_branch(are_equal, increment_bb, return_false);
|
||||
}
|
||||
|
||||
{
|
||||
env.builder.position_at_end(increment_bb);
|
||||
|
||||
let next_index = builder.build_int_add(curr_index, one, "nextindex");
|
||||
|
||||
builder.build_store(index_alloca, next_index);
|
||||
|
||||
// jump back to the top of the loop
|
||||
builder.build_unconditional_branch(loop_bb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
env.builder.position_at_end(return_true);
|
||||
env.builder
|
||||
.build_return(Some(&env.context.bool_type().const_int(1, false)));
|
||||
}
|
||||
|
||||
{
|
||||
env.builder.position_at_end(return_false);
|
||||
env.builder
|
||||
.build_return(Some(&env.context.bool_type().const_int(0, false)));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_neq<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
lhs_val: BasicValueEnum<'ctx>,
|
||||
rhs_val: BasicValueEnum<'ctx>,
|
||||
lhs_layout: &Layout<'a>,
|
||||
|
@ -103,6 +325,29 @@ pub fn build_neq<'a, 'ctx, 'env>(
|
|||
|
||||
result.into()
|
||||
}
|
||||
(Builtin::EmptyList, Builtin::EmptyList) => {
|
||||
env.context.bool_type().const_int(0, false).into()
|
||||
}
|
||||
(Builtin::List(_, _), Builtin::EmptyList)
|
||||
| (Builtin::EmptyList, Builtin::List(_, _)) => {
|
||||
unreachable!("the `==` operator makes sure its two arguments have the same type and thus layout")
|
||||
}
|
||||
(Builtin::List(_, elem1), Builtin::List(_, elem2)) => {
|
||||
debug_assert_eq!(elem1, elem2);
|
||||
|
||||
let equal = build_list_eq(
|
||||
env,
|
||||
layout_ids,
|
||||
lhs_layout,
|
||||
elem1,
|
||||
lhs_val.into_struct_value(),
|
||||
rhs_val.into_struct_value(),
|
||||
);
|
||||
|
||||
let not_equal: IntValue = env.builder.build_not(equal.into_int_value(), "not");
|
||||
|
||||
not_equal.into()
|
||||
}
|
||||
(b1, b2) => {
|
||||
todo!("Handle not equals for builtin layouts {:?} == {:?}", b1, b2);
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ use bumpalo::collections::Vec;
|
|||
use inkwell::context::Context;
|
||||
use inkwell::debug_info::AsDIScope;
|
||||
use inkwell::module::Linkage;
|
||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum};
|
||||
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue};
|
||||
use inkwell::{AddressSpace, IntPredicate};
|
||||
use roc_module::symbol::Symbol;
|
||||
|
@ -993,12 +994,28 @@ pub fn build_header<'a, 'ctx, 'env>(
|
|||
fn_name: &str,
|
||||
) -> FunctionValue<'ctx> {
|
||||
let arena = env.arena;
|
||||
let context = &env.context;
|
||||
|
||||
let arg_type = basic_type_from_layout(arena, env.context, &layout, env.ptr_bytes);
|
||||
build_header_help(env, fn_name, env.context.void_type().into(), &[arg_type])
|
||||
}
|
||||
|
||||
// inc and dec return void
|
||||
let fn_type = context.void_type().fn_type(&[arg_type], false);
|
||||
/// Build an increment or decrement function for a specific layout
|
||||
pub fn build_header_help<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
fn_name: &str,
|
||||
return_type: AnyTypeEnum<'ctx>,
|
||||
arguments: &[BasicTypeEnum<'ctx>],
|
||||
) -> FunctionValue<'ctx> {
|
||||
use inkwell::types::AnyTypeEnum::*;
|
||||
let fn_type = match return_type {
|
||||
ArrayType(t) => t.fn_type(arguments, false),
|
||||
FloatType(t) => t.fn_type(arguments, false),
|
||||
FunctionType(_) => unreachable!("functions cannot return functions"),
|
||||
IntType(t) => t.fn_type(arguments, false),
|
||||
PointerType(t) => t.fn_type(arguments, false),
|
||||
StructType(t) => t.fn_type(arguments, false),
|
||||
VectorType(t) => t.fn_type(arguments, false),
|
||||
VoidType(t) => t.fn_type(arguments, false),
|
||||
};
|
||||
|
||||
let fn_val = env
|
||||
.module
|
||||
|
|
|
@ -1703,4 +1703,40 @@ mod gen_list {
|
|||
assert_evals_to!("List.sum [ 1, 2, 3 ]", 6, i64);
|
||||
assert_evals_to!("List.sum [ 1.1, 2.2, 3.3 ]", 6.6, f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_eq_empty() {
|
||||
assert_evals_to!("[] == []", true, bool);
|
||||
assert_evals_to!("[] != []", false, bool);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_eq_by_length() {
|
||||
assert_evals_to!("[1] == []", false, bool);
|
||||
assert_evals_to!("[] == [1]", false, bool);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_eq_compare_pointwise() {
|
||||
assert_evals_to!("[1] == [1]", true, bool);
|
||||
assert_evals_to!("[2] == [1]", false, bool);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_eq_nested() {
|
||||
assert_evals_to!("[[1]] == [[1]]", true, bool);
|
||||
assert_evals_to!("[[2]] == [[1]]", false, bool);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_neq_compare_pointwise() {
|
||||
assert_evals_to!("[1] != [1]", false, bool);
|
||||
assert_evals_to!("[2] != [1]", true, bool);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_neq_nested() {
|
||||
assert_evals_to!("[[1]] != [[1]]", false, bool);
|
||||
assert_evals_to!("[[2]] != [[1]]", true, bool);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -738,6 +738,7 @@ define_builtins! {
|
|||
10 INC: "#inc" // internal function that increments the refcount
|
||||
11 DEC: "#dec" // internal function that increments the refcount
|
||||
12 ARG_CLOSURE: "#arg_closure" // symbol used to store the closure record
|
||||
13 LIST_EQ: "#list_eq" // internal function that checks list equality
|
||||
}
|
||||
1 NUM: "Num" => {
|
||||
0 NUM_NUM: "Num" imported // the Num.Num type alias
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue