mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 14:54:47 +00:00

Figuring out what this module was doing, and why, took me a bit less than half an hour. We should document what is happening so that future readers don't necessarily need to ask about it on Zulip.
629 lines
19 KiB
Rust
629 lines
19 KiB
Rust
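//! This module inserts reset/reuse statements into the mono IR. These statements provide an
//! opportunity to reduce memory pressure by reusing the memory slots of non-shared values. From
//! the introduction of the relevant paper:
//!
//! > [We] have added two additional instructions to our IR: `let y = reset x` and
//! > `let z = (reuse y in ctor_i w)`. The two instructions are used together; if `x`
//! > is a shared value, then `y` is set to a special reference, and the reuse instruction
//! > just allocates a new constructor value `ctor_i w`. If `x` is not shared, then reset
//! > decrements the reference counters of the components of `x`, and `y` is set to `x`.
//! > Then, reuse reuses the memory cell used by `x` to store the constructor value `ctor_i w`.
//!
//! The pass follows the three functions `R`, `D` and `S` from the paper:
//!
//! - `function_r` (`R`) walks a procedure body. For every switch branch that matched a
//!   constructor of a union layout (except `NonRecursive` unions and tags represented as NULL),
//!   it calls `function_d` with the scrutinee of that switch.
//! - `function_d` (`D`) searches the branch for the point where the scrutinee is no longer live
//!   and calls `try_function_s` there.
//! - `function_s` (`S`) replaces a constructor application of a compatible layout with a
//!   `Reuse` of a fresh symbol. If anything was replaced, `insert_reset` puts the matching
//!   `let w = reset x` in front of the rewritten statement (after any leading projections).
//!
//! See also:
//!
//! - [Counting Immutable Beans](https://arxiv.org/pdf/1908.05647.pdf) (Ullrich and de Moura, 2020)
//! - [The Lean implementation](https://github.com/leanprover/lean4/blob/master/src/Lean/Compiler/IR/ResetReuse.lean)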
|
|
|
|
use crate::inc_dec::{collect_stmt, occurring_variables_expr, JPLiveVarMap, LiveVarSet};
|
|
use crate::ir::{
|
|
BranchInfo, Call, Expr, ListLiteralElement, Proc, Stmt, UpdateModeId, UpdateModeIds,
|
|
};
|
|
use crate::layout::{Layout, TagIdIntType, UnionLayout};
|
|
use bumpalo::collections::Vec;
|
|
use bumpalo::Bump;
|
|
use roc_collections::all::MutSet;
|
|
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
|
|
|
pub fn insert_reset_reuse<'a, 'i>(
|
|
arena: &'a Bump,
|
|
home: ModuleId,
|
|
ident_ids: &'i mut IdentIds,
|
|
update_mode_ids: &'i mut UpdateModeIds,
|
|
mut proc: Proc<'a>,
|
|
) -> Proc<'a> {
|
|
let mut env = Env {
|
|
arena,
|
|
home,
|
|
ident_ids,
|
|
update_mode_ids,
|
|
jp_live_vars: Default::default(),
|
|
};
|
|
|
|
let new_body = function_r(&mut env, arena.alloc(proc.body));
|
|
proc.body = new_body.clone();
|
|
|
|
proc
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct CtorInfo<'a> {
|
|
id: TagIdIntType,
|
|
layout: UnionLayout<'a>,
|
|
}
|
|
|
|
fn may_reuse(tag_layout: UnionLayout, tag_id: TagIdIntType, other: &CtorInfo) -> bool {
|
|
if tag_layout != other.layout {
|
|
return false;
|
|
}
|
|
|
|
// we should not get here if the tag we matched on is represented as NULL
|
|
debug_assert!(!tag_layout.tag_is_null(other.id as _));
|
|
|
|
// furthermore, we can only use the memory if the tag we're creating is non-NULL
|
|
!tag_layout.tag_is_null(tag_id)
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct Env<'a, 'i> {
|
|
arena: &'a Bump,
|
|
|
|
/// required for creating new `Symbol`s
|
|
home: ModuleId,
|
|
ident_ids: &'i mut IdentIds,
|
|
update_mode_ids: &'i mut UpdateModeIds,
|
|
|
|
jp_live_vars: JPLiveVarMap,
|
|
}
|
|
|
|
impl<'a, 'i> Env<'a, 'i> {
|
|
fn unique_symbol(&mut self) -> Symbol {
|
|
let ident_id = self.ident_ids.gen_unique();
|
|
|
|
Symbol::new(self.home, ident_id)
|
|
}
|
|
}
|
|
|
|
fn function_s<'a, 'i>(
|
|
env: &mut Env<'a, 'i>,
|
|
w: Opportunity,
|
|
c: &CtorInfo<'a>,
|
|
stmt: &'a Stmt<'a>,
|
|
) -> &'a Stmt<'a> {
|
|
use Stmt::*;
|
|
|
|
let arena = env.arena;
|
|
|
|
match stmt {
|
|
Let(symbol, expr, layout, continuation) => match expr {
|
|
Expr::Tag {
|
|
tag_layout,
|
|
tag_id,
|
|
tag_name,
|
|
arguments,
|
|
} if may_reuse(*tag_layout, *tag_id, c) => {
|
|
// for now, always overwrite the tag ID just to be sure
|
|
let update_tag_id = true;
|
|
|
|
let new_expr = Expr::Reuse {
|
|
symbol: w.symbol,
|
|
update_mode: w.update_mode,
|
|
update_tag_id,
|
|
tag_layout: *tag_layout,
|
|
tag_id: *tag_id,
|
|
tag_name: tag_name.clone(),
|
|
arguments,
|
|
};
|
|
let new_stmt = Let(*symbol, new_expr, *layout, continuation);
|
|
|
|
arena.alloc(new_stmt)
|
|
}
|
|
_ => {
|
|
let rest = function_s(env, w, c, continuation);
|
|
let new_stmt = Let(*symbol, expr.clone(), *layout, rest);
|
|
|
|
arena.alloc(new_stmt)
|
|
}
|
|
},
|
|
Join {
|
|
id,
|
|
parameters,
|
|
body,
|
|
remainder,
|
|
} => {
|
|
let id = *id;
|
|
let body: &Stmt = *body;
|
|
let new_body = function_s(env, w, c, body);
|
|
|
|
let new_join = if std::ptr::eq(body, new_body) || body == new_body {
|
|
// the join point body will consume w
|
|
Join {
|
|
id,
|
|
parameters,
|
|
body: new_body,
|
|
remainder,
|
|
}
|
|
} else {
|
|
let new_remainder = function_s(env, w, c, remainder);
|
|
|
|
Join {
|
|
id,
|
|
parameters,
|
|
body,
|
|
remainder: new_remainder,
|
|
}
|
|
};
|
|
|
|
arena.alloc(new_join)
|
|
}
|
|
Switch {
|
|
cond_symbol,
|
|
cond_layout,
|
|
branches,
|
|
default_branch,
|
|
ret_layout,
|
|
} => {
|
|
let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
|
|
new_branches.extend(branches.iter().map(|(tag, info, body)| {
|
|
let new_body = function_s(env, w, c, body);
|
|
|
|
(*tag, info.clone(), new_body.clone())
|
|
}));
|
|
|
|
let new_default = function_s(env, w, c, default_branch.1);
|
|
|
|
let new_switch = Switch {
|
|
cond_symbol: *cond_symbol,
|
|
cond_layout: *cond_layout,
|
|
branches: new_branches.into_bump_slice(),
|
|
default_branch: (default_branch.0.clone(), new_default),
|
|
ret_layout: *ret_layout,
|
|
};
|
|
|
|
arena.alloc(new_switch)
|
|
}
|
|
Refcounting(op, continuation) => {
|
|
let continuation: &Stmt = *continuation;
|
|
let new_continuation = function_s(env, w, c, continuation);
|
|
|
|
if std::ptr::eq(continuation, new_continuation) || continuation == new_continuation {
|
|
stmt
|
|
} else {
|
|
let new_refcounting = Refcounting(*op, new_continuation);
|
|
|
|
arena.alloc(new_refcounting)
|
|
}
|
|
}
|
|
Ret(_) | Jump(_, _) | RuntimeError(_) => stmt,
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Copy)]
|
|
struct Opportunity {
|
|
symbol: Symbol,
|
|
update_mode: UpdateModeId,
|
|
}
|
|
|
|
fn try_function_s<'a, 'i>(
|
|
env: &mut Env<'a, 'i>,
|
|
x: Symbol,
|
|
c: &CtorInfo<'a>,
|
|
stmt: &'a Stmt<'a>,
|
|
) -> &'a Stmt<'a> {
|
|
let w = Opportunity {
|
|
symbol: env.unique_symbol(),
|
|
update_mode: env.update_mode_ids.next_id(),
|
|
};
|
|
|
|
let new_stmt = function_s(env, w, c, stmt);
|
|
|
|
if std::ptr::eq(stmt, new_stmt) || stmt == new_stmt {
|
|
stmt
|
|
} else {
|
|
insert_reset(env, w, x, c.layout, new_stmt)
|
|
}
|
|
}
|
|
|
|
fn insert_reset<'a>(
|
|
env: &mut Env<'a, '_>,
|
|
w: Opportunity,
|
|
x: Symbol,
|
|
union_layout: UnionLayout<'a>,
|
|
mut stmt: &'a Stmt<'a>,
|
|
) -> &'a Stmt<'a> {
|
|
use crate::ir::Expr::*;
|
|
|
|
let mut stack = vec![];
|
|
|
|
while let Stmt::Let(symbol, expr, expr_layout, rest) = stmt {
|
|
match &expr {
|
|
StructAtIndex { .. } | GetTagId { .. } | UnionAtIndex { .. } => {
|
|
stack.push((symbol, expr, expr_layout));
|
|
stmt = rest;
|
|
}
|
|
Literal(_)
|
|
| Call(_)
|
|
| Tag { .. }
|
|
| Struct(_)
|
|
| Array { .. }
|
|
| EmptyArray
|
|
| Reuse { .. }
|
|
| Reset { .. }
|
|
| RuntimeErrorFunction(_) => break,
|
|
}
|
|
}
|
|
|
|
let reset_expr = Expr::Reset {
|
|
symbol: x,
|
|
update_mode: w.update_mode,
|
|
};
|
|
|
|
let layout = Layout::Union(union_layout);
|
|
|
|
stmt = env
|
|
.arena
|
|
.alloc(Stmt::Let(w.symbol, reset_expr, layout, stmt));
|
|
|
|
for (symbol, expr, expr_layout) in stack.into_iter().rev() {
|
|
stmt = env
|
|
.arena
|
|
.alloc(Stmt::Let(*symbol, expr.clone(), *expr_layout, stmt));
|
|
}
|
|
|
|
stmt
|
|
}
|
|
|
|
fn function_d_finalize<'a, 'i>(
|
|
env: &mut Env<'a, 'i>,
|
|
x: Symbol,
|
|
c: &CtorInfo<'a>,
|
|
output: (&'a Stmt<'a>, bool),
|
|
) -> &'a Stmt<'a> {
|
|
let (stmt, x_live_in_stmt) = output;
|
|
if x_live_in_stmt {
|
|
stmt
|
|
} else {
|
|
try_function_s(env, x, c, stmt)
|
|
}
|
|
}
|
|
|
|
fn function_d_main<'a, 'i>(
|
|
env: &mut Env<'a, 'i>,
|
|
x: Symbol,
|
|
c: &CtorInfo<'a>,
|
|
stmt: &'a Stmt<'a>,
|
|
) -> (&'a Stmt<'a>, bool) {
|
|
use Stmt::*;
|
|
|
|
let arena = env.arena;
|
|
|
|
match stmt {
|
|
Let(symbol, expr, layout, continuation) => {
|
|
match expr {
|
|
Expr::Tag { arguments, .. } if arguments.iter().any(|s| *s == x) => {
|
|
// If the scrutinee `x` (the one that is providing memory) is being
|
|
// stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
|
|
// It may work only if the new cell is consumed, but we ignore this case.
|
|
(stmt, true)
|
|
}
|
|
_ => {
|
|
let (b, found) = function_d_main(env, x, c, continuation);
|
|
|
|
// NOTE the &b != continuation is not found in the Lean source, but is required
|
|
// otherwise we observe the same symbol being reset twice
|
|
let mut result = MutSet::default();
|
|
if found
|
|
|| {
|
|
occurring_variables_expr(expr, &mut result);
|
|
!result.contains(&x)
|
|
}
|
|
|| &b != continuation
|
|
{
|
|
let let_stmt = Let(*symbol, expr.clone(), *layout, b);
|
|
|
|
(arena.alloc(let_stmt), found)
|
|
} else {
|
|
let b = try_function_s(env, x, c, b);
|
|
let let_stmt = Let(*symbol, expr.clone(), *layout, b);
|
|
|
|
(arena.alloc(let_stmt), found)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Switch {
|
|
cond_symbol,
|
|
cond_layout,
|
|
branches,
|
|
default_branch,
|
|
ret_layout,
|
|
} => {
|
|
if has_live_var(&env.jp_live_vars, stmt, x) {
|
|
// if `x` is live in `stmt`, we recursively process each branch
|
|
let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
|
|
|
|
for (tag, info, body) in branches.iter() {
|
|
let temp = function_d_main(env, x, c, body);
|
|
let new_body = function_d_finalize(env, x, c, temp);
|
|
|
|
new_branches.push((*tag, info.clone(), new_body.clone()));
|
|
}
|
|
|
|
let new_default = {
|
|
let (info, body) = default_branch;
|
|
let temp = function_d_main(env, x, c, body);
|
|
let new_body = function_d_finalize(env, x, c, temp);
|
|
|
|
(info.clone(), new_body)
|
|
};
|
|
|
|
let new_switch = Switch {
|
|
cond_symbol: *cond_symbol,
|
|
cond_layout: *cond_layout,
|
|
branches: new_branches.into_bump_slice(),
|
|
default_branch: new_default,
|
|
ret_layout: *ret_layout,
|
|
};
|
|
|
|
(arena.alloc(new_switch), true)
|
|
} else {
|
|
(stmt, false)
|
|
}
|
|
}
|
|
Refcounting(modify_rc, continuation) => {
|
|
let (b, found) = function_d_main(env, x, c, continuation);
|
|
|
|
if found || modify_rc.get_symbol() != x {
|
|
let refcounting = Refcounting(*modify_rc, b);
|
|
|
|
(arena.alloc(refcounting), found)
|
|
} else {
|
|
let b = try_function_s(env, x, c, b);
|
|
let refcounting = Refcounting(*modify_rc, b);
|
|
|
|
(arena.alloc(refcounting), found)
|
|
}
|
|
}
|
|
Join {
|
|
id,
|
|
parameters,
|
|
body,
|
|
remainder,
|
|
} => {
|
|
env.jp_live_vars.insert(*id, LiveVarSet::default());
|
|
|
|
let body_live_vars = collect_stmt(body, &env.jp_live_vars, LiveVarSet::default());
|
|
|
|
env.jp_live_vars.insert(*id, body_live_vars);
|
|
|
|
let (b, found) = function_d_main(env, x, c, remainder);
|
|
|
|
let (v, _found) = function_d_main(env, x, c, body);
|
|
|
|
env.jp_live_vars.remove(id);
|
|
|
|
// If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
|
|
// we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
|
|
// then it must also live in `b` since `j` is reachable from `b` with a `jmp`.
|
|
// On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/
|
|
let new_join = Join {
|
|
id: *id,
|
|
parameters,
|
|
body: v,
|
|
remainder: b,
|
|
};
|
|
|
|
(arena.alloc(new_join), found)
|
|
}
|
|
Ret(_) | Jump(_, _) | RuntimeError(_) => (stmt, has_live_var(&env.jp_live_vars, stmt, x)),
|
|
}
|
|
}
|
|
|
|
fn function_d<'a, 'i>(
|
|
env: &mut Env<'a, 'i>,
|
|
x: Symbol,
|
|
c: &CtorInfo<'a>,
|
|
stmt: &'a Stmt<'a>,
|
|
) -> &'a Stmt<'a> {
|
|
let temp = function_d_main(env, x, c, stmt);
|
|
|
|
function_d_finalize(env, x, c, temp)
|
|
}
|
|
|
|
fn function_r_branch_body<'a, 'i>(
|
|
env: &mut Env<'a, 'i>,
|
|
info: &BranchInfo<'a>,
|
|
body: &'a Stmt<'a>,
|
|
) -> &'a Stmt<'a> {
|
|
let temp = function_r(env, body);
|
|
|
|
match info {
|
|
BranchInfo::None => temp,
|
|
BranchInfo::Constructor {
|
|
scrutinee,
|
|
layout,
|
|
tag_id,
|
|
} => match layout {
|
|
Layout::Union(UnionLayout::NonRecursive(_)) => temp,
|
|
Layout::Union(union_layout) if !union_layout.tag_is_null(*tag_id) => {
|
|
let ctor_info = CtorInfo {
|
|
layout: *union_layout,
|
|
id: *tag_id,
|
|
};
|
|
function_d(env, *scrutinee, &ctor_info, temp)
|
|
}
|
|
_ => temp,
|
|
},
|
|
}
|
|
}
|
|
|
|
fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> {
|
|
use Stmt::*;
|
|
|
|
let arena = env.arena;
|
|
|
|
match stmt {
|
|
Switch {
|
|
cond_symbol,
|
|
cond_layout,
|
|
branches,
|
|
default_branch,
|
|
ret_layout,
|
|
} => {
|
|
let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
|
|
|
|
for (tag, info, body) in branches.iter() {
|
|
let new_body = function_r_branch_body(env, info, body);
|
|
|
|
new_branches.push((*tag, info.clone(), new_body.clone()));
|
|
}
|
|
|
|
let new_default = {
|
|
let (info, body) = default_branch;
|
|
|
|
let new_body = function_r_branch_body(env, info, body);
|
|
|
|
(info.clone(), new_body)
|
|
};
|
|
|
|
let new_switch = Switch {
|
|
cond_symbol: *cond_symbol,
|
|
cond_layout: *cond_layout,
|
|
branches: new_branches.into_bump_slice(),
|
|
default_branch: new_default,
|
|
ret_layout: *ret_layout,
|
|
};
|
|
|
|
arena.alloc(new_switch)
|
|
}
|
|
|
|
Join {
|
|
id,
|
|
parameters,
|
|
body,
|
|
remainder,
|
|
} => {
|
|
env.jp_live_vars.insert(*id, LiveVarSet::default());
|
|
|
|
let body_live_vars = collect_stmt(body, &env.jp_live_vars, LiveVarSet::default());
|
|
|
|
env.jp_live_vars.insert(*id, body_live_vars);
|
|
|
|
let b = function_r(env, remainder);
|
|
|
|
let v = function_r(env, body);
|
|
|
|
env.jp_live_vars.remove(id);
|
|
|
|
let join = Join {
|
|
id: *id,
|
|
parameters,
|
|
body: v,
|
|
remainder: b,
|
|
};
|
|
|
|
arena.alloc(join)
|
|
}
|
|
|
|
Let(symbol, expr, layout, continuation) => {
|
|
let b = function_r(env, continuation);
|
|
|
|
arena.alloc(Let(*symbol, expr.clone(), *layout, b))
|
|
}
|
|
Refcounting(modify_rc, continuation) => {
|
|
let b = function_r(env, continuation);
|
|
|
|
arena.alloc(Refcounting(*modify_rc, b))
|
|
}
|
|
|
|
Ret(_) | Jump(_, _) | RuntimeError(_) => {
|
|
// terminals
|
|
stmt
|
|
}
|
|
}
|
|
}
|
|
|
|
fn has_live_var<'a>(jp_live_vars: &JPLiveVarMap, stmt: &'a Stmt<'a>, needle: Symbol) -> bool {
|
|
use Stmt::*;
|
|
|
|
match stmt {
|
|
Let(s, e, _, c) => {
|
|
debug_assert_ne!(*s, needle);
|
|
has_live_var_expr(e, needle) || has_live_var(jp_live_vars, c, needle)
|
|
}
|
|
Switch { cond_symbol, .. } if *cond_symbol == needle => true,
|
|
Switch {
|
|
branches,
|
|
default_branch,
|
|
..
|
|
} => {
|
|
has_live_var(jp_live_vars, default_branch.1, needle)
|
|
|| branches
|
|
.iter()
|
|
.any(|(_, _, body)| has_live_var(jp_live_vars, body, needle))
|
|
}
|
|
Ret(s) => *s == needle,
|
|
Refcounting(modify_rc, cont) => {
|
|
modify_rc.get_symbol() == needle || has_live_var(jp_live_vars, cont, needle)
|
|
}
|
|
Join {
|
|
id,
|
|
parameters,
|
|
body,
|
|
remainder,
|
|
} => {
|
|
debug_assert!(parameters.iter().all(|p| p.symbol != needle));
|
|
|
|
let mut jp_live_vars = jp_live_vars.clone();
|
|
|
|
jp_live_vars.insert(*id, LiveVarSet::default());
|
|
|
|
let body_live_vars = collect_stmt(body, &jp_live_vars, LiveVarSet::default());
|
|
|
|
if body_live_vars.contains(&needle) {
|
|
return true;
|
|
}
|
|
|
|
jp_live_vars.insert(*id, body_live_vars);
|
|
|
|
has_live_var(&jp_live_vars, remainder, needle)
|
|
}
|
|
Jump(id, arguments) => {
|
|
arguments.iter().any(|s| *s == needle) || jp_live_vars[id].contains(&needle)
|
|
}
|
|
RuntimeError(_) => false,
|
|
}
|
|
}
|
|
|
|
fn has_live_var_expr<'a>(expr: &'a Expr<'a>, needle: Symbol) -> bool {
|
|
match expr {
|
|
Expr::Literal(_) => false,
|
|
Expr::Call(call) => has_live_var_call(call, needle),
|
|
Expr::Array { elems: fields, .. } => {
|
|
for element in fields.iter() {
|
|
if let ListLiteralElement::Symbol(s) = element {
|
|
if *s == needle {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
|
|
false
|
|
}
|
|
Expr::Tag {
|
|
arguments: fields, ..
|
|
}
|
|
| Expr::Struct(fields) => fields.iter().any(|s| *s == needle),
|
|
Expr::StructAtIndex { structure, .. }
|
|
| Expr::GetTagId { structure, .. }
|
|
| Expr::UnionAtIndex { structure, .. } => *structure == needle,
|
|
Expr::EmptyArray => false,
|
|
Expr::Reuse {
|
|
symbol, arguments, ..
|
|
} => needle == *symbol || arguments.iter().any(|s| *s == needle),
|
|
Expr::Reset { symbol, .. } => needle == *symbol,
|
|
Expr::RuntimeErrorFunction(_) => false,
|
|
}
|
|
}
|
|
|
|
/// Is `needle` one of the `arguments` of `call`?
fn has_live_var_call<'a>(call: &'a Call<'a>, needle: Symbol) -> bool {
    call.arguments.contains(&needle)
}
|