moved all crates into separate folder + related path fixes

This commit is contained in:
Anton-4 2022-07-01 17:37:43 +02:00
parent 12ef03bb86
commit eee85fa45d
No known key found for this signature in database
GPG key ID: C954D6E0F9C0ABFD
1063 changed files with 92 additions and 93 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,811 @@
use bumpalo::collections::vec::Vec;
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, Symbol};
use crate::ir::{
BranchInfo, Call, CallType, Expr, JoinPointId, Literal, Param, Stmt, UpdateModeId,
};
use crate::layout::{Builtin, Layout, TagIdIntType, UnionLayout};
use super::{let_lowlevel, CodeGenHelp, Context, LAYOUT_BOOL};
// Shorthand for the first two argument symbols of every generated helper proc.
const ARG_1: Symbol = Symbol::ARG_1;
const ARG_2: Symbol = Symbol::ARG_2;
/// Generate the body of a specialized `==` helper proc for `layout`.
///
/// Dispatches on the layout to the matching helper (`eq_list`, `eq_struct`,
/// `eq_tag_union`, `eq_boxed`), then wraps the resulting statement in `Let`
/// bindings for `BOOL_TRUE`/`BOOL_FALSE`, which the helpers' `Ret` statements
/// reference. Layouts that never get a generated proc (scalars, `Str`,
/// lambda sets, bare `RecursivePointer`) are `unreachable!`.
pub fn eq_generic<'a>(
    root: &mut CodeGenHelp<'a>,
    ident_ids: &mut IdentIds,
    ctx: &mut Context<'a>,
    layout: Layout<'a>,
) -> Stmt<'a> {
    let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout);
    let main_body = match layout {
        Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => {
            unreachable!(
                "No generated proc for `==`. Use direct code gen for {:?}",
                layout
            )
        }
        Layout::Builtin(Builtin::Str) => {
            unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
        }
        Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(),
        Layout::Builtin(Builtin::List(elem_layout)) => eq_list(root, ident_ids, ctx, elem_layout),
        Layout::Struct { field_layouts, .. } => eq_struct(root, ident_ids, ctx, field_layouts),
        Layout::Union(union_layout) => eq_tag_union(root, ident_ids, ctx, union_layout),
        Layout::Boxed(inner_layout) => eq_boxed(root, ident_ids, ctx, inner_layout),
        Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
        Layout::RecursivePointer => {
            unreachable!(
                "Can't perform `==` on RecursivePointer. Should have been replaced by a tag union."
            )
        }
    };
    // Bind the boolean literals once at the top; every `Ret(BOOL_TRUE)` /
    // `Ret(BOOL_FALSE)` emitted by the helpers above refers to these bindings.
    Stmt::Let(
        Symbol::BOOL_TRUE,
        Expr::Literal(Literal::Int(1i128.to_ne_bytes())),
        LAYOUT_BOOL,
        root.arena.alloc(Stmt::Let(
            Symbol::BOOL_FALSE,
            Expr::Literal(Literal::Int(0i128.to_ne_bytes())),
            LAYOUT_BOOL,
            root.arena.alloc(main_body),
        )),
    )
}
/// Fast path for heap values: if the two operands are the same pointer,
/// return `true` without comparing contents; otherwise fall through to
/// `following`.
///
/// Builds: cast both operands to isize addresses via `PtrCast`, compare with
/// lowlevel `Eq`, and switch on the result (branch `1` => `Ret(BOOL_TRUE)`).
fn if_pointers_equal_return_true<'a>(
    root: &CodeGenHelp<'a>,
    ident_ids: &mut IdentIds,
    operands: [Symbol; 2],
    following: &'a Stmt<'a>,
) -> Stmt<'a> {
    let ptr1_addr = root.create_symbol(ident_ids, "addr1");
    let ptr2_addr = root.create_symbol(ident_ids, "addr2");
    let ptr_eq = root.create_symbol(ident_ids, "eq_addr");
    Stmt::Let(
        ptr1_addr,
        Expr::Call(Call {
            call_type: CallType::LowLevel {
                op: LowLevel::PtrCast,
                update_mode: UpdateModeId::BACKEND_DUMMY,
            },
            arguments: root.arena.alloc([operands[0]]),
        }),
        root.layout_isize,
        root.arena.alloc(Stmt::Let(
            ptr2_addr,
            Expr::Call(Call {
                call_type: CallType::LowLevel {
                    op: LowLevel::PtrCast,
                    update_mode: UpdateModeId::BACKEND_DUMMY,
                },
                arguments: root.arena.alloc([operands[1]]),
            }),
            root.layout_isize,
            root.arena.alloc(Stmt::Let(
                ptr_eq,
                Expr::Call(Call {
                    call_type: CallType::LowLevel {
                        op: LowLevel::Eq,
                        update_mode: UpdateModeId::BACKEND_DUMMY,
                    },
                    arguments: root.arena.alloc([ptr1_addr, ptr2_addr]),
                }),
                LAYOUT_BOOL,
                root.arena.alloc(Stmt::Switch {
                    cond_symbol: ptr_eq,
                    cond_layout: LAYOUT_BOOL,
                    branches: root.arena.alloc([(
                        1,
                        BranchInfo::None,
                        Stmt::Ret(Symbol::BOOL_TRUE),
                    )]),
                    default_branch: (BranchInfo::None, following),
                    ret_layout: LAYOUT_BOOL,
                }),
            )),
        )),
    )
}
/// Short-circuit helper: if `symbol` (a boolean) is `0`, return `false`
/// immediately; otherwise continue executing `following`.
fn if_false_return_false<'a>(
    root: &CodeGenHelp<'a>,
    symbol: Symbol,
    following: &'a Stmt<'a>,
) -> Stmt<'a> {
    // Single branch for the `false` case; everything else falls through.
    let false_branch = (0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE));
    let branches = root.arena.alloc([false_branch]);
    Stmt::Switch {
        cond_symbol: symbol,
        cond_layout: LAYOUT_BOOL,
        ret_layout: LAYOUT_BOOL,
        branches,
        default_branch: (BranchInfo::None, following),
    }
}
/// Equality for struct layouts: compare every field pair-wise, returning
/// `false` at the first mismatch and `true` if all fields match.
///
/// The statement chain is built back-to-front (`.rev()`), so the innermost
/// statement is `Ret(BOOL_TRUE)` and each earlier field wraps the chain in
/// load-load-compare-shortcircuit. A pointer-equality fast path is prepended.
fn eq_struct<'a>(
    root: &mut CodeGenHelp<'a>,
    ident_ids: &mut IdentIds,
    ctx: &mut Context<'a>,
    field_layouts: &'a [Layout<'a>],
) -> Stmt<'a> {
    let mut else_stmt = Stmt::Ret(Symbol::BOOL_TRUE);
    for (i, layout) in field_layouts.iter().enumerate().rev() {
        let field1_sym = root.create_symbol(ident_ids, &format!("field_1_{}", i));
        let field1_expr = Expr::StructAtIndex {
            index: i as u64,
            field_layouts,
            structure: ARG_1,
        };
        let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next);
        let field2_sym = root.create_symbol(ident_ids, &format!("field_2_{}", i));
        let field2_expr = Expr::StructAtIndex {
            index: i as u64,
            field_layouts,
            structure: ARG_2,
        };
        let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
        // Either a call to a specialized helper proc or an inline lowlevel Eq,
        // depending on the field's layout.
        let eq_call_expr = root
            .call_specialized_op(
                ident_ids,
                ctx,
                *layout,
                root.arena.alloc([field1_sym, field2_sym]),
            )
            .unwrap();
        let eq_call_name = format!("eq_call_{}", i);
        let eq_call_sym = root.create_symbol(ident_ids, &eq_call_name);
        let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
        else_stmt = field1_stmt(root.arena.alloc(
            //
            field2_stmt(root.arena.alloc(
                //
                eq_call_stmt(root.arena.alloc(
                    //
                    if_false_return_false(root, eq_call_sym, root.arena.alloc(else_stmt)),
                )),
            )),
        ))
    }
    if_pointers_equal_return_true(root, ident_ids, [ARG_1, ARG_2], root.arena.alloc(else_stmt))
}
/// Equality for tag unions.
///
/// Normalizes the different `UnionLayout` shapes into a common call to
/// `eq_tag_union_help` (a slice of tag field layouts plus an optional null
/// tag id). For recursive unions, `ctx.recursive_union` is set for the
/// duration of the call so nested `RecursivePointer`s can be resolved, then
/// restored to its previous value afterwards.
fn eq_tag_union<'a>(
    root: &mut CodeGenHelp<'a>,
    ident_ids: &mut IdentIds,
    ctx: &mut Context<'a>,
    union_layout: UnionLayout<'a>,
) -> Stmt<'a> {
    use UnionLayout::*;
    // Save the outer value so nested unions restore it correctly.
    let parent_rec_ptr_layout = ctx.recursive_union;
    if !matches!(union_layout, NonRecursive(_)) {
        ctx.recursive_union = Some(union_layout);
    }
    let body = match union_layout {
        NonRecursive(tags) => eq_tag_union_help(root, ident_ids, ctx, union_layout, tags, None),
        Recursive(tags) => eq_tag_union_help(root, ident_ids, ctx, union_layout, tags, None),
        NonNullableUnwrapped(field_layouts) => {
            let tags = root.arena.alloc([field_layouts]);
            eq_tag_union_help(root, ident_ids, ctx, union_layout, tags, None)
        }
        NullableWrapped {
            other_tags,
            nullable_id,
        } => eq_tag_union_help(
            root,
            ident_ids,
            ctx,
            union_layout,
            other_tags,
            Some(nullable_id),
        ),
        NullableUnwrapped {
            other_fields,
            nullable_id,
        } => eq_tag_union_help(
            root,
            ident_ids,
            ctx,
            union_layout,
            root.arena.alloc([other_fields]),
            Some(nullable_id as TagIdIntType),
        ),
    };
    ctx.recursive_union = parent_rec_ptr_layout;
    body
}
/// Shared implementation for all tag-union shapes.
///
/// Generated structure:
///   1. pointer-equality fast path (skipped logically for non-recursive unions
///      since they are compared by value, but emitted uniformly);
///   2. load both tag ids and compare — different ids => `false`;
///   3. switch on the tag id, comparing that tag's fields via `eq_tag_fields`.
/// Recursive unions additionally get a join point (`tailrec_loop`) so that a
/// `RecursivePointer` field can loop instead of recursing.
fn eq_tag_union_help<'a>(
    root: &mut CodeGenHelp<'a>,
    ident_ids: &mut IdentIds,
    ctx: &mut Context<'a>,
    union_layout: UnionLayout<'a>,
    tag_layouts: &'a [&'a [Layout<'a>]],
    nullable_id: Option<TagIdIntType>,
) -> Stmt<'a> {
    let tailrec_loop = JoinPointId(root.create_symbol(ident_ids, "tailrec_loop"));
    let is_non_recursive = matches!(union_layout, UnionLayout::NonRecursive(_));
    // Non-recursive unions compare the proc arguments directly; recursive ones
    // compare the join-point parameters so the loop can rebind them.
    let operands = if is_non_recursive {
        [ARG_1, ARG_2]
    } else {
        [
            root.create_symbol(ident_ids, "a"),
            root.create_symbol(ident_ids, "b"),
        ]
    };
    let tag_id_layout = union_layout.tag_id_layout();
    let tag_id_a = root.create_symbol(ident_ids, "tag_id_a");
    let tag_id_a_stmt = |next| {
        Stmt::Let(
            tag_id_a,
            Expr::GetTagId {
                structure: operands[0],
                union_layout,
            },
            tag_id_layout,
            next,
        )
    };
    let tag_id_b = root.create_symbol(ident_ids, "tag_id_b");
    let tag_id_b_stmt = |next| {
        Stmt::Let(
            tag_id_b,
            Expr::GetTagId {
                structure: operands[1],
                union_layout,
            },
            tag_id_layout,
            next,
        )
    };
    let tag_ids_eq = root.create_symbol(ident_ids, "tag_ids_eq");
    let tag_ids_expr = Expr::Call(Call {
        call_type: CallType::LowLevel {
            op: LowLevel::Eq,
            update_mode: UpdateModeId::BACKEND_DUMMY,
        },
        arguments: root.arena.alloc([tag_id_a, tag_id_b]),
    });
    let tag_ids_eq_stmt = |next| Stmt::Let(tag_ids_eq, tag_ids_expr, LAYOUT_BOOL, next);
    let if_equal_ids_branches =
        root.arena
            .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]);
    //
    // Switch statement by tag ID
    //
    let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), root.arena);
    // If there's a null tag, check it first. We might not need to load any data from memory.
    if let Some(id) = nullable_id {
        tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE)))
    }
    // `tag_layouts` does not include an entry for the null tag, so tag ids are
    // assigned sequentially, skipping over the null id if there is one.
    let mut tag_id: TagIdIntType = 0;
    for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) {
        if let Some(null_id) = nullable_id {
            if tag_id == null_id as TagIdIntType {
                tag_id += 1;
            }
        }
        let tag_stmt = eq_tag_fields(
            root,
            ident_ids,
            ctx,
            tailrec_loop,
            union_layout,
            field_layouts,
            operands,
            tag_id,
        );
        tag_branches.push((tag_id as u64, BranchInfo::None, tag_stmt));
        tag_id += 1;
    }
    // The last tag becomes the switch's default branch.
    let tag_switch_stmt = Stmt::Switch {
        cond_symbol: tag_id_a,
        cond_layout: tag_id_layout,
        branches: tag_branches.into_bump_slice(),
        default_branch: (
            BranchInfo::None,
            root.arena.alloc(eq_tag_fields(
                root,
                ident_ids,
                ctx,
                tailrec_loop,
                union_layout,
                tag_layouts.last().unwrap(),
                operands,
                tag_id,
            )),
        ),
        ret_layout: LAYOUT_BOOL,
    };
    let if_equal_ids_stmt = Stmt::Switch {
        cond_symbol: tag_ids_eq,
        cond_layout: LAYOUT_BOOL,
        branches: if_equal_ids_branches,
        default_branch: (BranchInfo::None, root.arena.alloc(tag_switch_stmt)),
        ret_layout: LAYOUT_BOOL,
    };
    //
    // combine all the statements
    //
    let compare_values = tag_id_a_stmt(root.arena.alloc(
        //
        tag_id_b_stmt(root.arena.alloc(
            //
            tag_ids_eq_stmt(root.arena.alloc(
                //
                if_equal_ids_stmt,
            )),
        )),
    ));
    let compare_ptr_or_value =
        if_pointers_equal_return_true(root, ident_ids, operands, root.arena.alloc(compare_values));
    if is_non_recursive {
        compare_ptr_or_value
    } else {
        // Wrap the comparison in a join point so RecursivePointer fields can
        // jump back here instead of making a recursive call.
        let loop_params_iter = operands.iter().map(|arg| Param {
            symbol: *arg,
            borrow: true,
            layout: Layout::Union(union_layout),
        });
        let loop_start = Stmt::Jump(tailrec_loop, root.arena.alloc([ARG_1, ARG_2]));
        Stmt::Join {
            id: tailrec_loop,
            parameters: root.arena.alloc_slice_fill_iter(loop_params_iter),
            body: root.arena.alloc(compare_ptr_or_value),
            remainder: root.arena.alloc(loop_start),
        }
    }
}
/// Compare all fields of one tag, short-circuiting to `false` on the first
/// mismatch.
///
/// If the tag contains a `RecursivePointer` field, that field is compared
/// *last*, by jumping back to `tailrec_loop` with the two child values —
/// turning the recursion into a loop. Any additional RecursivePointer fields
/// fall through to ordinary (non-tail) recursive helper calls.
#[allow(clippy::too_many_arguments)]
fn eq_tag_fields<'a>(
    root: &mut CodeGenHelp<'a>,
    ident_ids: &mut IdentIds,
    ctx: &mut Context<'a>,
    tailrec_loop: JoinPointId,
    union_layout: UnionLayout<'a>,
    field_layouts: &'a [Layout<'a>],
    operands: [Symbol; 2],
    tag_id: TagIdIntType,
) -> Stmt<'a> {
    // Find a RecursivePointer to use in the tail recursion loop
    // (If there are more than one, the others will use non-tail recursion)
    let rec_ptr_index = field_layouts
        .iter()
        .position(|field| matches!(field, Layout::RecursivePointer));
    let (tailrec_index, innermost_stmt) = match rec_ptr_index {
        None => {
            // This tag has no RecursivePointers. Set tailrec_index out of range.
            (field_layouts.len(), Stmt::Ret(Symbol::BOOL_TRUE))
        }
        Some(i) => {
            // Implement tail recursion on this RecursivePointer,
            // in the innermost `else` clause after all other fields have been checked
            let field1_sym = root.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i));
            let field2_sym = root.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i));
            let field1_expr = Expr::UnionAtIndex {
                union_layout,
                tag_id,
                index: i as u64,
                structure: operands[0],
            };
            let field2_expr = Expr::UnionAtIndex {
                union_layout,
                tag_id,
                index: i as u64,
                structure: operands[1],
            };
            let inner = Stmt::Let(
                field1_sym,
                field1_expr,
                field_layouts[i],
                root.arena.alloc(
                    //
                    Stmt::Let(
                        field2_sym,
                        field2_expr,
                        field_layouts[i],
                        root.arena.alloc(
                            //
                            Stmt::Jump(tailrec_loop, root.arena.alloc([field1_sym, field2_sym])),
                        ),
                    ),
                ),
            );
            (i, inner)
        }
    };
    // Build the non-tail-recursive field comparisons back-to-front around the
    // innermost statement.
    let mut stmt = innermost_stmt;
    for (i, layout) in field_layouts.iter().enumerate().rev() {
        if i == tailrec_index {
            continue; // the tail-recursive field is handled elsewhere
        }
        let field1_sym = root.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i));
        let field2_sym = root.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i));
        let field1_expr = Expr::UnionAtIndex {
            union_layout,
            tag_id,
            index: i as u64,
            structure: operands[0],
        };
        let field2_expr = Expr::UnionAtIndex {
            union_layout,
            tag_id,
            index: i as u64,
            structure: operands[1],
        };
        let eq_call_expr = root
            .call_specialized_op(
                ident_ids,
                ctx,
                *layout,
                root.arena.alloc([field1_sym, field2_sym]),
            )
            .unwrap();
        let eq_call_name = format!("eq_call_{}", i);
        let eq_call_sym = root.create_symbol(ident_ids, &eq_call_name);
        stmt = Stmt::Let(
            field1_sym,
            field1_expr,
            field_layouts[i],
            root.arena.alloc(
                //
                Stmt::Let(
                    field2_sym,
                    field2_expr,
                    field_layouts[i],
                    root.arena.alloc(
                        //
                        Stmt::Let(
                            eq_call_sym,
                            eq_call_expr,
                            LAYOUT_BOOL,
                            root.arena.alloc(
                                //
                                if_false_return_false(
                                    root,
                                    eq_call_sym,
                                    root.arena.alloc(
                                        //
                                        stmt,
                                    ),
                                ),
                            ),
                        ),
                    ),
                ),
            ),
        )
    }
    stmt
}
/// Equality for `Layout::Boxed` values — not yet implemented.
///
/// Until this is implemented, hitting a boxed layout in `==` panics. Include
/// the inner layout in the panic message so the missing specialization is easy
/// to identify from a crash report (a bare `todo!()` gives no context).
fn eq_boxed<'a>(
    _root: &mut CodeGenHelp<'a>,
    _ident_ids: &mut IdentIds,
    _ctx: &mut Context<'a>,
    inner_layout: &'a Layout<'a>,
) -> Stmt<'a> {
    todo!("Specialized `==` operator for Boxed({:?})", inner_layout)
}
/// List equality
/// We can't use `ListGetUnsafe` because it increments the refcount, and we don't want that.
/// Another way to dereference a heap pointer is to use `Expr::UnionAtIndex`.
/// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout.
/// Then we can increment the Box pointer in a loop, dereferencing it each time.
/// (An alternative approach would be to create a new lowlevel like ListPeekUnsafe.)
fn eq_list<'a>(
root: &mut CodeGenHelp<'a>,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
elem_layout: &Layout<'a>,
) -> Stmt<'a> {
use LowLevel::*;
let layout_isize = root.layout_isize;
let arena = root.arena;
// A "Box" layout (heap pointer to a single list element)
let box_union_layout = UnionLayout::NonNullableUnwrapped(root.arena.alloc([*elem_layout]));
let box_layout = Layout::Union(box_union_layout);
// Compare lengths
let len_1 = root.create_symbol(ident_ids, "len_1");
let len_2 = root.create_symbol(ident_ids, "len_2");
let len_1_stmt = |next| let_lowlevel(arena, layout_isize, len_1, ListLen, &[ARG_1], next);
let len_2_stmt = |next| let_lowlevel(arena, layout_isize, len_2, ListLen, &[ARG_2], next);
let eq_len = root.create_symbol(ident_ids, "eq_len");
let eq_len_stmt = |next| let_lowlevel(arena, LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next);
// if lengths are equal...
// get element pointers
let elements_1 = root.create_symbol(ident_ids, "elements_1");
let elements_2 = root.create_symbol(ident_ids, "elements_2");
let elements_1_expr = Expr::StructAtIndex {
index: 0,
field_layouts: root.arena.alloc([box_layout, layout_isize]),
structure: ARG_1,
};
let elements_2_expr = Expr::StructAtIndex {
index: 0,
field_layouts: root.arena.alloc([box_layout, layout_isize]),
structure: ARG_2,
};
let elements_1_stmt = |next| Stmt::Let(elements_1, elements_1_expr, box_layout, next);
let elements_2_stmt = |next| Stmt::Let(elements_2, elements_2_expr, box_layout, next);
// Cast to integers
let start_1 = root.create_symbol(ident_ids, "start_1");
let start_2 = root.create_symbol(ident_ids, "start_2");
let start_1_stmt =
|next| let_lowlevel(arena, layout_isize, start_1, PtrCast, &[elements_1], next);
let start_2_stmt =
|next| let_lowlevel(arena, layout_isize, start_2, PtrCast, &[elements_2], next);
//
// Loop initialisation
//
// let size = literal int
let size = root.create_symbol(ident_ids, "size");
let size_expr = Expr::Literal(Literal::Int(
(elem_layout.stack_size(root.target_info) as i128).to_ne_bytes(),
));
let size_stmt = |next| Stmt::Let(size, size_expr, layout_isize, next);
// let list_size = len_1 * size
let list_size = root.create_symbol(ident_ids, "list_size");
let list_size_stmt =
|next| let_lowlevel(arena, layout_isize, list_size, NumMul, &[len_1, size], next);
// let end_1 = start_1 + list_size
let end_1 = root.create_symbol(ident_ids, "end_1");
let end_1_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
end_1,
NumAdd,
&[start_1, list_size],
next,
)
};
//
// Loop name & parameters
//
let elems_loop = JoinPointId(root.create_symbol(ident_ids, "elems_loop"));
let addr1 = root.create_symbol(ident_ids, "addr1");
let addr2 = root.create_symbol(ident_ids, "addr2");
let param_addr1 = Param {
symbol: addr1,
borrow: false,
layout: layout_isize,
};
let param_addr2 = Param {
symbol: addr2,
borrow: false,
layout: layout_isize,
};
//
// if we haven't reached the end yet...
//
// Cast integers to box pointers
let box1 = root.create_symbol(ident_ids, "box1");
let box2 = root.create_symbol(ident_ids, "box2");
let box1_stmt = |next| let_lowlevel(arena, box_layout, box1, PtrCast, &[addr1], next);
let box2_stmt = |next| let_lowlevel(arena, box_layout, box2, PtrCast, &[addr2], next);
// Dereference the box pointers to get the current elements
let elem1 = root.create_symbol(ident_ids, "elem1");
let elem2 = root.create_symbol(ident_ids, "elem2");
let elem1_expr = Expr::UnionAtIndex {
structure: box1,
union_layout: box_union_layout,
tag_id: 0,
index: 0,
};
let elem2_expr = Expr::UnionAtIndex {
structure: box2,
union_layout: box_union_layout,
tag_id: 0,
index: 0,
};
let elem1_stmt = |next| Stmt::Let(elem1, elem1_expr, *elem_layout, next);
let elem2_stmt = |next| Stmt::Let(elem2, elem2_expr, *elem_layout, next);
// Compare the two current elements
let eq_elems = root.create_symbol(ident_ids, "eq_elems");
let eq_elems_args = root.arena.alloc([elem1, elem2]);
let eq_elems_expr = root
.call_specialized_op(ident_ids, ctx, *elem_layout, eq_elems_args)
.unwrap();
let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next);
// If current elements are equal, loop back again
let next_1 = root.create_symbol(ident_ids, "next_1");
let next_2 = root.create_symbol(ident_ids, "next_2");
let next_1_stmt =
|next| let_lowlevel(arena, layout_isize, next_1, NumAdd, &[addr1, size], next);
let next_2_stmt =
|next| let_lowlevel(arena, layout_isize, next_2, NumAdd, &[addr2, size], next);
let jump_back = Stmt::Jump(elems_loop, root.arena.alloc([next_1, next_2]));
//
// Control flow
//
let is_end = root.create_symbol(ident_ids, "is_end");
let is_end_stmt =
|next| let_lowlevel(arena, LAYOUT_BOOL, is_end, NumGte, &[addr1, end_1], next);
let if_elems_not_equal = if_false_return_false(
root,
eq_elems,
// else
root.arena.alloc(
//
next_1_stmt(root.arena.alloc(
//
next_2_stmt(root.arena.alloc(
//
jump_back,
)),
)),
),
);
let if_end_of_list = Stmt::Switch {
cond_symbol: is_end,
cond_layout: LAYOUT_BOOL,
ret_layout: LAYOUT_BOOL,
branches: root
.arena
.alloc([(1, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))]),
default_branch: (
BranchInfo::None,
root.arena.alloc(
//
box1_stmt(root.arena.alloc(
//
box2_stmt(root.arena.alloc(
//
elem1_stmt(root.arena.alloc(
//
elem2_stmt(root.arena.alloc(
//
eq_elems_stmt(root.arena.alloc(
//
if_elems_not_equal,
)),
)),
)),
)),
)),
),
),
};
let joinpoint_loop = Stmt::Join {
id: elems_loop,
parameters: root.arena.alloc([param_addr1, param_addr2]),
body: root.arena.alloc(
//
is_end_stmt(
//
root.arena.alloc(if_end_of_list),
),
),
remainder: root
.arena
.alloc(Stmt::Jump(elems_loop, root.arena.alloc([start_1, start_2]))),
};
let if_different_lengths = if_false_return_false(
root,
eq_len,
// else
root.arena.alloc(
//
elements_1_stmt(root.arena.alloc(
//
elements_2_stmt(root.arena.alloc(
//
start_1_stmt(root.arena.alloc(
//
start_2_stmt(root.arena.alloc(
//
size_stmt(root.arena.alloc(
//
list_size_stmt(root.arena.alloc(
//
end_1_stmt(root.arena.alloc(
//
joinpoint_loop,
)),
)),
)),
)),
)),
)),
)),
),
);
let pointers_else = len_1_stmt(root.arena.alloc(
//
len_2_stmt(root.arena.alloc(
//
eq_len_stmt(root.arena.alloc(
//
if_different_lengths,
)),
)),
));
if_pointers_equal_return_true(
root,
ident_ids,
[ARG_1, ARG_2],
root.arena.alloc(pointers_else),
)
}

View file

@ -0,0 +1,548 @@
use bumpalo::collections::vec::Vec;
use bumpalo::Bump;
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
use roc_target::TargetInfo;
use crate::ir::{
Call, CallSpecId, CallType, Expr, HostExposedLayouts, JoinPointId, ModifyRc, Proc, ProcLayout,
SelfRecursive, Stmt, UpdateModeId,
};
use crate::layout::{Builtin, Layout, UnionLayout};
mod equality;
mod refcount;
// Commonly-used layouts and argument symbols for generated helper procs.
const LAYOUT_BOOL: Layout = Layout::Builtin(Builtin::Bool);
const LAYOUT_UNIT: Layout = Layout::UNIT;
const ARG_1: Symbol = Symbol::ARG_1;
const ARG_2: Symbol = Symbol::ARG_2;
/// "Infinite" reference count, for static values
/// Ref counts are encoded as negative numbers where isize::MIN represents 1
/// (so 0 is as far from 1 as possible, i.e. "never reaches zero").
pub const REFCOUNT_MAX: usize = 0;
/// The kind of helper proc being generated or called.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HelperOp {
    /// Increment a reference count
    Inc,
    /// Decrement a reference count, recursing into children
    Dec,
    /// Decrement only the outer refcount; the JoinPointId is the "done" target
    DecRef(JoinPointId),
    /// Reset a unique value for in-place reuse
    Reset,
    /// Structural equality
    Eq,
}
impl HelperOp {
    fn is_decref(&self) -> bool {
        matches!(self, Self::DecRef(_))
    }
}
/// One generated helper proc: a (op, layout) pair, the symbol it was assigned,
/// and — once generated — its body. `proc` is `None` while the body is being
/// built, because procs can be (mutually) recursive.
#[derive(Debug)]
struct Specialization<'a> {
    op: HelperOp,
    layout: Layout<'a>,
    symbol: Symbol,
    proc: Option<Proc<'a>>,
}
/// Per-call state threaded through helper generation.
#[derive(Debug)]
pub struct Context<'a> {
    // Symbols and layouts of procs created during this call, for the linker.
    new_linker_data: Vec<'a, (Symbol, ProcLayout<'a>)>,
    // The enclosing recursive union, used to resolve RecursivePointer layouts.
    recursive_union: Option<UnionLayout<'a>>,
    // Which helper operation is currently being generated.
    op: HelperOp,
}
/// Generate specialized helper procs for code gen
/// ----------------------------------------------
///
/// Some low level operations need specialized helper procs to traverse data structures at runtime.
/// This includes refcounting, hashing, and equality checks.
///
/// For example, when checking List equality, we need to visit each element and compare them.
/// Depending on the type of the list elements, we may need to recurse deeper into each element.
/// For tag unions, we may need branches for different tag IDs, etc.
///
/// This module creates specialized helper procs for all such operations and types used in the program.
///
/// The backend drives the process, in two steps:
/// 1) When it sees the relevant node, it calls CodeGenHelp to get the replacement IR.
///    CodeGenHelp returns IR for a call to the helper proc, and remembers the specialization.
/// 2) After the backend has generated code for all user procs, it takes the IR for all of the
///    specialized helpers procs, and generates target code for them too.
///
pub struct CodeGenHelp<'a> {
    // Arena all generated IR is allocated into.
    arena: &'a Bump,
    // Module the helper procs belong to.
    home: ModuleId,
    target_info: TargetInfo,
    // Pointer-sized integer layout for the target.
    layout_isize: Layout<'a>,
    // Layout of a refcount cell (boxed isize).
    union_refcount: UnionLayout<'a>,
    // All helper procs created so far, in creation order.
    specializations: Vec<'a, Specialization<'a>>,
    debug_recursion_depth: usize,
}
impl<'a> CodeGenHelp<'a> {
    /// Create a fresh helper generator for the given target and home module.
    pub fn new(arena: &'a Bump, target_info: TargetInfo, home: ModuleId) -> Self {
        let layout_isize = Layout::isize(target_info);
        // Refcount is a boxed isize. TODO: use the new Box layout when dev backends support it
        let union_refcount = UnionLayout::NonNullableUnwrapped(arena.alloc([layout_isize]));
        CodeGenHelp {
            arena,
            home,
            target_info,
            layout_isize,
            union_refcount,
            specializations: Vec::with_capacity_in(16, arena),
            debug_recursion_depth: 0,
        }
    }

    /// Take ownership of all generated helper procs, in creation order.
    /// Panics if any specialization's body was never filled in.
    pub fn take_procs(&mut self) -> Vec<'a, Proc<'a>> {
        let procs_iter = self
            .specializations
            .drain(0..)
            .map(|spec| spec.proc.unwrap());
        Vec::from_iter_in(procs_iter, self.arena)
    }

    // ============================================================================
    //
    //              CALL GENERATED PROCS
    //
    // ============================================================================

    /// Expand a `Refcounting` node to a `Let` node that calls a specialized helper proc.
    /// The helper procs themselves are to be generated later with `generate_procs`
    pub fn expand_refcount_stmt(
        &mut self,
        ident_ids: &mut IdentIds,
        layout: Layout<'a>,
        modify: &ModifyRc,
        following: &'a Stmt<'a>,
    ) -> (&'a Stmt<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) {
        if !refcount::is_rc_implemented_yet(&layout) {
            // Just a warning, so we can decouple backend development from refcounting development.
            // When we are closer to completion, we can change it to a panic.
            println!(
                "WARNING! MEMORY LEAK! Refcounting not yet implemented for Layout {:?}",
                layout
            );
            return (following, Vec::new_in(self.arena));
        }
        let op = match modify {
            ModifyRc::Inc(..) => HelperOp::Inc,
            ModifyRc::Dec(_) => HelperOp::Dec,
            ModifyRc::DecRef(_) => {
                let jp_decref = JoinPointId(self.create_symbol(ident_ids, "jp_decref"));
                HelperOp::DecRef(jp_decref)
            }
        };
        let mut ctx = Context {
            new_linker_data: Vec::new_in(self.arena),
            recursive_union: None,
            op,
        };
        let rc_stmt = refcount::refcount_stmt(self, ident_ids, &mut ctx, layout, modify, following);
        (rc_stmt, ctx.new_linker_data)
    }

    /// Build a call to a `Reset` helper proc for `layout`, creating the proc
    /// if it doesn't exist yet. Returns the call expression plus linker data
    /// for any newly-created procs.
    pub fn call_reset_refcount(
        &mut self,
        ident_ids: &mut IdentIds,
        layout: Layout<'a>,
        argument: Symbol,
    ) -> (Expr<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) {
        let mut ctx = Context {
            new_linker_data: Vec::new_in(self.arena),
            recursive_union: None,
            op: HelperOp::Reset,
        };
        let proc_name = self.find_or_create_proc(ident_ids, &mut ctx, layout);
        let arguments = self.arena.alloc([argument]);
        let ret_layout = self.arena.alloc(layout);
        let arg_layouts = self.arena.alloc([layout]);
        let expr = Expr::Call(Call {
            call_type: CallType::ByName {
                name: proc_name,
                ret_layout,
                arg_layouts,
                specialization_id: CallSpecId::BACKEND_DUMMY,
            },
            arguments,
        });
        (expr, ctx.new_linker_data)
    }

