fix and optimize tail-recursive decrement

This commit is contained in:
Folkert 2023-04-28 17:50:30 +02:00
parent 755c294d90
commit a7c7ad2d17
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
2 changed files with 53 additions and 44 deletions

View file

@ -538,42 +538,37 @@ impl<'a> CodeGenHelp<'a> {
fn union_tail_recursion_fields(
&self,
layout_interner: &STLayoutInterner<'a>,
union_in_layout: InLayout<'a>,
union: UnionLayout<'a>,
) -> Option<Vec<'a, Option<usize>>> {
use UnionLayout::*;
match union {
NonRecursive(_) => None,
Recursive(tags) => self.union_tail_recursion_fields_help(layout_interner, tags),
Recursive(tags) => self.union_tail_recursion_fields_help(union_in_layout, tags),
NonNullableUnwrapped(field_layouts) => {
self.union_tail_recursion_fields_help(layout_interner, &[field_layouts])
self.union_tail_recursion_fields_help(union_in_layout, &[field_layouts])
}
NullableWrapped {
other_tags: tags, ..
} => self.union_tail_recursion_fields_help(layout_interner, tags),
} => self.union_tail_recursion_fields_help(union_in_layout, tags),
NullableUnwrapped { other_fields, .. } => {
self.union_tail_recursion_fields_help(layout_interner, &[other_fields])
self.union_tail_recursion_fields_help(union_in_layout, &[other_fields])
}
}
}
fn union_tail_recursion_fields_help(
&self,
layout_interner: &STLayoutInterner<'a>,
in_layout: InLayout<'a>,
tags: &[&'a [InLayout<'a>]],
) -> Option<Vec<'a, Option<usize>>> {
let tailrec_indices = tags
.iter()
.map(|fields| {
let found_index = fields
.iter()
.position(|f| matches!(layout_interner.get(*f), Layout::RecursivePointer(_)));
found_index
})
.map(|fields| fields.iter().position(|f| *f == in_layout))
.collect_in::<Vec<_>>(self.arena);
if tailrec_indices.iter().any(|i| i.is_some()) {

View file

@ -156,6 +156,7 @@ pub fn refcount_generic<'a>(
ident_ids,
ctx,
layout_interner,
layout,
union_layout,
structure,
),
@ -335,24 +336,15 @@ pub fn refcount_reset_proc_body<'a>(
.unwrap();
let decrement_stmt = |next| Stmt::Let(decrement_unit, decrement_expr, LAYOUT_UNIT, next);
// Zero
let zero = root.create_symbol(ident_ids, "zero");
let zero_expr = Expr::Literal(Literal::Int(0i128.to_ne_bytes()));
let zero_stmt = |next| Stmt::Let(zero, zero_expr, root.layout_isize, next);
// Null pointer with union layout
let null = root.create_symbol(ident_ids, "null");
let null_stmt =
|next| let_lowlevel(root.arena, root.layout_isize, null, PtrCast, &[zero], next);
let null_stmt = |next| Stmt::Let(null, Expr::NullPointer, layout, next);
decrement_stmt(root.arena.alloc(
//
zero_stmt(root.arena.alloc(
null_stmt(root.arena.alloc(
//
null_stmt(root.arena.alloc(
//
Stmt::Ret(null),
)),
Stmt::Ret(null),
)),
))
};
@ -501,8 +493,7 @@ pub fn refcount_resetref_proc_body<'a>(
// Null pointer with union layout
let null = root.create_symbol(ident_ids, "null");
let null_stmt =
|next| let_lowlevel(root.arena, root.layout_isize, null, PtrCast, &[zero], next);
let null_stmt = |next| Stmt::Let(null, Expr::NullPointer, layout, next);
decrement_stmt(root.arena.alloc(
//
@ -1169,6 +1160,7 @@ fn refcount_union<'a>(
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
layout_interner: &mut STLayoutInterner<'a>,
union_in_layout: InLayout<'a>,
union: UnionLayout<'a>,
structure: Symbol,
) -> Stmt<'a> {
@ -1191,8 +1183,8 @@ fn refcount_union<'a>(
),
Recursive(tags) => {
let tailrec_idx = root.union_tail_recursion_fields(layout_interner, union);
if let (Some(tail_idx), false) = (tailrec_idx, ctx.op.is_decref()) {
let tailrec_idx = root.union_tail_recursion_fields(union_in_layout, union);
if let (Some(tail_idx), true) = (tailrec_idx, ctx.op.is_dec()) {
refcount_union_tailrec(
root,
ident_ids,
@ -1241,8 +1233,8 @@ fn refcount_union<'a>(
nullable_id,
} => {
let null_id = Some(nullable_id);
let tailrec_idx = root.union_tail_recursion_fields(layout_interner, union);
if let (Some(tail_idx), false) = (tailrec_idx, ctx.op.is_decref()) {
let tailrec_idx = root.union_tail_recursion_fields(union_in_layout, union);
if let (Some(tail_idx), true) = (tailrec_idx, ctx.op.is_dec()) {
refcount_union_tailrec(
root,
ident_ids,
@ -1274,8 +1266,8 @@ fn refcount_union<'a>(
} => {
let null_id = Some(nullable_id as TagIdIntType);
let tags = root.arena.alloc([other_fields]);
let tailrec_idx = root.union_tail_recursion_fields(layout_interner, union);
if let (Some(tail_idx), false) = (tailrec_idx, ctx.op.is_decref()) {
let tailrec_idx = root.union_tail_recursion_fields(union_in_layout, union);
if let (Some(tail_idx), true) = (tailrec_idx, ctx.op.is_dec()) {
refcount_union_tailrec(
root,
ident_ids,
@ -1620,20 +1612,12 @@ fn refcount_union_tailrec<'a>(
(filtered.into_bump_slice(), tail_stmt.unwrap())
} else {
let zero = root.create_symbol(ident_ids, "zero");
let zero_expr = Expr::Literal(Literal::Int(0i128.to_ne_bytes()));
let zero_stmt = |next| Stmt::Let(zero, zero_expr, root.layout_isize, next);
let null = root.create_symbol(ident_ids, "null");
let null_stmt =
|next| let_lowlevel(root.arena, layout, null, PtrCast, &[zero], next);
let null_stmt = |next| Stmt::Let(null, Expr::NullPointer, layout, next);
let tail_stmt = zero_stmt(root.arena.alloc(
let tail_stmt = null_stmt(root.arena.alloc(
//
null_stmt(root.arena.alloc(
//
Stmt::Jump(jp_modify_union, root.arena.alloc([null])),
)),
Stmt::Jump(jp_modify_union, root.arena.alloc([null])),
));
let field_layouts = field_layouts
@ -1671,6 +1655,36 @@ fn refcount_union_tailrec<'a>(
ret_layout: LAYOUT_UNIT,
};
let is_unique = root.create_symbol(ident_ids, "is_unique");
let null_pointer = root.create_symbol(ident_ids, "null_pointer");
let jump_with_null_ptr = Stmt::Let(
null_pointer,
Expr::NullPointer,
layout_interner.insert(Layout::Union(union_layout)),
root.arena.alloc(Stmt::Jump(
jp_modify_union,
root.arena.alloc([null_pointer]),
)),
);
let switch_with_unique_check = Stmt::if_then_else(
root.arena,
is_unique,
Layout::UNIT,
tag_id_switch,
root.arena.alloc(jump_with_null_ptr),
);
let switch_with_unique_check_and_let = let_lowlevel(
root.arena,
Layout::BOOL,
is_unique,
LowLevel::RefCountIsUnique,
&[current],
root.arena.alloc(switch_with_unique_check),
);
let jp_param = Param {
symbol: next_ptr,
ownership: Ownership::Borrowed,
@ -1681,7 +1695,7 @@ fn refcount_union_tailrec<'a>(
id: jp_modify_union,
parameters: root.arena.alloc([jp_param]),
body: root.arena.alloc(rc_structure_stmt),
remainder: root.arena.alloc(tag_id_switch),
remainder: root.arena.alloc(switch_with_unique_check_and_let),
}
};