mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 21:39:07 +00:00
search for multiple TRMC opportunities
This commit is contained in:
parent
b349fca521
commit
c87519c209
8 changed files with 138 additions and 116 deletions
|
@ -10,7 +10,7 @@ use crate::layout::{
|
|||
};
|
||||
use bumpalo::collections::Vec;
|
||||
use bumpalo::Bump;
|
||||
use roc_collections::MutMap;
|
||||
use roc_collections::{MutMap, VecSet};
|
||||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
||||
|
||||
|
@ -53,8 +53,11 @@ pub fn apply_trmc<'a, 'i>(
|
|||
for proc in procs.values_mut() {
|
||||
use self::SelfRecursive::*;
|
||||
if let SelfRecursive(id) = proc.is_self_recursive {
|
||||
if crate::tail_recursion::is_trmc_candidate(env.interner, proc) {
|
||||
let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc);
|
||||
let trmc_candidate_symbols = trmc_candidates(env.interner, proc);
|
||||
|
||||
if !trmc_candidate_symbols.is_empty() {
|
||||
let new_proc =
|
||||
crate::tail_recursion::TrmcEnv::init(env, proc, trmc_candidate_symbols);
|
||||
*proc = new_proc;
|
||||
} else {
|
||||
let mut args = Vec::with_capacity_in(proc.args.len(), arena);
|
||||
|
@ -402,7 +405,49 @@ fn insert_jumps<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_trmc_candidate<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> bool
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct TrmcCandidateSet {
|
||||
/// Recursive calls for which we have found a TRMC opportunity
|
||||
confirmed: VecSet<Symbol>,
|
||||
/// Recursive calls that are (still) considered for TRMC
|
||||
active: VecSet<Symbol>,
|
||||
/// Recursive calls that are used in such a way that makes TRMC impossible
|
||||
invalid: VecSet<Symbol>,
|
||||
}
|
||||
|
||||
impl TrmcCandidateSet {
|
||||
fn insert(&mut self, call: Symbol) {
|
||||
// there really is no way it could have been inserted already
|
||||
debug_assert!(!self.invalid.contains(&call));
|
||||
|
||||
self.active.insert(call);
|
||||
}
|
||||
|
||||
fn extend(&mut self, other: Self) {
|
||||
self.confirmed.keep_if_in_either(other.confirmed);
|
||||
self.invalid.keep_if_in_either(other.invalid);
|
||||
self.active.keep_if_in_either(other.active);
|
||||
|
||||
self.active.retain(|k| !self.invalid.contains(k));
|
||||
self.confirmed.retain(|k| !self.invalid.contains(k));
|
||||
}
|
||||
|
||||
fn retain<F>(&mut self, keep: F)
|
||||
where
|
||||
F: Fn(&Symbol) -> bool,
|
||||
{
|
||||
for c in self.active.iter() {
|
||||
if !keep(c) {
|
||||
self.invalid.insert(*c);
|
||||
}
|
||||
}
|
||||
|
||||
self.active.retain(|k| !self.invalid.contains(k));
|
||||
self.confirmed.retain(|k| !self.invalid.contains(k));
|
||||
}
|
||||
}
|
||||
|
||||
fn trmc_candidates<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> VecSet<Symbol>
|
||||
where
|
||||
I: LayoutInterner<'a>,
|
||||
{
|
||||
|
@ -411,87 +456,50 @@ where
|
|||
proc.is_self_recursive,
|
||||
crate::ir::SelfRecursive::SelfRecursive(_)
|
||||
) {
|
||||
return false;
|
||||
return VecSet::default();
|
||||
}
|
||||
|
||||
// and return a recursive tag union
|
||||
if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive())
|
||||
{
|
||||
return false;
|
||||
return VecSet::default();
|
||||
}
|
||||
|
||||
match has_cons_in_tail_position(&proc.body, proc.name, None) {
|
||||
SymbolUse::NotUsed | SymbolUse::Used => false,
|
||||
SymbolUse::TrmcOppotunity => true,
|
||||
}
|
||||
trmc_candidates_help(proc.name, &proc.body, TrmcCandidateSet::default()).confirmed
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[repr(C)]
|
||||
enum SymbolUse {
|
||||
#[default]
|
||||
NotUsed = 0,
|
||||
TrmcOppotunity = 1,
|
||||
Used = 2,
|
||||
}
|
||||
|
||||
impl SymbolUse {
|
||||
#[must_use]
|
||||
fn mappend(self, y: Self) -> Self {
|
||||
debug_assert_eq!(self.mappend_slow(y), Ord::max(self, y));
|
||||
|
||||
Ord::max(self, y)
|
||||
}
|
||||
|
||||
fn mappend_slow(self, y: Self) -> Self {
|
||||
use SymbolUse::*;
|
||||
|
||||
match (self, y) {
|
||||
(Used, _) | (_, Used) => Used,
|
||||
(TrmcOppotunity, _) | (_, TrmcOppotunity) => TrmcOppotunity,
|
||||
(NotUsed, NotUsed) => NotUsed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn has_cons_in_tail_position(
|
||||
stmt: &Stmt<'_>,
|
||||
fn trmc_candidates_help<'a>(
|
||||
function_name: LambdaName,
|
||||
recursive_call: Option<Symbol>,
|
||||
) -> SymbolUse {
|
||||
// we are looking for code of the form
|
||||
//
|
||||
// let x = Tag a b c
|
||||
// ret x
|
||||
|
||||
stmt: &'_ Stmt<'a>,
|
||||
mut candidates: TrmcCandidateSet,
|
||||
) -> TrmcCandidateSet {
|
||||
// if this stmt is the literal tail tag application and return, then this is a TRMC opportunity
|
||||
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
|
||||
// must use the result of a recursive call directly as an argument
|
||||
if let Some(recursive_call) = recursive_call {
|
||||
if cons_info.arguments.contains(&recursive_call) {
|
||||
return SymbolUse::TrmcOppotunity;
|
||||
// we pick the (syntactically) first one
|
||||
for recursive_call in candidates.active.iter() {
|
||||
if cons_info.arguments.contains(recursive_call) {
|
||||
return TrmcCandidateSet {
|
||||
confirmed: VecSet::singleton(*recursive_call),
|
||||
active: VecSet::default(),
|
||||
invalid: candidates.invalid,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if the stmt uses the active recursive call, that invalidates the recursive call for this branch
|
||||
if let Some(recursive_call) = recursive_call {
|
||||
if stmt_contains_symbol_nonrec(stmt, recursive_call) {
|
||||
// this means we really only check for the first recursive call (in each branch)
|
||||
// whether it presents a TRMC opportunity. In theory we can look at all recursive calls
|
||||
// this is future work.
|
||||
return SymbolUse::Used;
|
||||
}
|
||||
}
|
||||
candidates.retain(|recursive_call| !stmt_contains_symbol_nonrec(stmt, *recursive_call));
|
||||
|
||||
match stmt {
|
||||
Stmt::Let(symbol, expr, _, next) => {
|
||||
// find a new recursive call if we currently have none
|
||||
// that means we generally pick the first recursive call we find
|
||||
let recursive_call = recursive_call
|
||||
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
|
||||
if TrmcEnv::is_recursive_expr(expr, function_name).is_some() {
|
||||
candidates.insert(*symbol);
|
||||
}
|
||||
|
||||
has_cons_in_tail_position(next, function_name, recursive_call)
|
||||
trmc_candidates_help(function_name, next, candidates)
|
||||
}
|
||||
Stmt::Switch {
|
||||
branches,
|
||||
|
@ -503,54 +511,43 @@ fn has_cons_in_tail_position(
|
|||
.map(|(_, _, stmt)| stmt)
|
||||
.chain([default_branch.1]);
|
||||
|
||||
let mut accum = SymbolUse::NotUsed;
|
||||
let mut accum = candidates.clone();
|
||||
|
||||
for next in it {
|
||||
let x = has_cons_in_tail_position(next, function_name, recursive_call);
|
||||
accum = accum.mappend(x);
|
||||
let x = trmc_candidates_help(function_name, next, candidates.clone());
|
||||
|
||||
if let SymbolUse::Used = accum {
|
||||
return SymbolUse::Used;
|
||||
}
|
||||
accum.extend(x);
|
||||
}
|
||||
|
||||
accum
|
||||
}
|
||||
Stmt::Refcounting(_, next) => {
|
||||
has_cons_in_tail_position(next, function_name, recursive_call)
|
||||
}
|
||||
Stmt::Refcounting(_, next) => trmc_candidates_help(function_name, next, candidates),
|
||||
Stmt::Expect { remainder, .. }
|
||||
| Stmt::ExpectFx { remainder, .. }
|
||||
| Stmt::Dbg { remainder, .. } => {
|
||||
has_cons_in_tail_position(remainder, function_name, recursive_call)
|
||||
}
|
||||
| Stmt::Dbg { remainder, .. } => trmc_candidates_help(function_name, remainder, candidates),
|
||||
Stmt::Join {
|
||||
body, remainder, ..
|
||||
} => {
|
||||
let x = has_cons_in_tail_position(body, function_name, recursive_call);
|
||||
let mut x = trmc_candidates_help(function_name, body, candidates.clone());
|
||||
let y = trmc_candidates_help(function_name, remainder, candidates.clone());
|
||||
|
||||
if let SymbolUse::Used = x {
|
||||
SymbolUse::Used
|
||||
} else {
|
||||
let y = has_cons_in_tail_position(remainder, function_name, recursive_call);
|
||||
x.mappend(y)
|
||||
}
|
||||
x.extend(y);
|
||||
|
||||
x
|
||||
}
|
||||
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => SymbolUse::NotUsed,
|
||||
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => candidates,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct TrmcEnv<'a> {
|
||||
function_name: LambdaName<'a>,
|
||||
hole_symbol: Symbol,
|
||||
initial_ptr_symbol: Symbol,
|
||||
joinpoint_id: JoinPointId,
|
||||
return_layout: InLayout<'a>,
|
||||
ptr_return_layout: InLayout<'a>,
|
||||
|
||||
// the call we are performing TRMC on
|
||||
recursive_call: Option<(Symbol, Call<'a>)>,
|
||||
trmc_calls: MutMap<Symbol, Option<Call<'a>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -634,7 +631,11 @@ impl<'a> TrmcEnv<'a> {
|
|||
)
|
||||
}
|
||||
|
||||
pub fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>) -> Proc<'a> {
|
||||
pub fn init<'i>(
|
||||
env: &mut Env<'a, 'i>,
|
||||
proc: &Proc<'a>,
|
||||
trmc_calls: VecSet<Symbol>,
|
||||
) -> Proc<'a> {
|
||||
let arena = env.arena;
|
||||
let return_layout = proc.ret_layout;
|
||||
|
||||
|
@ -682,14 +683,15 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
let jump_stmt = Stmt::Jump(joinpoint_id, jump_arguments.into_bump_slice());
|
||||
|
||||
let trmc_calls = trmc_calls.iter().map(|s| (*s, None)).collect();
|
||||
|
||||
let mut this = Self {
|
||||
function_name: proc.name,
|
||||
hole_symbol,
|
||||
initial_ptr_symbol,
|
||||
joinpoint_id,
|
||||
return_layout,
|
||||
ptr_return_layout,
|
||||
recursive_call: None,
|
||||
trmc_calls,
|
||||
};
|
||||
|
||||
let param = Param {
|
||||
|
@ -733,24 +735,30 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
match stmt {
|
||||
Stmt::Let(symbol, expr, layout, next) => {
|
||||
if self.recursive_call.is_none() {
|
||||
if let Some(call) = Self::is_recursive_expr(expr, self.function_name) {
|
||||
let can_trmc =
|
||||
has_cons_in_tail_position(next, self.function_name, Some(*symbol));
|
||||
// if this is a TRMC call,
|
||||
if let Some(opt_call) = self.trmc_calls.get_mut(symbol) {
|
||||
debug_assert!(opt_call.is_none());
|
||||
|
||||
match can_trmc {
|
||||
SymbolUse::NotUsed => { /* the variable is dead */ }
|
||||
SymbolUse::TrmcOppotunity => {
|
||||
self.recursive_call = Some((*symbol, call));
|
||||
return self.walk_stmt(env, next);
|
||||
}
|
||||
SymbolUse::Used => { /* the variable is used making TRMC invaid */ }
|
||||
}
|
||||
}
|
||||
let call = match expr {
|
||||
Expr::Call(call) => call,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
*opt_call = Some(call.clone());
|
||||
|
||||
return self.walk_stmt(env, next);
|
||||
}
|
||||
|
||||
if let Some(cons_info) = Self::is_terminal_constructor(stmt) {
|
||||
match &self.recursive_call {
|
||||
// figure out which TRMC call to use here. We pick the first one that works
|
||||
let opt_recursive_call = cons_info.arguments.iter().find_map(|arg| {
|
||||
self.trmc_calls
|
||||
.get(arg)
|
||||
.and_then(|x| x.as_ref())
|
||||
.map(|x| (arg, x))
|
||||
});
|
||||
|
||||
match opt_recursive_call {
|
||||
None => {
|
||||
// this control flow path did not encounter a recursive call. Just
|
||||
// write the end result into the hole and we're done.
|
||||
|
@ -863,16 +871,12 @@ impl<'a> TrmcEnv<'a> {
|
|||
} => {
|
||||
let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
|
||||
|
||||
let opt_recursive_call = self.recursive_call.clone();
|
||||
|
||||
for (id, info, stmt) in branches.iter() {
|
||||
self.recursive_call = opt_recursive_call.clone();
|
||||
let new_stmt = self.walk_stmt(env, stmt);
|
||||
|
||||
new_branches.push((*id, info.clone(), new_stmt));
|
||||
}
|
||||
|
||||
self.recursive_call = opt_recursive_call;
|
||||
let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1));
|
||||
|
||||
Stmt::Switch {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue