Generate IR for list equality

This commit is contained in:
Brian Carroll 2021-12-21 09:40:59 +00:00
parent a2ada314ce
commit d58a2814f6

View file

@ -6,8 +6,8 @@ use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
use crate::ir::{
BranchInfo, Call, CallSpecId, CallType, Expr, HostExposedLayouts, Literal, ModifyRc, Proc,
ProcLayout, SelfRecursive, Stmt, UpdateModeId,
BranchInfo, Call, CallSpecId, CallType, Expr, HostExposedLayouts, JoinPointId, Literal,
ModifyRc, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId,
};
use crate::layout::{Builtin, Layout, UnionLayout};
@ -269,7 +269,7 @@ impl<'a> CodeGenHelp<'a> {
new_procs_info.push((symbol, proc_layout));
let mut visit_child = |child| {
if layout_needs_helper_proc(child, op) {
if layout_needs_helper_proc(child, op) && child.stack_size(self.ptr_size) > 0 {
self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child);
}
};
@ -409,8 +409,12 @@ impl<'a> CodeGenHelp<'a> {
&self,
op: HelperOp,
sub_layout: &Layout<'a>,
arguments: &'a [Symbol],
arguments: &[Symbol],
) -> Expr<'a> {
if matches!(op, HelperOp::Eq) && sub_layout.stack_size(self.ptr_size) == 0 {
return Expr::Literal(Literal::Int(1));
}
let found = self
.specs
.iter()
@ -436,7 +440,7 @@ impl<'a> CodeGenHelp<'a> {
arg_layouts,
specialization_id: CallSpecId::BACKEND_DUMMY,
},
arguments,
arguments: self.arena.alloc_slice_copy(arguments),
})
} else {
// By the time we get here (generating helper procs), the list of specializations is complete.
@ -455,7 +459,7 @@ impl<'a> CodeGenHelp<'a> {
op: lowlevel,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments,
arguments: self.arena.alloc_slice_copy(arguments),
})
}
}
@ -691,7 +695,8 @@ impl<'a> CodeGenHelp<'a> {
Layout::Builtin(Builtin::Str) => {
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(),
Layout::Builtin(Builtin::List(elem_layout)) => self.eq_list(ident_ids, elem_layout),
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts),
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout),
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
@ -808,14 +813,13 @@ impl<'a> CodeGenHelp<'a> {
};
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]);
let sub_layout = match (layout, rec_ptr_layout) {
(Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout),
_ => layout,
};
let eq_call_expr =
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, sub_layout_args);
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, &[field1_sym, field2_sym]);
let eq_call_name = format!("eq_call_{}", i);
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
@ -992,6 +996,281 @@ impl<'a> CodeGenHelp<'a> {
)),
))
}
/// List equality
/// We can't use `ListGetUnsafe` because it increments the refcount, and we don't want that.
/// Another way to dereference a heap pointer is to use `Expr::UnionAtIndex`.
/// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout.
/// Then we can increment the Box pointer in a loop, dereferencing it each time.
fn eq_list(&self, ident_ids: &mut IdentIds, elem_layout: &Layout<'a>) -> Stmt<'a> {
use LowLevel::*;
let layout_isize = self.layout_isize;
// A "Box" layout (heap pointer to a single list element)
let box_union_layout = UnionLayout::NonNullableUnwrapped(self.arena.alloc([*elem_layout]));
let box_layout = Layout::Union(box_union_layout);
// Compare lengths
let len_1 = self.create_symbol(ident_ids, "len_1");
let len_2 = self.create_symbol(ident_ids, "len_2");
let len_1_stmt = |next| self.let_lowlevel(layout_isize, len_1, ListLen, &[ARG_1], next);
let len_2_stmt = |next| self.let_lowlevel(layout_isize, len_2, ListLen, &[ARG_2], next);
let eq_len = self.create_symbol(ident_ids, "eq_len");
let eq_len_stmt = |next| self.let_lowlevel(LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next);
// if lengths are equal...
// get element pointers
let elements_1 = self.create_symbol(ident_ids, "elements_1");
let elements_2 = self.create_symbol(ident_ids, "elements_2");
let elements_1_expr = Expr::StructAtIndex {
index: 0,
field_layouts: self.arena.alloc([box_layout, layout_isize]),
structure: ARG_1,
};
let elements_2_expr = Expr::StructAtIndex {
index: 0,
field_layouts: self.arena.alloc([box_layout, layout_isize]),
structure: ARG_2,
};
let elements_1_stmt = |next| Stmt::Let(elements_1, elements_1_expr, box_layout, next);
let elements_2_stmt = |next| Stmt::Let(elements_2, elements_2_expr, box_layout, next);
// Cast to integers
let start_addr_1 = self.create_symbol(ident_ids, "start_addr_1");
let start_addr_2 = self.create_symbol(ident_ids, "start_addr_2");
let start_addr_1_stmt =
|next| self.let_lowlevel(layout_isize, start_addr_1, PtrCast, &[elements_1], next);
let start_addr_2_stmt =
|next| self.let_lowlevel(layout_isize, start_addr_2, PtrCast, &[elements_2], next);
//
// Loop initialisation
//
// let elem_size = literal int
let elem_size = self.create_symbol(ident_ids, "elem_size");
let elem_size_expr =
Expr::Literal(Literal::Int(elem_layout.stack_size(self.ptr_size) as i128));
let elem_size_stmt = |next| Stmt::Let(elem_size, elem_size_expr, layout_isize, next);
// let list_size = len_1 * elem_size
let list_size = self.create_symbol(ident_ids, "list_size");
let list_size_stmt =
|next| self.let_lowlevel(layout_isize, list_size, NumMul, &[len_1, elem_size], next);
// let end_addr_1 = start_addr_1 + list_size
let end_addr_1 = self.create_symbol(ident_ids, "end_addr_1");
let end_addr_1_stmt = |next| {
self.let_lowlevel(
layout_isize,
end_addr_1,
NumAdd,
&[start_addr_1, list_size],
next,
)
};
//
// Loop name & parameters
//
let elems_loop = JoinPointId(self.create_symbol(ident_ids, "elems_loop"));
let addr1 = self.create_symbol(ident_ids, "addr1");
let addr2 = self.create_symbol(ident_ids, "addr2");
let param_addr1 = Param {
symbol: addr1,
borrow: false,
layout: layout_isize,
};
let param_addr2 = Param {
symbol: addr2,
borrow: false,
layout: layout_isize,
};
//
// if we haven't reached the end yet...
//
// Cast integers to box pointers
let box1 = self.create_symbol(ident_ids, "box1");
let box2 = self.create_symbol(ident_ids, "box2");
let box1_stmt = |next| self.let_lowlevel(box_layout, box1, PtrCast, &[addr1], next);
let box2_stmt = |next| self.let_lowlevel(box_layout, box2, PtrCast, &[addr2], next);
// Dereference the box pointers to get the current elements
let elem1 = self.create_symbol(ident_ids, "elem1");
let elem2 = self.create_symbol(ident_ids, "elem2");
let elem1_expr = Expr::UnionAtIndex {
structure: box1,
union_layout: box_union_layout,
tag_id: 0,
index: 0,
};
let elem2_expr = Expr::UnionAtIndex {
structure: box2,
union_layout: box_union_layout,
tag_id: 0,
index: 0,
};
let elem1_stmt = |next| Stmt::Let(elem1, elem1_expr, *elem_layout, next);
let elem2_stmt = |next| Stmt::Let(elem2, elem2_expr, *elem_layout, next);
// Compare the two current elements
let eq_elems = self.create_symbol(ident_ids, "eq_elems");
let eq_elems_expr = self.apply_op_to_sub_layout(HelperOp::Eq, elem_layout, &[elem1, elem2]);
let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next);
// If current elements are equal, loop back again
let next_addr_1 = self.create_symbol(ident_ids, "next_addr_1");
let next_addr_2 = self.create_symbol(ident_ids, "next_addr_2");
let next_addr_1_stmt =
|next| self.let_lowlevel(layout_isize, next_addr_1, NumAdd, &[addr1, elem_size], next);
let next_addr_2_stmt =
|next| self.let_lowlevel(layout_isize, next_addr_2, NumAdd, &[addr2, elem_size], next);
let jump_back = Stmt::Jump(elems_loop, self.arena.alloc([next_addr_1, next_addr_2]));
//
// Control flow
//
let is_end = self.create_symbol(ident_ids, "is_end");
let is_end_stmt =
|next| self.let_lowlevel(LAYOUT_BOOL, is_end, NumGte, &[addr1, end_addr_1], next);
let if_elems_not_equal = self.if_false_return_false(
eq_elems,
// else
self.arena.alloc(
//
next_addr_1_stmt(self.arena.alloc(
//
next_addr_2_stmt(self.arena.alloc(
//
jump_back,
)),
)),
),
);
let if_end_of_list = Stmt::Switch {
cond_symbol: is_end,
cond_layout: LAYOUT_BOOL,
ret_layout: LAYOUT_BOOL,
branches: self
.arena
.alloc([(1, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))]),
default_branch: (
BranchInfo::None,
self.arena.alloc(
//
box1_stmt(self.arena.alloc(
//
box2_stmt(self.arena.alloc(
//
elem1_stmt(self.arena.alloc(
//
elem2_stmt(self.arena.alloc(
//
eq_elems_stmt(self.arena.alloc(
//
if_elems_not_equal,
)),
)),
)),
)),
)),
),
),
};
let joinpoint_loop = Stmt::Join {
id: elems_loop,
parameters: self.arena.alloc([param_addr1, param_addr2]),
body: self.arena.alloc(
//
is_end_stmt(
//
self.arena.alloc(if_end_of_list),
),
),
remainder: self.arena.alloc(Stmt::Jump(
elems_loop,
self.arena.alloc([start_addr_1, start_addr_2]),
)),
};
let if_different_lengths = self.if_false_return_false(
eq_len,
// else
self.arena.alloc(
//
elements_1_stmt(self.arena.alloc(
//
elements_2_stmt(self.arena.alloc(
//
start_addr_1_stmt(self.arena.alloc(
//
start_addr_2_stmt(self.arena.alloc(
//
elem_size_stmt(self.arena.alloc(
//
list_size_stmt(self.arena.alloc(
//
end_addr_1_stmt(self.arena.alloc(
//
joinpoint_loop,
)),
)),
)),
)),
)),
)),
)),
),
);
let pointers_else = len_1_stmt(self.arena.alloc(
//
len_2_stmt(self.arena.alloc(
//
eq_len_stmt(self.arena.alloc(
//
if_different_lengths,
)),
)),
));
self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(pointers_else))
}
fn let_lowlevel(
&self,
result_layout: Layout<'a>,
result: Symbol,
op: LowLevel,
args: &[Symbol],
next: &'a Stmt<'a>,
) -> Stmt<'a> {
Stmt::Let(
result,
Expr::Call(Call {
call_type: CallType::LowLevel {
op,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: self.arena.alloc_slice_copy(args),
}),
result_layout,
next,
)
}
}
/// Helper to derive a debug function name from a layout