mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-01 07:41:12 +00:00
Generate IR for list equality
This commit is contained in:
parent
a2ada314ce
commit
d58a2814f6
1 changed files with 288 additions and 9 deletions
|
@ -6,8 +6,8 @@ use roc_module::low_level::LowLevel;
|
|||
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
||||
|
||||
use crate::ir::{
|
||||
BranchInfo, Call, CallSpecId, CallType, Expr, HostExposedLayouts, Literal, ModifyRc, Proc,
|
||||
ProcLayout, SelfRecursive, Stmt, UpdateModeId,
|
||||
BranchInfo, Call, CallSpecId, CallType, Expr, HostExposedLayouts, JoinPointId, Literal,
|
||||
ModifyRc, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId,
|
||||
};
|
||||
use crate::layout::{Builtin, Layout, UnionLayout};
|
||||
|
||||
|
@ -269,7 +269,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
new_procs_info.push((symbol, proc_layout));
|
||||
|
||||
let mut visit_child = |child| {
|
||||
if layout_needs_helper_proc(child, op) {
|
||||
if layout_needs_helper_proc(child, op) && child.stack_size(self.ptr_size) > 0 {
|
||||
self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child);
|
||||
}
|
||||
};
|
||||
|
@ -409,8 +409,12 @@ impl<'a> CodeGenHelp<'a> {
|
|||
&self,
|
||||
op: HelperOp,
|
||||
sub_layout: &Layout<'a>,
|
||||
arguments: &'a [Symbol],
|
||||
arguments: &[Symbol],
|
||||
) -> Expr<'a> {
|
||||
if matches!(op, HelperOp::Eq) && sub_layout.stack_size(self.ptr_size) == 0 {
|
||||
return Expr::Literal(Literal::Int(1));
|
||||
}
|
||||
|
||||
let found = self
|
||||
.specs
|
||||
.iter()
|
||||
|
@ -436,7 +440,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
arg_layouts,
|
||||
specialization_id: CallSpecId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments,
|
||||
arguments: self.arena.alloc_slice_copy(arguments),
|
||||
})
|
||||
} else {
|
||||
// By the time we get here (generating helper procs), the list of specializations is complete.
|
||||
|
@ -455,7 +459,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
op: lowlevel,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments,
|
||||
arguments: self.arena.alloc_slice_copy(arguments),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -691,7 +695,8 @@ impl<'a> CodeGenHelp<'a> {
|
|||
Layout::Builtin(Builtin::Str) => {
|
||||
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
|
||||
}
|
||||
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
|
||||
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(),
|
||||
Layout::Builtin(Builtin::List(elem_layout)) => self.eq_list(ident_ids, elem_layout),
|
||||
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts),
|
||||
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout),
|
||||
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
|
||||
|
@ -808,14 +813,13 @@ impl<'a> CodeGenHelp<'a> {
|
|||
};
|
||||
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
|
||||
|
||||
let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]);
|
||||
let sub_layout = match (layout, rec_ptr_layout) {
|
||||
(Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout),
|
||||
_ => layout,
|
||||
};
|
||||
|
||||
let eq_call_expr =
|
||||
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, sub_layout_args);
|
||||
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, &[field1_sym, field2_sym]);
|
||||
let eq_call_name = format!("eq_call_{}", i);
|
||||
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
|
||||
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
|
||||
|
@ -992,6 +996,281 @@ impl<'a> CodeGenHelp<'a> {
|
|||
)),
|
||||
))
|
||||
}
|
||||
|
||||
/// List equality
|
||||
/// We can't use `ListGetUnsafe` because it increments the refcount, and we don't want that.
|
||||
/// Another way to dereference a heap pointer is to use `Expr::UnionAtIndex`.
|
||||
/// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout.
|
||||
/// Then we can increment the Box pointer in a loop, dereferencing it each time.
|
||||
fn eq_list(&self, ident_ids: &mut IdentIds, elem_layout: &Layout<'a>) -> Stmt<'a> {
|
||||
use LowLevel::*;
|
||||
let layout_isize = self.layout_isize;
|
||||
|
||||
// A "Box" layout (heap pointer to a single list element)
|
||||
let box_union_layout = UnionLayout::NonNullableUnwrapped(self.arena.alloc([*elem_layout]));
|
||||
let box_layout = Layout::Union(box_union_layout);
|
||||
|
||||
// Compare lengths
|
||||
|
||||
let len_1 = self.create_symbol(ident_ids, "len_1");
|
||||
let len_2 = self.create_symbol(ident_ids, "len_2");
|
||||
let len_1_stmt = |next| self.let_lowlevel(layout_isize, len_1, ListLen, &[ARG_1], next);
|
||||
let len_2_stmt = |next| self.let_lowlevel(layout_isize, len_2, ListLen, &[ARG_2], next);
|
||||
|
||||
let eq_len = self.create_symbol(ident_ids, "eq_len");
|
||||
let eq_len_stmt = |next| self.let_lowlevel(LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next);
|
||||
|
||||
// if lengths are equal...
|
||||
|
||||
// get element pointers
|
||||
let elements_1 = self.create_symbol(ident_ids, "elements_1");
|
||||
let elements_2 = self.create_symbol(ident_ids, "elements_2");
|
||||
let elements_1_expr = Expr::StructAtIndex {
|
||||
index: 0,
|
||||
field_layouts: self.arena.alloc([box_layout, layout_isize]),
|
||||
structure: ARG_1,
|
||||
};
|
||||
let elements_2_expr = Expr::StructAtIndex {
|
||||
index: 0,
|
||||
field_layouts: self.arena.alloc([box_layout, layout_isize]),
|
||||
structure: ARG_2,
|
||||
};
|
||||
let elements_1_stmt = |next| Stmt::Let(elements_1, elements_1_expr, box_layout, next);
|
||||
let elements_2_stmt = |next| Stmt::Let(elements_2, elements_2_expr, box_layout, next);
|
||||
|
||||
// Cast to integers
|
||||
let start_addr_1 = self.create_symbol(ident_ids, "start_addr_1");
|
||||
let start_addr_2 = self.create_symbol(ident_ids, "start_addr_2");
|
||||
let start_addr_1_stmt =
|
||||
|next| self.let_lowlevel(layout_isize, start_addr_1, PtrCast, &[elements_1], next);
|
||||
let start_addr_2_stmt =
|
||||
|next| self.let_lowlevel(layout_isize, start_addr_2, PtrCast, &[elements_2], next);
|
||||
|
||||
//
|
||||
// Loop initialisation
|
||||
//
|
||||
|
||||
// let elem_size = literal int
|
||||
let elem_size = self.create_symbol(ident_ids, "elem_size");
|
||||
let elem_size_expr =
|
||||
Expr::Literal(Literal::Int(elem_layout.stack_size(self.ptr_size) as i128));
|
||||
let elem_size_stmt = |next| Stmt::Let(elem_size, elem_size_expr, layout_isize, next);
|
||||
|
||||
// let list_size = len_1 * elem_size
|
||||
let list_size = self.create_symbol(ident_ids, "list_size");
|
||||
let list_size_stmt =
|
||||
|next| self.let_lowlevel(layout_isize, list_size, NumMul, &[len_1, elem_size], next);
|
||||
|
||||
// let end_addr_1 = start_addr_1 + list_size
|
||||
let end_addr_1 = self.create_symbol(ident_ids, "end_addr_1");
|
||||
let end_addr_1_stmt = |next| {
|
||||
self.let_lowlevel(
|
||||
layout_isize,
|
||||
end_addr_1,
|
||||
NumAdd,
|
||||
&[start_addr_1, list_size],
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
//
|
||||
// Loop name & parameters
|
||||
//
|
||||
|
||||
let elems_loop = JoinPointId(self.create_symbol(ident_ids, "elems_loop"));
|
||||
let addr1 = self.create_symbol(ident_ids, "addr1");
|
||||
let addr2 = self.create_symbol(ident_ids, "addr2");
|
||||
|
||||
let param_addr1 = Param {
|
||||
symbol: addr1,
|
||||
borrow: false,
|
||||
layout: layout_isize,
|
||||
};
|
||||
|
||||
let param_addr2 = Param {
|
||||
symbol: addr2,
|
||||
borrow: false,
|
||||
layout: layout_isize,
|
||||
};
|
||||
|
||||
//
|
||||
// if we haven't reached the end yet...
|
||||
//
|
||||
|
||||
// Cast integers to box pointers
|
||||
let box1 = self.create_symbol(ident_ids, "box1");
|
||||
let box2 = self.create_symbol(ident_ids, "box2");
|
||||
let box1_stmt = |next| self.let_lowlevel(box_layout, box1, PtrCast, &[addr1], next);
|
||||
let box2_stmt = |next| self.let_lowlevel(box_layout, box2, PtrCast, &[addr2], next);
|
||||
|
||||
// Dereference the box pointers to get the current elements
|
||||
let elem1 = self.create_symbol(ident_ids, "elem1");
|
||||
let elem2 = self.create_symbol(ident_ids, "elem2");
|
||||
let elem1_expr = Expr::UnionAtIndex {
|
||||
structure: box1,
|
||||
union_layout: box_union_layout,
|
||||
tag_id: 0,
|
||||
index: 0,
|
||||
};
|
||||
let elem2_expr = Expr::UnionAtIndex {
|
||||
structure: box2,
|
||||
union_layout: box_union_layout,
|
||||
tag_id: 0,
|
||||
index: 0,
|
||||
};
|
||||
let elem1_stmt = |next| Stmt::Let(elem1, elem1_expr, *elem_layout, next);
|
||||
let elem2_stmt = |next| Stmt::Let(elem2, elem2_expr, *elem_layout, next);
|
||||
|
||||
// Compare the two current elements
|
||||
let eq_elems = self.create_symbol(ident_ids, "eq_elems");
|
||||
let eq_elems_expr = self.apply_op_to_sub_layout(HelperOp::Eq, elem_layout, &[elem1, elem2]);
|
||||
let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next);
|
||||
|
||||
// If current elements are equal, loop back again
|
||||
let next_addr_1 = self.create_symbol(ident_ids, "next_addr_1");
|
||||
let next_addr_2 = self.create_symbol(ident_ids, "next_addr_2");
|
||||
let next_addr_1_stmt =
|
||||
|next| self.let_lowlevel(layout_isize, next_addr_1, NumAdd, &[addr1, elem_size], next);
|
||||
let next_addr_2_stmt =
|
||||
|next| self.let_lowlevel(layout_isize, next_addr_2, NumAdd, &[addr2, elem_size], next);
|
||||
|
||||
let jump_back = Stmt::Jump(elems_loop, self.arena.alloc([next_addr_1, next_addr_2]));
|
||||
|
||||
//
|
||||
// Control flow
|
||||
//
|
||||
|
||||
let is_end = self.create_symbol(ident_ids, "is_end");
|
||||
let is_end_stmt =
|
||||
|next| self.let_lowlevel(LAYOUT_BOOL, is_end, NumGte, &[addr1, end_addr_1], next);
|
||||
|
||||
let if_elems_not_equal = self.if_false_return_false(
|
||||
eq_elems,
|
||||
// else
|
||||
self.arena.alloc(
|
||||
//
|
||||
next_addr_1_stmt(self.arena.alloc(
|
||||
//
|
||||
next_addr_2_stmt(self.arena.alloc(
|
||||
//
|
||||
jump_back,
|
||||
)),
|
||||
)),
|
||||
),
|
||||
);
|
||||
|
||||
let if_end_of_list = Stmt::Switch {
|
||||
cond_symbol: is_end,
|
||||
cond_layout: LAYOUT_BOOL,
|
||||
ret_layout: LAYOUT_BOOL,
|
||||
branches: self
|
||||
.arena
|
||||
.alloc([(1, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))]),
|
||||
default_branch: (
|
||||
BranchInfo::None,
|
||||
self.arena.alloc(
|
||||
//
|
||||
box1_stmt(self.arena.alloc(
|
||||
//
|
||||
box2_stmt(self.arena.alloc(
|
||||
//
|
||||
elem1_stmt(self.arena.alloc(
|
||||
//
|
||||
elem2_stmt(self.arena.alloc(
|
||||
//
|
||||
eq_elems_stmt(self.arena.alloc(
|
||||
//
|
||||
if_elems_not_equal,
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
),
|
||||
};
|
||||
|
||||
let joinpoint_loop = Stmt::Join {
|
||||
id: elems_loop,
|
||||
parameters: self.arena.alloc([param_addr1, param_addr2]),
|
||||
body: self.arena.alloc(
|
||||
//
|
||||
is_end_stmt(
|
||||
//
|
||||
self.arena.alloc(if_end_of_list),
|
||||
),
|
||||
),
|
||||
remainder: self.arena.alloc(Stmt::Jump(
|
||||
elems_loop,
|
||||
self.arena.alloc([start_addr_1, start_addr_2]),
|
||||
)),
|
||||
};
|
||||
|
||||
let if_different_lengths = self.if_false_return_false(
|
||||
eq_len,
|
||||
// else
|
||||
self.arena.alloc(
|
||||
//
|
||||
elements_1_stmt(self.arena.alloc(
|
||||
//
|
||||
elements_2_stmt(self.arena.alloc(
|
||||
//
|
||||
start_addr_1_stmt(self.arena.alloc(
|
||||
//
|
||||
start_addr_2_stmt(self.arena.alloc(
|
||||
//
|
||||
elem_size_stmt(self.arena.alloc(
|
||||
//
|
||||
list_size_stmt(self.arena.alloc(
|
||||
//
|
||||
end_addr_1_stmt(self.arena.alloc(
|
||||
//
|
||||
joinpoint_loop,
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
),
|
||||
);
|
||||
|
||||
let pointers_else = len_1_stmt(self.arena.alloc(
|
||||
//
|
||||
len_2_stmt(self.arena.alloc(
|
||||
//
|
||||
eq_len_stmt(self.arena.alloc(
|
||||
//
|
||||
if_different_lengths,
|
||||
)),
|
||||
)),
|
||||
));
|
||||
|
||||
self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(pointers_else))
|
||||
}
|
||||
|
||||
fn let_lowlevel(
|
||||
&self,
|
||||
result_layout: Layout<'a>,
|
||||
result: Symbol,
|
||||
op: LowLevel,
|
||||
args: &[Symbol],
|
||||
next: &'a Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
Stmt::Let(
|
||||
result,
|
||||
Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: self.arena.alloc_slice_copy(args),
|
||||
}),
|
||||
result_layout,
|
||||
next,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to derive a debug function name from a layout
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue