mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-01 15:51:12 +00:00
Generate IR for list equality
This commit is contained in:
parent
a2ada314ce
commit
d58a2814f6
1 changed files with 288 additions and 9 deletions
|
@ -6,8 +6,8 @@ use roc_module::low_level::LowLevel;
|
||||||
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
||||||
|
|
||||||
use crate::ir::{
|
use crate::ir::{
|
||||||
BranchInfo, Call, CallSpecId, CallType, Expr, HostExposedLayouts, Literal, ModifyRc, Proc,
|
BranchInfo, Call, CallSpecId, CallType, Expr, HostExposedLayouts, JoinPointId, Literal,
|
||||||
ProcLayout, SelfRecursive, Stmt, UpdateModeId,
|
ModifyRc, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId,
|
||||||
};
|
};
|
||||||
use crate::layout::{Builtin, Layout, UnionLayout};
|
use crate::layout::{Builtin, Layout, UnionLayout};
|
||||||
|
|
||||||
|
@ -269,7 +269,7 @@ impl<'a> CodeGenHelp<'a> {
|
||||||
new_procs_info.push((symbol, proc_layout));
|
new_procs_info.push((symbol, proc_layout));
|
||||||
|
|
||||||
let mut visit_child = |child| {
|
let mut visit_child = |child| {
|
||||||
if layout_needs_helper_proc(child, op) {
|
if layout_needs_helper_proc(child, op) && child.stack_size(self.ptr_size) > 0 {
|
||||||
self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child);
|
self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -409,8 +409,12 @@ impl<'a> CodeGenHelp<'a> {
|
||||||
&self,
|
&self,
|
||||||
op: HelperOp,
|
op: HelperOp,
|
||||||
sub_layout: &Layout<'a>,
|
sub_layout: &Layout<'a>,
|
||||||
arguments: &'a [Symbol],
|
arguments: &[Symbol],
|
||||||
) -> Expr<'a> {
|
) -> Expr<'a> {
|
||||||
|
if matches!(op, HelperOp::Eq) && sub_layout.stack_size(self.ptr_size) == 0 {
|
||||||
|
return Expr::Literal(Literal::Int(1));
|
||||||
|
}
|
||||||
|
|
||||||
let found = self
|
let found = self
|
||||||
.specs
|
.specs
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -436,7 +440,7 @@ impl<'a> CodeGenHelp<'a> {
|
||||||
arg_layouts,
|
arg_layouts,
|
||||||
specialization_id: CallSpecId::BACKEND_DUMMY,
|
specialization_id: CallSpecId::BACKEND_DUMMY,
|
||||||
},
|
},
|
||||||
arguments,
|
arguments: self.arena.alloc_slice_copy(arguments),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// By the time we get here (generating helper procs), the list of specializations is complete.
|
// By the time we get here (generating helper procs), the list of specializations is complete.
|
||||||
|
@ -455,7 +459,7 @@ impl<'a> CodeGenHelp<'a> {
|
||||||
op: lowlevel,
|
op: lowlevel,
|
||||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||||
},
|
},
|
||||||
arguments,
|
arguments: self.arena.alloc_slice_copy(arguments),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -691,7 +695,8 @@ impl<'a> CodeGenHelp<'a> {
|
||||||
Layout::Builtin(Builtin::Str) => {
|
Layout::Builtin(Builtin::Str) => {
|
||||||
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
|
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
|
||||||
}
|
}
|
||||||
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
|
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(),
|
||||||
|
Layout::Builtin(Builtin::List(elem_layout)) => self.eq_list(ident_ids, elem_layout),
|
||||||
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts),
|
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts),
|
||||||
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout),
|
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout),
|
||||||
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
|
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
|
||||||
|
@ -808,14 +813,13 @@ impl<'a> CodeGenHelp<'a> {
|
||||||
};
|
};
|
||||||
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
|
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
|
||||||
|
|
||||||
let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]);
|
|
||||||
let sub_layout = match (layout, rec_ptr_layout) {
|
let sub_layout = match (layout, rec_ptr_layout) {
|
||||||
(Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout),
|
(Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout),
|
||||||
_ => layout,
|
_ => layout,
|
||||||
};
|
};
|
||||||
|
|
||||||
let eq_call_expr =
|
let eq_call_expr =
|
||||||
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, sub_layout_args);
|
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, &[field1_sym, field2_sym]);
|
||||||
let eq_call_name = format!("eq_call_{}", i);
|
let eq_call_name = format!("eq_call_{}", i);
|
||||||
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
|
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
|
||||||
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
|
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
|
||||||
|
@ -992,6 +996,281 @@ impl<'a> CodeGenHelp<'a> {
|
||||||
)),
|
)),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// List equality
|
||||||
|
/// We can't use `ListGetUnsafe` because it increments the refcount, and we don't want that.
|
||||||
|
/// Another way to dereference a heap pointer is to use `Expr::UnionAtIndex`.
|
||||||
|
/// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout.
|
||||||
|
/// Then we can increment the Box pointer in a loop, dereferencing it each time.
|
||||||
|
fn eq_list(&self, ident_ids: &mut IdentIds, elem_layout: &Layout<'a>) -> Stmt<'a> {
|
||||||
|
use LowLevel::*;
|
||||||
|
let layout_isize = self.layout_isize;
|
||||||
|
|
||||||
|
// A "Box" layout (heap pointer to a single list element)
|
||||||
|
let box_union_layout = UnionLayout::NonNullableUnwrapped(self.arena.alloc([*elem_layout]));
|
||||||
|
let box_layout = Layout::Union(box_union_layout);
|
||||||
|
|
||||||
|
// Compare lengths
|
||||||
|
|
||||||
|
let len_1 = self.create_symbol(ident_ids, "len_1");
|
||||||
|
let len_2 = self.create_symbol(ident_ids, "len_2");
|
||||||
|
let len_1_stmt = |next| self.let_lowlevel(layout_isize, len_1, ListLen, &[ARG_1], next);
|
||||||
|
let len_2_stmt = |next| self.let_lowlevel(layout_isize, len_2, ListLen, &[ARG_2], next);
|
||||||
|
|
||||||
|
let eq_len = self.create_symbol(ident_ids, "eq_len");
|
||||||
|
let eq_len_stmt = |next| self.let_lowlevel(LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next);
|
||||||
|
|
||||||
|
// if lengths are equal...
|
||||||
|
|
||||||
|
// get element pointers
|
||||||
|
let elements_1 = self.create_symbol(ident_ids, "elements_1");
|
||||||
|
let elements_2 = self.create_symbol(ident_ids, "elements_2");
|
||||||
|
let elements_1_expr = Expr::StructAtIndex {
|
||||||
|
index: 0,
|
||||||
|
field_layouts: self.arena.alloc([box_layout, layout_isize]),
|
||||||
|
structure: ARG_1,
|
||||||
|
};
|
||||||
|
let elements_2_expr = Expr::StructAtIndex {
|
||||||
|
index: 0,
|
||||||
|
field_layouts: self.arena.alloc([box_layout, layout_isize]),
|
||||||
|
structure: ARG_2,
|
||||||
|
};
|
||||||
|
let elements_1_stmt = |next| Stmt::Let(elements_1, elements_1_expr, box_layout, next);
|
||||||
|
let elements_2_stmt = |next| Stmt::Let(elements_2, elements_2_expr, box_layout, next);
|
||||||
|
|
||||||
|
// Cast to integers
|
||||||
|
let start_addr_1 = self.create_symbol(ident_ids, "start_addr_1");
|
||||||
|
let start_addr_2 = self.create_symbol(ident_ids, "start_addr_2");
|
||||||
|
let start_addr_1_stmt =
|
||||||
|
|next| self.let_lowlevel(layout_isize, start_addr_1, PtrCast, &[elements_1], next);
|
||||||
|
let start_addr_2_stmt =
|
||||||
|
|next| self.let_lowlevel(layout_isize, start_addr_2, PtrCast, &[elements_2], next);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Loop initialisation
|
||||||
|
//
|
||||||
|
|
||||||
|
// let elem_size = literal int
|
||||||
|
let elem_size = self.create_symbol(ident_ids, "elem_size");
|
||||||
|
let elem_size_expr =
|
||||||
|
Expr::Literal(Literal::Int(elem_layout.stack_size(self.ptr_size) as i128));
|
||||||
|
let elem_size_stmt = |next| Stmt::Let(elem_size, elem_size_expr, layout_isize, next);
|
||||||
|
|
||||||
|
// let list_size = len_1 * elem_size
|
||||||
|
let list_size = self.create_symbol(ident_ids, "list_size");
|
||||||
|
let list_size_stmt =
|
||||||
|
|next| self.let_lowlevel(layout_isize, list_size, NumMul, &[len_1, elem_size], next);
|
||||||
|
|
||||||
|
// let end_addr_1 = start_addr_1 + list_size
|
||||||
|
let end_addr_1 = self.create_symbol(ident_ids, "end_addr_1");
|
||||||
|
let end_addr_1_stmt = |next| {
|
||||||
|
self.let_lowlevel(
|
||||||
|
layout_isize,
|
||||||
|
end_addr_1,
|
||||||
|
NumAdd,
|
||||||
|
&[start_addr_1, list_size],
|
||||||
|
next,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// Loop name & parameters
|
||||||
|
//
|
||||||
|
|
||||||
|
let elems_loop = JoinPointId(self.create_symbol(ident_ids, "elems_loop"));
|
||||||
|
let addr1 = self.create_symbol(ident_ids, "addr1");
|
||||||
|
let addr2 = self.create_symbol(ident_ids, "addr2");
|
||||||
|
|
||||||
|
let param_addr1 = Param {
|
||||||
|
symbol: addr1,
|
||||||
|
borrow: false,
|
||||||
|
layout: layout_isize,
|
||||||
|
};
|
||||||
|
|
||||||
|
let param_addr2 = Param {
|
||||||
|
symbol: addr2,
|
||||||
|
borrow: false,
|
||||||
|
layout: layout_isize,
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// if we haven't reached the end yet...
|
||||||
|
//
|
||||||
|
|
||||||
|
// Cast integers to box pointers
|
||||||
|
let box1 = self.create_symbol(ident_ids, "box1");
|
||||||
|
let box2 = self.create_symbol(ident_ids, "box2");
|
||||||
|
let box1_stmt = |next| self.let_lowlevel(box_layout, box1, PtrCast, &[addr1], next);
|
||||||
|
let box2_stmt = |next| self.let_lowlevel(box_layout, box2, PtrCast, &[addr2], next);
|
||||||
|
|
||||||
|
// Dereference the box pointers to get the current elements
|
||||||
|
let elem1 = self.create_symbol(ident_ids, "elem1");
|
||||||
|
let elem2 = self.create_symbol(ident_ids, "elem2");
|
||||||
|
let elem1_expr = Expr::UnionAtIndex {
|
||||||
|
structure: box1,
|
||||||
|
union_layout: box_union_layout,
|
||||||
|
tag_id: 0,
|
||||||
|
index: 0,
|
||||||
|
};
|
||||||
|
let elem2_expr = Expr::UnionAtIndex {
|
||||||
|
structure: box2,
|
||||||
|
union_layout: box_union_layout,
|
||||||
|
tag_id: 0,
|
||||||
|
index: 0,
|
||||||
|
};
|
||||||
|
let elem1_stmt = |next| Stmt::Let(elem1, elem1_expr, *elem_layout, next);
|
||||||
|
let elem2_stmt = |next| Stmt::Let(elem2, elem2_expr, *elem_layout, next);
|
||||||
|
|
||||||
|
// Compare the two current elements
|
||||||
|
let eq_elems = self.create_symbol(ident_ids, "eq_elems");
|
||||||
|
let eq_elems_expr = self.apply_op_to_sub_layout(HelperOp::Eq, elem_layout, &[elem1, elem2]);
|
||||||
|
let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next);
|
||||||
|
|
||||||
|
// If current elements are equal, loop back again
|
||||||
|
let next_addr_1 = self.create_symbol(ident_ids, "next_addr_1");
|
||||||
|
let next_addr_2 = self.create_symbol(ident_ids, "next_addr_2");
|
||||||
|
let next_addr_1_stmt =
|
||||||
|
|next| self.let_lowlevel(layout_isize, next_addr_1, NumAdd, &[addr1, elem_size], next);
|
||||||
|
let next_addr_2_stmt =
|
||||||
|
|next| self.let_lowlevel(layout_isize, next_addr_2, NumAdd, &[addr2, elem_size], next);
|
||||||
|
|
||||||
|
let jump_back = Stmt::Jump(elems_loop, self.arena.alloc([next_addr_1, next_addr_2]));
|
||||||
|
|
||||||
|
//
|
||||||
|
// Control flow
|
||||||
|
//
|
||||||
|
|
||||||
|
let is_end = self.create_symbol(ident_ids, "is_end");
|
||||||
|
let is_end_stmt =
|
||||||
|
|next| self.let_lowlevel(LAYOUT_BOOL, is_end, NumGte, &[addr1, end_addr_1], next);
|
||||||
|
|
||||||
|
let if_elems_not_equal = self.if_false_return_false(
|
||||||
|
eq_elems,
|
||||||
|
// else
|
||||||
|
self.arena.alloc(
|
||||||
|
//
|
||||||
|
next_addr_1_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
next_addr_2_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
jump_back,
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let if_end_of_list = Stmt::Switch {
|
||||||
|
cond_symbol: is_end,
|
||||||
|
cond_layout: LAYOUT_BOOL,
|
||||||
|
ret_layout: LAYOUT_BOOL,
|
||||||
|
branches: self
|
||||||
|
.arena
|
||||||
|
.alloc([(1, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))]),
|
||||||
|
default_branch: (
|
||||||
|
BranchInfo::None,
|
||||||
|
self.arena.alloc(
|
||||||
|
//
|
||||||
|
box1_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
box2_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
elem1_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
elem2_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
eq_elems_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
if_elems_not_equal,
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let joinpoint_loop = Stmt::Join {
|
||||||
|
id: elems_loop,
|
||||||
|
parameters: self.arena.alloc([param_addr1, param_addr2]),
|
||||||
|
body: self.arena.alloc(
|
||||||
|
//
|
||||||
|
is_end_stmt(
|
||||||
|
//
|
||||||
|
self.arena.alloc(if_end_of_list),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
remainder: self.arena.alloc(Stmt::Jump(
|
||||||
|
elems_loop,
|
||||||
|
self.arena.alloc([start_addr_1, start_addr_2]),
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let if_different_lengths = self.if_false_return_false(
|
||||||
|
eq_len,
|
||||||
|
// else
|
||||||
|
self.arena.alloc(
|
||||||
|
//
|
||||||
|
elements_1_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
elements_2_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
start_addr_1_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
start_addr_2_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
elem_size_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
list_size_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
end_addr_1_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
joinpoint_loop,
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let pointers_else = len_1_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
len_2_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
eq_len_stmt(self.arena.alloc(
|
||||||
|
//
|
||||||
|
if_different_lengths,
|
||||||
|
)),
|
||||||
|
)),
|
||||||
|
));
|
||||||
|
|
||||||
|
self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(pointers_else))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn let_lowlevel(
|
||||||
|
&self,
|
||||||
|
result_layout: Layout<'a>,
|
||||||
|
result: Symbol,
|
||||||
|
op: LowLevel,
|
||||||
|
args: &[Symbol],
|
||||||
|
next: &'a Stmt<'a>,
|
||||||
|
) -> Stmt<'a> {
|
||||||
|
Stmt::Let(
|
||||||
|
result,
|
||||||
|
Expr::Call(Call {
|
||||||
|
call_type: CallType::LowLevel {
|
||||||
|
op,
|
||||||
|
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||||
|
},
|
||||||
|
arguments: self.arena.alloc_slice_copy(args),
|
||||||
|
}),
|
||||||
|
result_layout,
|
||||||
|
next,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper to derive a debug function name from a layout
|
/// Helper to derive a debug function name from a layout
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue