diff --git a/crates/compiler/mono/src/tail_recursion.rs b/crates/compiler/mono/src/tail_recursion.rs index aec5b38c76..f2afbc9c7e 100644 --- a/crates/compiler/mono/src/tail_recursion.rs +++ b/crates/compiler/mono/src/tail_recursion.rs @@ -420,67 +420,124 @@ where return false; } - has_cons_in_tail_position(&proc.body, proc.name, None) + match has_cons_in_tail_position(&proc.body, proc.name, None) { + SymbolUse::NotUsed | SymbolUse::Used => false, + SymbolUse::TrmcOppotunity => true, + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] +enum SymbolUse { + #[default] + NotUsed = 0, + TrmcOppotunity = 1, + Used = 2, +} + +impl SymbolUse { + #[must_use] + fn mappend(self, y: Self) -> Self { + debug_assert_eq!(self.mappend_slow(y), Ord::max(self, y)); + + Ord::max(self, y) + } + + fn mappend_slow(self, y: Self) -> Self { + use SymbolUse::*; + + match (self, y) { + (Used, _) | (_, Used) => Used, + (TrmcOppotunity, _) | (_, TrmcOppotunity) => TrmcOppotunity, + (NotUsed, NotUsed) => NotUsed, + } + } } fn has_cons_in_tail_position( - initial_stmt: &Stmt<'_>, + stmt: &Stmt<'_>, function_name: LambdaName, - initial_recursive_call: Option, -) -> bool { + recursive_call: Option, +) -> SymbolUse { // we are looking for code of the form // // let x = Tag a b c // ret x - let mut stack = vec![(initial_recursive_call, initial_stmt)]; - - while let Some((recursive_call, stmt)) = stack.pop() { - match stmt { - Stmt::Let(symbol, expr, _, next) => { - if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) { - // must use the result of a recursive call directly as an argument - if let Some(recursive_call) = recursive_call { - if cons_info.arguments.contains(&recursive_call) { - return true; - } - } - } - - let recursive_call = recursive_call - .or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol)); - - stack.push((recursive_call, next)); + // if this stmt is the literal tail tag application and return, then this is a TRMC opportunity + if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) { + // must use the result of a recursive call directly as an argument + if let Some(recursive_call) = recursive_call { + if cons_info.arguments.contains(&recursive_call) { + return SymbolUse::TrmcOppotunity; } - Stmt::Switch { - branches, - default_branch, - .. - } => { - for (_, _, stmt) in branches.iter() { - stack.push((recursive_call, stmt)); - } - stack.push((recursive_call, default_branch.1)); - } - Stmt::Refcounting(_, next) => { - stack.push((recursive_call, next)); - } - Stmt::Expect { remainder, .. } - | Stmt::ExpectFx { remainder, .. } - | Stmt::Dbg { remainder, .. } => { - stack.push((recursive_call, remainder)); - } - Stmt::Join { - body, remainder, .. - } => { - stack.push((recursive_call, body)); - stack.push((recursive_call, remainder)); - } - Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ } } } - false + // if the stmt uses the active recursive call, that invalidates the recursive call for this branch + if let Some(recursive_call) = recursive_call { + if stmt_contains_symbol_nonrec(stmt, recursive_call) { + // this means we really only check for the first recursive call (in each branch) + // whether it presents a TRMC opportunity. In theory we can look at all recursive calls + // this is future work. + return SymbolUse::Used; + } + } + + match stmt { + Stmt::Let(symbol, expr, _, next) => { + // find a new recursive call if we currently have none + // that means we generally pick the first recursive call we find + let recursive_call = recursive_call + .or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol)); + + has_cons_in_tail_position(next, function_name, recursive_call) + } + Stmt::Switch { + branches, + default_branch, + .. + } => { + let it = branches + .iter() + .map(|(_, _, stmt)| stmt) + .chain([default_branch.1]); + + let mut accum = SymbolUse::NotUsed; + + for next in it { + let x = has_cons_in_tail_position(next, function_name, recursive_call); + accum = accum.mappend(x); + + if let SymbolUse::Used = accum { + return SymbolUse::Used; + } + } + + accum + } + Stmt::Refcounting(_, next) => { + has_cons_in_tail_position(next, function_name, recursive_call) + } + Stmt::Expect { remainder, .. } + | Stmt::ExpectFx { remainder, .. } + | Stmt::Dbg { remainder, .. } => { + has_cons_in_tail_position(remainder, function_name, recursive_call) + } + Stmt::Join { + body, remainder, .. + } => { + let x = has_cons_in_tail_position(body, function_name, recursive_call); + + if let SymbolUse::Used = x { + SymbolUse::Used + } else { + let y = has_cons_in_tail_position(remainder, function_name, recursive_call); + x.mappend(y) + } + } + Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => SymbolUse::NotUsed, + } } #[derive(Clone)] @@ -678,9 +735,16 @@ impl<'a> TrmcEnv<'a> { Stmt::Let(symbol, expr, layout, next) => { if self.recursive_call.is_none() { if let Some(call) = Self::is_recursive_expr(expr, self.function_name) { - if has_cons_in_tail_position(next, self.function_name, Some(*symbol)) { - self.recursive_call = Some((*symbol, call)); - return self.walk_stmt(env, next); + let can_trmc = + has_cons_in_tail_position(next, self.function_name, Some(*symbol)); + + match can_trmc { + SymbolUse::NotUsed => { /* the variable is dead */ } + SymbolUse::TrmcOppotunity => { + self.recursive_call = Some((*symbol, call)); + return self.walk_stmt(env, next); + } + SymbolUse::Used => { /* the variable is used making TRMC invaid */ } } } } @@ -911,3 +975,51 @@ impl<'a> TrmcEnv<'a> { ) } } + +fn expr_contains_symbol(expr: &Expr, needle: Symbol) -> bool { + match expr { + Expr::Literal(_) => false, + Expr::Call(call) => call.arguments.contains(&needle), + Expr::Tag { arguments, .. } => arguments.contains(&needle), + Expr::Struct(fields) => fields.contains(&needle), + Expr::NullPointer => false, + Expr::StructAtIndex { structure, .. } + | Expr::GetTagId { structure, .. } + | Expr::UnionAtIndex { structure, .. } + | Expr::UnionFieldPtrAtIndex { structure, .. } => needle == *structure, + Expr::Array { elems, .. } => elems.iter().any(|element| match element { + crate::ir::ListLiteralElement::Literal(_) => false, + crate::ir::ListLiteralElement::Symbol(symbol) => needle == *symbol, + }), + Expr::EmptyArray => false, + Expr::ExprBox { symbol } | Expr::ExprUnbox { symbol } => needle == *symbol, + Expr::Reuse { + symbol, arguments, .. + } => needle == *symbol || arguments.contains(&needle), + Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => needle == *symbol, + Expr::RuntimeErrorFunction(_) => false, + } +} + +fn stmt_contains_symbol_nonrec(stmt: &Stmt, needle: Symbol) -> bool { + use crate::ir::ModifyRc::*; + + match stmt { + Stmt::Let(_, expr, _, _) => expr_contains_symbol(expr, needle), + Stmt::Switch { cond_symbol, .. } => needle == *cond_symbol, + Stmt::Ret(symbol) => needle == *symbol, + Stmt::Refcounting(modify, _) => { + matches!( modify, Inc(symbol, _) | Dec(symbol) | DecRef(symbol) if needle == *symbol ) + } + Stmt::Expect { + condition, lookups, .. + } + | Stmt::ExpectFx { + condition, lookups, .. + } => needle == *condition || lookups.contains(&needle), + Stmt::Dbg { symbol, .. } => needle == *symbol, + Stmt::Join { .. } => false, + Stmt::Jump(_, arguments) => arguments.contains(&needle), + Stmt::Crash(symbol, _) => needle == *symbol, + } +}