mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 21:39:07 +00:00
invalidate TRMC attempt when symbol is used before TRMC opportunity
This commit is contained in:
parent
9ab4413beb
commit
880d2ef788
1 changed files with 163 additions and 51 deletions
|
@ -420,67 +420,124 @@ where
|
|||
return false;
|
||||
}
|
||||
|
||||
has_cons_in_tail_position(&proc.body, proc.name, None)
|
||||
match has_cons_in_tail_position(&proc.body, proc.name, None) {
|
||||
SymbolUse::NotUsed | SymbolUse::Used => false,
|
||||
SymbolUse::TrmcOppotunity => true,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[repr(C)]
|
||||
enum SymbolUse {
|
||||
#[default]
|
||||
NotUsed = 0,
|
||||
TrmcOppotunity = 1,
|
||||
Used = 2,
|
||||
}
|
||||
|
||||
impl SymbolUse {
|
||||
#[must_use]
|
||||
fn mappend(self, y: Self) -> Self {
|
||||
debug_assert_eq!(self.mappend_slow(y), Ord::max(self, y));
|
||||
|
||||
Ord::max(self, y)
|
||||
}
|
||||
|
||||
fn mappend_slow(self, y: Self) -> Self {
|
||||
use SymbolUse::*;
|
||||
|
||||
match (self, y) {
|
||||
(Used, _) | (_, Used) => Used,
|
||||
(TrmcOppotunity, _) | (_, TrmcOppotunity) => TrmcOppotunity,
|
||||
(NotUsed, NotUsed) => NotUsed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn has_cons_in_tail_position(
|
||||
initial_stmt: &Stmt<'_>,
|
||||
stmt: &Stmt<'_>,
|
||||
function_name: LambdaName,
|
||||
initial_recursive_call: Option<Symbol>,
|
||||
) -> bool {
|
||||
recursive_call: Option<Symbol>,
|
||||
) -> SymbolUse {
|
||||
// we are looking for code of the form
|
||||
//
|
||||
// let x = Tag a b c
|
||||
// ret x
|
||||
|
||||
let mut stack = vec![(initial_recursive_call, initial_stmt)];
|
||||
|
||||
while let Some((recursive_call, stmt)) = stack.pop() {
|
||||
match stmt {
|
||||
Stmt::Let(symbol, expr, _, next) => {
|
||||
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
|
||||
// must use the result of a recursive call directly as an argument
|
||||
if let Some(recursive_call) = recursive_call {
|
||||
if cons_info.arguments.contains(&recursive_call) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let recursive_call = recursive_call
|
||||
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
|
||||
|
||||
stack.push((recursive_call, next));
|
||||
// if this stmt is the literal tail tag application and return, then this is a TRMC opportunity
|
||||
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
|
||||
// must use the result of a recursive call directly as an argument
|
||||
if let Some(recursive_call) = recursive_call {
|
||||
if cons_info.arguments.contains(&recursive_call) {
|
||||
return SymbolUse::TrmcOppotunity;
|
||||
}
|
||||
Stmt::Switch {
|
||||
branches,
|
||||
default_branch,
|
||||
..
|
||||
} => {
|
||||
for (_, _, stmt) in branches.iter() {
|
||||
stack.push((recursive_call, stmt));
|
||||
}
|
||||
stack.push((recursive_call, default_branch.1));
|
||||
}
|
||||
Stmt::Refcounting(_, next) => {
|
||||
stack.push((recursive_call, next));
|
||||
}
|
||||
Stmt::Expect { remainder, .. }
|
||||
| Stmt::ExpectFx { remainder, .. }
|
||||
| Stmt::Dbg { remainder, .. } => {
|
||||
stack.push((recursive_call, remainder));
|
||||
}
|
||||
Stmt::Join {
|
||||
body, remainder, ..
|
||||
} => {
|
||||
stack.push((recursive_call, body));
|
||||
stack.push((recursive_call, remainder));
|
||||
}
|
||||
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
// if the stmt uses the active recursive call, that invalidates the recursive call for this branch
|
||||
if let Some(recursive_call) = recursive_call {
|
||||
if stmt_contains_symbol_nonrec(stmt, recursive_call) {
|
||||
// this means we really only check for the first recursive call (in each branch)
|
||||
// whether it presents a TRMC opportunity. In theory we can look at all recursive calls
|
||||
// this is future work.
|
||||
return SymbolUse::Used;
|
||||
}
|
||||
}
|
||||
|
||||
match stmt {
|
||||
Stmt::Let(symbol, expr, _, next) => {
|
||||
// find a new recursive call if we currently have none
|
||||
// that means we generally pick the first recursive call we find
|
||||
let recursive_call = recursive_call
|
||||
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
|
||||
|
||||
has_cons_in_tail_position(next, function_name, recursive_call)
|
||||
}
|
||||
Stmt::Switch {
|
||||
branches,
|
||||
default_branch,
|
||||
..
|
||||
} => {
|
||||
let it = branches
|
||||
.iter()
|
||||
.map(|(_, _, stmt)| stmt)
|
||||
.chain([default_branch.1]);
|
||||
|
||||
let mut accum = SymbolUse::NotUsed;
|
||||
|
||||
for next in it {
|
||||
let x = has_cons_in_tail_position(next, function_name, recursive_call);
|
||||
accum = accum.mappend(x);
|
||||
|
||||
if let SymbolUse::Used = accum {
|
||||
return SymbolUse::Used;
|
||||
}
|
||||
}
|
||||
|
||||
accum
|
||||
}
|
||||
Stmt::Refcounting(_, next) => {
|
||||
has_cons_in_tail_position(next, function_name, recursive_call)
|
||||
}
|
||||
Stmt::Expect { remainder, .. }
|
||||
| Stmt::ExpectFx { remainder, .. }
|
||||
| Stmt::Dbg { remainder, .. } => {
|
||||
has_cons_in_tail_position(remainder, function_name, recursive_call)
|
||||
}
|
||||
Stmt::Join {
|
||||
body, remainder, ..
|
||||
} => {
|
||||
let x = has_cons_in_tail_position(body, function_name, recursive_call);
|
||||
|
||||
if let SymbolUse::Used = x {
|
||||
SymbolUse::Used
|
||||
} else {
|
||||
let y = has_cons_in_tail_position(remainder, function_name, recursive_call);
|
||||
x.mappend(y)
|
||||
}
|
||||
}
|
||||
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => SymbolUse::NotUsed,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -678,9 +735,16 @@ impl<'a> TrmcEnv<'a> {
|
|||
Stmt::Let(symbol, expr, layout, next) => {
|
||||
if self.recursive_call.is_none() {
|
||||
if let Some(call) = Self::is_recursive_expr(expr, self.function_name) {
|
||||
if has_cons_in_tail_position(next, self.function_name, Some(*symbol)) {
|
||||
self.recursive_call = Some((*symbol, call));
|
||||
return self.walk_stmt(env, next);
|
||||
let can_trmc =
|
||||
has_cons_in_tail_position(next, self.function_name, Some(*symbol));
|
||||
|
||||
match can_trmc {
|
||||
SymbolUse::NotUsed => { /* the variable is dead */ }
|
||||
SymbolUse::TrmcOppotunity => {
|
||||
self.recursive_call = Some((*symbol, call));
|
||||
return self.walk_stmt(env, next);
|
||||
}
|
||||
SymbolUse::Used => { /* the variable is used making TRMC invaid */ }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -911,3 +975,51 @@ impl<'a> TrmcEnv<'a> {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn expr_contains_symbol(expr: &Expr, needle: Symbol) -> bool {
|
||||
match expr {
|
||||
Expr::Literal(_) => false,
|
||||
Expr::Call(call) => call.arguments.contains(&needle),
|
||||
Expr::Tag { arguments, .. } => arguments.contains(&needle),
|
||||
Expr::Struct(fields) => fields.contains(&needle),
|
||||
Expr::NullPointer => false,
|
||||
Expr::StructAtIndex { structure, .. }
|
||||
| Expr::GetTagId { structure, .. }
|
||||
| Expr::UnionAtIndex { structure, .. }
|
||||
| Expr::UnionFieldPtrAtIndex { structure, .. } => needle == *structure,
|
||||
Expr::Array { elems, .. } => elems.iter().any(|element| match element {
|
||||
crate::ir::ListLiteralElement::Literal(_) => false,
|
||||
crate::ir::ListLiteralElement::Symbol(symbol) => needle == *symbol,
|
||||
}),
|
||||
Expr::EmptyArray => false,
|
||||
Expr::ExprBox { symbol } | Expr::ExprUnbox { symbol } => needle == *symbol,
|
||||
Expr::Reuse {
|
||||
symbol, arguments, ..
|
||||
} => needle == *symbol || arguments.contains(&needle),
|
||||
Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => needle == *symbol,
|
||||
Expr::RuntimeErrorFunction(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn stmt_contains_symbol_nonrec(stmt: &Stmt, needle: Symbol) -> bool {
|
||||
use crate::ir::ModifyRc::*;
|
||||
|
||||
match stmt {
|
||||
Stmt::Let(_, expr, _, _) => expr_contains_symbol(expr, needle),
|
||||
Stmt::Switch { cond_symbol, .. } => needle == *cond_symbol,
|
||||
Stmt::Ret(symbol) => needle == *symbol,
|
||||
Stmt::Refcounting(modify, _) => {
|
||||
matches!( modify, Inc(symbol, _) | Dec(symbol) | DecRef(symbol) if needle == *symbol )
|
||||
}
|
||||
Stmt::Expect {
|
||||
condition, lookups, ..
|
||||
}
|
||||
| Stmt::ExpectFx {
|
||||
condition, lookups, ..
|
||||
} => needle == *condition || lookups.contains(&needle),
|
||||
Stmt::Dbg { symbol, .. } => needle == *symbol,
|
||||
Stmt::Join { .. } => false,
|
||||
Stmt::Jump(_, arguments) => arguments.contains(&needle),
|
||||
Stmt::Crash(symbol, _) => needle == *symbol,
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue