This commit is contained in:
J.Teeuwissen 2023-04-24 13:10:14 +02:00
parent a0f8b50efe
commit 91533e1071
No known key found for this signature in database
GPG key ID: DB5F7A1ED8D478AD
2 changed files with 33 additions and 42 deletions

View file

@ -1,4 +1,5 @@
use bumpalo::collections::vec::Vec;
use bumpalo::collections::CollectIn;
use bumpalo::Bump;
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
@ -539,10 +540,10 @@ impl<'a> CodeGenHelp<'a> {
&self,
layout_interner: &STLayoutInterner<'a>,
union: UnionLayout<'a>,
) -> (bool, Vec<'a, Option<usize>>) {
) -> Option<Vec<'a, Option<usize>>> {
use UnionLayout::*;
match union {
NonRecursive(_) => (false, bumpalo::vec![in self.arena]),
NonRecursive(_) => None,
Recursive(tags) => self.union_tail_recursion_fields_help(layout_interner, tags),
@ -564,19 +565,22 @@ impl<'a> CodeGenHelp<'a> {
&self,
layout_interner: &STLayoutInterner<'a>,
tags: &[&'a [InLayout<'a>]],
) -> (bool, Vec<'a, Option<usize>>) {
let mut can_use_tailrec = false;
let mut tailrec_indices = Vec::with_capacity_in(tags.len(), self.arena);
) -> Option<Vec<'a, Option<usize>>> {
let tailrec_indices = tags
.iter()
.map(|fields| {
let found_index = fields
.iter()
.position(|f| matches!(layout_interner.get(*f), Layout::RecursivePointer(_)));
found_index
})
.collect_in::<Vec<_>>(self.arena);
for fields in tags.iter() {
let found_index = fields
.iter()
.position(|f| matches!(layout_interner.get(*f), Layout::RecursivePointer(_)));
tailrec_indices.push(found_index);
can_use_tailrec |= found_index.is_some();
if tailrec_indices.iter().any(|i| i.is_some()) {
None
} else {
Some(tailrec_indices)
}
(can_use_tailrec, tailrec_indices)
}
}

View file

@ -1202,8 +1202,8 @@ fn refcount_union<'a>(
),
Recursive(tags) => {
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(layout_interner, union);
if is_tailrec && !ctx.op.is_decref() {
let tailrec_idx = root.union_tail_recursion_fields(layout_interner, union);
if let (Some(tail_idx), false) = (tailrec_idx, ctx.op.is_decref()) {
refcount_union_tailrec(
root,
ident_ids,
@ -1256,8 +1256,8 @@ fn refcount_union<'a>(
nullable_id,
} => {
let null_id = Some(nullable_id);
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(layout_interner, union);
if is_tailrec && !ctx.op.is_decref() {
let tailrec_idx = root.union_tail_recursion_fields(layout_interner, union);
if let (Some(tail_idx), false) = (tailrec_idx, ctx.op.is_decref()) {
refcount_union_tailrec(
root,
ident_ids,
@ -1291,8 +1291,8 @@ fn refcount_union<'a>(
} => {
let null_id = Some(nullable_id as TagIdIntType);
let tags = root.arena.alloc([other_fields]);
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(layout_interner, union);
if is_tailrec && !ctx.op.is_decref() {
let tailrec_idx = root.union_tail_recursion_fields(layout_interner, union);
if let (Some(tail_idx), false) = (tailrec_idx, ctx.op.is_decref()) {
refcount_union_tailrec(
root,
ident_ids,
@ -1391,17 +1391,12 @@ fn refcount_union_contents<'a>(
if let Some(id) = null_id {
let ret = rc_return_stmt(root, ident_ids, ctx);
tag_branches.push((id as u64, BranchInfo::None, ret));
}
let mut tag_id: TagIdIntType = 0;
for field_layouts in tag_layouts.iter() {
match null_id {
Some(id) if id == tag_id => {
tag_id += 1;
}
_ => {}
}
};
for (field_layouts, tag_id) in tag_layouts
.iter()
.zip((0..).filter(|tag_id| !matches!(null_id, Some(id) if tag_id == &id)))
{
// After refcounting the fields, jump to modify the union itself
// (Order is important, to avoid use-after-free for Dec)
let following = Stmt::Jump(jp_contents_modified, &[]);
@ -1426,8 +1421,6 @@ fn refcount_union_contents<'a>(
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
@ -1620,15 +1613,11 @@ fn refcount_union_tailrec<'a>(
tag_branches.push((id as u64, BranchInfo::None, ret));
}
let mut tag_id: TagIdIntType = 0;
for (field_layouts, opt_tailrec_index) in tag_layouts.iter().zip(tailrec_indices) {
match null_id {
Some(id) if id == tag_id => {
tag_id += 1;
}
_ => {}
}
for ((field_layouts, opt_tailrec_index), tag_id) in tag_layouts
.iter()
.zip(tailrec_indices)
.zip((0..).filter(|tag_id| !matches!(null_id, Some(id) if tag_id == &id)))
{
// After refcounting the fields, jump to modify the union itself.
// The loop param is a pointer to the next union. It gets passed through two jumps.
let (non_tailrec_fields, jump_to_modify_union) =
@ -1694,8 +1683,6 @@ fn refcount_union_tailrec<'a>(
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;