search for multiple TRMC opportunities

This commit is contained in:
Folkert 2023-06-21 20:56:36 +02:00
parent b349fca521
commit c87519c209
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
8 changed files with 138 additions and 116 deletions

View file

@ -10,7 +10,7 @@ use crate::layout::{
};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_collections::MutMap;
use roc_collections::{MutMap, VecSet};
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
@ -53,8 +53,11 @@ pub fn apply_trmc<'a, 'i>(
for proc in procs.values_mut() {
use self::SelfRecursive::*;
if let SelfRecursive(id) = proc.is_self_recursive {
if crate::tail_recursion::is_trmc_candidate(env.interner, proc) {
let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc);
let trmc_candidate_symbols = trmc_candidates(env.interner, proc);
if !trmc_candidate_symbols.is_empty() {
let new_proc =
crate::tail_recursion::TrmcEnv::init(env, proc, trmc_candidate_symbols);
*proc = new_proc;
} else {
let mut args = Vec::with_capacity_in(proc.args.len(), arena);
@ -402,7 +405,49 @@ fn insert_jumps<'a>(
}
}
pub(crate) fn is_trmc_candidate<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> bool
#[derive(Debug, Clone, Default)]
struct TrmcCandidateSet {
/// Recursive calls for which we have found a TRMC opportunity
confirmed: VecSet<Symbol>,
/// Recursive calls that are (still) considered for TRMC
active: VecSet<Symbol>,
/// Recursive calls that are used in such a way that makes TRMC impossible
invalid: VecSet<Symbol>,
}
impl TrmcCandidateSet {
fn insert(&mut self, call: Symbol) {
// there really is no way it could have been inserted already
debug_assert!(!self.invalid.contains(&call));
self.active.insert(call);
}
fn extend(&mut self, other: Self) {
self.confirmed.keep_if_in_either(other.confirmed);
self.invalid.keep_if_in_either(other.invalid);
self.active.keep_if_in_either(other.active);
self.active.retain(|k| !self.invalid.contains(k));
self.confirmed.retain(|k| !self.invalid.contains(k));
}
fn retain<F>(&mut self, keep: F)
where
F: Fn(&Symbol) -> bool,
{
for c in self.active.iter() {
if !keep(c) {
self.invalid.insert(*c);
}
}
self.active.retain(|k| !self.invalid.contains(k));
self.confirmed.retain(|k| !self.invalid.contains(k));
}
}
fn trmc_candidates<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> VecSet<Symbol>
where
I: LayoutInterner<'a>,
{
@ -411,87 +456,50 @@ where
proc.is_self_recursive,
crate::ir::SelfRecursive::SelfRecursive(_)
) {
return false;
return VecSet::default();
}
// and return a recursive tag union
if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive())
{
return false;
return VecSet::default();
}
match has_cons_in_tail_position(&proc.body, proc.name, None) {
SymbolUse::NotUsed | SymbolUse::Used => false,
SymbolUse::TrmcOppotunity => true,
}
trmc_candidates_help(proc.name, &proc.body, TrmcCandidateSet::default()).confirmed
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
enum SymbolUse {
#[default]
NotUsed = 0,
TrmcOppotunity = 1,
Used = 2,
}
impl SymbolUse {
#[must_use]
fn mappend(self, y: Self) -> Self {
debug_assert_eq!(self.mappend_slow(y), Ord::max(self, y));
Ord::max(self, y)
}
fn mappend_slow(self, y: Self) -> Self {
use SymbolUse::*;
match (self, y) {
(Used, _) | (_, Used) => Used,
(TrmcOppotunity, _) | (_, TrmcOppotunity) => TrmcOppotunity,
(NotUsed, NotUsed) => NotUsed,
}
}
}
fn has_cons_in_tail_position(
stmt: &Stmt<'_>,
fn trmc_candidates_help<'a>(
function_name: LambdaName,
recursive_call: Option<Symbol>,
) -> SymbolUse {
// we are looking for code of the form
//
// let x = Tag a b c
// ret x
stmt: &'_ Stmt<'a>,
mut candidates: TrmcCandidateSet,
) -> TrmcCandidateSet {
// if this stmt is the literal tail tag application and return, then this is a TRMC opportunity
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
// must use the result of a recursive call directly as an argument
if let Some(recursive_call) = recursive_call {
if cons_info.arguments.contains(&recursive_call) {
return SymbolUse::TrmcOppotunity;
// we pick the (syntactically) first one
for recursive_call in candidates.active.iter() {
if cons_info.arguments.contains(recursive_call) {
return TrmcCandidateSet {
confirmed: VecSet::singleton(*recursive_call),
active: VecSet::default(),
invalid: candidates.invalid,
};
}
}
}
// if the stmt uses the active recursive call, that invalidates the recursive call for this branch
if let Some(recursive_call) = recursive_call {
if stmt_contains_symbol_nonrec(stmt, recursive_call) {
// this means we really only check for the first recursive call (in each branch)
// whether it presents a TRMC opportunity. In theory we can look at all recursive calls
// this is future work.
return SymbolUse::Used;
}
}
candidates.retain(|recursive_call| !stmt_contains_symbol_nonrec(stmt, *recursive_call));
match stmt {
Stmt::Let(symbol, expr, _, next) => {
// find a new recursive call if we currently have none
// that means we generally pick the first recursive call we find
let recursive_call = recursive_call
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
if TrmcEnv::is_recursive_expr(expr, function_name).is_some() {
candidates.insert(*symbol);
}
has_cons_in_tail_position(next, function_name, recursive_call)
trmc_candidates_help(function_name, next, candidates)
}
Stmt::Switch {
branches,
@ -503,54 +511,43 @@ fn has_cons_in_tail_position(
.map(|(_, _, stmt)| stmt)
.chain([default_branch.1]);
let mut accum = SymbolUse::NotUsed;
let mut accum = candidates.clone();
for next in it {
let x = has_cons_in_tail_position(next, function_name, recursive_call);
accum = accum.mappend(x);
let x = trmc_candidates_help(function_name, next, candidates.clone());
if let SymbolUse::Used = accum {
return SymbolUse::Used;
}
accum.extend(x);
}
accum
}
Stmt::Refcounting(_, next) => {
has_cons_in_tail_position(next, function_name, recursive_call)
}
Stmt::Refcounting(_, next) => trmc_candidates_help(function_name, next, candidates),
Stmt::Expect { remainder, .. }
| Stmt::ExpectFx { remainder, .. }
| Stmt::Dbg { remainder, .. } => {
has_cons_in_tail_position(remainder, function_name, recursive_call)
}
| Stmt::Dbg { remainder, .. } => trmc_candidates_help(function_name, remainder, candidates),
Stmt::Join {
body, remainder, ..
} => {
let x = has_cons_in_tail_position(body, function_name, recursive_call);
let mut x = trmc_candidates_help(function_name, body, candidates.clone());
let y = trmc_candidates_help(function_name, remainder, candidates.clone());
if let SymbolUse::Used = x {
SymbolUse::Used
} else {
let y = has_cons_in_tail_position(remainder, function_name, recursive_call);
x.mappend(y)
}
x.extend(y);
x
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => SymbolUse::NotUsed,
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => candidates,
}
}
#[derive(Clone)]
pub(crate) struct TrmcEnv<'a> {
function_name: LambdaName<'a>,
hole_symbol: Symbol,
initial_ptr_symbol: Symbol,
joinpoint_id: JoinPointId,
return_layout: InLayout<'a>,
ptr_return_layout: InLayout<'a>,
// the call we are performing TRMC on
recursive_call: Option<(Symbol, Call<'a>)>,
trmc_calls: MutMap<Symbol, Option<Call<'a>>>,
}
#[derive(Debug)]
@ -634,7 +631,11 @@ impl<'a> TrmcEnv<'a> {
)
}
pub fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>) -> Proc<'a> {
pub fn init<'i>(
env: &mut Env<'a, 'i>,
proc: &Proc<'a>,
trmc_calls: VecSet<Symbol>,
) -> Proc<'a> {
let arena = env.arena;
let return_layout = proc.ret_layout;
@ -682,14 +683,15 @@ impl<'a> TrmcEnv<'a> {
let jump_stmt = Stmt::Jump(joinpoint_id, jump_arguments.into_bump_slice());
let trmc_calls = trmc_calls.iter().map(|s| (*s, None)).collect();
let mut this = Self {
function_name: proc.name,
hole_symbol,
initial_ptr_symbol,
joinpoint_id,
return_layout,
ptr_return_layout,
recursive_call: None,
trmc_calls,
};
let param = Param {
@ -733,24 +735,30 @@ impl<'a> TrmcEnv<'a> {
match stmt {
Stmt::Let(symbol, expr, layout, next) => {
if self.recursive_call.is_none() {
if let Some(call) = Self::is_recursive_expr(expr, self.function_name) {
let can_trmc =
has_cons_in_tail_position(next, self.function_name, Some(*symbol));
// if this is a TRMC call,
if let Some(opt_call) = self.trmc_calls.get_mut(symbol) {
debug_assert!(opt_call.is_none());
match can_trmc {
SymbolUse::NotUsed => { /* the variable is dead */ }
SymbolUse::TrmcOppotunity => {
self.recursive_call = Some((*symbol, call));
return self.walk_stmt(env, next);
}
SymbolUse::Used => { /* the variable is used making TRMC invaid */ }
}
}
let call = match expr {
Expr::Call(call) => call,
_ => unreachable!(),
};
*opt_call = Some(call.clone());
return self.walk_stmt(env, next);
}
if let Some(cons_info) = Self::is_terminal_constructor(stmt) {
match &self.recursive_call {
// figure out which TRMC call to use here. We pick the first one that works
let opt_recursive_call = cons_info.arguments.iter().find_map(|arg| {
self.trmc_calls
.get(arg)
.and_then(|x| x.as_ref())
.map(|x| (arg, x))
});
match opt_recursive_call {
None => {
// this control flow path did not encounter a recursive call. Just
// write the end result into the hole and we're done.
@ -863,16 +871,12 @@ impl<'a> TrmcEnv<'a> {
} => {
let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
let opt_recursive_call = self.recursive_call.clone();
for (id, info, stmt) in branches.iter() {
self.recursive_call = opt_recursive_call.clone();
let new_stmt = self.walk_stmt(env, stmt);
new_branches.push((*id, info.clone(), new_stmt));
}
self.recursive_call = opt_recursive_call;
let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1));
Stmt::Switch {