Merge branch 'trunk' of github.com:rtfeldman/roc into wasm-module-refactor

Brian Carroll 2022-01-11 21:51:03 +00:00
commit 54f2855349
9 changed files with 1153 additions and 81 deletions


@@ -329,6 +329,21 @@ values - including other records, or even functions!
{ birds: 4, nestedRecord: { someFunction: (\arg -> arg + 1), name: "Sam" } }
```
### Record shorthands
Roc has a couple of shorthands you can use to express some record-related operations more concisely.
Instead of writing `\record -> record.x` we can write `.x` and it will evaluate to the same thing:
a function that takes a record and returns its `x` field. You can do this with any field you want.
For example:
```elm
returnFoo = .foo
returnFoo { foo: "hi!", bar: "blah" }
# returns "hi!"
```
Whenever we're setting a field to be a def that has the same name as the field -
for example, `{ x: x }` - we can shorten it to just writing the name of the def alone -
for example, `{ x }`. We can do this with as many fields as we like, e.g.
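(The defs below, `birds` and `iguanas`, are just placeholder names for illustration.)
```elm
birds = 3
iguanas = 2

{ birds, iguanas }
# same as writing { birds: birds, iguanas: iguanas }
```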


@@ -1036,6 +1036,12 @@ fn byte_to_ast<'a>(env: &Env<'a, '_>, value: u8, content: &Content) -> Expr<'a>
FlatType::TagUnion(tags, _) if tags.len() == 1 => {
let (tag_name, payload_vars) = unpack_single_element_tag_union(env.subs, *tags);
// If this tag union represents a number, skip right to
// returning it as an Expr::Num
if let TagName::Private(Symbol::NUM_AT_NUM) = &tag_name {
return Expr::Num(env.arena.alloc_str(&value.to_string()));
}
let loc_tag_expr = {
let tag_name = &tag_name.as_ident_str(env.interns, env.home);
let tag_expr = if tag_name.starts_with('@') {
@@ -1122,7 +1128,7 @@ fn num_to_ast<'a>(env: &Env<'a, '_>, num_expr: Expr<'a>, content: &Content) -> E
let (tag_name, payload_vars) = unpack_single_element_tag_union(env.subs, *tags);
// If this tag union represents a number, skip right to
// returning tis as an Expr::Num
// returning it as an Expr::Num
if let TagName::Private(Symbol::NUM_AT_NUM) = &tag_name {
return num_expr;
}


@@ -851,6 +851,20 @@ mod repl_eval {
)
}
#[test]
fn print_u8s() {
expect_success(
indoc!(
r#"
x : U8
x = 129
x
"#
),
"129 : U8",
)
}
#[test]
fn parse_problem() {
expect_failure(


@@ -32,6 +32,12 @@ enum HelperOp {
Eq,
}
impl HelperOp {
fn is_decref(&self) -> bool {
matches!(self, Self::DecRef(_))
}
}
#[derive(Debug)]
struct Specialization<'a> {
op: HelperOp,
@@ -174,7 +180,7 @@ impl<'a> CodeGenHelp<'a> {
) -> Option<Expr<'a>> {
use HelperOp::*;
debug_assert!(self.debug_recursion_depth < 10);
// debug_assert!(self.debug_recursion_depth < 100);
self.debug_recursion_depth += 1;
let layout = if matches!(called_layout, Layout::RecursivePointer) {
@@ -225,6 +231,8 @@ impl<'a> CodeGenHelp<'a> {
) -> Symbol {
use HelperOp::*;
let layout = self.replace_rec_ptr(ctx, layout);
let found = self
.specializations
.iter()
@@ -323,6 +331,98 @@ impl<'a> CodeGenHelp<'a> {
let ident_id = ident_ids.add(Ident::from(debug_name));
Symbol::new(self.home, ident_id)
}
// When creating or looking up Specializations, we need to replace RecursivePointer
// with the particular Union layout it represents at this point in the tree.
// For example if a program uses `RoseTree a : [ Tree a (List (RoseTree a)) ]`
// then it could have both `RoseTree I64` and `RoseTree Str`. In this case it
// needs *two* specializations for `List(RecursivePointer)`, not just one.
fn replace_rec_ptr(&self, ctx: &Context<'a>, layout: Layout<'a>) -> Layout<'a> {
match layout {
Layout::Builtin(Builtin::Dict(k, v)) => Layout::Builtin(Builtin::Dict(
self.arena.alloc(self.replace_rec_ptr(ctx, *k)),
self.arena.alloc(self.replace_rec_ptr(ctx, *v)),
)),
Layout::Builtin(Builtin::Set(k)) => Layout::Builtin(Builtin::Set(
self.arena.alloc(self.replace_rec_ptr(ctx, *k)),
)),
Layout::Builtin(Builtin::List(v)) => Layout::Builtin(Builtin::List(
self.arena.alloc(self.replace_rec_ptr(ctx, *v)),
)),
Layout::Builtin(_) => layout,
Layout::Struct(fields) => {
let new_fields_iter = fields.iter().map(|f| self.replace_rec_ptr(ctx, *f));
Layout::Struct(self.arena.alloc_slice_fill_iter(new_fields_iter))
}
Layout::Union(UnionLayout::NonRecursive(tags)) => {
let mut new_tags = Vec::with_capacity_in(tags.len(), self.arena);
for fields in tags {
let mut new_fields = Vec::with_capacity_in(fields.len(), self.arena);
for field in fields.iter() {
new_fields.push(self.replace_rec_ptr(ctx, *field))
}
new_tags.push(new_fields.into_bump_slice());
}
Layout::Union(UnionLayout::NonRecursive(new_tags.into_bump_slice()))
}
Layout::Union(_) => layout,
Layout::LambdaSet(lambda_set) => {
self.replace_rec_ptr(ctx, lambda_set.runtime_representation())
}
// This line is the whole point of the function
Layout::RecursivePointer => Layout::Union(ctx.recursive_union.unwrap()),
}
}
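// For each tag in a recursive union, find the index of the field (if any) that is
// the recursive pointer back to the union itself. The refcount helpers use these
// indices to loop over that field instead of recursing, limiting stack growth.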
fn union_tail_recursion_fields(
&self,
union: UnionLayout<'a>,
) -> (bool, Vec<'a, Option<usize>>) {
use UnionLayout::*;
match union {
NonRecursive(_) => return (false, bumpalo::vec![in self.arena]),
Recursive(tags) => self.union_tail_recursion_fields_help(tags),
NonNullableUnwrapped(field_layouts) => {
self.union_tail_recursion_fields_help(&[field_layouts])
}
NullableWrapped {
other_tags: tags, ..
} => self.union_tail_recursion_fields_help(tags),
NullableUnwrapped { other_fields, .. } => {
self.union_tail_recursion_fields_help(&[other_fields])
}
}
}
fn union_tail_recursion_fields_help(
&self,
tags: &[&'a [Layout<'a>]],
) -> (bool, Vec<'a, Option<usize>>) {
let mut can_use_tailrec = false;
let mut tailrec_indices = Vec::with_capacity_in(tags.len(), self.arena);
for fields in tags.iter() {
let found_index = fields
.iter()
.position(|f| matches!(f, Layout::RecursivePointer));
tailrec_indices.push(found_index);
can_use_tailrec |= found_index.is_some();
}
(can_use_tailrec, tailrec_indices)
}
}
fn let_lowlevel<'a>(


@@ -1,3 +1,4 @@
use bumpalo::collections::vec::Vec;
use roc_builtins::bitcode::IntWidth;
use roc_module::low_level::{LowLevel, LowLevel::*};
use roc_module::symbol::{IdentIds, Symbol};
@@ -6,7 +7,7 @@ use crate::code_gen_help::let_lowlevel;
use crate::ir::{
BranchInfo, Call, CallType, Expr, JoinPointId, Literal, ModifyRc, Param, Stmt, UpdateModeId,
};
use crate::layout::{Builtin, Layout, UnionLayout};
use crate::layout::{Builtin, Layout, TagIdIntType, UnionLayout};
use super::{CodeGenHelp, Context, HelperOp};
@@ -67,8 +68,9 @@ pub fn refcount_stmt<'a>(
refcount_stmt(root, ident_ids, ctx, layout, modify, following)
}
// Struct is stack-only, so DecRef is a no-op
// Struct and non-recursive Unions are stack-only, so DecRef is a no-op
Layout::Struct(_) => following,
Layout::Union(UnionLayout::NonRecursive(_)) => following,
// Inline the refcounting code instead of making a function. Don't iterate fields,
// and replace any return statements with jumps to the `following` statement.
@@ -112,9 +114,12 @@ pub fn refcount_generic<'a>(
Layout::Struct(field_layouts) => {
refcount_struct(root, ident_ids, ctx, field_layouts, structure)
}
Layout::Union(_) => rc_todo(),
Layout::LambdaSet(_) => {
unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.")
Layout::Union(union_layout) => {
refcount_union(root, ident_ids, ctx, union_layout, structure)
}
Layout::LambdaSet(lambda_set) => {
let runtime_layout = lambda_set.runtime_representation();
refcount_generic(root, ident_ids, ctx, runtime_layout, structure)
}
Layout::RecursivePointer => rc_todo(),
}
@@ -124,12 +129,32 @@ pub fn refcount_generic<'a>(
// In the short term, it helps us to skip refcounting and let it leak, so we can make
// progress incrementally. Kept in sync with generate_procs using assertions.
pub fn is_rc_implemented_yet(layout: &Layout) -> bool {
use UnionLayout::*;
match layout {
Layout::Builtin(Builtin::Dict(..) | Builtin::Set(_)) => false,
Layout::Builtin(Builtin::List(elem_layout)) => is_rc_implemented_yet(elem_layout),
Layout::Builtin(_) => true,
Layout::Struct(fields) => fields.iter().all(is_rc_implemented_yet),
_ => false,
Layout::Union(union_layout) => match union_layout {
NonRecursive(tags) => tags
.iter()
.all(|fields| fields.iter().all(is_rc_implemented_yet)),
Recursive(tags) => tags
.iter()
.all(|fields| fields.iter().all(is_rc_implemented_yet)),
NonNullableUnwrapped(fields) => fields.iter().all(is_rc_implemented_yet),
NullableWrapped { other_tags, .. } => other_tags
.iter()
.all(|fields| fields.iter().all(is_rc_implemented_yet)),
NullableUnwrapped { other_fields, .. } => {
other_fields.iter().all(is_rc_implemented_yet)
}
},
Layout::LambdaSet(lambda_set) => {
is_rc_implemented_yet(&lambda_set.runtime_representation())
}
Layout::RecursivePointer => true,
}
}
@@ -165,6 +190,7 @@ pub fn rc_ptr_from_data_ptr<'a>(
ident_ids: &mut IdentIds,
structure: Symbol,
rc_ptr_sym: Symbol,
mask_lower_bits: bool,
following: &'a Stmt<'a>,
) -> Stmt<'a> {
// Typecast the structure pointer to an integer
@@ -179,6 +205,21 @@ pub fn rc_ptr_from_data_ptr<'a>(
});
let addr_stmt = |next| Stmt::Let(addr_sym, addr_expr, root.layout_isize, next);
// Mask for lower bits (for tag union id)
let mask_sym = root.create_symbol(ident_ids, "mask");
let mask_expr = Expr::Literal(Literal::Int(-(root.ptr_size as i128)));
let mask_stmt = |next| Stmt::Let(mask_sym, mask_expr, root.layout_isize, next);
let masked_sym = root.create_symbol(ident_ids, "masked");
let and_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::And,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([addr_sym, mask_sym]),
});
let and_stmt = |next| Stmt::Let(masked_sym, and_expr, root.layout_isize, next);
// Pointer size constant
let ptr_size_sym = root.create_symbol(ident_ids, "ptr_size");
let ptr_size_expr = Expr::Literal(Literal::Int(root.ptr_size as i128));
@@ -186,38 +227,67 @@ pub fn rc_ptr_from_data_ptr<'a>(
// Refcount address
let rc_addr_sym = root.create_symbol(ident_ids, "rc_addr");
let rc_addr_expr = Expr::Call(Call {
let sub_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::NumSub,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([addr_sym, ptr_size_sym]),
arguments: root.arena.alloc([
if mask_lower_bits {
masked_sym
} else {
addr_sym
},
ptr_size_sym,
]),
});
let rc_addr_stmt = |next| Stmt::Let(rc_addr_sym, rc_addr_expr, root.layout_isize, next);
let sub_stmt = |next| Stmt::Let(rc_addr_sym, sub_expr, root.layout_isize, next);
// Typecast the refcount address from integer to pointer
let rc_ptr_expr = Expr::Call(Call {
let cast_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: root.arena.alloc([rc_addr_sym]),
});
let rc_ptr_stmt = |next| Stmt::Let(rc_ptr_sym, rc_ptr_expr, LAYOUT_PTR, next);
let cast_stmt = |next| Stmt::Let(rc_ptr_sym, cast_expr, LAYOUT_PTR, next);
if mask_lower_bits {
addr_stmt(root.arena.alloc(
//
mask_stmt(root.arena.alloc(
//
and_stmt(root.arena.alloc(
//
ptr_size_stmt(root.arena.alloc(
//
sub_stmt(root.arena.alloc(
//
cast_stmt(root.arena.alloc(
//
following,
)),
)),
)),
)),
)),
))
} else {
addr_stmt(root.arena.alloc(
//
ptr_size_stmt(root.arena.alloc(
//
rc_addr_stmt(root.arena.alloc(
sub_stmt(root.arena.alloc(
//
rc_ptr_stmt(root.arena.alloc(
cast_stmt(root.arena.alloc(
//
following,
)),
)),
)),
))
}
}
fn modify_refcount<'a>(
@@ -332,6 +402,7 @@ fn refcount_str<'a>(
ident_ids,
elements,
rc_ptr,
false,
root.arena.alloc(
//
mod_rc_stmt,
@@ -412,12 +483,32 @@ fn refcount_list<'a>(
//
// modify refcount of the list and its elements
// (elements first, to avoid use-after-free for Dec)
//
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let alignment = layout.alignment_bytes(root.ptr_size);
let modify_elems = if elem_layout.is_refcounted() && !matches!(ctx.op, HelperOp::DecRef(_)) {
let ret_stmt = rc_return_stmt(root, ident_ids, ctx);
let modify_list = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
arena.alloc(ret_stmt),
);
let get_rc_and_modify_list = rc_ptr_from_data_ptr(
root,
ident_ids,
elements,
rc_ptr,
false,
arena.alloc(modify_list),
);
let modify_elems_and_list = if elem_layout.is_refcounted() && !ctx.op.is_decref() {
refcount_list_elems(
root,
ident_ids,
@@ -427,36 +518,31 @@ fn refcount_list<'a>(
box_union_layout,
len,
elements,
get_rc_and_modify_list,
)
} else {
rc_return_stmt(root, ident_ids, ctx)
get_rc_and_modify_list
};
let modify_list = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
arena.alloc(modify_elems),
);
let modify_list_and_elems = elements_stmt(arena.alloc(
//
rc_ptr_from_data_ptr(root, ident_ids, elements, rc_ptr, arena.alloc(modify_list)),
));
//
// Do nothing if the list is empty
//
let non_empty_branch = root.arena.alloc(
//
elements_stmt(root.arena.alloc(
//
modify_elems_and_list,
)),
);
let if_stmt = Stmt::Switch {
cond_symbol: is_empty,
cond_layout: LAYOUT_BOOL,
branches: root
.arena
.alloc([(1, BranchInfo::None, rc_return_stmt(root, ident_ids, ctx))]),
default_branch: (BranchInfo::None, root.arena.alloc(modify_list_and_elems)),
default_branch: (BranchInfo::None, non_empty_branch),
ret_layout: LAYOUT_UNIT,
};
@@ -482,6 +568,7 @@ fn refcount_list_elems<'a>(
box_union_layout: UnionLayout<'a>,
length: Symbol,
elements: Symbol,
following: Stmt<'a>,
) -> Stmt<'a> {
use LowLevel::*;
let layout_isize = root.layout_isize;
@@ -496,9 +583,9 @@ fn refcount_list_elems<'a>(
//
// let size = literal int
let size = root.create_symbol(ident_ids, "size");
let size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(root.ptr_size) as i128));
let size_stmt = |next| Stmt::Let(size, size_expr, layout_isize, next);
let elem_size = root.create_symbol(ident_ids, "elem_size");
let elem_size_expr = Expr::Literal(Literal::Int(elem_layout.stack_size(root.ptr_size) as i128));
let elem_size_stmt = |next| Stmt::Let(elem_size, elem_size_expr, layout_isize, next);
// let list_size = len * size
let list_size = root.create_symbol(ident_ids, "list_size");
@@ -508,7 +595,7 @@ fn refcount_list_elems<'a>(
layout_isize,
list_size,
NumMul,
&[length, size],
&[length, elem_size],
next,
)
};
@@ -564,8 +651,16 @@ fn refcount_list_elems<'a>(
// Next loop iteration
//
let next_addr = root.create_symbol(ident_ids, "next_addr");
let next_addr_stmt =
|next| let_lowlevel(arena, layout_isize, next_addr, NumAdd, &[addr, size], next);
let next_addr_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
next_addr,
NumAdd,
&[addr, elem_size],
next,
)
};
//
// Control flow
@@ -578,9 +673,7 @@ fn refcount_list_elems<'a>(
cond_symbol: is_end,
cond_layout: LAYOUT_BOOL,
ret_layout,
branches: root
.arena
.alloc([(1, BranchInfo::None, rc_return_stmt(root, ident_ids, ctx))]),
branches: root.arena.alloc([(1, BranchInfo::None, following)]),
default_branch: (
BranchInfo::None,
arena.alloc(box_stmt(arena.alloc(
@@ -616,7 +709,7 @@ fn refcount_list_elems<'a>(
start_stmt(arena.alloc(
//
size_stmt(arena.alloc(
elem_size_stmt(arena.alloc(
//
list_size_stmt(arena.alloc(
//
@@ -667,3 +760,510 @@ fn refcount_struct<'a>(
stmt
}
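// Generate refcounting code for a tag union, dispatching on its UnionLayout.
// Recursive variants use a tail-recursive loop when a tag stores the recursive
// pointer directly in one of its fields (and the op is not DecRef).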
fn refcount_union<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union: UnionLayout<'a>,
structure: Symbol,
) -> Stmt<'a> {
use UnionLayout::*;
let parent_rec_ptr_layout = ctx.recursive_union;
if !matches!(union, NonRecursive(_)) {
ctx.recursive_union = Some(union);
}
let body = match union {
NonRecursive(tags) => refcount_union_nonrec(root, ident_ids, ctx, union, tags, structure),
Recursive(tags) => {
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
if is_tailrec && !ctx.op.is_decref() {
refcount_union_tailrec(root, ident_ids, ctx, union, tags, None, tail_idx, structure)
} else {
refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure)
}
}
NonNullableUnwrapped(field_layouts) => {
// We don't do tail recursion on NonNullableUnwrapped.
// Its RecursionPointer is always nested inside a List, Option, or other sub-layout, since
// a direct RecursionPointer is only possible if there's at least one non-recursive variant.
// This nesting makes it harder to do tail recursion, so we just don't.
let tags = root.arena.alloc([field_layouts]);
refcount_union_rec(root, ident_ids, ctx, union, tags, None, structure)
}
NullableWrapped {
other_tags: tags,
nullable_id,
} => {
let null_id = Some(nullable_id);
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
if is_tailrec && !ctx.op.is_decref() {
refcount_union_tailrec(
root, ident_ids, ctx, union, tags, null_id, tail_idx, structure,
)
} else {
refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
}
}
NullableUnwrapped {
other_fields,
nullable_id,
} => {
let null_id = Some(nullable_id as TagIdIntType);
let tags = root.arena.alloc([other_fields]);
let (is_tailrec, tail_idx) = root.union_tail_recursion_fields(union);
if is_tailrec && !ctx.op.is_decref() {
refcount_union_tailrec(
root, ident_ids, ctx, union, tags, null_id, tail_idx, structure,
)
} else {
refcount_union_rec(root, ident_ids, ctx, union, tags, null_id, structure)
}
}
};
ctx.recursive_union = parent_rec_ptr_layout;
body
}
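// Non-recursive unions live on the stack, so there is no refcount to modify for
// the union value itself; just switch on the tag id and refcount the live tag's fields.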
fn refcount_union_nonrec<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
structure: Symbol,
) -> Stmt<'a> {
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
let tag_id_stmt = |next| {
Stmt::Let(
tag_id_sym,
Expr::GetTagId {
structure,
union_layout,
},
tag_id_layout,
next,
)
};
let continuation = rc_return_stmt(root, ident_ids, ctx);
let switch_stmt = refcount_union_contents(
root,
ident_ids,
ctx,
union_layout,
tag_layouts,
None,
structure,
tag_id_sym,
tag_id_layout,
continuation,
);
tag_id_stmt(root.arena.alloc(
//
switch_stmt,
))
}
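// Switch on the tag id and refcount the fields of the live tag, then jump to a
// joinpoint that runs `modify_union_stmt` (modifying the union's own refcount,
// or simply returning for non-recursive unions).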
#[allow(clippy::too_many_arguments)]
fn refcount_union_contents<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
structure: Symbol,
tag_id_sym: Symbol,
tag_id_layout: Layout<'a>,
modify_union_stmt: Stmt<'a>,
) -> Stmt<'a> {
let jp_modify_union = JoinPointId(root.create_symbol(ident_ids, "jp_modify_union"));
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena);
if let Some(id) = null_id {
let ret = rc_return_stmt(root, ident_ids, ctx);
tag_branches.push((id as u64, BranchInfo::None, ret));
}
let mut tag_id: TagIdIntType = 0;
for field_layouts in tag_layouts.iter() {
match null_id {
Some(id) if id == tag_id => {
tag_id += 1;
}
_ => {}
}
// After refcounting the fields, jump to modify the union itself
// (Order is important, to avoid use-after-free for Dec)
let following = Stmt::Jump(jp_modify_union, &[]);
let fields_stmt = refcount_tag_fields(
root,
ident_ids,
ctx,
union_layout,
field_layouts,
structure,
tag_id,
following,
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
let tag_id_switch = Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: tag_branches.into_bump_slice(),
default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)),
ret_layout: LAYOUT_UNIT,
};
Stmt::Join {
id: jp_modify_union,
parameters: &[],
body: root.arena.alloc(modify_union_stmt),
remainder: root.arena.alloc(tag_id_switch),
}
}
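// Refcount a recursive union: modify the live tag's fields first (avoiding
// use-after-free for Dec), then the refcount of the heap-allocated union itself,
// masking off any tag id bits stored in the pointer. DecRef skips the fields.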
fn refcount_union_rec<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
structure: Symbol,
) -> Stmt<'a> {
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
let tag_id_stmt = |next| {
Stmt::Let(
tag_id_sym,
Expr::GetTagId {
structure,
union_layout,
},
tag_id_layout,
next,
)
};
let rc_structure_stmt = {
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let alignment = Layout::Union(union_layout).alignment_bytes(root.ptr_size);
let ret_stmt = rc_return_stmt(root, ident_ids, ctx);
let modify_structure_stmt = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
root.arena.alloc(ret_stmt),
);
rc_ptr_from_data_ptr(
root,
ident_ids,
structure,
rc_ptr,
union_layout.stores_tag_id_in_pointer(root.ptr_size),
root.arena.alloc(modify_structure_stmt),
)
};
let rc_contents_then_structure = if ctx.op.is_decref() {
rc_structure_stmt
} else {
refcount_union_contents(
root,
ident_ids,
ctx,
union_layout,
tag_layouts,
null_id,
structure,
tag_id_sym,
tag_id_layout,
rc_structure_stmt,
)
};
if ctx.op.is_decref() && null_id.is_none() {
rc_contents_then_structure
} else {
tag_id_stmt(root.arena.alloc(
//
rc_contents_then_structure,
))
}
}
// Refcount a recursive union using tail-call elimination to limit stack growth
#[allow(clippy::too_many_arguments)]
fn refcount_union_tailrec<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]],
null_id: Option<TagIdIntType>,
tailrec_indices: Vec<'a, Option<usize>>,
initial_structure: Symbol,
) -> Stmt<'a> {
let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop"));
let current = root.create_symbol(ident_ids, "current");
let next_ptr = root.create_symbol(ident_ids, "next_ptr");
let layout = Layout::Union(union_layout);
let tag_id_layout = union_layout.tag_id_layout();
let tag_id_sym = root.create_symbol(ident_ids, "tag_id");
let tag_id_stmt = |next| {
Stmt::Let(
tag_id_sym,
Expr::GetTagId {
structure: current,
union_layout,
},
tag_id_layout,
next,
)
};
// Do refcounting on the structure itself
// In the control flow, this comes *after* refcounting the fields
// It receives a `next` parameter to pass through to the outer joinpoint
let rc_structure_stmt = {
let rc_ptr = root.create_symbol(ident_ids, "rc_ptr");
let next_addr = root.create_symbol(ident_ids, "next_addr");
let exit_stmt = rc_return_stmt(root, ident_ids, ctx);
let jump_to_loop = Stmt::Jump(tailrec_loop, root.arena.alloc([next_ptr]));
let loop_or_exit = Stmt::Switch {
cond_symbol: next_addr,
cond_layout: root.layout_isize,
branches: root.arena.alloc([(0, BranchInfo::None, exit_stmt)]),
default_branch: (BranchInfo::None, root.arena.alloc(jump_to_loop)),
ret_layout: LAYOUT_UNIT,
};
let loop_or_exit_based_on_next_addr = {
let_lowlevel(
root.arena,
root.layout_isize,
next_addr,
PtrCast,
&[next_ptr],
root.arena.alloc(loop_or_exit),
)
};
let alignment = layout.alignment_bytes(root.ptr_size);
let modify_structure_stmt = modify_refcount(
root,
ident_ids,
ctx,
rc_ptr,
alignment,
root.arena.alloc(loop_or_exit_based_on_next_addr),
);
rc_ptr_from_data_ptr(
root,
ident_ids,
current,
rc_ptr,
union_layout.stores_tag_id_in_pointer(root.ptr_size),
root.arena.alloc(modify_structure_stmt),
)
};
let rc_contents_then_structure = {
let jp_modify_union = JoinPointId(root.create_symbol(ident_ids, "jp_modify_union"));
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len() + 1, root.arena);
// If this is null, there is no refcount, no `next`, no fields. Just return.
if let Some(id) = null_id {
let ret = rc_return_stmt(root, ident_ids, ctx);
tag_branches.push((id as u64, BranchInfo::None, ret));
}
let mut tag_id: TagIdIntType = 0;
for (field_layouts, opt_tailrec_index) in tag_layouts.iter().zip(tailrec_indices) {
match null_id {
Some(id) if id == tag_id => {
tag_id += 1;
}
_ => {}
}
// After refcounting the fields, jump to modify the union itself.
// The loop param is a pointer to the next union. It gets passed through two jumps.
let (non_tailrec_fields, jump_to_modify_union) =
if let Some(tailrec_index) = opt_tailrec_index {
let mut filtered = Vec::with_capacity_in(field_layouts.len() - 1, root.arena);
let mut tail_stmt = None;
for (i, field) in field_layouts.iter().enumerate() {
if i != tailrec_index {
filtered.push(*field);
} else {
let field_val =
root.create_symbol(ident_ids, &format!("field_{}_{}", tag_id, i));
let field_val_expr = Expr::UnionAtIndex {
union_layout,
tag_id,
index: i as u64,
structure: current,
};
let jump_params = root.arena.alloc([field_val]);
let jump = root.arena.alloc(Stmt::Jump(jp_modify_union, jump_params));
tail_stmt = Some(Stmt::Let(field_val, field_val_expr, *field, jump));
}
}
(filtered.into_bump_slice(), tail_stmt.unwrap())
} else {
let zero = root.create_symbol(ident_ids, "zero");
let zero_expr = Expr::Literal(Literal::Int(0));
let zero_stmt = |next| Stmt::Let(zero, zero_expr, root.layout_isize, next);
let null = root.create_symbol(ident_ids, "null");
let null_stmt =
|next| let_lowlevel(root.arena, layout, null, PtrCast, &[zero], next);
let tail_stmt = zero_stmt(root.arena.alloc(
//
null_stmt(root.arena.alloc(
//
Stmt::Jump(jp_modify_union, root.arena.alloc([null])),
)),
));
(*field_layouts, tail_stmt)
};
let fields_stmt = refcount_tag_fields(
root,
ident_ids,
ctx,
union_layout,
non_tailrec_fields,
current,
tag_id,
jump_to_modify_union,
);
tag_branches.push((tag_id as u64, BranchInfo::None, fields_stmt));
tag_id += 1;
}
let default_stmt: Stmt<'a> = tag_branches.pop().unwrap().2;
let tag_id_switch = Stmt::Switch {
cond_symbol: tag_id_sym,
cond_layout: tag_id_layout,
branches: tag_branches.into_bump_slice(),
default_branch: (BranchInfo::None, root.arena.alloc(default_stmt)),
ret_layout: LAYOUT_UNIT,
};
let jp_param = Param {
symbol: next_ptr,
borrow: true,
layout,
};
Stmt::Join {
id: jp_modify_union,
parameters: root.arena.alloc([jp_param]),
body: root.arena.alloc(rc_structure_stmt),
remainder: root.arena.alloc(tag_id_switch),
}
};
let loop_body = tag_id_stmt(root.arena.alloc(
//
rc_contents_then_structure,
));
let loop_init = Stmt::Jump(tailrec_loop, root.arena.alloc([initial_structure]));
let loop_param = Param {
symbol: current,
borrow: true,
layout: Layout::Union(union_layout),
};
Stmt::Join {
id: tailrec_loop,
parameters: root.arena.alloc([loop_param]),
body: root.arena.alloc(loop_body),
remainder: root.arena.alloc(loop_init),
}
}
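// For each refcounted field of the given tag, load it with UnionAtIndex and call
// its specialized refcount helper, then continue with `following`.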
#[allow(clippy::too_many_arguments)]
fn refcount_tag_fields<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
field_layouts: &'a [Layout<'a>],
structure: Symbol,
tag_id: TagIdIntType,
following: Stmt<'a>,
) -> Stmt<'a> {
let mut stmt = following;
for (i, field_layout) in field_layouts.iter().enumerate().rev() {
if field_layout.contains_refcounted() {
let field_val = root.create_symbol(ident_ids, &format!("field_{}_{}", tag_id, i));
let field_val_expr = Expr::UnionAtIndex {
union_layout,
tag_id,
index: i as u64,
structure,
};
let field_val_stmt = |next| Stmt::Let(field_val, field_val_expr, *field_layout, next);
let mod_unit = root.create_symbol(ident_ids, &format!("mod_field_{}_{}", tag_id, i));
let mod_args = refcount_args(root, ctx, field_val);
let mod_expr = root
.call_specialized_op(ident_ids, ctx, *field_layout, mod_args)
.unwrap();
let mod_stmt = |next| Stmt::Let(mod_unit, mod_expr, LAYOUT_UNIT, next);
stmt = field_val_stmt(root.arena.alloc(
//
mod_stmt(root.arena.alloc(
//
stmt,
)),
))
}
}
stmt
}


@@ -498,6 +498,48 @@ fn eq_rosetree() {
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn eq_different_rosetrees() {
// Requires two different equality procedures for `List (Rose I64)` and `List (Rose Str)`
// even though both appear in the mono Layout as `List(RecursivePointer)`
assert_evals_to!(
indoc!(
r#"
Rose a : [ Rose a (List (Rose a)) ]
a1 : Rose I64
a1 = Rose 999 []
a2 : Rose I64
a2 = Rose 0 [a1]
b1 : Rose I64
b1 = Rose 999 []
b2 : Rose I64
b2 = Rose 0 [b1]
ab = a2 == b2
c1 : Rose Str
c1 = Rose "hello" []
c2 : Rose Str
c2 = Rose "" [c1]
d1 : Rose Str
d1 = Rose "hello" []
d2 : Rose Str
d2 = Rose "" [d1]
cd = c2 == d2
ab && cd
"#
),
true,
bool
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
#[ignore]


@@ -7,6 +7,10 @@ use indoc::indoc;
#[allow(unused_imports)]
use roc_std::{RocList, RocStr};
// A "good enough" representation of a pointer for these tests, because
// we ignore the return value. As long as it's the right stack size, it's fine.
type Pointer = usize;
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn str_inc() {
@@ -150,3 +154,277 @@ fn struct_dealloc() {
&[0] // s
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_nonrecursive_inc() {
type TwoStr = (RocStr, RocStr, i64);
assert_refcounts!(
indoc!(
r#"
TwoOrNone a: [ Two a a, None ]
s = Str.concat "A long enough string " "to be heap-allocated"
two : TwoOrNone Str
two = Two s s
four : TwoOrNone (TwoOrNone Str)
four = Two two two
four
"#
),
(TwoStr, TwoStr, i64),
&[4]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_nonrecursive_dec() {
assert_refcounts!(
indoc!(
r#"
TwoOrNone a: [ Two a a, None ]
s = Str.concat "A long enough string " "to be heap-allocated"
two : TwoOrNone Str
two = Two s s
when two is
Two x _ -> x
None -> ""
"#
),
RocStr,
&[1] // s
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_recursive_inc() {
assert_refcounts!(
indoc!(
r#"
Expr : [ Sym Str, Add Expr Expr ]
s = Str.concat "heap_allocated" "_symbol_name"
x : Expr
x = Sym s
e : Expr
e = Add x x
Pair e e
"#
),
(Pointer, Pointer),
&[
4, // s
4, // sym
2, // e
]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_recursive_dec() {
assert_refcounts!(
indoc!(
r#"
Expr : [ Sym Str, Add Expr Expr ]
s = Str.concat "heap_allocated" "_symbol_name"
x : Expr
x = Sym s
e : Expr
e = Add x x
when e is
Add y _ -> y
Sym _ -> e
"#
),
Pointer,
&[
1, // s
1, // sym
0 // e
]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn refcount_different_rosetrees_inc() {
// Requires two different Inc procedures for `List (Rose I64)` and `List (Rose Str)`
// even though both appear in the mono Layout as `List(RecursivePointer)`
assert_refcounts!(
indoc!(
r#"
Rose a : [ Rose a (List (Rose a)) ]
s = Str.concat "A long enough string " "to be heap-allocated"
i1 : Rose I64
i1 = Rose 999 []
s1 : Rose Str
s1 = Rose s []
i2 : Rose I64
i2 = Rose 0 [i1, i1, i1]
s2 : Rose Str
s2 = Rose "" [s1, s1]
Tuple i2 s2
"#
),
(Pointer, Pointer),
&[
2, // s
3, // i1
2, // s1
1, // [i1, i1]
1, // i2
1, // [s1, s1]
1 // s2
]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn refcount_different_rosetrees_dec() {
// Requires two different Dec procedures for `List (Rose I64)` and `List (Rose Str)`
// even though both appear in the mono Layout as `List(RecursivePointer)`
assert_refcounts!(
indoc!(
r#"
Rose a : [ Rose a (List (Rose a)) ]
s = Str.concat "A long enough string " "to be heap-allocated"
i1 : Rose I64
i1 = Rose 999 []
s1 : Rose Str
s1 = Rose s []
i2 : Rose I64
i2 = Rose 0 [i1, i1]
s2 : Rose Str
s2 = Rose "" [s1, s1]
when (Tuple i2 s2) is
Tuple (Rose x _) _ -> x
"#
),
i64,
&[
0, // s
0, // i1
0, // s1
0, // [i1, i1]
0, // i2
0, // [s1, s1]
0, // s2
]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_linked_list_inc() {
assert_refcounts!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
s = Str.concat "A long enough string " "to be heap-allocated"
linked : LinkedList Str
linked = Cons s (Cons s (Cons s Nil))
Tuple linked linked
"#
),
(Pointer, Pointer),
&[
6, // s
2, // Cons
2, // Cons
2, // Cons
]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_linked_list_dec() {
assert_refcounts!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
s = Str.concat "A long enough string " "to be heap-allocated"
linked : LinkedList Str
linked = Cons s (Cons s (Cons s Nil))
when linked is
Cons x _ -> x
Nil -> ""
"#
),
RocStr,
&[
1, // s
0, // Cons
0, // Cons
0, // Cons
]
);
}
#[test]
#[cfg(any(feature = "gen-wasm"))]
fn union_linked_list_long_dec() {
assert_refcounts!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
prependOnes = \n, tail ->
if n == 0 then
tail
else
prependOnes (n-1) (Cons 1 tail)
main =
n = 1_000
linked : LinkedList I64
linked = prependOnes n Nil
when linked is
Cons x _ -> x
Nil -> -1
"#
),
i64,
&[0; 1_000]
);
}


@@ -307,28 +307,38 @@ where
let memory = instance.exports.get_memory(MEMORY_NAME).unwrap();
let expected_len = num_refcounts as i32;
let init_refcount_test = instance.exports.get_function("init_refcount_test").unwrap();
let init_result = init_refcount_test.call(&[wasmer::Value::I32(num_refcounts as i32)]);
let refcount_array_addr = match init_result {
let init_result = init_refcount_test.call(&[wasmer::Value::I32(expected_len)]);
let refcount_vector_addr = match init_result {
Err(e) => return Err(format!("{:?}", e)),
Ok(result) => match result[0] {
wasmer::Value::I32(a) => a,
_ => panic!(),
},
};
// An array of refcount pointers
let refcount_ptr_array: WasmPtr<WasmPtr<i32>, wasmer::Array> =
WasmPtr::new(refcount_array_addr as u32);
let refcount_ptrs: &[Cell<WasmPtr<i32>>] = refcount_ptr_array
.deref(memory, 0, num_refcounts as u32)
.unwrap();
// Run the test
let test_wrapper = instance.exports.get_function(TEST_WRAPPER_NAME).unwrap();
match test_wrapper.call(&[]) {
Err(e) => return Err(format!("{:?}", e)),
Ok(_) => {}
}
// Check we got the right number of refcounts
let refcount_vector_len: WasmPtr<i32> = WasmPtr::new(refcount_vector_addr as u32);
let actual_len = refcount_vector_len.deref(memory).unwrap().get();
if actual_len != expected_len {
panic!("Expected {} refcounts but got {}", expected_len, actual_len);
}
// Read the actual refcount values
let refcount_ptr_array: WasmPtr<WasmPtr<i32>, wasmer::Array> =
WasmPtr::new(4 + refcount_vector_addr as u32);
let refcount_ptrs: &[Cell<WasmPtr<i32>>] = refcount_ptr_array
.deref(memory, 0, num_refcounts as u32)
.unwrap();
let mut refcounts = Vec::with_capacity(num_refcounts);
for i in 0..num_refcounts {
let rc_ptr = refcount_ptrs[i].get();


@@ -3,35 +3,41 @@
// Makes test runs take 50% longer, due to linking
#define ENABLE_PRINTF 0
typedef struct
{
size_t length;
size_t *elements[]; // flexible array member
} Vector;
// Globals for refcount testing
size_t **rc_pointers; // array of pointers to refcount values
size_t rc_pointers_len;
size_t rc_pointers_index;
Vector *rc_pointers;
size_t rc_pointers_capacity;
// The rust test passes us the max number of allocations it expects to make,
// and we tell it where we're going to write the refcount pointers.
// It won't actually read that memory until later, when the test is done.
size_t **init_refcount_test(size_t max_allocs)
Vector *init_refcount_test(size_t capacity)
{
rc_pointers = malloc(max_allocs * sizeof(size_t *));
rc_pointers_len = max_allocs;
rc_pointers_index = 0;
for (size_t i = 0; i < max_allocs; ++i)
rc_pointers[i] = NULL;
rc_pointers_capacity = capacity;
rc_pointers = malloc((1 + capacity) * sizeof(size_t *));
rc_pointers->length = 0;
for (size_t i = 0; i < capacity; ++i)
rc_pointers->elements[i] = NULL;
return rc_pointers;
}
#if ENABLE_PRINTF
#define ASSERT(x) \
if (!(x)) \
#define ASSERT(condition, format, ...) \
if (!(condition)) \
{ \
printf("FAILED: " #x "\n"); \
printf("ASSERT FAILED: " #format "\n", __VA_ARGS__); \
abort(); \
}
#else
#define ASSERT(x) \
if (!(x)) \
#define ASSERT(condition, format, ...) \
if (!(condition)) \
abort();
#endif
@@ -50,12 +56,13 @@ void *roc_alloc(size_t size, unsigned int alignment)
if (rc_pointers)
{
ASSERT(alignment >= sizeof(size_t));
ASSERT(rc_pointers_index < rc_pointers_len);
ASSERT(alignment >= sizeof(size_t), "alignment %zd != %zd", alignment, sizeof(size_t));
size_t num_alloc = rc_pointers->length + 1;
ASSERT(num_alloc <= rc_pointers_capacity, "Too many allocations %zd > %zd", num_alloc, rc_pointers_capacity);
size_t *rc_ptr = alloc_ptr_to_rc_ptr(allocated, alignment);
rc_pointers[rc_pointers_index] = rc_ptr;
rc_pointers_index++;
rc_pointers->elements[rc_pointers->length] = rc_ptr;
rc_pointers->length++;
}
#if ENABLE_PRINTF
@@ -94,16 +101,16 @@ void roc_dealloc(void *ptr, unsigned int alignment)
// Then even if malloc reuses the space, everything still works
size_t *rc_ptr = alloc_ptr_to_rc_ptr(ptr, alignment);
int i = 0;
for (; i < rc_pointers_index; ++i)
for (; i < rc_pointers->length; ++i)
{
if (rc_pointers[i] == rc_ptr)
if (rc_pointers->elements[i] == rc_ptr)
{
rc_pointers[i] = NULL;
rc_pointers->elements[i] = NULL;
break;
}
}
int was_found = i < rc_pointers_index;
ASSERT(was_found);
int was_found = i < rc_pointers->length;
ASSERT(was_found, "RC pointer not found %p", rc_ptr);
}
#if ENABLE_PRINTF