    /// Generate a refcount increment procedure, *without* a Call expression.
    /// *This method should be rarely used* - only when the proc is to be called from Zig.
    /// Otherwise you want to generate the Proc and the Call together, using another method.
    pub fn gen_refcount_proc(
        &mut self,
        ident_ids: &mut IdentIds,
        layout: Layout<'a>,
        op: HelperOp,
    ) -> (Symbol, Vec<'a, (Symbol, ProcLayout<'a>)>) {
        let mut ctx = Context {
            new_linker_data: Vec::new_in(self.arena),
            recursive_union: None,
            op,
        };
        let proc_name = self.find_or_create_proc(ident_ids, &mut ctx, layout);
        (proc_name, ctx.new_linker_data)
    }

    /// Replace a generic `Lowlevel::Eq` call with a specialized helper proc.
    /// The helper procs themselves are to be generated later with `generate_procs`
    pub fn call_specialized_equals(
        &mut self,
        ident_ids: &mut IdentIds,
        layout: &Layout<'a>,
        arguments: &'a [Symbol],
    ) -> (Expr<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) {
        let mut ctx = Context {
            new_linker_data: Vec::new_in(self.arena),
            recursive_union: None,
            op: HelperOp::Eq,
        };
        let expr = self
            .call_specialized_op(ident_ids, &mut ctx, *layout, arguments)
            .unwrap();
        (expr, ctx.new_linker_data)
    }

    // ============================================================================
    //
    //              CALL SPECIALIZED OP
    //
    // ============================================================================

    /// Build a call expression for `ctx.op` applied to `arguments`.
    ///
    /// Returns a ByName call to a helper proc when the layout needs one,
    /// a plain lowlevel `Eq` call when it doesn't but the op is `Eq`,
    /// and `None` otherwise (no-op for refcounting on non-refcounted layouts).
    fn call_specialized_op(
        &mut self,
        ident_ids: &mut IdentIds,
        ctx: &mut Context<'a>,
        called_layout: Layout<'a>,
        arguments: &'a [Symbol],
    ) -> Option<Expr<'a>> {
        use HelperOp::*;
        // debug_assert!(self.debug_recursion_depth < 100);
        self.debug_recursion_depth += 1;
        // A bare RecursivePointer stands for the enclosing recursive union.
        let layout = if matches!(called_layout, Layout::RecursivePointer) {
            let union_layout = ctx.recursive_union.unwrap();
            Layout::Union(union_layout)
        } else {
            called_layout
        };
        if layout_needs_helper_proc(&layout, ctx.op) {
            let proc_name = self.find_or_create_proc(ident_ids, ctx, layout);
            let (ret_layout, arg_layouts): (&'a Layout<'a>, &'a [Layout<'a>]) = {
                let arg = self.replace_rec_ptr(ctx, layout);
                match ctx.op {
                    Dec | DecRef(_) => (&LAYOUT_UNIT, self.arena.alloc([arg])),
                    Reset => (self.arena.alloc(layout), self.arena.alloc([layout])),
                    Inc => (&LAYOUT_UNIT, self.arena.alloc([arg, self.layout_isize])),
                    Eq => (&LAYOUT_BOOL, self.arena.alloc([arg, arg])),
                }
            };
            Some(Expr::Call(Call {
                call_type: CallType::ByName {
                    name: proc_name,
                    ret_layout,
                    arg_layouts,
                    specialization_id: CallSpecId::BACKEND_DUMMY,
                },
                arguments,
            }))
        } else if ctx.op == HelperOp::Eq {
            Some(Expr::Call(Call {
                call_type: CallType::LowLevel {
                    op: LowLevel::Eq,
                    update_mode: UpdateModeId::BACKEND_DUMMY,
                },
                arguments,
            }))
        } else {
            None
        }
    }

    /// Look up an existing specialization for (ctx.op, layout), or create the
    /// proc symbol, generate the body recursively, and register it.
    fn find_or_create_proc(
        &mut self,
        ident_ids: &mut IdentIds,
        ctx: &mut Context<'a>,
        layout: Layout<'a>,
    ) -> Symbol {
        use HelperOp::*;
        let layout = self.replace_rec_ptr(ctx, layout);
        let found = self
            .specializations
            .iter()
            .find(|spec| spec.op == ctx.op && spec.layout == layout);
        if let Some(spec) = found {
            return spec.symbol;
        }
        // Procs can be recursive, so we need to create the symbol before the body is complete
        // But with nested recursion, that means Symbols and Procs can end up in different orders.
        // We want the same order, especially for function indices in Wasm. So create an empty slot and fill it in later.
        let (proc_symbol, proc_layout) = self.create_proc_symbol(ident_ids, ctx, &layout);
        ctx.new_linker_data.push((proc_symbol, proc_layout));
        let spec_index = self.specializations.len();
        self.specializations.push(Specialization {
            op: ctx.op,
            layout,
            symbol: proc_symbol,
            proc: None,
        });
        // Recursively generate the body of the Proc and sub-procs
        let (ret_layout, body) = match ctx.op {
            Inc | Dec | DecRef(_) => (
                LAYOUT_UNIT,
                refcount::refcount_generic(self, ident_ids, ctx, layout, Symbol::ARG_1),
            ),
            Reset => (
                layout,
                refcount::refcount_reset_proc_body(self, ident_ids, ctx, layout, Symbol::ARG_1),
            ),
            Eq => (
                LAYOUT_BOOL,
                equality::eq_generic(self, ident_ids, ctx, layout),
            ),
        };
        let args: &'a [(Layout<'a>, Symbol)] = {
            let roc_value = (layout, ARG_1);
            match ctx.op {
                Inc => {
                    let inc_amount = (self.layout_isize, ARG_2);
                    self.arena.alloc([roc_value, inc_amount])
                }
                Dec | DecRef(_) | Reset => self.arena.alloc([roc_value]),
                Eq => self.arena.alloc([roc_value, (layout, ARG_2)]),
            }
        };
        self.specializations[spec_index].proc = Some(Proc {
            name: proc_symbol,
            args,
            body,
            closure_data_layout: None,
            ret_layout,
            is_self_recursive: SelfRecursive::NotSelfRecursive,
            must_own_arguments: false,
            host_exposed_layouts: HostExposedLayouts::NotHostExposed,
        });
        proc_symbol
    }

    /// Create a debug-named symbol and the ProcLayout for a new helper proc.
    fn create_proc_symbol(
        &self,
        ident_ids: &mut IdentIds,
        ctx: &mut Context<'a>,
        layout: &Layout<'a>,
    ) -> (Symbol, ProcLayout<'a>) {
        let debug_name = format!(
            "#help{}_{:?}_{:?}",
            self.specializations.len(),
            ctx.op,
            layout
        )
        .replace("Builtin", "");
        let proc_symbol: Symbol = self.create_symbol(ident_ids, &debug_name);
        let proc_layout = match ctx.op {
            HelperOp::Inc => ProcLayout {
                arguments: self.arena.alloc([*layout, self.layout_isize]),
                result: LAYOUT_UNIT,
            },
            HelperOp::Dec => ProcLayout {
                arguments: self.arena.alloc([*layout]),
                result: LAYOUT_UNIT,
            },
            HelperOp::Reset => ProcLayout {
                arguments: self.arena.alloc([*layout]),
                result: *layout,
            },
            HelperOp::DecRef(_) => unreachable!("No generated Proc for DecRef"),
            HelperOp::Eq => ProcLayout {
                arguments: self.arena.alloc([*layout, *layout]),
                result: LAYOUT_BOOL,
            },
        };
        (proc_symbol, proc_layout)
    }

    /// Create a new symbol in the helper module with the given debug name.
    fn create_symbol(&self, ident_ids: &mut IdentIds, debug_name: &str) -> Symbol {
        let ident_id = ident_ids.add_str(debug_name);
        Symbol::new(self.home, ident_id)
    }

    // When creating or looking up Specializations, we need to replace RecursivePointer
    // with the particular Union layout it represents at this point in the tree.
    // For example if a program uses `RoseTree a : [Tree a (List (RoseTree a))]`
    // then it could have both `RoseTree I64` and `RoseTree Str`. In this case it
    // needs *two* specializations for `List(RecursivePointer)`, not just one.
    fn replace_rec_ptr(&self, ctx: &Context<'a>, layout: Layout<'a>) -> Layout<'a> {
        match layout {
            Layout::Builtin(Builtin::Dict(k, v)) => Layout::Builtin(Builtin::Dict(
                self.arena.alloc(self.replace_rec_ptr(ctx, *k)),
                self.arena.alloc(self.replace_rec_ptr(ctx, *v)),
            )),
            Layout::Builtin(Builtin::Set(k)) => Layout::Builtin(Builtin::Set(
                self.arena.alloc(self.replace_rec_ptr(ctx, *k)),
            )),
            Layout::Builtin(Builtin::List(v)) => Layout::Builtin(Builtin::List(
                self.arena.alloc(self.replace_rec_ptr(ctx, *v)),
            )),
            Layout::Builtin(_) => layout,
            Layout::Struct {
                field_layouts,
                field_order_hash,
            } => {
                let new_fields_iter = field_layouts.iter().map(|f| self.replace_rec_ptr(ctx, *f));
                Layout::Struct {
                    field_layouts: self.arena.alloc_slice_fill_iter(new_fields_iter),
                    field_order_hash,
                }
            }
            Layout::Union(UnionLayout::NonRecursive(tags)) => {
                let mut new_tags = Vec::with_capacity_in(tags.len(), self.arena);
                for fields in tags {
                    let mut new_fields = Vec::with_capacity_in(fields.len(), self.arena);
                    for field in fields.iter() {
                        new_fields.push(self.replace_rec_ptr(ctx, *field))
                    }
                    new_tags.push(new_fields.into_bump_slice());
                }
                Layout::Union(UnionLayout::NonRecursive(new_tags.into_bump_slice()))
            }
            Layout::Union(_) => {
                // we always fully unroll recursive types. That means that when we find a
                // recursive tag union we can replace it with the layout
                layout
            }
            // NOTE(review): this returns the *inner* layout without re-wrapping
            // it in Layout::Boxed — looks like the Box wrapper is dropped here;
            // confirm this is intentional.
            Layout::Boxed(inner) => self.replace_rec_ptr(ctx, *inner),
            Layout::LambdaSet(lambda_set) => {
                self.replace_rec_ptr(ctx, lambda_set.runtime_representation())
            }
            // This line is the whole point of the function
            Layout::RecursivePointer => Layout::Union(ctx.recursive_union.unwrap()),
        }
    }

    /// For each tag of `union`, find the index of a RecursivePointer field
    /// (if any). Returns whether *any* tag has one, plus the per-tag indices.
    fn union_tail_recursion_fields(
        &self,
        union: UnionLayout<'a>,
    ) -> (bool, Vec<'a, Option<usize>>) {
        use UnionLayout::*;
        match union {
            NonRecursive(_) => return (false, bumpalo::vec![in self.arena]),
            Recursive(tags) => self.union_tail_recursion_fields_help(tags),
            NonNullableUnwrapped(field_layouts) => {
                self.union_tail_recursion_fields_help(&[field_layouts])
            }
            NullableWrapped {
                other_tags: tags, ..
            } => self.union_tail_recursion_fields_help(tags),
            NullableUnwrapped { other_fields, .. } => {
                self.union_tail_recursion_fields_help(&[other_fields])
            }
        }
    }

    fn union_tail_recursion_fields_help(
        &self,
        tags: &[&'a [Layout<'a>]],
    ) -> (bool, Vec<'a, Option<usize>>) {
        let mut can_use_tailrec = false;
        let mut tailrec_indices = Vec::with_capacity_in(tags.len(), self.arena);
        for fields in tags.iter() {
            let found_index = fields
                .iter()
                .position(|f| matches!(f, Layout::RecursivePointer));
            tailrec_indices.push(found_index);
            can_use_tailrec |= found_index.is_some();
        }
        (can_use_tailrec, tailrec_indices)
    }
}
/// Build a `let result = op(arguments...)` statement for a low-level call,
/// continuing with `next`. `arguments` is copied into the arena.
fn let_lowlevel<'a>(
    arena: &'a Bump,
    result_layout: Layout<'a>,
    result: Symbol,
    op: LowLevel,
    arguments: &[Symbol],
    next: &'a Stmt<'a>,
) -> Stmt<'a> {
    // Helper procs do no update-mode analysis, so the backend placeholder id is used.
    let call = Call {
        call_type: CallType::LowLevel {
            op,
            update_mode: UpdateModeId::BACKEND_DUMMY,
        },
        arguments: arena.alloc_slice_copy(arguments),
    };
    Stmt::Let(result, Expr::Call(call), result_layout, next)
}
/// Decide whether operation `op` on a value of `layout` requires a generated
/// helper procedure, as opposed to direct code gen or a Zig builtin.
fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool {
    match layout {
        // Scalars: always handled with direct code gen, never a helper.
        Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => {
            false
        }
        // Str type can use either Zig functions or generated IR, since it's not generic.
        // Eq uses a Zig function, refcount uses generated IR.
        // Both are fine, they were just developed at different times.
        Layout::Builtin(Builtin::Str) => {
            matches!(op, HelperOp::Inc | HelperOp::Dec | HelperOp::DecRef(_))
        }
        Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => true,
        // An empty non-recursive union has no runtime representation, so no helper.
        Layout::Union(UnionLayout::NonRecursive(tags)) => !tags.is_empty(),
        // A bare recursive pointer is only reachable through its enclosing union.
        Layout::RecursivePointer => false,
        // note: we do generate a helper for Unit (an empty struct), with just a Stmt::Ret
        Layout::Struct { .. } | Layout::Union(_) | Layout::LambdaSet(_) | Layout::Boxed(_) => true,
    }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,791 @@
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_can::{
def::Def,
expr::{AccessorData, ClosureData, Expr, Field, WhenBranch},
};
use roc_types::{
subs::{
self, AliasVariables, Descriptor, OptVariable, RecordFields, Subs, SubsSlice, UnionLambdas,
UnionTags, Variable, VariableSubsSlice,
},
types::Uls,
};
/// Deep copies the type variables in the type hosted by [`var`] into [`expr`].
/// Returns [`None`] if the expression does not need to be copied.
///
/// On success, returns the fresh copy of the root variable together with a copy
/// of `expr` in which every stored type variable has been replaced by its copy.
/// The original `subs` entries and `expr` are left untouched.
pub fn deep_copy_type_vars_into_expr<'a>(
    arena: &'a Bump,
    subs: &mut Subs,
    var: Variable,
    expr: &Expr,
) -> Option<(Variable, Expr)> {
    // Always deal with the root, so that aliases propagate correctly.
    let var = subs.get_root_key_without_compacting(var);
    let substitutions = deep_copy_type_vars(arena, subs, var);
    if substitutions.is_empty() {
        return None;
    }
    // `deep_copy_type_vars` guarantees the root itself appears in the map.
    let new_var = substitutions
        .iter()
        .find_map(|&(original, new)| if original == var { Some(new) } else { None })
        .expect("Variable marked as cloned, but it isn't");
    return Some((new_var, help(subs, expr, &substitutions)));
    // Mechanical structural recursion over the canonical AST: clone every node,
    // substituting each stored type variable via `substitutions`.
    fn help(subs: &Subs, expr: &Expr, substitutions: &[(Variable, Variable)]) -> Expr {
        use Expr::*;
        // Replace a variable with its copy if its root was cloned; otherwise
        // return the variable unchanged.
        macro_rules! sub {
            ($var:expr) => {{
                // Always deal with the root, so that aliases propagate correctly.
                let root = subs.get_root_key_without_compacting($var);
                substitutions
                    .iter()
                    .find_map(|&(original, new)| if original == root { Some(new) } else { None })
                    .unwrap_or($var)
            }};
        }
        // Shorthand for recursing into a sub-expression.
        let go_help = |e: &Expr| help(subs, e, substitutions);
        match expr {
            Num(var, str, val, bound) => Num(sub!(*var), str.clone(), val.clone(), *bound),
            Int(v1, v2, str, val, bound) => {
                Int(sub!(*v1), sub!(*v2), str.clone(), val.clone(), *bound)
            }
            Float(v1, v2, str, val, bound) => {
                Float(sub!(*v1), sub!(*v2), str.clone(), *val, *bound)
            }
            Str(str) => Str(str.clone()),
            SingleQuote(char) => SingleQuote(*char),
            List {
                elem_var,
                loc_elems,
            } => List {
                elem_var: sub!(*elem_var),
                loc_elems: loc_elems.iter().map(|le| le.map(go_help)).collect(),
            },
            Var(sym) => Var(*sym),
            &AbilityMember(sym, specialization, specialization_var) => {
                AbilityMember(sym, specialization, specialization_var)
            }
            When {
                loc_cond,
                cond_var,
                expr_var,
                region,
                branches,
                branches_cond_var,
                exhaustive,
            } => When {
                loc_cond: Box::new(loc_cond.map(go_help)),
                cond_var: sub!(*cond_var),
                expr_var: sub!(*expr_var),
                region: *region,
                branches: branches
                    .iter()
                    .map(
                        |WhenBranch {
                             patterns,
                             value,
                             guard,
                             redundant,
                         }| WhenBranch {
                            patterns: patterns.clone(),
                            value: value.map(go_help),
                            guard: guard.as_ref().map(|le| le.map(go_help)),
                            redundant: *redundant,
                        },
                    )
                    .collect(),
                branches_cond_var: sub!(*branches_cond_var),
                exhaustive: *exhaustive,
            },
            If {
                cond_var,
                branch_var,
                branches,
                final_else,
            } => If {
                cond_var: sub!(*cond_var),
                branch_var: sub!(*branch_var),
                branches: branches
                    .iter()
                    .map(|(c, e)| (c.map(go_help), e.map(go_help)))
                    .collect(),
                final_else: Box::new(final_else.map(go_help)),
            },
            LetRec(defs, body, cycle_mark) => LetRec(
                defs.iter()
                    .map(
                        |Def {
                             loc_pattern,
                             loc_expr,
                             expr_var,
                             pattern_vars,
                             annotation,
                         }| Def {
                            loc_pattern: loc_pattern.clone(),
                            loc_expr: loc_expr.map(go_help),
                            expr_var: sub!(*expr_var),
                            pattern_vars: pattern_vars
                                .iter()
                                .map(|(s, v)| (*s, sub!(*v)))
                                .collect(),
                            annotation: annotation.clone(),
                        },
                    )
                    .collect(),
                Box::new(body.map(go_help)),
                *cycle_mark,
            ),
            LetNonRec(def, body) => {
                let Def {
                    loc_pattern,
                    loc_expr,
                    expr_var,
                    pattern_vars,
                    annotation,
                } = &**def;
                let def = Def {
                    loc_pattern: loc_pattern.clone(),
                    loc_expr: loc_expr.map(go_help),
                    expr_var: sub!(*expr_var),
                    pattern_vars: pattern_vars.iter().map(|(s, v)| (*s, sub!(*v))).collect(),
                    annotation: annotation.clone(),
                };
                LetNonRec(Box::new(def), Box::new(body.map(go_help)))
            }
            Call(f, args, called_via) => {
                let (fn_var, fn_expr, clos_var, ret_var) = &**f;
                Call(
                    Box::new((
                        sub!(*fn_var),
                        fn_expr.map(go_help),
                        sub!(*clos_var),
                        sub!(*ret_var),
                    )),
                    args.iter()
                        .map(|(var, expr)| (sub!(*var), expr.map(go_help)))
                        .collect(),
                    *called_via,
                )
            }
            RunLowLevel { op, args, ret_var } => RunLowLevel {
                op: *op,
                args: args
                    .iter()
                    .map(|(var, expr)| (sub!(*var), go_help(expr)))
                    .collect(),
                ret_var: sub!(*ret_var),
            },
            ForeignCall {
                foreign_symbol,
                args,
                ret_var,
            } => ForeignCall {
                foreign_symbol: foreign_symbol.clone(),
                args: args
                    .iter()
                    .map(|(var, expr)| (sub!(*var), go_help(expr)))
                    .collect(),
                ret_var: sub!(*ret_var),
            },
            Closure(ClosureData {
                function_type,
                closure_type,
                return_type,
                name,
                captured_symbols,
                recursive,
                arguments,
                loc_body,
            }) => Closure(ClosureData {
                function_type: sub!(*function_type),
                closure_type: sub!(*closure_type),
                return_type: sub!(*return_type),
                name: *name,
                captured_symbols: captured_symbols
                    .iter()
                    .map(|(s, v)| (*s, sub!(*v)))
                    .collect(),
                recursive: *recursive,
                arguments: arguments
                    .iter()
                    .map(|(v, mark, pat)| (sub!(*v), *mark, pat.clone()))
                    .collect(),
                loc_body: Box::new(loc_body.map(go_help)),
            }),
            Record { record_var, fields } => Record {
                record_var: sub!(*record_var),
                fields: fields
                    .iter()
                    .map(
                        |(
                            k,
                            Field {
                                var,
                                region,
                                loc_expr,
                            },
                        )| {
                            (
                                k.clone(),
                                Field {
                                    var: sub!(*var),
                                    region: *region,
                                    loc_expr: Box::new(loc_expr.map(go_help)),
                                },
                            )
                        },
                    )
                    .collect(),
            },
            EmptyRecord => EmptyRecord,
            Access {
                record_var,
                ext_var,
                field_var,
                loc_expr,
                field,
            } => Access {
                record_var: sub!(*record_var),
                ext_var: sub!(*ext_var),
                field_var: sub!(*field_var),
                loc_expr: Box::new(loc_expr.map(go_help)),
                field: field.clone(),
            },
            Accessor(AccessorData {
                name,
                function_var,
                record_var,
                closure_var,
                ext_var,
                field_var,
                field,
            }) => Accessor(AccessorData {
                name: *name,
                function_var: sub!(*function_var),
                record_var: sub!(*record_var),
                closure_var: sub!(*closure_var),
                ext_var: sub!(*ext_var),
                field_var: sub!(*field_var),
                field: field.clone(),
            }),
            Update {
                record_var,
                ext_var,
                symbol,
                updates,
            } => Update {
                record_var: sub!(*record_var),
                ext_var: sub!(*ext_var),
                symbol: *symbol,
                updates: updates
                    .iter()
                    .map(
                        |(
                            k,
                            Field {
                                var,
                                region,
                                loc_expr,
                            },
                        )| {
                            (
                                k.clone(),
                                Field {
                                    var: sub!(*var),
                                    region: *region,
                                    loc_expr: Box::new(loc_expr.map(go_help)),
                                },
                            )
                        },
                    )
                    .collect(),
            },
            Tag {
                variant_var,
                ext_var,
                name,
                arguments,
            } => Tag {
                variant_var: sub!(*variant_var),
                ext_var: sub!(*ext_var),
                name: name.clone(),
                arguments: arguments
                    .iter()
                    .map(|(v, e)| (sub!(*v), e.map(go_help)))
                    .collect(),
            },
            ZeroArgumentTag {
                closure_name,
                variant_var,
                ext_var,
                name,
            } => ZeroArgumentTag {
                closure_name: *closure_name,
                variant_var: sub!(*variant_var),
                ext_var: sub!(*ext_var),
                name: name.clone(),
            },
            OpaqueRef {
                opaque_var,
                name,
                argument,
                specialized_def_type,
                type_arguments,
                lambda_set_variables,
            } => OpaqueRef {
                opaque_var: sub!(*opaque_var),
                name: *name,
                argument: Box::new((sub!(argument.0), argument.1.map(go_help))),
                // These shouldn't matter for opaques during mono, because they are only used for reporting
                // and pretty-printing to the user. During mono we decay immediately into the argument.
                // NB: if there are bugs, check if not substituting here is the problem!
                specialized_def_type: specialized_def_type.clone(),
                type_arguments: type_arguments.clone(),
                lambda_set_variables: lambda_set_variables.clone(),
            },
            Expect {
                loc_condition,
                loc_continuation,
                lookups_in_cond,
            } => Expect {
                loc_condition: Box::new(loc_condition.map(go_help)),
                loc_continuation: Box::new(loc_continuation.map(go_help)),
                lookups_in_cond: lookups_in_cond.to_vec(),
            },
            TypedHole(v) => TypedHole(sub!(*v)),
            RuntimeError(err) => RuntimeError(err.clone()),
        }
    }
}
/// Deep copies the type variables in [`var`], returning a map of original -> new type variable for
/// all type variables copied.
///
/// Copies are marked in-place via each descriptor's `copy` field while the
/// traversal runs; the marks are cleared again before returning, so `subs` is
/// left with no lingering copy state.
fn deep_copy_type_vars<'a>(
    arena: &'a Bump,
    subs: &mut Subs,
    var: Variable,
) -> Vec<'a, (Variable, Variable)> {
    // Always deal with the root, so that unified variables are treated the same.
    let var = subs.get_root_key_without_compacting(var);
    let mut copied = Vec::with_capacity_in(16, arena);
    let cloned_var = help(arena, subs, &mut copied, var);
    // we have tracked all visited variables, and can now traverse them
    // in one go (without looking at the UnificationTable) and clear the copy field
    let mut result = Vec::with_capacity_in(copied.len(), arena);
    for var in copied {
        subs.modify(var, |descriptor| {
            if let Some(copy) = descriptor.copy.into_variable() {
                result.push((var, copy));
                descriptor.copy = OptVariable::NONE;
            } else {
                debug_assert!(false, "{:?} marked as copied but it wasn't", var);
            }
        })
    }
    debug_assert!(result.contains(&(var, cloned_var)));
    return result;
    // Recursive worker: clones `var`'s content into a fresh variable, marking the
    // original's `copy` field so cycles (recursive types) terminate. Every
    // visited variable is recorded in `visited` so the marks can be cleared.
    #[must_use]
    fn help(arena: &Bump, subs: &mut Subs, visited: &mut Vec<Variable>, var: Variable) -> Variable {
        use roc_types::subs::Content::*;
        use roc_types::subs::FlatType::*;
        // Always deal with the root, so that unified variables are treated the same.
        let var = subs.get_root_key_without_compacting(var);
        let desc = subs.get(var);
        // Unlike `deep_copy_var` in solve, here we are cloning *all* flex and rigid vars.
        // So we only want to short-circuit if we've already done the cloning work for a particular
        // var.
        if let Some(copy) = desc.copy.into_variable() {
            return copy;
        }
        let content = desc.content;
        // Reserve the copy before descending, so self-references resolve to it.
        let copy_descriptor = Descriptor {
            content: Error, // we'll update this below
            rank: desc.rank,
            mark: desc.mark,
            copy: OptVariable::NONE,
        };
        let copy = subs.fresh(copy_descriptor);
        subs.set_copy(var, copy.into());
        visited.push(var);
        // Recurse into every variable of a slice (marking copies, discarding results).
        macro_rules! descend_slice {
            ($slice:expr) => {
                for var_index in $slice {
                    let var = subs[var_index];
                    let _ = help(arena, subs, visited, var);
                }
            };
        }
        // Recurse into a single variable, yielding its copy.
        macro_rules! descend_var {
            ($var:expr) => {{
                help(arena, subs, visited, $var)
            }};
        }
        // Build a new subs slice whose entries are the (already-descended) copies.
        macro_rules! clone_var_slice {
            ($slice:expr) => {{
                let new_arguments = VariableSubsSlice::reserve_into_subs(subs, $slice.len());
                for (target_index, var_index) in (new_arguments.indices()).zip($slice) {
                    let var = subs[var_index];
                    let copy_var = subs.get_copy(var).into_variable().unwrap_or(var);
                    subs.variables[target_index] = copy_var;
                }
                new_arguments
            }};
        }
        macro_rules! perform_clone {
            ($do_clone:expr) => {{
                // It may the case that while deep-copying nested variables of this type, we
                // ended up copying the type itself (notably if it was self-referencing, in a
                // recursive type). In that case, short-circuit with the known copy.
                // if let Some(copy) = subs.get_ref(var).copy.into_variable() {
                //     return copy;
                // }
                // Perform the clone.
                $do_clone
            }};
        }
        // Now we recursively copy the content of the variable.
        // We have already marked the variable as copied, so we
        // will not repeat this work or crawl this variable again.
        let new_content = match content {
            // The vars for which we want to do something interesting.
            FlexVar(opt_name) => FlexVar(opt_name),
            FlexAbleVar(opt_name, ability) => FlexAbleVar(opt_name, ability),
            RigidVar(name) => RigidVar(name),
            RigidAbleVar(name, ability) => RigidAbleVar(name, ability),
            // Everything else is a mechanical descent.
            Structure(flat_type) => match flat_type {
                EmptyRecord | EmptyTagUnion | Erroneous(_) => Structure(flat_type),
                Apply(symbol, arguments) => {
                    descend_slice!(arguments);
                    perform_clone!({
                        let new_arguments = clone_var_slice!(arguments);
                        Structure(Apply(symbol, new_arguments))
                    })
                }
                Func(arguments, closure_var, ret_var) => {
                    descend_slice!(arguments);
                    let new_closure_var = descend_var!(closure_var);
                    let new_ret_var = descend_var!(ret_var);
                    perform_clone!({
                        let new_arguments = clone_var_slice!(arguments);
                        Structure(Func(new_arguments, new_closure_var, new_ret_var))
                    })
                }
                Record(fields, ext_var) => {
                    let new_ext_var = descend_var!(ext_var);
                    descend_slice!(fields.variables());
                    perform_clone!({
                        let new_variables = clone_var_slice!(fields.variables());
                        // Field names and types are shared; only the variables are fresh.
                        let new_fields = {
                            RecordFields {
                                length: fields.length,
                                field_names_start: fields.field_names_start,
                                variables_start: new_variables.start,
                                field_types_start: fields.field_types_start,
                            }
                        };
                        Structure(Record(new_fields, new_ext_var))
                    })
                }
                TagUnion(tags, ext_var) => {
                    let new_ext_var = descend_var!(ext_var);
                    for variables_slice_index in tags.variables() {
                        let variables_slice = subs[variables_slice_index];
                        descend_slice!(variables_slice);
                    }
                    perform_clone!({
                        let new_variable_slices =
                            SubsSlice::reserve_variable_slices(subs, tags.len());
                        let it = (new_variable_slices.indices()).zip(tags.variables());
                        for (target_index, index) in it {
                            let slice = subs[index];
                            let new_variables = clone_var_slice!(slice);
                            subs.variable_slices[target_index] = new_variables;
                        }
                        let new_union_tags =
                            UnionTags::from_slices(tags.labels(), new_variable_slices);
                        Structure(TagUnion(new_union_tags, new_ext_var))
                    })
                }
                RecursiveTagUnion(rec_var, tags, ext_var) => {
                    let new_ext_var = descend_var!(ext_var);
                    let new_rec_var = descend_var!(rec_var);
                    for variables_slice_index in tags.variables() {
                        let variables_slice = subs[variables_slice_index];
                        descend_slice!(variables_slice);
                    }
                    perform_clone!({
                        let new_variable_slices =
                            SubsSlice::reserve_variable_slices(subs, tags.len());
                        let it = (new_variable_slices.indices()).zip(tags.variables());
                        for (target_index, index) in it {
                            let slice = subs[index];
                            let new_variables = clone_var_slice!(slice);
                            subs.variable_slices[target_index] = new_variables;
                        }
                        let new_union_tags =
                            UnionTags::from_slices(tags.labels(), new_variable_slices);
                        Structure(RecursiveTagUnion(new_rec_var, new_union_tags, new_ext_var))
                    })
                }
                FunctionOrTagUnion(tag_name, symbol, ext_var) => {
                    let new_ext_var = descend_var!(ext_var);
                    perform_clone!(Structure(FunctionOrTagUnion(tag_name, symbol, new_ext_var)))
                }
            },
            RecursionVar {
                opt_name,
                structure,
            } => {
                let new_structure = descend_var!(structure);
                perform_clone!({
                    RecursionVar {
                        opt_name,
                        structure: new_structure,
                    }
                })
            }
            Alias(symbol, arguments, real_type_var, kind) => {
                let new_real_type_var = descend_var!(real_type_var);
                descend_slice!(arguments.all_variables());
                perform_clone!({
                    let new_variables = clone_var_slice!(arguments.all_variables());
                    let new_arguments = AliasVariables {
                        variables_start: new_variables.start,
                        ..arguments
                    };
                    Alias(symbol, new_arguments, new_real_type_var, kind)
                })
            }
            LambdaSet(subs::LambdaSet {
                solved,
                recursion_var,
                unspecialized,
            }) => {
                let new_rec_var = recursion_var.map(|var| descend_var!(var));
                for variables_slice_index in solved.variables() {
                    let variables_slice = subs[variables_slice_index];
                    descend_slice!(variables_slice);
                }
                for uls_index in unspecialized {
                    let Uls(var, _, _) = subs[uls_index];
                    descend_var!(var);
                }
                perform_clone!({
                    let new_variable_slices =
                        SubsSlice::reserve_variable_slices(subs, solved.len());
                    let it = (new_variable_slices.indices()).zip(solved.variables());
                    for (target_index, index) in it {
                        let slice = subs[index];
                        let new_variables = clone_var_slice!(slice);
                        subs.variable_slices[target_index] = new_variables;
                    }
                    let new_solved =
                        UnionLambdas::from_slices(solved.labels(), new_variable_slices);
                    let new_unspecialized = SubsSlice::reserve_uls_slice(subs, unspecialized.len());
                    for (target_index, uls_index) in
                        (new_unspecialized.into_iter()).zip(unspecialized.into_iter())
                    {
                        let Uls(var, sym, region) = subs[uls_index];
                        let copy_var = subs.get_copy(var).into_variable().unwrap_or(var);
                        subs[target_index] = Uls(copy_var, sym, region);
                    }
                    LambdaSet(subs::LambdaSet {
                        solved: new_solved,
                        recursion_var: new_rec_var,
                        unspecialized: new_unspecialized,
                    })
                })
            }
            RangedNumber(typ, range) => {
                let new_typ = descend_var!(typ);
                perform_clone!(RangedNumber(new_typ, range))
            }
            Error => Error,
        };
        // Patch the reserved copy with the real content and return it.
        subs.set_content(copy, new_content);
        copy
    }
}
#[cfg(test)]
mod test {
    use super::deep_copy_type_vars;
    use bumpalo::Bump;
    use roc_error_macros::internal_error;
    use roc_module::symbol::Symbol;
    use roc_types::subs::{
        Content, Content::*, Descriptor, Mark, OptVariable, Rank, Subs, SubsIndex, Variable,
    };

    /// Introduce a fresh top-level variable with the given content.
    #[cfg(test)]
    fn new_var(subs: &mut Subs, content: Content) -> Variable {
        subs.fresh(Descriptor {
            content,
            rank: Rank::toplevel(),
            mark: Mark::NONE,
            copy: OptVariable::NONE,
        })
    }

    #[test]
    fn copy_flex_var() {
        let mut subs = Subs::new();
        let arena = Bump::new();
        let field_name = SubsIndex::push_new(&mut subs.field_names, "a".into());
        let var = new_var(&mut subs, FlexVar(Some(field_name)));
        let mut copies = deep_copy_type_vars(&arena, &mut subs, var);
        // Exactly the one variable is copied, to a distinct fresh variable.
        assert_eq!(copies.len(), 1);
        let (original, new) = copies.pop().unwrap();
        assert_ne!(original, new);
        assert_eq!(original, var);
        match subs.get_content_without_compacting(new) {
            FlexVar(Some(name)) => {
                assert_eq!(subs[*name].as_str(), "a");
            }
            // consistency: use `internal_error!` (as `copy_rigid_able_var` does)
            // rather than `unreachable!` for impossible content
            it => internal_error!("{:?}", it),
        }
    }

    #[test]
    fn copy_rigid_var() {
        let mut subs = Subs::new();
        let arena = Bump::new();
        let field_name = SubsIndex::push_new(&mut subs.field_names, "a".into());
        let var = new_var(&mut subs, RigidVar(field_name));
        let mut copies = deep_copy_type_vars(&arena, &mut subs, var);
        assert_eq!(copies.len(), 1);
        let (original, new) = copies.pop().unwrap();
        assert_ne!(original, new);
        assert_eq!(original, var);
        match subs.get_content_without_compacting(new) {
            RigidVar(name) => {
                assert_eq!(subs[*name].as_str(), "a");
            }
            it => internal_error!("{:?}", it),
        }
    }

    #[test]
    fn copy_flex_able_var() {
        let mut subs = Subs::new();
        let arena = Bump::new();
        let field_name = SubsIndex::push_new(&mut subs.field_names, "a".into());
        let var = new_var(&mut subs, FlexAbleVar(Some(field_name), Symbol::UNDERSCORE));
        let mut copies = deep_copy_type_vars(&arena, &mut subs, var);
        assert_eq!(copies.len(), 1);
        let (original, new) = copies.pop().unwrap();
        assert_ne!(original, new);
        assert_eq!(original, var);
        match subs.get_content_without_compacting(new) {
            FlexAbleVar(Some(name), Symbol::UNDERSCORE) => {
                assert_eq!(subs[*name].as_str(), "a");
            }
            it => internal_error!("{:?}", it),
        }
    }

    #[test]
    fn copy_rigid_able_var() {
        let mut subs = Subs::new();
        let arena = Bump::new();
        let field_name = SubsIndex::push_new(&mut subs.field_names, "a".into());
        let var = new_var(&mut subs, RigidAbleVar(field_name, Symbol::UNDERSCORE));
        let mut copies = deep_copy_type_vars(&arena, &mut subs, var);
        assert_eq!(copies.len(), 1);
        let (original, new) = copies.pop().unwrap();
        assert_ne!(original, new);
        assert_eq!(original, var);
        match subs.get_content_without_compacting(new) {
            RigidAbleVar(name, Symbol::UNDERSCORE) => {
                assert_eq!(subs[*name].as_str(), "a");
            }
            it => internal_error!("{:?}", it),
        }
    }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,672 @@
/// UNUSED
///
/// but kept for future reference
///
///
// pub fn optimize_refcount_operations<'i, T>(
// arena: &'a Bump,
// home: ModuleId,
// ident_ids: &'i mut IdentIds,
// procs: &mut MutMap<T, Proc<'a>>,
// ) {
// use crate::expand_rc;
//
// let deferred = expand_rc::Deferred {
// inc_dec_map: Default::default(),
// assignments: Vec::new_in(arena),
// decrefs: Vec::new_in(arena),
// };
//
// let mut env = expand_rc::Env {
// home,
// arena,
// ident_ids,
// layout_map: Default::default(),
// alias_map: Default::default(),
// constructor_map: Default::default(),
// deferred,
// };
//
// for (_, proc) in procs.iter_mut() {
// let b = expand_rc::expand_and_cancel_proc(
// &mut env,
// arena.alloc(proc.body.clone()),
// proc.args,
// );
// proc.body = b.clone();
// }
// }
use crate::ir::{BranchInfo, Expr, ModifyRc, Stmt};
use crate::layout::{Layout, UnionLayout};
use bumpalo::collections::Vec;
use bumpalo::Bump;
// use linked_hash_map::LinkedHashMap;
use roc_collections::all::MutMap;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
// This file is heavily inspired by the Perceus paper
//
// https://www.microsoft.com/en-us/research/uploads/prod/2020/11/perceus-tr-v1.pdf
//
// With how we insert RC instructions, this pattern is very common:
//
// when xs is
// Cons x xx ->
// inc x;
// inc xx;
// dec xs;
// ...
//
// This pattern is very inefficient, because it will first increment the tail (recursively),
// and then decrement it again. We can see this more clearly if we inline/specialize the `dec xs`
//
// when xs is
// Cons x xx ->
// inc x;
// inc xx;
// dec x;
// dec xx;
// decref xs
// ...
//
// Here `decref` non-recursively decrements (and possibly frees) `xs`. Now the idea is that we can
// fuse `inc x; dec x` by just doing nothing: they cancel out
//
// We can do slightly more, in the `Nil` case
//
// when xs is
// ...
// Nil ->
// dec xs;
// accum
//
// Here we know that `Nil` is represented by NULL (a linked list has a NullableUnwrapped layout),
// so we can just drop the `dec xs`
//
// # complications
//
// Let's work through the `Cons x xx` example
//
// First we need to know the constructor of `xs` in the particular block. This information would
// normally be lost when we compile pattern matches, but we keep it in the `BranchInfo` field of
// switch branches. here we also store the symbol that was switched on, and the layout of that
// symbol.
//
// Next, we need to know that `x` and `xx` alias the head and tail of `xs`. We store that
// information when encountering an `AccessAtIndex` into `xs`.
//
// In most cases these two pieces of information are enough. We keep track of a
// `LinkedHashMap<Symbol, i64>`: `LinkedHashMap` remembers insertion order, which is crucial here.
// The `i64` value represents the increment (positive value) or decrement (negative value). When
// the value is 0, increments and decrements have cancelled out and we just emit nothing.
//
// We need to do slightly more work in the case of
//
// when xs is
// Cons _ xx ->
// recurse xx (1 + accum)
//
// In this case, the head is not bound. That's OK when the list elements are not refcounted (or
// contain anything refcounted). But when they do, we can't expand the `dec xs` because there is no
// way to reference the head element.
//
// Our refcounting mechanism can't deal well with unused variables (it'll leak their memory). But
// we can insert the access after RC instructions have been inserted. So in the above case we
// actually get
//
// when xs is
// Cons _ xx ->
// let v1 = AccessAtIndex 1 xs
// inc v1;
// let xx = AccessAtIndex 2 xs
// inc xx;
// dec v1;
// dec xx;
// decref xs;
// recurse xx (1 + accum)
//
// Here we see another problem: the increments and decrements cannot be fused immediately.
// Therefore we add a rule that we can "push down" increments and decrements past
//
// - `Let`s binding an `AccessAtIndex`
// - refcount operations
//
// This allows the aforementioned `LinkedHashMap` to accumulate all changes, and then emit
// all (uncancelled) modifications at once before any "non-push-downable-stmt", hence:
//
// when xs is
// Cons _ xx ->
// let v1 = AccessAtIndex 1 xs
// let xx = AccessAtIndex 2 xs
// dec v1;
// decref xs;
// recurse xx (1 + accum)
/// Mutable working state for the inc/dec expansion-and-cancellation pass.
pub struct Env<'a, 'i> {
    /// bump allocator
    pub arena: &'a Bump,
    /// required for creating new `Symbol`s
    pub home: ModuleId,
    /// ident generator for `home`, used to mint fresh `Symbol`s
    pub ident_ids: &'i mut IdentIds,
    /// layout of the symbol
    pub layout_map: MutMap<Symbol, Layout<'a>>,
    /// record for each symbol, the aliases of its fields
    pub alias_map: MutMap<Symbol, MutMap<u64, Symbol>>,
    /// for a symbol (found in a `when x is`), record in which branch we are
    pub constructor_map: MutMap<Symbol, u64>,
    /// increments and decrements deferred until later
    pub deferred: Deferred<'a>,
}
/// Refcount work postponed while pushing inc/dec past transparent statements.
// NOTE(review): `LinkedHashMap` has its `use` commented out at the top of this
// (UNUSED, kept-for-reference) file — confirm the import before re-enabling.
#[derive(Debug)]
pub struct Deferred<'a> {
    /// net refcount delta per symbol (insertion-ordered); a value of 0 means
    /// increments and decrements cancelled and nothing is emitted
    pub inc_dec_map: LinkedHashMap<Symbol, i64>,
    /// field loads that must be materialized before the deferred RC ops
    pub assignments: Vec<'a, (Symbol, Expr<'a>, Layout<'a>)>,
    /// symbols whose cell gets a non-recursive `decref`
    pub decrefs: Vec<'a, Symbol>,
}
impl<'a, 'i> Env<'a, 'i> {
    /// On entering a switch branch: record which constructor (tag id) the
    /// scrutinee has in this branch, along with its layout.
    fn insert_branch_info(&mut self, info: &BranchInfo<'a>) {
        match info {
            BranchInfo::Constructor {
                layout,
                scrutinee,
                tag_id,
            } => {
                self.constructor_map.insert(*scrutinee, *tag_id as u64);
                self.layout_map.insert(*scrutinee, *layout);
            }
            BranchInfo::None => (),
        }
    }
    /// On leaving a switch branch: undo `insert_branch_info`.
    fn remove_branch_info(&mut self, info: &BranchInfo) {
        match info {
            BranchInfo::Constructor { scrutinee, .. } => {
                self.constructor_map.remove(scrutinee);
                self.layout_map.remove(scrutinee);
            }
            BranchInfo::None => (),
        }
    }
    /// If `layout` is a struct, record it for `symbol` (structs act as
    /// single-constructor values with tag id 0).
    fn try_insert_struct_info(&mut self, symbol: Symbol, layout: &Layout<'a>) {
        use Layout::*;
        if let Struct(fields) = layout {
            self.constructor_map.insert(symbol, 0);
            self.layout_map.insert(symbol, Layout::Struct(fields));
        }
    }
    /// Record that `symbol` is a struct with the given field layouts.
    fn insert_struct_info(&mut self, symbol: Symbol, fields: &'a [Layout<'a>]) {
        self.constructor_map.insert(symbol, 0);
        self.layout_map.insert(symbol, Layout::Struct(fields));
    }
    /// Undo `insert_struct_info` / `try_insert_struct_info`.
    fn remove_struct_info(&mut self, symbol: Symbol) {
        self.constructor_map.remove(&symbol);
        self.layout_map.remove(&symbol);
    }
    /// Mint a fresh symbol in the home module.
    pub fn unique_symbol(&mut self) -> Symbol {
        let ident_id = self.ident_ids.gen_unique();
        Symbol::new(self.home, ident_id)
    }
    /// Like `unique_symbol`, usable where `self` is partially borrowed.
    #[allow(dead_code)]
    fn manual_unique_symbol(home: ModuleId, ident_ids: &mut IdentIds) -> Symbol {
        let ident_id = ident_ids.gen_unique();
        Symbol::new(home, ident_id)
    }
}
/// Given a value's layout and the constructor (tag id) it is known to have,
/// return that constructor's field layouts, or `IsNull` when the constructor
/// is represented as a null pointer (so there is nothing to decrement).
fn layout_for_constructor<'a>(
    _arena: &'a Bump,
    layout: &Layout<'a>,
    constructor: u64,
) -> ConstructorLayout<&'a [Layout<'a>]> {
    use ConstructorLayout::*;
    use Layout::*;
    match layout {
        Union(variant) => {
            use UnionLayout::*;
            match variant {
                NullableUnwrapped {
                    nullable_id,
                    other_fields,
                } => {
                    // NOTE(review): assumes `nullable_id` here is the bool that marks
                    // which of the two tags is null — confirm against `UnionLayout`.
                    if (constructor > 0) == *nullable_id {
                        ConstructorLayout::IsNull
                    } else {
                        ConstructorLayout::HasFields(other_fields)
                    }
                }
                NullableWrapped {
                    nullable_id,
                    other_tags,
                } => {
                    if constructor as i64 == *nullable_id {
                        ConstructorLayout::IsNull
                    } else {
                        ConstructorLayout::HasFields(other_tags[constructor as usize])
                    }
                }
                NonRecursive(fields) | Recursive(fields) => HasFields(fields[constructor as usize]),
                NonNullableUnwrapped(fields) => {
                    // single-tag union: the only valid constructor is 0
                    debug_assert_eq!(constructor, 0);
                    HasFields(fields)
                }
            }
        }
        Struct(fields) => {
            // structs behave as a single constructor with tag id 0
            debug_assert_eq!(constructor, 0);
            HasFields(fields)
        }
        other => unreachable!("weird layout {:?}", other),
    }
}
/// For a symbol about to be decremented, determine the per-field work needed to
/// inline the decrement: the symbols (pattern-bound aliases or freshly loaded
/// fields) whose refcounts stand in for the whole structure's.
/// Returns `Unknown` when the constructor/layout is not tracked, or when
/// inlining would generate more code than a plain `dec`.
fn work_for_constructor<'a>(
    env: &mut Env<'a, '_>,
    symbol: &Symbol,
) -> ConstructorLayout<Vec<'a, Symbol>> {
    use ConstructorLayout::*;
    let mut result = Vec::new_in(env.arena);
    let constructor = match env.constructor_map.get(symbol) {
        None => return ConstructorLayout::Unknown,
        Some(v) => *v,
    };
    let full_layout = match env.layout_map.get(symbol) {
        None => return ConstructorLayout::Unknown,
        Some(v) => v,
    };
    let field_aliases = env.alias_map.get(symbol);
    match layout_for_constructor(env.arena, full_layout, constructor) {
        Unknown => Unknown,
        IsNull => IsNull,
        HasFields(constructor_layout) => {
            // figure out if there is at least one aliased refcounted field. Only then
            // does it make sense to inline the decrement
            let at_least_one_aliased = (|| {
                for (i, field_layout) in constructor_layout.iter().enumerate() {
                    if field_layout.contains_refcounted()
                        && field_aliases.and_then(|map| map.get(&(i as u64))).is_some()
                    {
                        return true;
                    }
                }
                false
            })();
            // for each field, if it has refcounted content, check if it has an alias
            // if so, use the alias, otherwise load the field.
            for (i, field_layout) in constructor_layout.iter().enumerate() {
                if field_layout.contains_refcounted() {
                    match field_aliases.and_then(|map| map.get(&(i as u64))) {
                        Some(alias_symbol) => {
                            // the field was bound in a pattern match
                            result.push(*alias_symbol);
                        }
                        None if at_least_one_aliased => {
                            // the field was not bound in a pattern match
                            // we have to extract it now, but we only extract it
                            // if at least one field is aliased.
                            todo!("get the tag id");
                            /*
                            let expr = Expr::AccessAtIndex {
                                index: i as u64,
                                field_layouts: constructor_layout,
                                structure: *symbol,
                                wrapped: todo!("get the tag id"),
                            };
                            // create a fresh symbol for this field
                            let alias_symbol = Env::manual_unique_symbol(env.home, env.ident_ids);
                            let layout = if let Layout::RecursivePointer = field_layout {
                                *full_layout
                            } else {
                                *field_layout
                            };
                            env.deferred.assignments.push((alias_symbol, expr, layout));
                            result.push(alias_symbol);
                            */
                        }
                        None => {
                            // if all refcounted fields were unaliased, generate a normal decrement
                            // of the whole structure (less code generated this way)
                            return ConstructorLayout::Unknown;
                        }
                    }
                }
            }
            ConstructorLayout::HasFields(result)
        }
    }
}
fn can_push_inc_through(stmt: &Stmt) -> bool {
use Stmt::*;
match stmt {
Let(_, expr, _, _) => {
// we can always delay an increment/decrement until after a field access
matches!(expr, Expr::StructAtIndex { .. } | Expr::Literal(_))
}
Refcounting(ModifyRc::Inc(_, _), _) => true,
Refcounting(ModifyRc::Dec(_), _) => true,
_ => false,
}
}
/// What is known about the fields of a value, given its constructor.
#[derive(Debug)]
enum ConstructorLayout<T> {
    /// the constructor is represented as a null pointer; it has no fields
    IsNull,
    /// the constructor has these fields / this per-field work
    HasFields(T),
    /// not enough static information to specialize
    Unknown,
}
/// Run the expand-and-cancel pass over a whole procedure body, seeding the
/// environment with struct info for any struct-typed arguments so their field
/// aliases can be tracked inside the body.
pub fn expand_and_cancel_proc<'a>(
    env: &mut Env<'a, '_>,
    stmt: &'a Stmt<'a>,
    arguments: &'a [(Layout<'a>, Symbol)],
) -> &'a Stmt<'a> {
    let mut registered = Vec::new_in(env.arena);
    for (arg_layout, arg_symbol) in arguments.iter() {
        if let Layout::Struct(field_layouts) = arg_layout {
            env.insert_struct_info(*arg_symbol, field_layouts);
            registered.push(*arg_symbol);
        }
    }
    let new_body = expand_and_cancel(env, stmt);
    // Struct info is scoped to this proc: remove what we registered above.
    for arg_symbol in registered.into_iter() {
        env.remove_struct_info(arg_symbol);
    }
    new_body
}
/// Core of the pass: rebuild the statement tree bottom-up, expanding struct
/// field accesses via the alias/constructor maps in `env` and cancelling
/// matching refcount increments/decrements. Deferred refcount operations are
/// balanced in `Deferred.inc_dec_map` (net count per symbol) and flushed
/// around the rebuilt statement at the end.
fn expand_and_cancel<'a>(env: &mut Env<'a, '_>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> {
    use Stmt::*;

    // Fresh deferred state for this subtree.
    let mut deferred = Deferred {
        inc_dec_map: Default::default(),
        assignments: Vec::new_in(env.arena),
        decrefs: Vec::new_in(env.arena),
    };

    // If increments cannot be pushed through `stmt`, take over the
    // environment's pending deferred ops: they must be emitted here, around
    // this statement, instead of being pushed further down.
    if !can_push_inc_through(stmt) {
        std::mem::swap(&mut deferred, &mut env.deferred);
    }

    let mut result = {
        match stmt {
            Let(mut symbol, expr, layout, cont) => {
                env.layout_map.insert(symbol, *layout);

                let mut expr = expr;
                let mut layout = layout;
                let mut cont = cont;

                // prevent long chains of `Let`s from blowing the stack
                //
                // Walk down a run of `Let`s iteratively, stashing the
                // "uninteresting" bindings; stop at expressions the match
                // below handles specially.
                let mut literal_stack = Vec::new_in(env.arena);

                while !matches!(
                    &expr,
                    Expr::StructAtIndex { .. } | Expr::Struct(_) | Expr::Call(_)
                ) {
                    if let Stmt::Let(symbol1, expr1, layout1, cont1) = cont {
                        literal_stack.push((symbol, expr.clone(), *layout));

                        symbol = *symbol1;
                        expr = expr1;
                        layout = layout1;
                        cont = cont1;
                    } else {
                        break;
                    }
                }

                let new_cont;

                match &expr {
                    Expr::StructAtIndex {
                        structure,
                        index,
                        field_layouts,
                    } => {
                        // Record that `symbol` aliases field `index` of `structure`
                        // while we process the continuation.
                        let entry = env
                            .alias_map
                            .entry(*structure)
                            .or_insert_with(MutMap::default);

                        entry.insert(*index, symbol);

                        env.layout_map
                            .insert(*structure, Layout::Struct(field_layouts));

                        // if the field is a struct, we know its constructor too!
                        let field_layout = &field_layouts[*index as usize];
                        env.try_insert_struct_info(symbol, field_layout);

                        new_cont = expand_and_cancel(env, cont);

                        env.remove_struct_info(symbol);

                        // make sure to remove the alias, so other branches don't use it by accident
                        env.alias_map
                            .get_mut(structure)
                            .and_then(|map| map.remove(index));
                    }
                    Expr::Struct(_) => {
                        // A freshly built struct: we know its field layouts for
                        // the duration of the continuation.
                        if let Layout::Struct(fields) = layout {
                            env.insert_struct_info(symbol, fields);

                            new_cont = expand_and_cancel(env, cont);

                            env.remove_struct_info(symbol);
                        } else {
                            new_cont = expand_and_cancel(env, cont);
                        }
                    }
                    Expr::Call(_) => {
                        // A call returning a struct likewise pins down the
                        // constructor of `symbol`.
                        if let Layout::Struct(fields) = layout {
                            env.insert_struct_info(symbol, fields);

                            new_cont = expand_and_cancel(env, cont);

                            env.remove_struct_info(symbol);
                        } else {
                            new_cont = expand_and_cancel(env, cont);
                        }
                    }
                    _ => {
                        new_cont = expand_and_cancel(env, cont);
                    }
                }

                let stmt = Let(symbol, expr.clone(), *layout, new_cont);
                let mut stmt = &*env.arena.alloc(stmt);

                // Re-wrap the bindings we peeled off above, innermost first.
                for (symbol, expr, layout) in literal_stack.into_iter().rev() {
                    stmt = env.arena.alloc(Stmt::Let(symbol, expr, layout, stmt));
                }

                stmt
            }
            Switch {
                cond_symbol,
                cond_layout,
                ret_layout,
                branches,
                default_branch,
            } => {
                let mut new_branches = Vec::with_capacity_in(branches.len(), env.arena);

                for (id, info, branch) in branches.iter() {
                    // Branch info may pin down the constructor of the
                    // condition symbol inside this branch only.
                    env.insert_branch_info(info);

                    let branch = expand_and_cancel(env, branch);

                    env.remove_branch_info(info);
                    env.constructor_map.remove(cond_symbol);

                    new_branches.push((*id, info.clone(), branch.clone()));
                }

                env.insert_branch_info(&default_branch.0);

                let new_default = (
                    default_branch.0.clone(),
                    expand_and_cancel(env, default_branch.1),
                );

                env.remove_branch_info(&default_branch.0);

                let stmt = Switch {
                    cond_symbol: *cond_symbol,
                    cond_layout: *cond_layout,
                    ret_layout: *ret_layout,
                    branches: new_branches.into_bump_slice(),
                    default_branch: new_default,
                };

                &*env.arena.alloc(stmt)
            }
            Refcounting(ModifyRc::DecRef(symbol), cont) => {
                // decref the current cell
                env.deferred.decrefs.push(*symbol);

                expand_and_cancel(env, cont)
            }
            Refcounting(ModifyRc::Dec(symbol), cont) => {
                use ConstructorLayout::*;

                match work_for_constructor(env, symbol) {
                    HasFields(dec_symbols) => {
                        // we can inline the decrement

                        // decref the current cell
                        env.deferred.decrefs.push(*symbol);

                        // and record decrements for all the fields
                        for dec_symbol in dec_symbols {
                            let count = env.deferred.inc_dec_map.entry(dec_symbol).or_insert(0);
                            *count -= 1;
                        }
                    }
                    Unknown => {
                        // we can't inline the decrement; just record it
                        let count = env.deferred.inc_dec_map.entry(*symbol).or_insert(0);
                        *count -= 1;
                    }
                    IsNull => {
                        // we decrement a value represented as `NULL` at runtime;
                        // we can drop this decrement completely
                    }
                }

                expand_and_cancel(env, cont)
            }
            Refcounting(ModifyRc::Inc(symbol, inc_amount), cont) => {
                // Increments just raise the running balance; later decrements
                // for the same symbol cancel against it.
                let count = env.deferred.inc_dec_map.entry(*symbol).or_insert(0);
                *count += *inc_amount as i64;

                expand_and_cancel(env, cont)
            }
            Join {
                id,
                parameters,
                body: continuation,
                remainder,
            } => {
                let continuation = expand_and_cancel(env, continuation);
                let remainder = expand_and_cancel(env, remainder);

                let stmt = Join {
                    id: *id,
                    parameters,
                    body: continuation,
                    remainder,
                };

                env.arena.alloc(stmt)
            }
            Ret(_) | Jump(_, _) | RuntimeError(_) => stmt,
        }
    };

    // Flush this subtree's deferred operations around the rebuilt statement:
    // decrefs first, then net-negative balances as decrements, then
    // net-positive balances as increments, then pending assignments.
    for symbol in deferred.decrefs {
        let stmt = Refcounting(ModifyRc::DecRef(symbol), result);
        result = env.arena.alloc(stmt);
    }

    // do all decrements
    for (symbol, amount) in deferred.inc_dec_map.iter().rev() {
        use std::cmp::Ordering;
        match amount.cmp(&0) {
            Ordering::Equal => {
                // do nothing else
            }
            Ordering::Greater => {
                // do nothing yet
            }
            Ordering::Less => {
                // the RC insertion should not double decrement in a block
                debug_assert_eq!(*amount, -1);

                // insert missing decrements
                let stmt = Refcounting(ModifyRc::Dec(*symbol), result);
                result = env.arena.alloc(stmt);
            }
        }
    }

    // ...then all (remaining) increments
    for (symbol, amount) in deferred.inc_dec_map.into_iter().rev() {
        use std::cmp::Ordering;
        match amount.cmp(&0) {
            Ordering::Equal => {
                // do nothing else
            }
            Ordering::Greater => {
                // insert missing increments
                let stmt = Refcounting(ModifyRc::Inc(symbol, amount as u64), result);
                result = env.arena.alloc(stmt);
            }
            Ordering::Less => {
                // already done
            }
        }
    }

    for (symbol, expr, layout) in deferred.assignments {
        let stmt = Stmt::Let(symbol, expr, layout, result);
        result = env.arena.alloc(stmt);
    }

    result
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,907 @@
use crate::layout::{ext_var_is_empty_record, ext_var_is_empty_tag_union};
use roc_builtins::bitcode::{FloatWidth, IntWidth};
use roc_collections::all::MutMap;
use roc_module::symbol::Symbol;
use roc_target::TargetInfo;
use roc_types::subs::{self, Content, FlatType, Subs, Variable};
use roc_types::types::RecordField;
use std::collections::hash_map::Entry;
/// A typed index into one of the flat arrays held by `Layouts`.
/// The phantom type parameter records what kind of element it points at.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Index<T> {
    index: u32,
    _marker: std::marker::PhantomData<T>,
}

impl<T> Index<T> {
    /// Wrap a raw array position as a typed index.
    pub const fn new(index: u32) -> Self {
        let _marker = std::marker::PhantomData;

        Index { index, _marker }
    }
}

/// A typed (start, length) view into one of the flat arrays held by
/// `Layouts`; the elements themselves live in that array.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Slice<T> {
    start: u32,
    length: u16,
    _marker: std::marker::PhantomData<T>,
}

impl<T> Slice<T> {
    /// Build a slice covering `length` elements beginning at `start`.
    pub const fn new(start: u32, length: u16) -> Self {
        let _marker = std::marker::PhantomData;

        Slice {
            start,
            length,
            _marker,
        }
    }

    /// Number of elements this slice covers.
    pub const fn len(&self) -> usize {
        self.length as usize
    }

    /// True when the slice covers no elements.
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The half-open range of raw array positions this slice covers.
    pub const fn indices(&self) -> std::ops::Range<usize> {
        let first = self.start as usize;
        let one_past_last = first + self.length as usize;

        first..one_past_last
    }

    /// Iterate over the covered positions as typed `Index` values.
    pub fn into_iter(&self) -> impl Iterator<Item = Index<T>> {
        self.indices().map(|raw| Index::new(raw as u32))
    }
}
/// Reserve room for `length` elements in the relevant array of `layouts`,
/// filling it with placeholders; the caller overwrites the entries afterwards.
trait Reserve {
    fn reserve(layouts: &mut Layouts, length: usize) -> Self;
}
impl Reserve for Slice<Layout> {
fn reserve(layouts: &mut Layouts, length: usize) -> Self {
let start = layouts.layouts.len() as u32;
let it = std::iter::repeat(Layout::Reserved).take(length);
layouts.layouts.extend(it);
Self {
start,
length: length as u16,
_marker: Default::default(),
}
}
}
impl Reserve for Slice<Slice<Layout>> {
fn reserve(layouts: &mut Layouts, length: usize) -> Self {
let start = layouts.layout_slices.len() as u32;
let empty: Slice<Layout> = Slice::new(0, 0);
let it = std::iter::repeat(empty).take(length);
layouts.layout_slices.extend(it);
Self {
start,
length: length as u16,
_marker: Default::default(),
}
}
}
static_assertions::assert_eq_size!([u8; 12], Layout);
/// Struct-of-arrays storage for layouts: `Index`/`Slice` values point into
/// these flat vectors instead of using nested heap allocations.
pub struct Layouts {
    layouts: Vec<Layout>,
    layout_slices: Vec<Slice<Layout>>,
    // function_layouts: Vec<(Slice<Layout>, Index<LambdaSet>)>,
    lambda_sets: Vec<LambdaSet>,
    symbols: Vec<Symbol>,
    /// Caches the layout slot of a recursive type while (and after) it is
    /// being constructed, keyed by the root of its recursion variable.
    recursion_variable_to_structure_variable_map: MutMap<Variable, Index<Layout>>,
    target_info: TargetInfo,
}
/// Layout of a function: argument layouts, result layout, and its lambda set.
pub struct FunctionLayout {
    /// last element is the result, prior elements the arguments
    arguments_and_result: Slice<Layout>,
    pub lambda_set: Index<LambdaSet>,
}
impl FunctionLayout {
    /// Build a `FunctionLayout` for the (function) type behind `var`.
    pub fn from_var(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
    ) -> Result<Self, LayoutError> {
        // so we can set some things/clean up
        Self::from_var_help(layouts, subs, var)
    }

    fn from_var_help(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
    ) -> Result<Self, LayoutError> {
        let content = &subs.get_content_without_compacting(var);
        Self::from_content(layouts, subs, var, content)
    }

    /// Dispatch on the resolved type content; unbound variables are errors
    /// here, and aliases/ranged numbers are looked through.
    fn from_content(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
        content: &Content,
    ) -> Result<Self, LayoutError> {
        use LayoutError::*;

        match content {
            Content::FlexVar(_)
            | Content::RigidVar(_)
            | Content::FlexAbleVar(_, _)
            | Content::RigidAbleVar(_, _) => Err(UnresolvedVariable(var)),
            Content::RecursionVar { .. } => Err(TypeError(())),
            Content::LambdaSet(lset) => Self::from_lambda_set(layouts, subs, *lset),
            Content::Structure(flat_type) => Self::from_flat_type(layouts, subs, flat_type),
            Content::Alias(_, _, actual, _) => Self::from_var_help(layouts, subs, *actual),
            Content::RangedNumber(actual, _) => Self::from_var_help(layouts, subs, *actual),
            Content::Error => Err(TypeError(())),
        }
    }

    // Not implemented yet: a lambda set in function position.
    fn from_lambda_set(
        _layouts: &mut Layouts,
        _subs: &Subs,
        _lset: subs::LambdaSet,
    ) -> Result<Self, LayoutError> {
        todo!();
    }

    fn from_flat_type(
        layouts: &mut Layouts,
        subs: &Subs,
        flat_type: &FlatType,
    ) -> Result<Self, LayoutError> {
        use LayoutError::*;

        match flat_type {
            FlatType::Func(arguments, lambda_set, result) => {
                // Reserve arguments.len() + 1 slots: arguments first, then
                // the result in the final slot.
                let slice = Slice::reserve(layouts, arguments.len() + 1);

                let variable_slice = &subs.variables[arguments.indices()];
                let it = slice.indices().zip(variable_slice);

                for (target_index, var) in it {
                    let layout = Layout::from_var_help(layouts, subs, *var)?;

                    layouts.layouts[target_index] = layout;
                }

                // Fill the last reserved slot with the result layout.
                let result_layout = Layout::from_var_help(layouts, subs, *result)?;

                let result_index: Index<Layout> = Index::new(slice.start + slice.len() as u32 - 1);

                layouts.layouts[result_index.index as usize] = result_layout;

                let lambda_set = LambdaSet::from_var(layouts, subs, *lambda_set)?;
                let lambda_set_index = Index::new(layouts.lambda_sets.len() as u32);
                layouts.lambda_sets.push(lambda_set);

                Ok(Self {
                    arguments_and_result: slice,
                    lambda_set: lambda_set_index,
                })
            }
            FlatType::Erroneous(_) => Err(TypeError(())),
            _ => todo!(),
        }
    }

    /// The argument layouts only (drops the trailing result slot).
    pub fn argument_slice(&self) -> Slice<Layout> {
        let mut result = self.arguments_and_result;
        result.length -= 1;
        result
    }

    /// Index of the result layout (the last slot of `arguments_and_result`).
    pub fn result_index(&self) -> Index<Layout> {
        Index::new(self.arguments_and_result.start + self.arguments_and_result.length as u32 - 1)
    }
}
/// Idea: don't include the symbols for the first 3 cases in --optimize mode
pub enum LambdaSet {
    /// A single function that captures nothing.
    Empty {
        symbol: Index<Symbol>,
    },
    /// A single function capturing exactly one value.
    Single {
        symbol: Index<Symbol>,
        layout: Index<Layout>,
    },
    /// A single function capturing multiple values (stored like a struct).
    Struct {
        symbol: Index<Symbol>,
        layouts: Slice<Layout>,
    },
    /// Multiple functions; represented as a union over their captures.
    Union {
        symbols: Slice<Symbol>,
        layouts: Slice<Slice<Layout>>,
    },
}
impl LambdaSet {
    /// Build a `LambdaSet` layout for the lambda-set type behind `var`.
    pub fn from_var(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
    ) -> Result<Self, LayoutError> {
        // so we can set some things/clean up
        Self::from_var_help(layouts, subs, var)
    }

    fn from_var_help(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
    ) -> Result<Self, LayoutError> {
        let content = &subs.get_content_without_compacting(var);
        Self::from_content(layouts, subs, var, content)
    }

    fn from_content(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
        content: &Content,
    ) -> Result<Self, LayoutError> {
        use LayoutError::*;

        match content {
            Content::FlexVar(_)
            | Content::RigidVar(_)
            | Content::FlexAbleVar(_, _)
            | Content::RigidAbleVar(_, _) => Err(UnresolvedVariable(var)),
            Content::RecursionVar { .. } => {
                unreachable!("lambda sets cannot currently be recursive")
            }
            Content::LambdaSet(lset) => Self::from_lambda_set(layouts, subs, *lset),
            Content::Structure(_flat_type) => unreachable!(),
            Content::Alias(_, _, actual, _) => Self::from_var_help(layouts, subs, *actual),
            Content::RangedNumber(actual, _) => Self::from_var_help(layouts, subs, *actual),
            Content::Error => Err(TypeError(())),
        }
    }

    fn from_lambda_set(
        layouts: &mut Layouts,
        subs: &Subs,
        lset: subs::LambdaSet,
    ) -> Result<Self, LayoutError> {
        let subs::LambdaSet {
            solved,
            recursion_var: _,
            unspecialized: _,
        } = lset;

        // TODO: handle unspecialized
        debug_assert!(
            !solved.is_empty(),
            "lambda set must contain atleast the function itself"
        );

        let lambda_names = solved.labels();
        let closure_names = Self::get_closure_names(layouts, subs, lambda_names);

        let variables = solved.variables();
        if variables.len() == 1 {
            // Exactly one lambda in the set: pick Empty/Single/Struct based on
            // how many values it captures.
            let symbol = subs.closure_names[lambda_names.start as usize];
            let symbol_index = Index::new(layouts.symbols.len() as u32);
            layouts.symbols.push(symbol);

            let variable_slice = subs.variable_slices[variables.start as usize];

            match variable_slice.len() {
                0 => Ok(LambdaSet::Empty {
                    symbol: symbol_index,
                }),
                1 => {
                    let var = subs.variables[variable_slice.start as usize];
                    let layout = Layout::from_var(layouts, subs, var)?;

                    let index = Index::new(layouts.layouts.len() as u32);
                    layouts.layouts.push(layout);

                    Ok(LambdaSet::Single {
                        symbol: symbol_index,
                        layout: index,
                    })
                }
                _ => {
                    let slice = Layout::from_variable_slice(layouts, subs, variable_slice)?;

                    Ok(LambdaSet::Struct {
                        symbol: symbol_index,
                        layouts: slice,
                    })
                }
            }
        } else {
            // Multiple lambdas: a union over each lambda's captured values.
            let layouts = Layout::from_slice_variable_slice(layouts, subs, solved.variables())?;

            Ok(LambdaSet::Union {
                symbols: closure_names,
                layouts,
            })
        }
    }

    /// Copy the closure-name symbols out of `subs` into `layouts.symbols`,
    /// returning the slice that covers them.
    fn get_closure_names(
        layouts: &mut Layouts,
        subs: &Subs,
        subs_slice: roc_types::subs::SubsSlice<Symbol>,
    ) -> Slice<Symbol> {
        let slice = Slice::new(layouts.symbols.len() as u32, subs_slice.len() as u16);

        let symbols = &subs.closure_names[subs_slice.indices()];

        for symbol in symbols {
            layouts.symbols.push(*symbol);
        }

        slice
    }
}
/// A layout with all indirection routed through the flat arrays in `Layouts`.
/// Kept small (see the size assertion above) so layouts pack densely.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Layout {
    // theory: we can zero out memory to reserve space for many layouts
    Reserved,

    // Question: where to store signedness information?
    Int(IntWidth),
    Float(FloatWidth),
    Decimal,

    Str,
    Dict(Index<(Layout, Layout)>),
    Set(Index<Layout>),
    List(Index<Layout>),

    Struct(Slice<Layout>),

    UnionNonRecursive(Slice<Slice<Layout>>),

    Boxed(Index<Layout>),
    UnionRecursive(Slice<Slice<Layout>>),
    // UnionNonNullableUnwrapped(Slice<Layout>),
    // UnionNullableWrapper {
    //     data: NullableUnionIndex,
    //     tag_id: u16,
    // },
    //
    // UnionNullableUnwrappedTrue(Slice<Layout>),
    // UnionNullableUnwrappedFalse(Slice<Layout>),
    // RecursivePointer,
}
/// Round `unaligned` up to the nearest multiple of `alignment_bytes`.
///
/// Alignments of 0 or 1 are treated as "no alignment required"; any other
/// non-power-of-two alignment panics.
fn round_up_to_alignment(unaligned: u16, alignment_bytes: u16) -> u16 {
    // Do the arithmetic in i32 so the two's-complement mask trick below works
    // and intermediate sums cannot wrap in u16.
    let value = unaligned as i32;
    let align = alignment_bytes as i32;

    if align <= 1 {
        return unaligned;
    }

    if align.count_ones() != 1 {
        panic!(
            "Cannot align to {} bytes. Not a power of 2.",
            align
        );
    }

    // Adding `align - 1` pushes past the next boundary whenever the low bits
    // are non-zero; `-align` has ones in the high bits only, so the `&`
    // clears the low bits back down to the boundary.
    ((value + align - 1) & -align) as u16
}
impl Layouts {
    const VOID_INDEX: Index<Layout> = Index::new(0);
    const VOID_TUPLE: Index<(Layout, Layout)> = Index::new(0);
    const UNIT_INDEX: Index<Layout> = Index::new(2);

    /// Create a fresh `Layouts`, pre-seeding the layout array with the
    /// well-known entries that `VOID_INDEX`, `VOID_TUPLE` and `UNIT_INDEX`
    /// point at.
    pub fn new(target_info: TargetInfo) -> Self {
        let mut layouts = Vec::with_capacity(64);

        layouts.push(Layout::VOID);
        layouts.push(Layout::VOID);
        layouts.push(Layout::UNIT);

        // sanity check
        debug_assert_eq!(layouts[Self::VOID_INDEX.index as usize], Layout::VOID);
        debug_assert_eq!(layouts[Self::VOID_TUPLE.index as usize + 1], Layout::VOID);
        debug_assert_eq!(layouts[Self::UNIT_INDEX.index as usize], Layout::UNIT);

        Layouts {
            // BUGFIX: this previously stored `Vec::default()`, discarding the
            // three entries pushed (and asserted!) above and leaving the
            // VOID/UNIT indices dangling.
            layouts,
            layout_slices: Vec::default(),
            lambda_sets: Vec::default(),
            symbols: Vec::default(),
            recursion_variable_to_structure_variable_map: MutMap::default(),
            target_info,
        }
    }

    /// sort a slice according to elements' alignment
    fn sort_slice_by_alignment(&mut self, layout_slice: Slice<Layout>) {
        let slice = &mut self.layouts[layout_slice.indices()];

        // SAFETY: the align_of function does not mutate the layouts vector
        // this unsafety is required to circumvent the borrow checker
        let sneaky_slice =
            unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr(), slice.len()) };

        sneaky_slice.sort_by(|layout1, layout2| {
            let align1 = self.align_of_layout(*layout1);
            let align2 = self.align_of_layout(*layout2);

            // we want the biggest alignment first
            align2.cmp(&align1)
        });
    }

    /// Layout of the target's native pointer-sized integer.
    fn usize(&self) -> Layout {
        let usize_int_width = match self.target_info.ptr_width() {
            roc_target::PtrWidth::Bytes4 => IntWidth::U32,
            roc_target::PtrWidth::Bytes8 => IntWidth::U64,
        };

        Layout::Int(usize_int_width)
    }

    fn align_of_layout_index(&self, index: Index<Layout>) -> u16 {
        let layout = self.layouts[index.index as usize];

        self.align_of_layout(layout)
    }

    /// Alignment (in bytes) of a value of this layout on the target.
    fn align_of_layout(&self, layout: Layout) -> u16 {
        let usize_int_width = match self.target_info.ptr_width() {
            roc_target::PtrWidth::Bytes4 => IntWidth::U32,
            roc_target::PtrWidth::Bytes8 => IntWidth::U64,
        };

        let ptr_alignment = usize_int_width.alignment_bytes(self.target_info) as u16;

        match layout {
            Layout::Reserved => unreachable!(),
            Layout::Int(int_width) => int_width.alignment_bytes(self.target_info) as u16,
            Layout::Float(float_width) => float_width.alignment_bytes(self.target_info) as u16,
            Layout::Decimal => IntWidth::U128.alignment_bytes(self.target_info) as u16,
            Layout::Str | Layout::Dict(_) | Layout::Set(_) | Layout::List(_) => ptr_alignment,
            Layout::Struct(slice) => self.align_of_layout_slice(slice),
            Layout::Boxed(_) | Layout::UnionRecursive(_) => ptr_alignment,
            Layout::UnionNonRecursive(slices) => {
                // the tag id is stored alongside the payload
                let tag_id_align = IntWidth::I64.alignment_bytes(self.target_info) as u16;

                self.align_of_layout_slices(slices).max(tag_id_align)
            }
            // Layout::UnionNonNullableUnwrapped(_) => todo!(),
            // Layout::UnionNullableWrapper { data, tag_id } => todo!(),
            // Layout::UnionNullableUnwrappedTrue(_) => todo!(),
            // Layout::UnionNullableUnwrappedFalse(_) => todo!(),
            // Layout::RecursivePointer => todo!(),
        }
    }

    /// Invariant: the layouts are sorted from biggest to smallest alignment,
    /// so the first element's alignment is the whole slice's alignment.
    fn align_of_layout_slice(&self, slice: Slice<Layout>) -> u16 {
        match slice.into_iter().next() {
            None => 0,
            Some(first_index) => self.align_of_layout_index(first_index),
        }
    }

    fn align_of_layout_slices(&self, slice: Slice<Slice<Layout>>) -> u16 {
        slice
            .into_iter()
            .map(|index| self.layout_slices[index.index as usize])
            .map(|slice| self.align_of_layout_slice(slice))
            .max()
            .unwrap_or_default()
    }

    /// Invariant: the layouts are sorted from biggest to smallest alignment
    fn size_of_layout_slice(&self, slice: Slice<Layout>) -> u16 {
        match slice.into_iter().next() {
            None => 0,
            Some(first_index) => {
                let alignment = self.align_of_layout_index(first_index);

                let mut sum = 0;

                for index in slice.into_iter() {
                    sum += self.size_of_layout_index(index);
                }

                round_up_to_alignment(sum, alignment)
            }
        }
    }

    pub fn size_of_layout_index(&self, index: Index<Layout>) -> u16 {
        let layout = self.layouts[index.index as usize];

        self.size_of_layout(layout)
    }

    /// Stack size (in bytes) of a value of this layout on the target.
    pub fn size_of_layout(&self, layout: Layout) -> u16 {
        let usize_int_width = match self.target_info.ptr_width() {
            roc_target::PtrWidth::Bytes4 => IntWidth::U32,
            roc_target::PtrWidth::Bytes8 => IntWidth::U64,
        };

        let ptr_width = usize_int_width.stack_size() as u16;

        match layout {
            Layout::Reserved => unreachable!(),
            Layout::Int(int_width) => int_width.stack_size() as _,
            // BUGFIX: this used to be `float_width as _`, which casts the
            // enum *discriminant* (0, 1, ...) instead of the width in bytes.
            Layout::Float(float_width) => float_width.stack_size() as _,
            Layout::Decimal => (std::mem::size_of::<roc_std::RocDec>()) as _,
            Layout::Str | Layout::Dict(_) | Layout::Set(_) | Layout::List(_) => 2 * ptr_width,
            Layout::Struct(slice) => self.size_of_layout_slice(slice),
            Layout::Boxed(_) | Layout::UnionRecursive(_) => ptr_width,
            Layout::UnionNonRecursive(slices) if slices.is_empty() => 0,
            Layout::UnionNonRecursive(slices) => {
                let tag_id = IntWidth::I64;

                // BUGFIX: the payload is the size of the *largest* variant;
                // this previously took each variant's alignment instead of
                // its size.
                let max_slice_size = slices
                    .into_iter()
                    .map(|index| self.layout_slices[index.index as usize])
                    .map(|slice| self.size_of_layout_slice(slice))
                    .max()
                    .unwrap_or_default();

                tag_id.stack_size() as u16 + max_slice_size
            }
            // Layout::UnionNonNullableUnwrapped(_) => todo!(),
            // Layout::UnionNullableWrapper { data, tag_id } => todo!(),
            // Layout::UnionNullableUnwrappedTrue(_) => todo!(),
            // Layout::UnionNullableUnwrappedFalse(_) => todo!(),
            // Layout::RecursivePointer => todo!(),
        }
    }
}
/// Errors produced while converting a type variable into a `Layout`.
pub enum LayoutError {
    /// The variable is still unbound; no concrete layout exists (yet).
    UnresolvedVariable(Variable),
    /// The type contains an error; no layout can be produced.
    TypeError(()),
}
impl Layout {
    pub const UNIT: Self = Self::Struct(Slice::new(0, 0));
    pub const VOID: Self = Self::UnionNonRecursive(Slice::new(0, 0));

    pub const EMPTY_LIST: Self = Self::List(Layouts::VOID_INDEX);
    pub const EMPTY_DICT: Self = Self::Dict(Layouts::VOID_TUPLE);
    pub const EMPTY_SET: Self = Self::Set(Layouts::VOID_INDEX);

    /// Build the layout for the type behind `var`.
    pub fn from_var(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
    ) -> Result<Layout, LayoutError> {
        // so we can set some things/clean up
        Self::from_var_help(layouts, subs, var)
    }

    fn from_var_help(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
    ) -> Result<Layout, LayoutError> {
        let content = &subs.get_content_without_compacting(var);
        Self::from_content(layouts, subs, var, content)
    }

    /// Used in situations where an unspecialized variable is not a problem,
    /// and we can substitute with `[]`, the empty tag union.
    /// e.g. an empty list literal has type `List *`. We can still generate code
    /// in those cases by just picking any concrete type for the list element,
    /// and we pick the empty tag union in practice.
    fn from_var_help_or_void(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
    ) -> Result<Layout, LayoutError> {
        let content = &subs.get_content_without_compacting(var);

        match content {
            Content::FlexVar(_) | Content::RigidVar(_) => Ok(Layout::VOID),
            _ => Self::from_content(layouts, subs, var, content),
        }
    }

    fn from_content(
        layouts: &mut Layouts,
        subs: &Subs,
        var: Variable,
        content: &Content,
    ) -> Result<Layout, LayoutError> {
        use LayoutError::*;

        match content {
            Content::FlexVar(_)
            | Content::RigidVar(_)
            | Content::FlexAbleVar(_, _)
            | Content::RigidAbleVar(_, _) => Err(UnresolvedVariable(var)),
            Content::RecursionVar {
                structure,
                opt_name: _,
            } => {
                // Recursive occurrence: reserve (or reuse) a slot for the
                // structure's layout, and represent the back-edge as a Boxed
                // pointer to that slot.
                let structure = subs.get_root_key_without_compacting(*structure);

                let entry = layouts
                    .recursion_variable_to_structure_variable_map
                    .entry(structure);

                match entry {
                    Entry::Vacant(vacant) => {
                        let reserved = Index::new(layouts.layouts.len() as _);
                        layouts.layouts.push(Layout::Reserved);

                        vacant.insert(reserved);

                        let layout = Layout::from_var(layouts, subs, structure)?;

                        layouts.layouts[reserved.index as usize] = layout;

                        Ok(Layout::Boxed(reserved))
                    }
                    Entry::Occupied(occupied) => {
                        let index = occupied.get();

                        Ok(Layout::Boxed(*index))
                    }
                }
            }

            // Lambda set layout is same as tag union
            Content::LambdaSet(lset) => Self::from_lambda_set(layouts, subs, *lset),

            Content::Structure(flat_type) => Self::from_flat_type(layouts, subs, flat_type),

            Content::Alias(symbol, _, actual, _) => {
                let symbol = *symbol;

                // Number aliases map directly to primitive layouts.
                if let Some(int_width) = IntWidth::try_from_symbol(symbol) {
                    return Ok(Layout::Int(int_width));
                }

                if let Some(float_width) = FloatWidth::try_from_symbol(symbol) {
                    return Ok(Layout::Float(float_width));
                }

                match symbol {
                    Symbol::NUM_DECIMAL => Ok(Layout::Decimal),

                    Symbol::NUM_NAT | Symbol::NUM_NATURAL => Ok(layouts.usize()),

                    _ => {
                        // at this point we throw away alias information
                        Self::from_var_help(layouts, subs, *actual)
                    }
                }
            }
            Content::RangedNumber(typ, _) => Self::from_var_help(layouts, subs, *typ),
            Content::Error => Err(TypeError(())),
        }
    }

    fn from_lambda_set(
        layouts: &mut Layouts,
        subs: &Subs,
        lset: subs::LambdaSet,
    ) -> Result<Layout, LayoutError> {
        let subs::LambdaSet {
            solved,
            recursion_var,
            unspecialized: _,
        } = lset;

        // TODO: handle unspecialized lambda set

        match recursion_var.into_variable() {
            Some(rec_var) => {
                // the lambda set is recursive; consult/seed the cache keyed
                // by the root recursion variable.
                let rec_var = subs.get_root_key_without_compacting(rec_var);

                let cached = layouts
                    .recursion_variable_to_structure_variable_map
                    .get(&rec_var);

                if let Some(layout_index) = cached {
                    match layouts.layouts[layout_index.index as usize] {
                        Layout::Reserved => {
                            // we have to do the work here to fill this reserved variable in
                        }
                        other => {
                            return Ok(other);
                        }
                    }
                }

                // NOTE(review): the reserved slot found above does not appear
                // to be overwritten here — confirm whether callers rely on
                // the `RecursionVar` path in `from_content` doing that.
                let slices = Self::from_slice_variable_slice(layouts, subs, solved.variables())?;

                Ok(Layout::UnionRecursive(slices))
            }
            None => {
                let slices = Self::from_slice_variable_slice(layouts, subs, solved.variables())?;

                Ok(Layout::UnionNonRecursive(slices))
            }
        }
    }

    fn from_flat_type(
        layouts: &mut Layouts,
        subs: &Subs,
        flat_type: &FlatType,
    ) -> Result<Layout, LayoutError> {
        use LayoutError::*;

        match flat_type {
            FlatType::Apply(Symbol::LIST_LIST, arguments) => {
                debug_assert_eq!(arguments.len(), 1);

                let element_var = subs.variables[arguments.start as usize];
                let element_layout = Self::from_var_help_or_void(layouts, subs, element_var)?;

                let element_index = Index::new(layouts.layouts.len() as _);
                layouts.layouts.push(element_layout);

                Ok(Layout::List(element_index))
            }

            FlatType::Apply(Symbol::DICT_DICT, arguments) => {
                debug_assert_eq!(arguments.len(), 2);

                let key_var = subs.variables[arguments.start as usize];
                let value_var = subs.variables[arguments.start as usize + 1];

                let key_layout = Self::from_var_help_or_void(layouts, subs, key_var)?;
                let value_layout = Self::from_var_help_or_void(layouts, subs, value_var)?;

                // the Dict index points at the key; the value is stored in
                // the next slot.
                let index = Index::new(layouts.layouts.len() as _);
                layouts.layouts.push(key_layout);
                layouts.layouts.push(value_layout);

                Ok(Layout::Dict(index))
            }

            FlatType::Apply(Symbol::SET_SET, arguments) => {
                debug_assert_eq!(arguments.len(), 1);

                let element_var = subs.variables[arguments.start as usize];
                let element_layout = Self::from_var_help_or_void(layouts, subs, element_var)?;

                let element_index = Index::new(layouts.layouts.len() as _);
                layouts.layouts.push(element_layout);

                Ok(Layout::Set(element_index))
            }

            FlatType::Apply(symbol, _) => {
                unreachable!("Symbol {:?} does not have a layout", symbol)
            }

            FlatType::Func(_arguments, lambda_set, _result) => {
                // in this case, a function (pointer) is represented by the environment it
                // captures: the lambda set
                Self::from_var_help(layouts, subs, *lambda_set)
            }
            FlatType::Record(fields, ext) => {
                debug_assert!(ext_var_is_empty_record(subs, *ext));

                let mut slice = Slice::reserve(layouts, fields.len());

                let mut non_optional_fields = 0;
                let it = slice.indices().zip(fields.iter_all());

                for (target_index, (_, field_index, var_index)) in it {
                    match subs.record_fields[field_index.index as usize] {
                        RecordField::Optional(_) => {
                            // do nothing
                        }
                        RecordField::Required(_) | RecordField::Demanded(_) => {
                            let var = subs.variables[var_index.index as usize];
                            let layout = Layout::from_var_help(layouts, subs, var)?;

                            layouts.layouts[target_index] = layout;
                            non_optional_fields += 1;
                        }
                    }
                }

                // we have some wasted space in the case of optional fields; so be it
                //
                // NOTE(review): a skipped optional field leaves a `Reserved`
                // hole at its reserved slot while the truncation below drops
                // entries from the *end* — confirm that records containing
                // optional fields are handled upstream.
                slice.length = non_optional_fields;

                layouts.sort_slice_by_alignment(slice);

                Ok(Layout::Struct(slice))
            }
            FlatType::TagUnion(union_tags, ext) => {
                debug_assert!(ext_var_is_empty_tag_union(subs, *ext));

                let slices =
                    Self::from_slice_variable_slice(layouts, subs, union_tags.variables())?;

                Ok(Layout::UnionNonRecursive(slices))
            }
            FlatType::FunctionOrTagUnion(_, _, ext) => {
                debug_assert!(ext_var_is_empty_tag_union(subs, *ext));

                // at this point we know this is a tag
                Ok(Layout::UNIT)
            }
            FlatType::RecursiveTagUnion(rec_var, union_tags, ext) => {
                debug_assert!(ext_var_is_empty_tag_union(subs, *ext));

                let rec_var = subs.get_root_key_without_compacting(*rec_var);

                let cached = layouts
                    .recursion_variable_to_structure_variable_map
                    .get(&rec_var);

                if let Some(layout_index) = cached {
                    match layouts.layouts[layout_index.index as usize] {
                        Layout::Reserved => {
                            // we have to do the work here to fill this reserved variable in
                        }
                        other => {
                            return Ok(other);
                        }
                    }
                }

                let slices =
                    Self::from_slice_variable_slice(layouts, subs, union_tags.variables())?;

                Ok(Layout::UnionRecursive(slices))
            }
            FlatType::Erroneous(_) => Err(TypeError(())),
            FlatType::EmptyRecord => Ok(Layout::UNIT),
            FlatType::EmptyTagUnion => Ok(Layout::VOID),
        }
    }

    /// Convert a slice-of-variable-slices (e.g. tag union payloads) into a
    /// slice-of-layout-slices, reserving the outer slice first so recursive
    /// calls can append freely.
    fn from_slice_variable_slice(
        layouts: &mut Layouts,
        subs: &Subs,
        slice_variable_slice: roc_types::subs::SubsSlice<roc_types::subs::VariableSubsSlice>,
    ) -> Result<Slice<Slice<Layout>>, LayoutError> {
        let slice = Slice::reserve(layouts, slice_variable_slice.len());

        let variable_slices = &subs.variable_slices[slice_variable_slice.indices()];
        let it = slice.indices().zip(variable_slices);
        for (target_index, variable_slice) in it {
            let layout_slice = Layout::from_variable_slice(layouts, subs, *variable_slice)?;
            layouts.layout_slices[target_index] = layout_slice;
        }

        Ok(slice)
    }

    /// Convert a slice of type variables into a slice of layouts, sorted by
    /// decreasing alignment (the invariant the align/size helpers rely on).
    fn from_variable_slice(
        layouts: &mut Layouts,
        subs: &Subs,
        variable_subs_slice: roc_types::subs::VariableSubsSlice,
    ) -> Result<Slice<Layout>, LayoutError> {
        let slice = Slice::reserve(layouts, variable_subs_slice.len());

        let variable_slice = &subs.variables[variable_subs_slice.indices()];
        let it = slice.indices().zip(variable_slice);
        for (target_index, var) in it {
            let layout = Layout::from_var_help(layouts, subs, *var)?;
            layouts.layouts[target_index] = layout;
        }

        layouts.sort_slice_by_alignment(slice);

        Ok(slice)
    }
}

View file

@ -0,0 +1,19 @@
#![warn(clippy::dbg_macro)]
// See github.com/rtfeldman/roc/issues/800 for discussion of the large_enum_variant check.
#![allow(clippy::large_enum_variant, clippy::upper_case_acronyms)]
pub mod borrow;
pub mod code_gen_help;
mod copy;
pub mod inc_dec;
pub mod ir;
pub mod layout;
pub mod layout_soa;
pub mod low_level;
pub mod reset_reuse;
pub mod tail_recursion;
// Temporary, while we can build up test cases and optimize the exhaustiveness checking.
// For now, following this warning's advice will lead to nasty type inference errors.
//#[allow(clippy::ptr_arg)]
pub mod decision_tree;

View file

@ -0,0 +1,210 @@
use roc_module::symbol::Symbol;
/// A higher-order builtin call. Each variant records the symbols of its
/// container argument(s) (and accumulator state where applicable); the
/// function argument itself lives in the surrounding call's argument list
/// (see `function_index`/`closure_data_index`).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum HigherOrder {
    ListMap {
        xs: Symbol,
    },
    ListMap2 {
        xs: Symbol,
        ys: Symbol,
    },
    ListMap3 {
        xs: Symbol,
        ys: Symbol,
        zs: Symbol,
    },
    ListMap4 {
        xs: Symbol,
        ys: Symbol,
        zs: Symbol,
        ws: Symbol,
    },
    ListMapWithIndex {
        xs: Symbol,
    },
    ListKeepIf {
        xs: Symbol,
    },
    ListWalk {
        xs: Symbol,
        state: Symbol,
    },
    ListWalkUntil {
        xs: Symbol,
        state: Symbol,
    },
    ListWalkBackwards {
        xs: Symbol,
        state: Symbol,
    },
    ListKeepOks {
        xs: Symbol,
    },
    ListKeepErrs {
        xs: Symbol,
    },
    ListSortWith {
        xs: Symbol,
    },
    ListAny {
        xs: Symbol,
    },
    ListAll {
        xs: Symbol,
    },
    ListFindUnsafe {
        xs: Symbol,
    },
    DictWalk {
        xs: Symbol,
        state: Symbol,
    },
}
impl HigherOrder {
pub fn function_arity(&self) -> usize {
match self {
HigherOrder::ListMap { .. } => 1,
HigherOrder::ListMap2 { .. } => 2,
HigherOrder::ListMap3 { .. } => 3,
HigherOrder::ListMap4 { .. } => 4,
HigherOrder::ListMapWithIndex { .. } => 2,
HigherOrder::ListKeepIf { .. } => 1,
HigherOrder::ListWalk { .. } => 2,
HigherOrder::ListWalkUntil { .. } => 2,
HigherOrder::ListWalkBackwards { .. } => 2,
HigherOrder::ListKeepOks { .. } => 1,
HigherOrder::ListKeepErrs { .. } => 1,
HigherOrder::ListSortWith { .. } => 2,
HigherOrder::ListFindUnsafe { .. } => 1,
HigherOrder::DictWalk { .. } => 2,
HigherOrder::ListAny { .. } => 1,
HigherOrder::ListAll { .. } => 1,
}
}
/// Index in the array of arguments of the symbol that is the closure data
/// (captured environment for this function)
pub const fn closure_data_index(&self) -> usize {
use HigherOrder::*;
match self {
ListMap { .. }
| ListMapWithIndex { .. }
| ListSortWith { .. }
| ListKeepIf { .. }
| ListKeepOks { .. }
| ListKeepErrs { .. }
| ListAny { .. }
| ListAll { .. }
| ListFindUnsafe { .. } => 2,
ListMap2 { .. } => 3,
ListMap3 { .. } => 4,
ListMap4 { .. } => 5,
ListWalk { .. } | ListWalkUntil { .. } | ListWalkBackwards { .. } | DictWalk { .. } => {
3
}
}
}
/// Index of the function symbol in the argument list
pub const fn function_index(&self) -> usize {
self.closure_data_index() - 1
}
}
/// First-order (no function argument) builtin operations, listed for
/// reference/classification; currently unused, hence the `dead_code` allow.
#[allow(dead_code)]
enum FirstOrder {
    StrConcat,
    StrJoinWith,
    StrIsEmpty,
    StrStartsWith,
    StrStartsWithCodePt,
    StrEndsWith,
    StrSplit,
    StrCountGraphemes,
    StrFromInt,
    StrFromUtf8,
    StrFromUtf8Range,
    StrToUtf8,
    StrRepeat,
    StrFromFloat,
    ListLen,
    ListGetUnsafe,
    ListSet,
    ListSublist,
    ListDropAt,
    ListSingle,
    ListRepeat,
    ListReverse,
    ListConcat,
    ListContains,
    ListAppend,
    ListPrepend,
    ListJoin,
    ListRange,
    ListSwap,
    DictSize,
    DictEmpty,
    DictInsert,
    DictRemove,
    DictContains,
    DictGetUnsafe,
    DictKeys,
    DictValues,
    DictUnion,
    DictIntersection,
    DictDifference,
    SetFromList,
    NumAdd,
    NumAddWrap,
    NumAddChecked,
    NumSub,
    NumSubWrap,
    NumSubChecked,
    NumMul,
    NumMulWrap,
    NumMulSaturated,
    NumMulChecked,
    NumGt,
    NumGte,
    NumLt,
    NumLte,
    NumCompare,
    NumDivUnchecked,
    NumRemUnchecked,
    NumIsMultipleOf,
    NumAbs,
    NumNeg,
    NumSin,
    NumCos,
    NumSqrtUnchecked,
    NumLogUnchecked,
    NumRound,
    NumToFrac,
    NumPow,
    NumCeiling,
    NumPowInt,
    NumFloor,
    NumIsFinite,
    NumAtan,
    NumAcos,
    NumAsin,
    NumBitwiseAnd,
    NumBitwiseXor,
    NumBitwiseOr,
    NumShiftLeftBy,
    NumShiftRightBy,
    NumBytesToU16,
    NumBytesToU32,
    NumShiftRightZfBy,
    NumIntCast,
    NumFloatCast,
    Eq,
    NotEq,
    And,
    Or,
    Not,
    Hash,
}

View file

@ -0,0 +1,722 @@
//! This module inserts reset/reuse statements into the mono IR. These statements provide an
//! opportunity to reduce memory pressure by reusing memory slots of non-shared values. From the
//! introduction of the relevant paper:
//!
//! > [We] have added two additional instructions to our IR: `let y = reset x` and
//! > `let z = (reuse y in ctor_i w)`. The two instructions are used together; if `x`
//! > is a shared value, then `y` is set to a special reference, and the reuse instruction
//! > just allocates a new constructor value `ctor_i w`. If `x` is not shared, then reset
//! > decrements the reference counters of the components of `x`, and `y` is set to `x`.
//! > Then, reuse reuses the memory cell used by `x` to store the constructor value `ctor_i w`.
//!
//! See also
//! - [Counting Immutable Beans](https://arxiv.org/pdf/1908.05647.pdf) (Ullrich and Moura, 2020)
//! - [The lean implementation](https://github.com/leanprover/lean4/blob/master/src/Lean/Compiler/IR/ResetReuse.lean)
use crate::inc_dec::{collect_stmt, occurring_variables_expr, JPLiveVarMap, LiveVarSet};
use crate::ir::{
BranchInfo, Call, Expr, ListLiteralElement, Proc, Stmt, UpdateModeId, UpdateModeIds,
};
use crate::layout::{Layout, TagIdIntType, UnionLayout};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_collections::all::MutSet;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
pub fn insert_reset_reuse<'a, 'i>(
arena: &'a Bump,
home: ModuleId,
ident_ids: &'i mut IdentIds,
update_mode_ids: &'i mut UpdateModeIds,
mut proc: Proc<'a>,
) -> Proc<'a> {
let mut env = Env {
arena,
home,
ident_ids,
update_mode_ids,
jp_live_vars: Default::default(),
};
let new_body = function_r(&mut env, arena.alloc(proc.body));
proc.body = new_body.clone();
proc
}
/// Describes the constructor a `Switch` branch matched on: its tag id and the
/// union layout it belongs to. Used to decide whether the matched value's
/// memory cell can be reused for a newly built constructor.
#[derive(Debug)]
struct CtorInfo<'a> {
    /// tag id of the matched constructor
    id: TagIdIntType,
    /// union layout the constructor belongs to
    layout: UnionLayout<'a>,
}
/// Can a memory cell holding the constructor described by `other` be reused
/// to store a new constructor with layout `tag_layout` and tag `tag_id`?
fn may_reuse(tag_layout: UnionLayout, tag_id: TagIdIntType, other: &CtorInfo) -> bool {
    // the cell is only compatible when the union layouts match exactly
    let same_layout = tag_layout == other.layout;

    if same_layout {
        // we should not get here if the tag we matched on is represented as NULL
        debug_assert!(!tag_layout.tag_is_null(other.id as _));
    }

    // furthermore, we can only use the memory if the tag we're creating is non-NULL
    same_layout && !tag_layout.tag_is_null(tag_id)
}
/// Mutable state threaded through the reset/reuse insertion pass.
#[derive(Debug)]
struct Env<'a, 'i> {
    arena: &'a Bump,
    /// required for creating new `Symbol`s
    home: ModuleId,
    ident_ids: &'i mut IdentIds,
    update_mode_ids: &'i mut UpdateModeIds,
    /// live-variable sets of the join points currently in scope
    jp_live_vars: JPLiveVarMap,
}
impl<'a, 'i> Env<'a, 'i> {
    /// Create a fresh `Symbol` in the home module.
    fn unique_symbol(&mut self) -> Symbol {
        Symbol::new(self.home, self.ident_ids.gen_unique())
    }
}
/// The `S` function from the paper: try to rewrite one constructor
/// application in `stmt` into a reuse of the reset cell `w` (which will hold
/// the memory of the value described by `c`).
///
/// At most one `Tag` is rewritten; once a reuse is inserted, the rest of that
/// continuation is left untouched. When nothing can be reused the result is
/// structurally equal to the input — callers detect this (see
/// `try_function_s`) to avoid emitting a useless `reset`.
fn function_s<'a, 'i>(
    env: &mut Env<'a, 'i>,
    w: Opportunity,
    c: &CtorInfo<'a>,
    stmt: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
    use Stmt::*;

    let arena = env.arena;

    match stmt {
        Let(symbol, expr, layout, continuation) => match expr {
            Expr::Tag {
                tag_layout,
                tag_id,
                tag_name,
                arguments,
            } if may_reuse(*tag_layout, *tag_id, c) => {
                // this constructor is compatible with the cell named by `w`:
                // allocate into that cell instead of a fresh one.
                // for now, always overwrite the tag ID just to be sure
                let update_tag_id = true;

                let new_expr = Expr::Reuse {
                    symbol: w.symbol,
                    update_mode: w.update_mode,
                    update_tag_id,
                    tag_layout: *tag_layout,
                    tag_id: *tag_id,
                    tag_name: tag_name.clone(),
                    arguments,
                };
                let new_stmt = Let(*symbol, new_expr, *layout, continuation);

                arena.alloc(new_stmt)
            }
            _ => {
                // not a reusable constructor: keep looking in the continuation
                let rest = function_s(env, w, c, continuation);
                let new_stmt = Let(*symbol, expr.clone(), *layout, rest);

                arena.alloc(new_stmt)
            }
        },
        Join {
            id,
            parameters,
            body,
            remainder,
        } => {
            let id = *id;
            let body: &Stmt = *body;
            let new_body = function_s(env, w, c, body);

            // BUGFIX: these two branches were previously swapped, which both
            // discarded a rewritten join-point body and skipped the remainder
            // when the body was unchanged — so reuse was never applied through
            // a `Join`. This now follows the Lean reference implementation
            // (ResetReuse.lean, linked in the module docs).
            let new_join = if std::ptr::eq(body, new_body) || body == new_body {
                // no reuse opportunity in the body; try the remainder instead
                let new_remainder = function_s(env, w, c, remainder);

                Join {
                    id,
                    parameters,
                    body,
                    remainder: new_remainder,
                }
            } else {
                // the join point body will consume w
                Join {
                    id,
                    parameters,
                    body: new_body,
                    remainder,
                }
            };

            arena.alloc(new_join)
        }
        Switch {
            cond_symbol,
            cond_layout,
            branches,
            default_branch,
            ret_layout,
        } => {
            // branches are mutually exclusive, so each one may independently
            // consume `w`
            let mut new_branches = Vec::with_capacity_in(branches.len(), arena);

            new_branches.extend(branches.iter().map(|(tag, info, body)| {
                let new_body = function_s(env, w, c, body);

                (*tag, info.clone(), new_body.clone())
            }));

            let new_default = function_s(env, w, c, default_branch.1);

            let new_switch = Switch {
                cond_symbol: *cond_symbol,
                cond_layout: *cond_layout,
                branches: new_branches.into_bump_slice(),
                default_branch: (default_branch.0.clone(), new_default),
                ret_layout: *ret_layout,
            };

            arena.alloc(new_switch)
        }
        Refcounting(op, continuation) => {
            let continuation: &Stmt = *continuation;
            let new_continuation = function_s(env, w, c, continuation);

            if std::ptr::eq(continuation, new_continuation) || continuation == new_continuation {
                // nothing changed below: hand back the original statement
                stmt
            } else {
                let new_refcounting = Refcounting(*op, new_continuation);

                arena.alloc(new_refcounting)
            }
        }
        Expect {
            condition,
            region,
            lookups,
            layouts,
            remainder,
        } => {
            let continuation: &Stmt = *remainder;
            let new_continuation = function_s(env, w, c, continuation);

            if std::ptr::eq(continuation, new_continuation) || continuation == new_continuation {
                stmt
            } else {
                let new_refcounting = Expect {
                    condition: *condition,
                    region: *region,
                    lookups,
                    layouts,
                    remainder: new_continuation,
                };

                arena.alloc(new_refcounting)
            }
        }
        // terminals: no constructor applications can occur below this point
        Ret(_) | Jump(_, _) | RuntimeError(_) => stmt,
    }
}
/// A candidate reuse: `symbol` will name the `reset` cell if the opportunity
/// is actually taken, paired with a fresh update mode id.
#[derive(Clone, Copy)]
struct Opportunity {
    symbol: Symbol,
    update_mode: UpdateModeId,
}
/// Allocate a fresh reuse opportunity and run `function_s` over `stmt`.
/// If some constructor application was rewritten to reuse the cell of `x`,
/// prepend the matching `let w = reset x`; otherwise return `stmt` untouched.
fn try_function_s<'a, 'i>(
    env: &mut Env<'a, 'i>,
    x: Symbol,
    c: &CtorInfo<'a>,
    stmt: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
    let opportunity = Opportunity {
        symbol: env.unique_symbol(),
        update_mode: env.update_mode_ids.next_id(),
    };

    let candidate = function_s(env, opportunity, c, stmt);

    // structural equality means no reuse was inserted anywhere
    let unchanged = std::ptr::eq(stmt, candidate) || stmt == candidate;

    if unchanged {
        stmt
    } else {
        insert_reset(env, opportunity, x, c.layout, candidate)
    }
}
/// Insert `let w = reset x` in front of `stmt`.
///
/// Leading bindings that project out of a value (`StructAtIndex`, `GetTagId`,
/// `UnionAtIndex`) are peeled off first and re-emitted *above* the reset, so
/// that the reset ends up below them — presumably because resetting `x` may
/// invalidate its fields (NOTE(review): rationale inferred, confirm).
fn insert_reset<'a>(
    env: &mut Env<'a, '_>,
    w: Opportunity,
    x: Symbol,
    union_layout: UnionLayout<'a>,
    mut stmt: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
    use crate::ir::Expr::*;

    // bindings that must stay above the inserted reset, outermost first
    let mut stack = vec![];

    while let Stmt::Let(symbol, expr, expr_layout, rest) = stmt {
        match &expr {
            // projections: keep walking past them
            StructAtIndex { .. } | GetTagId { .. } | UnionAtIndex { .. } => {
                stack.push((symbol, expr, expr_layout));

                stmt = rest;
            }
            ExprBox { .. } | ExprUnbox { .. } => {
                // TODO
                break;
            }
            // any other binding stops the walk; the reset goes right here
            Literal(_)
            | Call(_)
            | Tag { .. }
            | Struct(_)
            | Array { .. }
            | EmptyArray
            | Reuse { .. }
            | Reset { .. }
            | RuntimeErrorFunction(_) => break,
        }
    }

    let reset_expr = Expr::Reset {
        symbol: x,
        update_mode: w.update_mode,
    };

    let layout = Layout::Union(union_layout);

    stmt = env
        .arena
        .alloc(Stmt::Let(w.symbol, reset_expr, layout, stmt));

    // re-emit the peeled-off projections in their original order on top of
    // the reset (reverse iteration because each wraps the accumulated stmt)
    for (symbol, expr, expr_layout) in stack.into_iter().rev() {
        stmt = env
            .arena
            .alloc(Stmt::Let(*symbol, expr.clone(), *expr_layout, stmt));
    }

    stmt
}
/// Second half of the `D` function: when `x` was still live in the statement
/// produced by `function_d_main`, leave it alone; otherwise this is the point
/// where `x` dies, so try to reset/reuse its memory here.
fn function_d_finalize<'a, 'i>(
    env: &mut Env<'a, 'i>,
    x: Symbol,
    c: &CtorInfo<'a>,
    output: (&'a Stmt<'a>, bool),
) -> &'a Stmt<'a> {
    match output {
        (stmt, true) => stmt,
        (stmt, false) => try_function_s(env, x, c, stmt),
    }
}
/// Main workhorse of the `D` function: walk `stmt` looking for the point
/// where the matched scrutinee `x` dies, and insert a reset/reuse pair there.
///
/// Returns the (possibly rewritten) statement, plus a flag telling whether
/// `x` is still live inside it (true = live, so no reuse was attempted at
/// this level).
fn function_d_main<'a, 'i>(
    env: &mut Env<'a, 'i>,
    x: Symbol,
    c: &CtorInfo<'a>,
    stmt: &'a Stmt<'a>,
) -> (&'a Stmt<'a>, bool) {
    use Stmt::*;

    let arena = env.arena;

    match stmt {
        Let(symbol, expr, layout, continuation) => {
            match expr {
                Expr::Tag { arguments, .. } if arguments.iter().any(|s| *s == x) => {
                    // If the scrutinee `x` (the one that is providing memory) is being
                    // stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
                    // It may work only if the new cell is consumed, but we ignore this case.
                    (stmt, true)
                }
                _ => {
                    let (b, found) = function_d_main(env, x, c, continuation);
                    // NOTE the &b != continuation is not found in the Lean source, but is required
                    // otherwise we observe the same symbol being reset twice
                    let mut result = MutSet::default();
                    // rebuild the `Let` unchanged when: `x` is live later on
                    // (`found`), this binding does not mention `x`, or a reset
                    // was already inserted deeper in the continuation
                    if found
                        || {
                            occurring_variables_expr(expr, &mut result);
                            !result.contains(&x)
                        }
                        || &b != continuation
                    {
                        let let_stmt = Let(*symbol, expr.clone(), *layout, b);

                        (arena.alloc(let_stmt), found)
                    } else {
                        // `x` occurs in `expr` and is dead afterwards: this is
                        // its last use, so try to reuse its cell in `b`
                        let b = try_function_s(env, x, c, b);
                        let let_stmt = Let(*symbol, expr.clone(), *layout, b);

                        (arena.alloc(let_stmt), found)
                    }
                }
            }
        }
        Switch {
            cond_symbol,
            cond_layout,
            branches,
            default_branch,
            ret_layout,
        } => {
            if has_live_var(&env.jp_live_vars, stmt, x) {
                // if `x` is live in `stmt`, we recursively process each branch
                let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
                for (tag, info, body) in branches.iter() {
                    let temp = function_d_main(env, x, c, body);
                    let new_body = function_d_finalize(env, x, c, temp);

                    new_branches.push((*tag, info.clone(), new_body.clone()));
                }

                let new_default = {
                    let (info, body) = default_branch;
                    let temp = function_d_main(env, x, c, body);
                    let new_body = function_d_finalize(env, x, c, temp);

                    (info.clone(), new_body)
                };

                let new_switch = Switch {
                    cond_symbol: *cond_symbol,
                    cond_layout: *cond_layout,
                    branches: new_branches.into_bump_slice(),
                    default_branch: new_default,
                    ret_layout: *ret_layout,
                };

                (arena.alloc(new_switch), true)
            } else {
                // `x` is dead everywhere below; let the caller insert the reset
                (stmt, false)
            }
        }
        Refcounting(modify_rc, continuation) => {
            let (b, found) = function_d_main(env, x, c, continuation);

            if found || modify_rc.get_symbol() != x {
                let refcounting = Refcounting(*modify_rc, b);

                (arena.alloc(refcounting), found)
            } else {
                // the refcount op targets `x` itself and `x` is dead after it:
                // try to reuse its memory in the continuation
                let b = try_function_s(env, x, c, b);

                let refcounting = Refcounting(*modify_rc, b);

                (arena.alloc(refcounting), found)
            }
        }
        Expect {
            condition,
            region,
            lookups,
            layouts,
            remainder,
        } => {
            let (b, found) = function_d_main(env, x, c, remainder);

            if found || *condition != x {
                let refcounting = Expect {
                    condition: *condition,
                    region: *region,
                    lookups,
                    layouts,
                    remainder: b,
                };

                (arena.alloc(refcounting), found)
            } else {
                // `x` is the expect condition and dead afterwards
                let b = try_function_s(env, x, c, b);

                let refcounting = Expect {
                    condition: *condition,
                    region: *region,
                    lookups,
                    layouts,
                    remainder: b,
                };

                (arena.alloc(refcounting), found)
            }
        }
        Join {
            id,
            parameters,
            body,
            remainder,
        } => {
            // register this join point's live variables so jumps in the
            // remainder can be resolved
            env.jp_live_vars.insert(*id, LiveVarSet::default());

            let body_live_vars = collect_stmt(body, &env.jp_live_vars, LiveVarSet::default());

            env.jp_live_vars.insert(*id, body_live_vars);

            let (b, found) = function_d_main(env, x, c, remainder);

            let (v, _found) = function_d_main(env, x, c, body);

            env.jp_live_vars.remove(id);

            // If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
            // we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
            // then it must also live in `b` since `j` is reachable from `b` with a `jmp`.
            // On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/
            let new_join = Join {
                id: *id,
                parameters,
                body: v,
                remainder: b,
            };

            (arena.alloc(new_join), found)
        }
        // terminals: liveness of `x` decides whether a reset may go above
        Ret(_) | Jump(_, _) | RuntimeError(_) => (stmt, has_live_var(&env.jp_live_vars, stmt, x)),
    }
}
/// The `D` function from the paper: find the point where the scrutinee `x`
/// dies and try to reuse its memory cell for a compatible constructor `c`.
fn function_d<'a, 'i>(
    env: &mut Env<'a, 'i>,
    x: Symbol,
    c: &CtorInfo<'a>,
    stmt: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
    let output = function_d_main(env, x, c, stmt);

    function_d_finalize(env, x, c, output)
}
/// Process one `Switch` branch for the `R` function. After recursing into the
/// body, run `D` for the scrutinee when the branch matched a constructor of a
/// recursive tag union whose tag is not represented as NULL.
fn function_r_branch_body<'a, 'i>(
    env: &mut Env<'a, 'i>,
    info: &BranchInfo<'a>,
    body: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
    let new_body = function_r(env, body);

    match info {
        BranchInfo::Constructor {
            scrutinee,
            layout: Layout::Union(union_layout),
            tag_id,
        } if !matches!(union_layout, UnionLayout::NonRecursive(_))
            && !union_layout.tag_is_null(*tag_id) =>
        {
            let ctor_info = CtorInfo {
                layout: *union_layout,
                id: *tag_id,
            };

            function_d(env, *scrutinee, &ctor_info, new_body)
        }
        // no constructor info, non-recursive unions, non-union layouts, and
        // NULL-represented tags: nothing to reuse
        _ => new_body,
    }
}
/// The `R` function from the paper: traverse the whole procedure body and, at
/// every `Switch` branch, delegate to `function_r_branch_body`, which decides
/// whether the matched scrutinee's memory can be reused (`D`).
fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> {
    use Stmt::*;

    let arena = env.arena;

    match stmt {
        Switch {
            cond_symbol,
            cond_layout,
            branches,
            default_branch,
            ret_layout,
        } => {
            let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
            for (tag, info, body) in branches.iter() {
                let new_body = function_r_branch_body(env, info, body);

                new_branches.push((*tag, info.clone(), new_body.clone()));
            }

            let new_default = {
                let (info, body) = default_branch;
                let new_body = function_r_branch_body(env, info, body);

                (info.clone(), new_body)
            };

            let new_switch = Switch {
                cond_symbol: *cond_symbol,
                cond_layout: *cond_layout,
                branches: new_branches.into_bump_slice(),
                default_branch: new_default,
                ret_layout: *ret_layout,
            };

            arena.alloc(new_switch)
        }
        Join {
            id,
            parameters,
            body,
            remainder,
        } => {
            // track this join point's live variables while both the remainder
            // and the body are processed, then drop the entry again
            env.jp_live_vars.insert(*id, LiveVarSet::default());

            let body_live_vars = collect_stmt(body, &env.jp_live_vars, LiveVarSet::default());

            env.jp_live_vars.insert(*id, body_live_vars);

            let b = function_r(env, remainder);

            let v = function_r(env, body);

            env.jp_live_vars.remove(id);

            let join = Join {
                id: *id,
                parameters,
                body: v,
                remainder: b,
            };

            arena.alloc(join)
        }
        Let(symbol, expr, layout, continuation) => {
            let b = function_r(env, continuation);

            arena.alloc(Let(*symbol, expr.clone(), *layout, b))
        }
        Refcounting(modify_rc, continuation) => {
            let b = function_r(env, continuation);

            arena.alloc(Refcounting(*modify_rc, b))
        }
        Expect {
            condition,
            region,
            lookups,
            layouts,
            remainder,
        } => {
            let b = function_r(env, remainder);

            let expect = Expect {
                condition: *condition,
                region: *region,
                lookups,
                layouts,
                remainder: b,
            };

            arena.alloc(expect)
        }
        Ret(_) | Jump(_, _) | RuntimeError(_) => {
            // terminals
            stmt
        }
    }
}
/// Is `needle` used ("live") anywhere in `stmt`?
///
/// `jp_live_vars` supplies the live-variable sets of join points already in
/// scope, so that a `Jump` can be charged with the variables its target uses.
fn has_live_var<'a>(jp_live_vars: &JPLiveVarMap, stmt: &'a Stmt<'a>, needle: Symbol) -> bool {
    use Stmt::*;

    match stmt {
        Let(s, e, _, c) => {
            // the needle must be bound outside `stmt`, never here
            debug_assert_ne!(*s, needle);
            has_live_var_expr(e, needle) || has_live_var(jp_live_vars, c, needle)
        }
        Switch { cond_symbol, .. } if *cond_symbol == needle => true,
        Switch {
            branches,
            default_branch,
            ..
        } => {
            has_live_var(jp_live_vars, default_branch.1, needle)
                || branches
                    .iter()
                    .any(|(_, _, body)| has_live_var(jp_live_vars, body, needle))
        }
        Ret(s) => *s == needle,
        Refcounting(modify_rc, cont) => {
            modify_rc.get_symbol() == needle || has_live_var(jp_live_vars, cont, needle)
        }
        Expect {
            condition,
            remainder,
            ..
        } => *condition == needle || has_live_var(jp_live_vars, remainder, needle),
        Join {
            id,
            parameters,
            body,
            remainder,
        } => {
            debug_assert!(parameters.iter().all(|p| p.symbol != needle));

            // register this join point's live set before scanning the
            // remainder, because jumps in the remainder target it
            let mut jp_live_vars = jp_live_vars.clone();

            jp_live_vars.insert(*id, LiveVarSet::default());

            let body_live_vars = collect_stmt(body, &jp_live_vars, LiveVarSet::default());

            if body_live_vars.contains(&needle) {
                return true;
            }

            jp_live_vars.insert(*id, body_live_vars);

            has_live_var(&jp_live_vars, remainder, needle)
        }
        Jump(id, arguments) => {
            // live if passed directly, or if the target join point uses it
            arguments.iter().any(|s| *s == needle) || jp_live_vars[id].contains(&needle)
        }
        RuntimeError(_) => false,
    }
}
/// Does `expr` mention `needle` anywhere?
fn has_live_var_expr<'a>(expr: &'a Expr<'a>, needle: Symbol) -> bool {
    match expr {
        // expressions that reference no symbols at all
        Expr::Literal(_) | Expr::EmptyArray | Expr::RuntimeErrorFunction(_) => false,
        Expr::Call(call) => has_live_var_call(call, needle),
        Expr::Array { elems, .. } => elems
            .iter()
            .any(|element| matches!(element, ListLiteralElement::Symbol(s) if *s == needle)),
        Expr::Tag {
            arguments: fields, ..
        }
        | Expr::Struct(fields) => fields.contains(&needle),
        Expr::StructAtIndex { structure, .. }
        | Expr::GetTagId { structure, .. }
        | Expr::UnionAtIndex { structure, .. } => *structure == needle,
        Expr::Reuse {
            symbol, arguments, ..
        } => *symbol == needle || arguments.contains(&needle),
        Expr::Reset { symbol, .. }
        | Expr::ExprBox { symbol, .. }
        | Expr::ExprUnbox { symbol, .. } => *symbol == needle,
    }
}
/// Does `call` pass `needle` as one of its arguments?
fn has_live_var_call<'a>(call: &'a Call<'a>, needle: Symbol) -> bool {
    call.arguments.contains(&needle)
}

View file

@ -0,0 +1,280 @@
#![allow(clippy::manual_map)]
use crate::ir::{CallType, Expr, JoinPointId, Param, Stmt};
use crate::layout::Layout;
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_module::symbol::Symbol;
/// Make tail calls into loops (using join points)
///
/// e.g.
///
/// > factorial n accum = if n == 1 then accum else factorial (n - 1) (n * accum)
///
/// becomes
///
/// ```elm
/// factorial n1 accum1 =
/// let joinpoint j n accum =
/// if n == 1 then
/// accum
/// else
/// jump j (n - 1) (n * accum)
///
/// in
/// jump j n1 accum1
/// ```
///
/// This will effectively compile into a loop in llvm, and
/// won't grow the call stack for each iteration
pub fn make_tail_recursive<'a>(
    arena: &'a Bump,
    id: JoinPointId,
    needle: Symbol,
    stmt: Stmt<'a>,
    args: &'a [(Layout<'a>, Symbol, Symbol)],
    ret_layout: Layout,
) -> Option<Stmt<'a>> {
    let allocated = arena.alloc(stmt);

    // replace `let r = call needle …; ret r` tail calls with jumps to `id`;
    // `None` means there were none, so the caller keeps the original stmt
    let new_stmt = insert_jumps(arena, allocated, id, needle, args, ret_layout)?;

    // if we did not early-return, jumps were inserted, we must now add a join point

    // join-point parameters come from the second tuple element of `args`
    let params = Vec::from_iter_in(
        args.iter().map(|(layout, symbol, _)| Param {
            symbol: *symbol,
            layout: *layout,
            borrow: true,
        }),
        arena,
    )
    .into_bump_slice();

    // TODO could this be &[]?
    // the initial jump enters the loop with the third tuple element of `args`
    let args = Vec::from_iter_in(args.iter().map(|t| t.2), arena).into_bump_slice();

    let jump = arena.alloc(Stmt::Jump(id, args));

    let join = Stmt::Join {
        id,
        remainder: jump,
        parameters: params,
        body: new_stmt,
    };

    Some(join)
}
/// Scan `stmt` for self-recursive tail calls — a `let r = call needle …`
/// whose continuation is exactly `ret r`, with matching argument and return
/// layouts — and replace each with `jump goal_id args`.
///
/// Returns `Some(new_stmt)` when at least one jump was inserted, and `None`
/// when nothing changed, so callers can skip building a join point.
fn insert_jumps<'a>(
    arena: &'a Bump,
    stmt: &'a Stmt<'a>,
    goal_id: JoinPointId,
    needle: Symbol,
    needle_arguments: &'a [(Layout<'a>, Symbol, Symbol)],
    needle_result: Layout,
) -> Option<&'a Stmt<'a>> {
    use Stmt::*;

    // to insert a tail-call, it must not just be a call to the function itself, but it must also
    // have the same layout. In particular when lambda sets get involved, a self-recursive call may
    // have a different type and should not be converted to a jump!
    let is_equal_function = |function_name: Symbol, arguments: &[_], result| {
        let it = needle_arguments.iter().map(|t| &t.0);

        needle == function_name && it.eq(arguments.iter()) && needle_result == result
    };

    match stmt {
        // the tail-call pattern: bind the recursive call, then immediately
        // return the bound symbol
        Let(
            symbol,
            Expr::Call(crate::ir::Call {
                call_type:
                    CallType::ByName {
                        name: fsym,
                        ret_layout,
                        arg_layouts,
                        ..
                    },
                arguments,
            }),
            _,
            Stmt::Ret(rsym),
        ) if symbol == rsym && is_equal_function(*fsym, arg_layouts, **ret_layout) => {
            // replace the call and return with a jump

            let jump = Stmt::Jump(goal_id, arguments);

            Some(arena.alloc(jump))
        }

        Let(symbol, expr, layout, cont) => {
            let opt_cont = insert_jumps(
                arena,
                cont,
                goal_id,
                needle,
                needle_arguments,
                needle_result,
            );

            // only rebuild the `Let` when the continuation actually changed
            if opt_cont.is_some() {
                let cont = opt_cont.unwrap_or(cont);

                Some(arena.alloc(Let(*symbol, expr.clone(), *layout, cont)))
            } else {
                None
            }
        }

        Join {
            id,
            parameters,
            remainder,
            body: continuation,
        } => {
            let opt_remainder = insert_jumps(
                arena,
                remainder,
                goal_id,
                needle,
                needle_arguments,
                needle_result,
            );
            let opt_continuation = insert_jumps(
                arena,
                continuation,
                goal_id,
                needle,
                needle_arguments,
                needle_result,
            );

            // rebuild when either side changed, keeping the unchanged side
            if opt_remainder.is_some() || opt_continuation.is_some() {
                let remainder = opt_remainder.unwrap_or(remainder);
                let continuation = opt_continuation.unwrap_or(*continuation);

                Some(arena.alloc(Join {
                    id: *id,
                    parameters,
                    remainder,
                    body: continuation,
                }))
            } else {
                None
            }
        }
        Switch {
            cond_symbol,
            cond_layout,
            branches,
            default_branch,
            ret_layout,
        } => {
            let opt_default = insert_jumps(
                arena,
                default_branch.1,
                goal_id,
                needle,
                needle_arguments,
                needle_result,
            );

            // `did_change` records whether any branch body was rewritten;
            // note `Vec::from_iter_in` is eager, so it is set before the
            // check below
            let mut did_change = false;

            let opt_branches = Vec::from_iter_in(
                branches.iter().map(|(label, info, branch)| {
                    match insert_jumps(
                        arena,
                        branch,
                        goal_id,
                        needle,
                        needle_arguments,
                        needle_result,
                    ) {
                        None => None,
                        Some(branch) => {
                            did_change = true;
                            Some((*label, info.clone(), branch.clone()))
                        }
                    }
                }),
                arena,
            );

            if opt_default.is_some() || did_change {
                let default_branch = (
                    default_branch.0.clone(),
                    opt_default.unwrap_or(default_branch.1),
                );

                // splice rewritten branches over the untouched originals
                let branches = if did_change {
                    let new = Vec::from_iter_in(
                        opt_branches.into_iter().zip(branches.iter()).map(
                            |(opt_branch, branch)| match opt_branch {
                                None => branch.clone(),
                                Some(new_branch) => new_branch,
                            },
                        ),
                        arena,
                    );

                    new.into_bump_slice()
                } else {
                    branches
                };

                Some(arena.alloc(Switch {
                    cond_symbol: *cond_symbol,
                    cond_layout: *cond_layout,
                    default_branch,
                    branches,
                    ret_layout: *ret_layout,
                }))
            } else {
                None
            }
        }
        Refcounting(modify, cont) => {
            match insert_jumps(
                arena,
                cont,
                goal_id,
                needle,
                needle_arguments,
                needle_result,
            ) {
                Some(cont) => Some(arena.alloc(Refcounting(*modify, cont))),
                None => None,
            }
        }
        Expect {
            condition,
            region,
            lookups,
            layouts,
            remainder,
        } => match insert_jumps(
            arena,
            remainder,
            goal_id,
            needle,
            needle_arguments,
            needle_result,
        ) {
            Some(cont) => Some(arena.alloc(Expect {
                condition: *condition,
                region: *region,
                lookups,
                layouts,
                remainder: cont,
            })),
            None => None,
        },

        // terminals: a bare `Ret` (without a preceding self-call) is not a
        // tail call, and jumps/errors cannot contain one
        Ret(_) => None,
        Jump(_, _) => None,
        RuntimeError(_) => None,
    }
}