invalidate TRMC attempt when symbol is used before TRMC opportunity

This commit is contained in:
Folkert 2023-06-20 21:20:33 +02:00
parent 9ab4413beb
commit 880d2ef788
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C

View file

@ -420,67 +420,124 @@ where
return false;
}
has_cons_in_tail_position(&proc.body, proc.name, None)
match has_cons_in_tail_position(&proc.body, proc.name, None) {
SymbolUse::NotUsed | SymbolUse::Used => false,
SymbolUse::TrmcOppotunity => true,
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
enum SymbolUse {
#[default]
NotUsed = 0,
TrmcOppotunity = 1,
Used = 2,
}
impl SymbolUse {
#[must_use]
fn mappend(self, y: Self) -> Self {
debug_assert_eq!(self.mappend_slow(y), Ord::max(self, y));
Ord::max(self, y)
}
fn mappend_slow(self, y: Self) -> Self {
use SymbolUse::*;
match (self, y) {
(Used, _) | (_, Used) => Used,
(TrmcOppotunity, _) | (_, TrmcOppotunity) => TrmcOppotunity,
(NotUsed, NotUsed) => NotUsed,
}
}
}
fn has_cons_in_tail_position(
initial_stmt: &Stmt<'_>,
stmt: &Stmt<'_>,
function_name: LambdaName,
initial_recursive_call: Option<Symbol>,
) -> bool {
recursive_call: Option<Symbol>,
) -> SymbolUse {
// we are looking for code of the form
//
// let x = Tag a b c
// ret x
let mut stack = vec![(initial_recursive_call, initial_stmt)];
while let Some((recursive_call, stmt)) = stack.pop() {
match stmt {
Stmt::Let(symbol, expr, _, next) => {
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
// must use the result of a recursive call directly as an argument
if let Some(recursive_call) = recursive_call {
if cons_info.arguments.contains(&recursive_call) {
return true;
}
}
}
let recursive_call = recursive_call
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
stack.push((recursive_call, next));
// if this stmt is the literal tail tag application and return, then this is a TRMC opportunity
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
// must use the result of a recursive call directly as an argument
if let Some(recursive_call) = recursive_call {
if cons_info.arguments.contains(&recursive_call) {
return SymbolUse::TrmcOppotunity;
}
Stmt::Switch {
branches,
default_branch,
..
} => {
for (_, _, stmt) in branches.iter() {
stack.push((recursive_call, stmt));
}
stack.push((recursive_call, default_branch.1));
}
Stmt::Refcounting(_, next) => {
stack.push((recursive_call, next));
}
Stmt::Expect { remainder, .. }
| Stmt::ExpectFx { remainder, .. }
| Stmt::Dbg { remainder, .. } => {
stack.push((recursive_call, remainder));
}
Stmt::Join {
body, remainder, ..
} => {
stack.push((recursive_call, body));
stack.push((recursive_call, remainder));
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
}
}
false
// if the stmt uses the active recursive call, that invalidates the recursive call for this branch
if let Some(recursive_call) = recursive_call {
if stmt_contains_symbol_nonrec(stmt, recursive_call) {
// this means we really only check for the first recursive call (in each branch)
// whether it presents a TRMC opportunity. In theory we can look at all recursive calls
// this is future work.
return SymbolUse::Used;
}
}
match stmt {
Stmt::Let(symbol, expr, _, next) => {
// find a new recursive call if we currently have none
// that means we generally pick the first recursive call we find
let recursive_call = recursive_call
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
has_cons_in_tail_position(next, function_name, recursive_call)
}
Stmt::Switch {
branches,
default_branch,
..
} => {
let it = branches
.iter()
.map(|(_, _, stmt)| stmt)
.chain([default_branch.1]);
let mut accum = SymbolUse::NotUsed;
for next in it {
let x = has_cons_in_tail_position(next, function_name, recursive_call);
accum = accum.mappend(x);
if let SymbolUse::Used = accum {
return SymbolUse::Used;
}
}
accum
}
Stmt::Refcounting(_, next) => {
has_cons_in_tail_position(next, function_name, recursive_call)
}
Stmt::Expect { remainder, .. }
| Stmt::ExpectFx { remainder, .. }
| Stmt::Dbg { remainder, .. } => {
has_cons_in_tail_position(remainder, function_name, recursive_call)
}
Stmt::Join {
body, remainder, ..
} => {
let x = has_cons_in_tail_position(body, function_name, recursive_call);
if let SymbolUse::Used = x {
SymbolUse::Used
} else {
let y = has_cons_in_tail_position(remainder, function_name, recursive_call);
x.mappend(y)
}
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => SymbolUse::NotUsed,
}
}
#[derive(Clone)]
@ -678,9 +735,16 @@ impl<'a> TrmcEnv<'a> {
Stmt::Let(symbol, expr, layout, next) => {
if self.recursive_call.is_none() {
if let Some(call) = Self::is_recursive_expr(expr, self.function_name) {
if has_cons_in_tail_position(next, self.function_name, Some(*symbol)) {
self.recursive_call = Some((*symbol, call));
return self.walk_stmt(env, next);
let can_trmc =
has_cons_in_tail_position(next, self.function_name, Some(*symbol));
match can_trmc {
SymbolUse::NotUsed => { /* the variable is dead */ }
SymbolUse::TrmcOppotunity => {
self.recursive_call = Some((*symbol, call));
return self.walk_stmt(env, next);
}
SymbolUse::Used => { /* the variable is used making TRMC invaid */ }
}
}
}
@ -911,3 +975,51 @@ impl<'a> TrmcEnv<'a> {
)
}
}
fn expr_contains_symbol(expr: &Expr, needle: Symbol) -> bool {
match expr {
Expr::Literal(_) => false,
Expr::Call(call) => call.arguments.contains(&needle),
Expr::Tag { arguments, .. } => arguments.contains(&needle),
Expr::Struct(fields) => fields.contains(&needle),
Expr::NullPointer => false,
Expr::StructAtIndex { structure, .. }
| Expr::GetTagId { structure, .. }
| Expr::UnionAtIndex { structure, .. }
| Expr::UnionFieldPtrAtIndex { structure, .. } => needle == *structure,
Expr::Array { elems, .. } => elems.iter().any(|element| match element {
crate::ir::ListLiteralElement::Literal(_) => false,
crate::ir::ListLiteralElement::Symbol(symbol) => needle == *symbol,
}),
Expr::EmptyArray => false,
Expr::ExprBox { symbol } | Expr::ExprUnbox { symbol } => needle == *symbol,
Expr::Reuse {
symbol, arguments, ..
} => needle == *symbol || arguments.contains(&needle),
Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => needle == *symbol,
Expr::RuntimeErrorFunction(_) => false,
}
}
fn stmt_contains_symbol_nonrec(stmt: &Stmt, needle: Symbol) -> bool {
use crate::ir::ModifyRc::*;
match stmt {
Stmt::Let(_, expr, _, _) => expr_contains_symbol(expr, needle),
Stmt::Switch { cond_symbol, .. } => needle == *cond_symbol,
Stmt::Ret(symbol) => needle == *symbol,
Stmt::Refcounting(modify, _) => {
matches!( modify, Inc(symbol, _) | Dec(symbol) | DecRef(symbol) if needle == *symbol )
}
Stmt::Expect {
condition, lookups, ..
}
| Stmt::ExpectFx {
condition, lookups, ..
} => needle == *condition || lookups.contains(&needle),
Stmt::Dbg { symbol, .. } => needle == *symbol,
Stmt::Join { .. } => false,
Stmt::Jump(_, arguments) => arguments.contains(&needle),
Stmt::Crash(symbol, _) => needle == *symbol,
}
}