mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 21:39:07 +00:00
changes after review
This commit is contained in:
parent
0b03a0bc26
commit
654cf7b861
17 changed files with 142 additions and 123 deletions
|
@ -10,7 +10,7 @@ use crate::layout::{
|
|||
};
|
||||
use bumpalo::collections::Vec;
|
||||
use bumpalo::Bump;
|
||||
use roc_collections::{MutMap, VecSet};
|
||||
use roc_collections::{MutMap, VecMap, VecSet};
|
||||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
||||
|
||||
|
@ -423,15 +423,6 @@ impl TrmcCandidateSet {
|
|||
self.active.insert(call);
|
||||
}
|
||||
|
||||
fn extend(&mut self, other: Self) {
|
||||
self.confirmed.keep_if_in_either(other.confirmed);
|
||||
self.invalid.keep_if_in_either(other.invalid);
|
||||
self.active.keep_if_in_either(other.active);
|
||||
|
||||
self.active.retain(|k| !self.invalid.contains(k));
|
||||
self.confirmed.retain(|k| !self.invalid.contains(k));
|
||||
}
|
||||
|
||||
fn retain<F>(&mut self, keep: F)
|
||||
where
|
||||
F: Fn(&Symbol) -> bool,
|
||||
|
@ -443,7 +434,7 @@ impl TrmcCandidateSet {
|
|||
}
|
||||
|
||||
self.active.retain(|k| !self.invalid.contains(k));
|
||||
self.confirmed.retain(|k| !self.invalid.contains(k));
|
||||
debug_assert!(!self.confirmed.iter().any(|x| self.invalid.contains(x)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -465,26 +456,31 @@ where
|
|||
return VecSet::default();
|
||||
}
|
||||
|
||||
trmc_candidates_help(proc.name, &proc.body, TrmcCandidateSet::default()).confirmed
|
||||
let mut candidate_set = TrmcCandidateSet::default();
|
||||
trmc_candidates_help(proc.name, &proc.body, &mut candidate_set);
|
||||
candidate_set.confirmed
|
||||
}
|
||||
|
||||
fn trmc_candidates_help<'a>(
|
||||
function_name: LambdaName,
|
||||
stmt: &'_ Stmt<'a>,
|
||||
mut candidates: TrmcCandidateSet,
|
||||
) -> TrmcCandidateSet {
|
||||
candidates: &mut TrmcCandidateSet,
|
||||
) {
|
||||
// if this stmt is the literal tail tag application and return, then this is a TRMC opportunity
|
||||
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
|
||||
// must use the result of a recursive call directly as an argument
|
||||
// we pick the (syntactically) first one
|
||||
for recursive_call in candidates.active.iter() {
|
||||
if cons_info.arguments.contains(recursive_call) {
|
||||
return TrmcCandidateSet {
|
||||
confirmed: VecSet::singleton(*recursive_call),
|
||||
active: VecSet::default(),
|
||||
invalid: candidates.invalid,
|
||||
};
|
||||
}
|
||||
// the tag application must directly use the result of the recursive call
|
||||
let recursive_call = candidates
|
||||
.active
|
||||
.iter()
|
||||
.copied()
|
||||
.find(|call| cons_info.arguments.contains(call));
|
||||
|
||||
// if we find a usage, this is a confirmed TRMC call
|
||||
if let Some(recursive_call) = recursive_call {
|
||||
candidates.active.remove(&recursive_call);
|
||||
candidates.confirmed.insert(recursive_call);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -511,15 +507,9 @@ fn trmc_candidates_help<'a>(
|
|||
.map(|(_, _, stmt)| stmt)
|
||||
.chain([default_branch.1]);
|
||||
|
||||
let mut accum = candidates.clone();
|
||||
|
||||
for next in it {
|
||||
let x = trmc_candidates_help(function_name, next, candidates.clone());
|
||||
|
||||
accum.extend(x);
|
||||
trmc_candidates_help(function_name, next, candidates);
|
||||
}
|
||||
|
||||
accum
|
||||
}
|
||||
Stmt::Refcounting(_, next) => trmc_candidates_help(function_name, next, candidates),
|
||||
Stmt::Expect { remainder, .. }
|
||||
|
@ -528,26 +518,63 @@ fn trmc_candidates_help<'a>(
|
|||
Stmt::Join {
|
||||
body, remainder, ..
|
||||
} => {
|
||||
let mut x = trmc_candidates_help(function_name, body, candidates.clone());
|
||||
let y = trmc_candidates_help(function_name, remainder, candidates.clone());
|
||||
|
||||
x.extend(y);
|
||||
|
||||
x
|
||||
trmc_candidates_help(function_name, body, candidates);
|
||||
trmc_candidates_help(function_name, remainder, candidates);
|
||||
}
|
||||
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => candidates,
|
||||
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
|
||||
}
|
||||
}
|
||||
|
||||
// TRMC (tail recursion modulo constructor) is an optimization for some recursive functions that return a recursive data type. The most basic example is a repeat function on linked lists:
|
||||
//
|
||||
// ```roc
|
||||
// LinkedList a : [ Nil, Cons a (LinkedList a) ]
|
||||
//
|
||||
// repeat : a, Nat -> LinkedList a
|
||||
// repeat = \element, n ->
|
||||
// when n is
|
||||
// 0 -> Nil
|
||||
// _ -> Cons element (repeat element (n - 1))
|
||||
// ```
|
||||
//
|
||||
// This function is recursive, but cannot use standard tail-call elimintation, because the recursive call is not in tail position (i.e. the last thing happening before a return). Rather the recursive call is an argument to a constructor of the recursive output type. This means that `repeat n` will creat `n` stack frames. For big inputs, a stack overflow is inevitable.
|
||||
//
|
||||
// But there is a trick: TRMC. Using TRMC and join points, we are able to convert this function into a loop, which uses only one stack frame for the whole process.
|
||||
//
|
||||
// ```pseudo-roc
|
||||
// repeat : a, Nat -> LinkedList a
|
||||
// repeat = \initialElement, initialN ->
|
||||
// joinpoint trmc = \element, n, hole, head ->
|
||||
// when n is
|
||||
// 0 ->
|
||||
// # write the value `Nil` into the hole
|
||||
// *hole = Nil
|
||||
// # dereference (load from) the pointer to the first element
|
||||
// *head
|
||||
//
|
||||
// _ ->
|
||||
// *hole = Cons element NULL
|
||||
// newHole = &hole.Cons.1
|
||||
// jump trmc element (n - 1) newHole head
|
||||
// in
|
||||
// # creates a stack allocation, gives a pointer to that stack allocation
|
||||
// initial : Ptr (LinkedList a) = #alloca NULL
|
||||
// jump trmc initialElement initialN initial initial
|
||||
// ```
|
||||
//
|
||||
// The functionality here figures out whether this transformation can be applied in valid way, and then performs the transformation.
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct TrmcEnv<'a> {
|
||||
/// Current hole to fill
|
||||
hole_symbol: Symbol,
|
||||
/// Pointer to the first constructor ("the head of the list")
|
||||
head_symbol: Symbol,
|
||||
joinpoint_id: JoinPointId,
|
||||
return_layout: InLayout<'a>,
|
||||
ptr_return_layout: InLayout<'a>,
|
||||
|
||||
trmc_calls: MutMap<Symbol, Option<Call<'a>>>,
|
||||
trmc_calls: VecMap<Symbol, Option<Call<'a>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -599,7 +626,7 @@ impl<'a> TrmcEnv<'a> {
|
|||
fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool {
|
||||
match call.call_type {
|
||||
CallType::ByName { name, .. } => {
|
||||
// TODO are there other restrictions?
|
||||
// because we do not allow polymorphic recursion, this is the only constraint
|
||||
name == lambda_name
|
||||
}
|
||||
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
|
||||
|
@ -617,7 +644,6 @@ impl<'a> TrmcEnv<'a> {
|
|||
let ptr_write = Call {
|
||||
call_type: crate::ir::CallType::LowLevel {
|
||||
op: LowLevel::PtrStore,
|
||||
// update_mode: env.next_update_mode_id(),
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: env.arena.alloc([ptr, value]),
|
||||
|
@ -639,9 +665,9 @@ impl<'a> TrmcEnv<'a> {
|
|||
let arena = env.arena;
|
||||
let return_layout = proc.ret_layout;
|
||||
|
||||
let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
|
||||
let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 2, env.arena);
|
||||
let mut new_proc_arguments = Vec::with_capacity_in(proc.args.len(), env.arena);
|
||||
let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
|
||||
let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 2, env.arena);
|
||||
|
||||
for (i, (layout, old_symbol)) in proc.args.iter().enumerate() {
|
||||
let symbol = env.named_unique_symbol(&format!("arg_{i}"));
|
||||
|
@ -670,7 +696,7 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
let call = Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::PtrToStackValue,
|
||||
op: LowLevel::Alloca,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: arena.alloc([null_symbol]),
|
||||
|
@ -744,9 +770,13 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
match stmt {
|
||||
Stmt::Let(symbol, expr, layout, next) => {
|
||||
// if this is a TRMC call,
|
||||
// if this is a TRMC call, remember what the call looks like, so we can turn it
|
||||
// into a jump later. The call is then removed from the Stmt
|
||||
if let Some(opt_call) = self.trmc_calls.get_mut(symbol) {
|
||||
debug_assert!(opt_call.is_none());
|
||||
debug_assert!(
|
||||
opt_call.is_none(),
|
||||
"didn't expect to visit call again since symbols are unique"
|
||||
);
|
||||
|
||||
let call = match expr {
|
||||
Expr::Call(call) => call,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue