mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 14:54:47 +00:00
Implement Eq for tags in CodeGenHelp
This commit is contained in:
parent
6b28635978
commit
635c9757dd
1 changed files with 192 additions and 9 deletions
|
@ -690,7 +690,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
}
|
||||
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
|
||||
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts),
|
||||
Layout::Union(_) => eq_todo(),
|
||||
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout),
|
||||
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
|
||||
Layout::RecursivePointer => eq_todo(),
|
||||
};
|
||||
|
@ -715,9 +715,9 @@ impl<'a> CodeGenHelp<'a> {
|
|||
ptr2: Symbol,
|
||||
following: &'a Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
let ptr1_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr1));
|
||||
let ptr2_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr2));
|
||||
let ptr_eq = self.create_symbol(ident_ids, &format!("eq_{:?}_{:?}", ptr1_addr, ptr2_addr));
|
||||
let ptr1_addr = self.create_symbol(ident_ids, "addr1");
|
||||
let ptr2_addr = self.create_symbol(ident_ids, "addr2");
|
||||
let ptr_eq = self.create_symbol(ident_ids, "eq_addr");
|
||||
|
||||
Stmt::Let(
|
||||
ptr1_addr,
|
||||
|
@ -780,7 +780,9 @@ impl<'a> CodeGenHelp<'a> {
|
|||
fn eq_struct(&self, ident_ids: &mut IdentIds, field_layouts: &'a [Layout<'a>]) -> Stmt<'a> {
|
||||
let else_clause = self.eq_fields(
|
||||
ident_ids,
|
||||
0,
|
||||
field_layouts,
|
||||
None,
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
);
|
||||
|
@ -795,14 +797,15 @@ impl<'a> CodeGenHelp<'a> {
|
|||
fn eq_fields(
|
||||
&self,
|
||||
ident_ids: &mut IdentIds,
|
||||
tag_id: u64,
|
||||
field_layouts: &'a [Layout<'a>],
|
||||
rec_ptr_layout: Option<Layout<'a>>,
|
||||
arguments: &'a [Symbol],
|
||||
following: Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
let mut stmt = following;
|
||||
for (i, layout) in field_layouts.iter().enumerate().rev() {
|
||||
let field1_name = format!("{:?}_field_{}", arguments[0], i);
|
||||
let field1_sym = self.create_symbol(ident_ids, &field1_name);
|
||||
let field1_sym = self.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i));
|
||||
let field1_expr = Expr::StructAtIndex {
|
||||
index: i as u64,
|
||||
field_layouts,
|
||||
|
@ -810,8 +813,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
};
|
||||
let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next);
|
||||
|
||||
let field2_name = format!("{:?}_field_{}", arguments[1], i);
|
||||
let field2_sym = self.create_symbol(ident_ids, &field2_name);
|
||||
let field2_sym = self.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i));
|
||||
let field2_expr = Expr::StructAtIndex {
|
||||
index: i as u64,
|
||||
field_layouts,
|
||||
|
@ -820,7 +822,13 @@ impl<'a> CodeGenHelp<'a> {
|
|||
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
|
||||
|
||||
let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]);
|
||||
let eq_call_expr = self.apply_op_to_sub_layout(HelperOp::Eq, layout, sub_layout_args);
|
||||
let sub_layout = match (layout, rec_ptr_layout) {
|
||||
(Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout),
|
||||
_ => layout,
|
||||
};
|
||||
|
||||
let eq_call_expr =
|
||||
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, sub_layout_args);
|
||||
let eq_call_name = format!("eq_call_{}", i);
|
||||
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
|
||||
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
|
||||
|
@ -838,6 +846,181 @@ impl<'a> CodeGenHelp<'a> {
|
|||
}
|
||||
stmt
|
||||
}
|
||||
|
||||
fn eq_tag_union(&self, ident_ids: &mut IdentIds, union_layout: UnionLayout<'a>) -> Stmt<'a> {
|
||||
use UnionLayout::*;
|
||||
|
||||
let main_stmt = match union_layout {
|
||||
NonRecursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None),
|
||||
|
||||
Recursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None),
|
||||
|
||||
NonNullableUnwrapped(field_layouts) => self.eq_fields(
|
||||
ident_ids,
|
||||
0,
|
||||
field_layouts,
|
||||
Some(Layout::Union(union_layout)),
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
),
|
||||
|
||||
NullableWrapped {
|
||||
other_tags,
|
||||
nullable_id,
|
||||
} => self.eq_tag_union_help(ident_ids, union_layout, other_tags, Some(nullable_id)),
|
||||
|
||||
NullableUnwrapped {
|
||||
other_fields,
|
||||
nullable_id: n,
|
||||
} => self.eq_tag_union_help(
|
||||
ident_ids,
|
||||
union_layout,
|
||||
self.arena.alloc([other_fields]),
|
||||
Some(n as u16),
|
||||
),
|
||||
};
|
||||
|
||||
self.if_pointers_equal_return_true(
|
||||
ident_ids,
|
||||
Symbol::ARG_1,
|
||||
Symbol::ARG_2,
|
||||
self.arena.alloc(main_stmt),
|
||||
)
|
||||
}
|
||||
|
||||
fn eq_tag_union_help(
|
||||
&self,
|
||||
ident_ids: &mut IdentIds,
|
||||
union_layout: UnionLayout<'a>,
|
||||
tag_layouts: &'a [&'a [Layout<'a>]],
|
||||
nullable_id: Option<u16>,
|
||||
) -> Stmt<'a> {
|
||||
let tag_id_layout = union_layout.tag_id_layout();
|
||||
|
||||
let tag_id_a = self.create_symbol(ident_ids, "tag_id_a");
|
||||
let tag_id_a_stmt = |next| {
|
||||
Stmt::Let(
|
||||
tag_id_a,
|
||||
Expr::GetTagId {
|
||||
structure: Symbol::ARG_1,
|
||||
union_layout,
|
||||
},
|
||||
tag_id_layout,
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
let tag_id_b = self.create_symbol(ident_ids, "tag_id_b");
|
||||
let tag_id_b_stmt = |next| {
|
||||
Stmt::Let(
|
||||
tag_id_b,
|
||||
Expr::GetTagId {
|
||||
structure: Symbol::ARG_2,
|
||||
union_layout,
|
||||
},
|
||||
tag_id_layout,
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
let tag_ids_eq = self.create_symbol(ident_ids, "tag_ids_eq");
|
||||
let tag_ids_eq_stmt = |next| {
|
||||
Stmt::Let(
|
||||
tag_ids_eq,
|
||||
Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::Eq,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: self.arena.alloc([tag_id_a, tag_id_b]),
|
||||
}),
|
||||
LAYOUT_BOOL,
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
let if_equal_ids_stmt = |next| Stmt::Switch {
|
||||
cond_symbol: tag_ids_eq,
|
||||
cond_layout: LAYOUT_BOOL,
|
||||
branches: self
|
||||
.arena
|
||||
.alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]),
|
||||
default_branch: (BranchInfo::None, next),
|
||||
ret_layout: LAYOUT_BOOL,
|
||||
};
|
||||
|
||||
//
|
||||
// Switch statement by tag ID
|
||||
//
|
||||
|
||||
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), self.arena);
|
||||
|
||||
// If there's a null tag, check it first. We might not need to load any data from memory.
|
||||
if let Some(id) = nullable_id {
|
||||
tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE)))
|
||||
}
|
||||
|
||||
let recursive_ptr_layout = Some(Layout::Union(union_layout));
|
||||
|
||||
let mut tag_id: u64 = 0;
|
||||
for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) {
|
||||
if let Some(null_id) = nullable_id {
|
||||
if tag_id == null_id as u64 {
|
||||
tag_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
tag_branches.push((
|
||||
tag_id,
|
||||
BranchInfo::None,
|
||||
self.eq_fields(
|
||||
ident_ids,
|
||||
tag_id,
|
||||
field_layouts,
|
||||
recursive_ptr_layout,
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
),
|
||||
));
|
||||
|
||||
tag_id += 1;
|
||||
}
|
||||
|
||||
let tag_switch_stmt = Stmt::Switch {
|
||||
cond_symbol: tag_id_a,
|
||||
cond_layout: tag_id_layout,
|
||||
branches: tag_branches.into_bump_slice(),
|
||||
default_branch: (
|
||||
BranchInfo::None,
|
||||
self.arena.alloc(self.eq_fields(
|
||||
ident_ids,
|
||||
tag_id,
|
||||
tag_layouts.last().unwrap(),
|
||||
recursive_ptr_layout,
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
)),
|
||||
),
|
||||
ret_layout: LAYOUT_BOOL,
|
||||
};
|
||||
|
||||
//
|
||||
// combine all the statments
|
||||
//
|
||||
tag_id_a_stmt(self.arena.alloc(
|
||||
//
|
||||
tag_id_b_stmt(self.arena.alloc(
|
||||
//
|
||||
tag_ids_eq_stmt(self.arena.alloc(
|
||||
//
|
||||
if_equal_ids_stmt(self.arena.alloc(
|
||||
//
|
||||
tag_switch_stmt,
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to derive a debug function name from a layout
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue