changes after review

This commit is contained in:
Folkert 2023-06-23 22:36:21 +02:00
parent 0b03a0bc26
commit 654cf7b861
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
17 changed files with 142 additions and 123 deletions

View file

@ -10,7 +10,7 @@ use crate::layout::{
};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_collections::{MutMap, VecSet};
use roc_collections::{MutMap, VecMap, VecSet};
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
@ -423,15 +423,6 @@ impl TrmcCandidateSet {
self.active.insert(call);
}
fn extend(&mut self, other: Self) {
self.confirmed.keep_if_in_either(other.confirmed);
self.invalid.keep_if_in_either(other.invalid);
self.active.keep_if_in_either(other.active);
self.active.retain(|k| !self.invalid.contains(k));
self.confirmed.retain(|k| !self.invalid.contains(k));
}
fn retain<F>(&mut self, keep: F)
where
F: Fn(&Symbol) -> bool,
@ -443,7 +434,7 @@ impl TrmcCandidateSet {
}
self.active.retain(|k| !self.invalid.contains(k));
self.confirmed.retain(|k| !self.invalid.contains(k));
debug_assert!(!self.confirmed.iter().any(|x| self.invalid.contains(x)));
}
}
@ -465,26 +456,31 @@ where
return VecSet::default();
}
trmc_candidates_help(proc.name, &proc.body, TrmcCandidateSet::default()).confirmed
let mut candidate_set = TrmcCandidateSet::default();
trmc_candidates_help(proc.name, &proc.body, &mut candidate_set);
candidate_set.confirmed
}
fn trmc_candidates_help<'a>(
function_name: LambdaName,
stmt: &'_ Stmt<'a>,
mut candidates: TrmcCandidateSet,
) -> TrmcCandidateSet {
candidates: &mut TrmcCandidateSet,
) {
// if this stmt is the literal tail tag application and return, then this is a TRMC opportunity
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
// must use the result of a recursive call directly as an argument
// we pick the (syntactically) first one
for recursive_call in candidates.active.iter() {
if cons_info.arguments.contains(recursive_call) {
return TrmcCandidateSet {
confirmed: VecSet::singleton(*recursive_call),
active: VecSet::default(),
invalid: candidates.invalid,
};
}
// the tag application must directly use the result of the recursive call
let recursive_call = candidates
.active
.iter()
.copied()
.find(|call| cons_info.arguments.contains(call));
// if we find a usage, this is a confirmed TRMC call
if let Some(recursive_call) = recursive_call {
candidates.active.remove(&recursive_call);
candidates.confirmed.insert(recursive_call);
return;
}
}
@ -511,15 +507,9 @@ fn trmc_candidates_help<'a>(
.map(|(_, _, stmt)| stmt)
.chain([default_branch.1]);
let mut accum = candidates.clone();
for next in it {
let x = trmc_candidates_help(function_name, next, candidates.clone());
accum.extend(x);
trmc_candidates_help(function_name, next, candidates);
}
accum
}
Stmt::Refcounting(_, next) => trmc_candidates_help(function_name, next, candidates),
Stmt::Expect { remainder, .. }
@ -528,26 +518,63 @@ fn trmc_candidates_help<'a>(
Stmt::Join {
body, remainder, ..
} => {
let mut x = trmc_candidates_help(function_name, body, candidates.clone());
let y = trmc_candidates_help(function_name, remainder, candidates.clone());
x.extend(y);
x
trmc_candidates_help(function_name, body, candidates);
trmc_candidates_help(function_name, remainder, candidates);
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => candidates,
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
}
}
// TRMC (tail recursion modulo constructor) is an optimization for some recursive functions that return a recursive data type. The most basic example is a repeat function on linked lists:
//
// ```roc
// LinkedList a : [ Nil, Cons a (LinkedList a) ]
//
// repeat : a, Nat -> LinkedList a
// repeat = \element, n ->
// when n is
// 0 -> Nil
// _ -> Cons element (repeat element (n - 1))
// ```
//
// This function is recursive, but cannot use standard tail-call elimintation, because the recursive call is not in tail position (i.e. the last thing happening before a return). Rather the recursive call is an argument to a constructor of the recursive output type. This means that `repeat n` will creat `n` stack frames. For big inputs, a stack overflow is inevitable.
//
// But there is a trick: TRMC. Using TRMC and join points, we are able to convert this function into a loop, which uses only one stack frame for the whole process.
//
// ```pseudo-roc
// repeat : a, Nat -> LinkedList a
// repeat = \initialElement, initialN ->
// joinpoint trmc = \element, n, hole, head ->
// when n is
// 0 ->
// # write the value `Nil` into the hole
// *hole = Nil
// # dereference (load from) the pointer to the first element
// *head
//
// _ ->
// *hole = Cons element NULL
// newHole = &hole.Cons.1
// jump trmc element (n - 1) newHole head
// in
// # creates a stack allocation, gives a pointer to that stack allocation
// initial : Ptr (LinkedList a) = #alloca NULL
// jump trmc initialElement initialN initial initial
// ```
//
// The functionality here figures out whether this transformation can be applied in valid way, and then performs the transformation.
#[derive(Clone)]
pub(crate) struct TrmcEnv<'a> {
/// Current hole to fill
hole_symbol: Symbol,
/// Pointer to the first constructor ("the head of the list")
head_symbol: Symbol,
joinpoint_id: JoinPointId,
return_layout: InLayout<'a>,
ptr_return_layout: InLayout<'a>,
trmc_calls: MutMap<Symbol, Option<Call<'a>>>,
trmc_calls: VecMap<Symbol, Option<Call<'a>>>,
}
#[derive(Debug)]
@ -599,7 +626,7 @@ impl<'a> TrmcEnv<'a> {
fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool {
match call.call_type {
CallType::ByName { name, .. } => {
// TODO are there other restrictions?
// because we do not allow polymorphic recursion, this is the only constraint
name == lambda_name
}
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
@ -617,7 +644,6 @@ impl<'a> TrmcEnv<'a> {
let ptr_write = Call {
call_type: crate::ir::CallType::LowLevel {
op: LowLevel::PtrStore,
// update_mode: env.next_update_mode_id(),
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([ptr, value]),
@ -639,9 +665,9 @@ impl<'a> TrmcEnv<'a> {
let arena = env.arena;
let return_layout = proc.ret_layout;
let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 2, env.arena);
let mut new_proc_arguments = Vec::with_capacity_in(proc.args.len(), env.arena);
let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 2, env.arena);
for (i, (layout, old_symbol)) in proc.args.iter().enumerate() {
let symbol = env.named_unique_symbol(&format!("arg_{i}"));
@ -670,7 +696,7 @@ impl<'a> TrmcEnv<'a> {
let call = Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrToStackValue,
op: LowLevel::Alloca,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: arena.alloc([null_symbol]),
@ -744,9 +770,13 @@ impl<'a> TrmcEnv<'a> {
match stmt {
Stmt::Let(symbol, expr, layout, next) => {
// if this is a TRMC call,
// if this is a TRMC call, remember what the call looks like, so we can turn it
// into a jump later. The call is then removed from the Stmt
if let Some(opt_call) = self.trmc_calls.get_mut(symbol) {
debug_assert!(opt_call.is_none());
debug_assert!(
opt_call.is_none(),
"didn't expect to visit call again since symbols are unique"
);
let call = match expr {
Expr::Call(call) => call,